cloudroam
2025-04-15 acc5c1281b50c12e4d04c81b899410f6ca2cacac
ner_config.py
@@ -30,7 +30,7 @@
    # 交叉验证配置
    N_SPLITS = 3      # CPU环境下减少折数
    N_SEEDS = 1       # CPU环境下减少种子数量
    # 确保标签列表完整
    LABELS =  [
    "O",
@@ -72,40 +72,40 @@
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # 数据增强配置
    USE_DATA_AUGMENTATION = False
    AUGMENTATION_RATIO = 0.3
    # 训练策略
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25      # 添加保存步数
    SAVE_TOTAL_LIMIT = 2 # 添加保存检查点数量限制
    # 路径配置
    DATA_PATH = "data/repayment.txt"
    MODEL_PATH = "./models/repayment_model"
    LOG_PATH = "./logs_repayment"
    # 训练配置优化
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU环境配置
    MAX_GRAD_NORM = 1.0
    FP16 = False        # CPU环境下关闭FP16
    # CPU环境下的数据加载优化
    DATALOADER_NUM_WORKERS = 0  # CPU环境下设为0
    DATALOADER_PIN_MEMORY = False  # CPU环境下关闭
    # 交叉验证配置
    N_SPLITS = 3
    N_SEEDS = 1
    # 标签列表
    LABELS = [
        "O",
@@ -149,40 +149,40 @@
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # 数据增强配置
    USE_DATA_AUGMENTATION = False
    AUGMENTATION_RATIO = 0.3
    # 训练策略
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25
    SAVE_TOTAL_LIMIT = 2
    # 路径配置
    DATA_PATH = "data/income.txt"
    MODEL_PATH = "./models/income_model"
    LOG_PATH = "./logs_income"
    # 训练配置优化
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU环境配置
    MAX_GRAD_NORM = 1.0
    FP16 = False
    # CPU环境下的数据加载优化
    DATALOADER_NUM_WORKERS = 0
    DATALOADER_PIN_MEMORY = False
    # 交叉验证配置
    N_SPLITS = 3
    N_SEEDS = 1
    # 标签列表
    LABELS = [
        "O",
@@ -211,4 +211,154 @@
        'max_integer_digits': 12,  # 整数部分最大位数
        'currency_symbols': ['¥', '¥', 'RMB', '元'],  # 货币符号
        'decimal_context_range': 3  # 查找小数点的上下文范围
    }
class FlightNERConfig:
    # 优化模型参数 (与 RepaymentNERConfig 保持一致)
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 128
    BATCH_SIZE = 4
    EPOCHS = 10
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # 训练策略
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25
    SAVE_TOTAL_LIMIT = 2
    # 路径配置
    DATA_PATH = "data/flight.txt"
    MODEL_PATH = "./models/flight_model"
    LOG_PATH = "./logs_flight"
    # 训练配置
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU环境配置
    MAX_GRAD_NORM = 1.0
    FP16 = False
    DATALOADER_NUM_WORKERS = 0
    DATALOADER_PIN_MEMORY = False
    # 交叉验证配置
    N_SPLITS = 3
    N_SEEDS = 3  # 增加种子数量以提高模型稳定性
    # 标签列表 - 保持与需求一致
    LABELS = [
        "O",
        "B-FLIGHT", "I-FLIGHT",     # 航班号
        "B-COMPANY", "I-COMPANY",   # 航空公司
        "B-START", "I-START",       # 出发地
        "B-END", "I-END",           # 目的地
        "B-DATE", "I-DATE",         # 日期
        "B-TIME", "I-TIME",         # 时间
        "B-DEPARTURE_TIME", "I-DEPARTURE_TIME",  # 起飞时间
        "B-ARRIVAL_TIME", "I-ARRIVAL_TIME",      # 到达时间
        "B-TICKET_NUM", "I-TICKET_NUM",  # 机票号码
        "B-SEAT", "I-SEAT"  # 座位等信息
    ]
    # 实体长度限制 - 更新键名与LABELS一致
    MAX_ENTITY_LENGTH = {
        "FLIGHT": 10,       # 航班号
        "COMPANY": 15,      # 航空公司
        "START": 10,        # 出发地
        "END": 10,          # 目的地
        "DATE": 15,         # 日期
        "TIME": 10,         # 时间
        "DEPARTURE_TIME": 10,  # 起飞时间
        "ARRIVAL_TIME": 10,    # 到达时间
        "TICKET_NUM": 10,      # 用户姓名
        "SEAT": 10             # 座位等信息
    }
    # 航班号配置
    FLIGHT_CONFIG = {
        'pattern': r'[A-Z]{2}\d{3,4}',
        'min_length': 4,
        'max_length': 7,
        'carrier_codes': ['CA', 'MU', 'CZ', 'HU', '3U', 'ZH', 'FM', 'MF', 'SC', '9C']  # 常见航司代码
    }
class TrainNERConfig:
    # 模型参数
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 128
    BATCH_SIZE = 4
    EPOCHS = 10
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # 训练策略
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25
    SAVE_TOTAL_LIMIT = 2
    # 路径配置
    DATA_PATH = "data/train.txt"
    MODEL_PATH = "./models/train_model"
    LOG_PATH = "./logs_train"
    # 训练配置
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU环境配置
    MAX_GRAD_NORM = 1.0
    FP16 = False
    DATALOADER_NUM_WORKERS = 0
    DATALOADER_PIN_MEMORY = False
    # 交叉验证配置
    N_SPLITS = 3
    N_SEEDS = 3  # 增加种子数量以提高模型稳定性
    # 标签列表
    LABELS = [
        "O",
        "B-COMPANY", "I-COMPANY",  # 车次
        "B-TRIPS", "I-TRIPS",       # 车次
        "B-START", "I-START",     # 出发站
        "B-END", "I-END",         # 到达站
        "B-DATE", "I-DATE",      # 日期
        "B-TIME", "I-TIME",      # 时间
        "B-SEAT", "I-SEAT",      # 座位等信息
        "B-NAME", "I-NAME"       # 用户姓名
    ]
    # 实体长度限制 - 更新键名与LABELS一致
    MAX_ENTITY_LENGTH = {
        "COMPANY": 8,          # 12306
        "TRIPS": 8,            # 车次
        "START": 10,           # 出发站
        "END": 10,             # 到达站
        "DATE": 15,            # 日期
        "TIME": 10,            # 时间
        "SEAT": 10,            # 座位等信息
        "NAME": 10             # 用户姓名
    }
    # 车次配置
    TRIPS_CONFIG = {
        'patterns': [
            r'[GDCZTKY]\d{1,2}',            # G1, D1, C1等
            r'[GDCZTKY]\d{1,2}/\d{1,2}',    # G1/2等联运车次
            r'[GDCZTKY]\d{1,2}-\d{1,2}',    # G1-2等联运车次
            r'\d{1,4}',                     # 普通车次如1234次
            r'[A-Z]\d{1,4}'                 # Z1234等特殊车次
        ],
        'min_length': 1,
        'max_length': 8,
        'train_types': ['G', 'D', 'C', 'Z', 'T', 'K', 'Y']  # 车次类型前缀
    }