From e6fed94443177826cf7497a85e9cdcfc7c43ee21 Mon Sep 17 00:00:00 2001
From: cloudroam <cloudroam>
Date: 星期一, 21 四月 2025 16:49:49 +0800
Subject: [PATCH] fix

---
 ner_config.py |  184 +++++++++++++++++++++++++++++++++++++++++----
 1 files changed, 167 insertions(+), 17 deletions(-)

diff --git a/ner_config.py b/ner_config.py
index 71581d7..999015d 100644
--- a/ner_config.py
+++ b/ner_config.py
@@ -30,7 +30,7 @@
     # 交叉验证配置
     N_SPLITS = 3      # CPU环境下减少折数
     N_SEEDS = 1       # CPU环境下减少种子数量
-    
+
     # 确保标签列表完整
     LABELS =  [
     "O",
@@ -72,40 +72,40 @@
     LEARNING_RATE = 3e-5
     WARMUP_RATIO = 0.1
     WEIGHT_DECAY = 0.01
-    
+
     # 数据增强配置
     USE_DATA_AUGMENTATION = False
     AUGMENTATION_RATIO = 0.3
-    
+
     # 训练策略
     GRADIENT_ACCUMULATION_STEPS = 4
     EVAL_STEPS = 25
     LOGGING_STEPS = 10
     SAVE_STEPS = 25      # 添加保存步数
     SAVE_TOTAL_LIMIT = 2 # 添加保存检查点数量限制
-    
+
     # 路径配置
     DATA_PATH = "data/repayment.txt"
     MODEL_PATH = "./models/repayment_model"
     LOG_PATH = "./logs_repayment"
-    
+
     # 训练配置优化
     SEED = 42
     TEST_SIZE = 0.1
     EARLY_STOPPING_PATIENCE = 2
-    
+
     # CPU环境配置
     MAX_GRAD_NORM = 1.0
     FP16 = False        # CPU环境下关闭FP16
-    
+
     # CPU环境下的数据加载优化
     DATALOADER_NUM_WORKERS = 0  # CPU环境下设为0
     DATALOADER_PIN_MEMORY = False  # CPU环境下关闭
-    
+
     # 交叉验证配置
     N_SPLITS = 3
     N_SEEDS = 1
-    
+
     # 标签列表
     LABELS = [
         "O",
@@ -149,40 +149,40 @@
     LEARNING_RATE = 3e-5
     WARMUP_RATIO = 0.1
     WEIGHT_DECAY = 0.01
-    
+
     # 数据增强配置
     USE_DATA_AUGMENTATION = False
     AUGMENTATION_RATIO = 0.3
-    
+
     # 训练策略
     GRADIENT_ACCUMULATION_STEPS = 4
     EVAL_STEPS = 25
     LOGGING_STEPS = 10
     SAVE_STEPS = 25
     SAVE_TOTAL_LIMIT = 2
-    
+
     # 路径配置
     DATA_PATH = "data/income.txt"
     MODEL_PATH = "./models/income_model"
     LOG_PATH = "./logs_income"
-    
+
     # 训练配置优化
     SEED = 42
     TEST_SIZE = 0.1
     EARLY_STOPPING_PATIENCE = 2
-    
+
     # CPU环境配置
     MAX_GRAD_NORM = 1.0
     FP16 = False
-    
+
     # CPU环境下的数据加载优化
     DATALOADER_NUM_WORKERS = 0
     DATALOADER_PIN_MEMORY = False
-    
+
     # 交叉验证配置
     N_SPLITS = 3
     N_SEEDS = 1
-    
+
     # 标签列表
     LABELS = [
         "O",
@@ -211,4 +211,154 @@
         'max_integer_digits': 12,  # 整数部分最大位数
         'currency_symbols': ['¥', '¥', 'RMB', '元'],  # 货币符号
         'decimal_context_range': 3  # 查找小数点的上下文范围
+    }
+
+class FlightNERConfig:
+    # 优化模型参数 (与 RepaymentNERConfig 保持一致)
+    MODEL_NAME = "bert-base-chinese"
+    MAX_LENGTH = 128
+    BATCH_SIZE = 4
+    EPOCHS = 10
+    LEARNING_RATE = 3e-5
+    WARMUP_RATIO = 0.1
+    WEIGHT_DECAY = 0.01
+
+    # 训练策略
+    GRADIENT_ACCUMULATION_STEPS = 4
+    EVAL_STEPS = 25
+    LOGGING_STEPS = 10
+    SAVE_STEPS = 25
+    SAVE_TOTAL_LIMIT = 2
+
+    # 路径配置
+    DATA_PATH = "data/flight.txt"
+    MODEL_PATH = "./models/flight_model"
+    LOG_PATH = "./logs_flight"
+
+    # 训练配置
+    SEED = 42
+    TEST_SIZE = 0.1
+    EARLY_STOPPING_PATIENCE = 2
+
+    # CPU环境配置
+    MAX_GRAD_NORM = 1.0
+    FP16 = False
+    DATALOADER_NUM_WORKERS = 0
+    DATALOADER_PIN_MEMORY = False
+
+    # 交叉验证配置
+    N_SPLITS = 3
+    N_SEEDS = 3  # 增加种子数量以提高模型稳定性
+
+    # 标签列表 - 保持与需求一致
+    LABELS = [
+        "O",
+        "B-FLIGHT", "I-FLIGHT",     # 航班号
+        "B-COMPANY", "I-COMPANY",   # 航空公司
+        "B-START", "I-START",       # 出发地
+        "B-END", "I-END",           # 目的地
+        "B-DATE", "I-DATE",         # 日期
+        "B-TIME", "I-TIME",         # 时间
+        "B-DEPARTURE_TIME", "I-DEPARTURE_TIME",  # 起飞时间
+        "B-ARRIVAL_TIME", "I-ARRIVAL_TIME",      # 到达时间
+        "B-TICKET_NUM", "I-TICKET_NUM",  # 机票号码
+        "B-SEAT", "I-SEAT"  # 座位等信息
+    ]
+
+    # 实体长度限制 - 更新键名与LABELS一致
+    MAX_ENTITY_LENGTH = {
+        "FLIGHT": 10,       # 航班号
+        "COMPANY": 15,      # 航空公司
+        "START": 10,        # 出发地
+        "END": 10,          # 目的地
+        "DATE": 15,         # 日期
+        "TIME": 10,         # 时间
+        "DEPARTURE_TIME": 10,  # 起飞时间
+        "ARRIVAL_TIME": 10,    # 到达时间
+        "TICKET_NUM": 10,      # 用户姓名
+        "SEAT": 10             # 座位等信息
+    }
+
+    # 航班号配置
+    FLIGHT_CONFIG = {
+        'pattern': r'[A-Z]{2}\d{3,4}',
+        'min_length': 4,
+        'max_length': 7,
+        'carrier_codes': ['CA', 'MU', 'CZ', 'HU', '3U', 'ZH', 'FM', 'MF', 'SC', '9C']  # 常见航司代码
+    }
+
+class TrainNERConfig:
+    # 模型参数
+    MODEL_NAME = "bert-base-chinese"
+    MAX_LENGTH = 128
+    BATCH_SIZE = 4
+    EPOCHS = 10
+    LEARNING_RATE = 3e-5
+    WARMUP_RATIO = 0.1
+    WEIGHT_DECAY = 0.01
+
+    # 训练策略
+    GRADIENT_ACCUMULATION_STEPS = 4
+    EVAL_STEPS = 25
+    LOGGING_STEPS = 10
+    SAVE_STEPS = 25
+    SAVE_TOTAL_LIMIT = 2
+
+    # 路径配置
+    DATA_PATH = "data/train.txt"
+    MODEL_PATH = "./models/train_model"
+    LOG_PATH = "./logs_train"
+
+    # 训练配置
+    SEED = 42
+    TEST_SIZE = 0.1
+    EARLY_STOPPING_PATIENCE = 2
+
+    # CPU环境配置
+    MAX_GRAD_NORM = 1.0
+    FP16 = False
+    DATALOADER_NUM_WORKERS = 0
+    DATALOADER_PIN_MEMORY = False
+
+    # 交叉验证配置
+    N_SPLITS = 3
+    N_SEEDS = 3  # 增加种子数量以提高模型稳定性
+
+    # 标签列表
+    LABELS = [
+        "O",
+        "B-COMPANY", "I-COMPANY",  # 车次
+        "B-TRIPS", "I-TRIPS",       # 车次
+        "B-START", "I-START",     # 出发站
+        "B-END", "I-END",         # 到达站
+        "B-DATE", "I-DATE",      # 日期
+        "B-TIME", "I-TIME",      # 时间
+        "B-SEAT", "I-SEAT",      # 座位等信息
+        "B-NAME", "I-NAME"       # 用户姓名
+    ]
+
+    # 实体长度限制 - 更新键名与LABELS一致
+    MAX_ENTITY_LENGTH = {
+        "COMPANY": 8,          # 12306
+        "TRIPS": 8,            # 车次
+        "START": 10,           # 出发站
+        "END": 10,             # 到达站
+        "DATE": 15,            # 日期
+        "TIME": 10,            # 时间
+        "SEAT": 10,            # 座位等信息
+        "NAME": 10             # 用户姓名
+    }
+    
+    # 车次配置
+    TRIPS_CONFIG = {
+        'patterns': [
+            r'[GDCZTKY]\d{1,2}',            # G1, D1, C1等
+            r'[GDCZTKY]\d{1,2}/\d{1,2}',    # G1/2等联运车次
+            r'[GDCZTKY]\d{1,2}-\d{1,2}',    # G1-2等联运车次
+            r'\d{1,4}',                     # 普通车次如1234次
+            r'[A-Z]\d{1,4}'                 # Z1234等特殊车次
+        ],
+        'min_length': 1,
+        'max_length': 8,
+        'train_types': ['G', 'D', 'C', 'Z', 'T', 'K', 'Y']  # 车次类型前缀
     }
\ No newline at end of file

--
Gitblit v1.9.3