From acc5c1281b50c12e4d04c81b899410f6ca2cacac Mon Sep 17 00:00:00 2001 From: cloudroam <cloudroam> Date: 星期二, 15 四月 2025 15:13:30 +0800 Subject: [PATCH] add: 增加航班和火车票 --- ner_config.py | 184 +++++++++++++++++++++++++++++++++++++++++---- 1 files changed, 167 insertions(+), 17 deletions(-) diff --git a/ner_config.py b/ner_config.py index 71581d7..999015d 100644 --- a/ner_config.py +++ b/ner_config.py @@ -30,7 +30,7 @@ # 交叉验证配置 N_SPLITS = 3 # CPU环境下减少折数 N_SEEDS = 1 # CPU环境下减少种子数量 - + # 确保标签列表完整 LABELS = [ "O", @@ -72,40 +72,40 @@ LEARNING_RATE = 3e-5 WARMUP_RATIO = 0.1 WEIGHT_DECAY = 0.01 - + # 数据增强配置 USE_DATA_AUGMENTATION = False AUGMENTATION_RATIO = 0.3 - + # 训练策略 GRADIENT_ACCUMULATION_STEPS = 4 EVAL_STEPS = 25 LOGGING_STEPS = 10 SAVE_STEPS = 25 # 添加保存步数 SAVE_TOTAL_LIMIT = 2 # 添加保存检查点数量限制 - + # 路径配置 DATA_PATH = "data/repayment.txt" MODEL_PATH = "./models/repayment_model" LOG_PATH = "./logs_repayment" - + # 训练配置优化 SEED = 42 TEST_SIZE = 0.1 EARLY_STOPPING_PATIENCE = 2 - + # CPU环境配置 MAX_GRAD_NORM = 1.0 FP16 = False # CPU环境下关闭FP16 - + # CPU环境下的数据加载优化 DATALOADER_NUM_WORKERS = 0 # CPU环境下设为0 DATALOADER_PIN_MEMORY = False # CPU环境下关闭 - + # 交叉验证配置 N_SPLITS = 3 N_SEEDS = 1 - + # 标签列表 LABELS = [ "O", @@ -149,40 +149,40 @@ LEARNING_RATE = 3e-5 WARMUP_RATIO = 0.1 WEIGHT_DECAY = 0.01 - + # 数据增强配置 USE_DATA_AUGMENTATION = False AUGMENTATION_RATIO = 0.3 - + # 训练策略 GRADIENT_ACCUMULATION_STEPS = 4 EVAL_STEPS = 25 LOGGING_STEPS = 10 SAVE_STEPS = 25 SAVE_TOTAL_LIMIT = 2 - + # 路径配置 DATA_PATH = "data/income.txt" MODEL_PATH = "./models/income_model" LOG_PATH = "./logs_income" - + # 训练配置优化 SEED = 42 TEST_SIZE = 0.1 EARLY_STOPPING_PATIENCE = 2 - + # CPU环境配置 MAX_GRAD_NORM = 1.0 FP16 = False - + # CPU环境下的数据加载优化 DATALOADER_NUM_WORKERS = 0 DATALOADER_PIN_MEMORY = False - + # 交叉验证配置 N_SPLITS = 3 N_SEEDS = 1 - + # 标签列表 LABELS = [ "O", @@ -211,4 +211,154 @@ 'max_integer_digits': 12, # 整数部分最大位数 'currency_symbols': ['¥', '¥', 'RMB', '元'], # 货币符号 'decimal_context_range': 3 # 查找小数点的上下文范围 + } + +class FlightNERConfig: + # 优化模型参数 (与 RepaymentNERConfig 保持一致) + MODEL_NAME = "bert-base-chinese" + MAX_LENGTH = 128 + BATCH_SIZE = 4 + EPOCHS = 10 + LEARNING_RATE = 3e-5 + WARMUP_RATIO = 0.1 + WEIGHT_DECAY = 0.01 + + # 训练策略 + GRADIENT_ACCUMULATION_STEPS = 4 + EVAL_STEPS = 25 + LOGGING_STEPS = 10 + SAVE_STEPS = 25 + SAVE_TOTAL_LIMIT = 2 + + # 路径配置 + DATA_PATH = "data/flight.txt" + MODEL_PATH = "./models/flight_model" + LOG_PATH = "./logs_flight" + + # 训练配置 + SEED = 42 + TEST_SIZE = 0.1 + EARLY_STOPPING_PATIENCE = 2 + + # CPU环境配置 + MAX_GRAD_NORM = 1.0 + FP16 = False + DATALOADER_NUM_WORKERS = 0 + DATALOADER_PIN_MEMORY = False + + # 交叉验证配置 + N_SPLITS = 3 + N_SEEDS = 3 # 增加种子数量以提高模型稳定性 + + # 标签列表 - 保持与需求一致 + LABELS = [ + "O", + "B-FLIGHT", "I-FLIGHT", # 航班号 + "B-COMPANY", "I-COMPANY", # 航空公司 + "B-START", "I-START", # 出发地 + "B-END", "I-END", # 目的地 + "B-DATE", "I-DATE", # 日期 + "B-TIME", "I-TIME", # 时间 + "B-DEPARTURE_TIME", "I-DEPARTURE_TIME", # 起飞时间 + "B-ARRIVAL_TIME", "I-ARRIVAL_TIME", # 到达时间 + "B-TICKET_NUM", "I-TICKET_NUM", # 机票号码 + "B-SEAT", "I-SEAT" # 座位等信息 + ] + + # 实体长度限制 - 更新键名与LABELS一致 + MAX_ENTITY_LENGTH = { + "FLIGHT": 10, # 航班号 + "COMPANY": 15, # 航空公司 + "START": 10, # 出发地 + "END": 10, # 目的地 + "DATE": 15, # 日期 + "TIME": 10, # 时间 + "DEPARTURE_TIME": 10, # 起飞时间 + "ARRIVAL_TIME": 10, # 到达时间 + "TICKET_NUM": 10, # 用户姓名 + "SEAT": 10 # 座位等信息 + } + + # 航班号配置 + FLIGHT_CONFIG = { + 'pattern': r'[A-Z]{2}\d{3,4}', + 'min_length': 4, + 'max_length': 7, + 'carrier_codes': ['CA', 'MU', 'CZ', 'HU', '3U', 'ZH', 'FM', 'MF', 'SC', '9C'] # 常见航司代码 + } + +class TrainNERConfig: + # 模型参数 + MODEL_NAME = "bert-base-chinese" + MAX_LENGTH = 128 + BATCH_SIZE = 4 + EPOCHS = 10 + LEARNING_RATE = 3e-5 + WARMUP_RATIO = 0.1 + WEIGHT_DECAY = 0.01 + + # 训练策略 + GRADIENT_ACCUMULATION_STEPS = 4 + EVAL_STEPS = 25 + LOGGING_STEPS = 10 + SAVE_STEPS = 25 + SAVE_TOTAL_LIMIT = 2 + + # 路径配置 + DATA_PATH = "data/train.txt" + MODEL_PATH = "./models/train_model" + LOG_PATH = "./logs_train" + + # 训练配置 + SEED = 42 + TEST_SIZE = 0.1 + EARLY_STOPPING_PATIENCE = 2 + + # CPU环境配置 + MAX_GRAD_NORM = 1.0 + FP16 = False + DATALOADER_NUM_WORKERS = 0 + DATALOADER_PIN_MEMORY = False + + # 交叉验证配置 + N_SPLITS = 3 + N_SEEDS = 3 # 增加种子数量以提高模型稳定性 + + # 标签列表 + LABELS = [ + "O", + "B-COMPANY", "I-COMPANY", # 车次 + "B-TRIPS", "I-TRIPS", # 车次 + "B-START", "I-START", # 出发站 + "B-END", "I-END", # 到达站 + "B-DATE", "I-DATE", # 日期 + "B-TIME", "I-TIME", # 时间 + "B-SEAT", "I-SEAT", # 座位等信息 + "B-NAME", "I-NAME" # 用户姓名 + ] + + # 实体长度限制 - 更新键名与LABELS一致 + MAX_ENTITY_LENGTH = { + "COMPANY": 8, # 12306 + "TRIPS": 8, # 车次 + "START": 10, # 出发站 + "END": 10, # 到达站 + "DATE": 15, # 日期 + "TIME": 10, # 时间 + "SEAT": 10, # 座位等信息 + "NAME": 10 # 用户姓名 + } + + # 车次配置 + TRIPS_CONFIG = { + 'patterns': [ + r'[GDCZTKY]\d{1,2}', # G1, D1, C1等 + r'[GDCZTKY]\d{1,2}/\d{1,2}', # G1/2等联运车次 + r'[GDCZTKY]\d{1,2}-\d{1,2}', # G1-2等联运车次 + r'\d{1,4}', # 普通车次如1234次 + r'[A-Z]\d{1,4}' # Z1234等特殊车次 + ], + 'min_length': 1, + 'max_length': 8, + 'train_types': ['G', 'D', 'C', 'Z', 'T', 'K', 'Y'] # 车次类型前缀 } \ No newline at end of file -- Gitblit v1.9.3