cloudroam
2025-04-15 acc5c1281b50c12e4d04c81b899410f6ca2cacac
add: flight and train tickets
6 files changed
1 file renamed
1 file added
4224 ■■■■■ lines changed
.gitignore 111 ●●●●●
app.py 254 ●●●●●
check_labels.py 7 ●●●●●
data/flight.txt
data/train.txt 3082 ●●●●
ner_config.py 184 ●●●●●
train_flight_ner.py 297 ●●●●●
train_train_ner.py 289 ●●●●●
.gitignore
new file
@@ -0,0 +1,111 @@
/models/ner_model/
/.idea/.gitignore
/models/classifier/checkpoint-11/config.json
/models/classifier/checkpoint-22/config.json
/models/classifier/checkpoint-33/config.json
/models/classifier/config.json
/models/income_model/best_model/config.json
/models/income_model/checkpoint-25/config.json
/models/income_model/checkpoint-75/config.json
/models/repayment_model/best_model/config.json
/models/repayment_model/checkpoint-75/config.json
/models/repayment_model/checkpoint-125/config.json
/models/classifier/confusion_matrix.png
/logs_ner/seed_0/events.out.tfevents.1742464122.DESKTOP-U3O8B5H.5008.0
/logs_ner/seed_0/events.out.tfevents.1742469089.DESKTOP-U3O8B5H.8812.0
/logs_ner/seed_0/events.out.tfevents.1742470664.DESKTOP-U3O8B5H.8812.1
/logs_ner/seed_1/events.out.tfevents.1742470670.DESKTOP-U3O8B5H.8812.2
/logs_ner/seed_1/events.out.tfevents.1742472137.DESKTOP-U3O8B5H.8812.3
/logs_ner/seed_0/events.out.tfevents.1742522311.DESKTOP-U3O8B5H.9272.0
/logs_ner/seed_0/events.out.tfevents.1742523686.DESKTOP-U3O8B5H.9272.1
/logs_ner/seed_0/events.out.tfevents.1742529705.DESKTOP-U3O8B5H.11904.0
/logs_ner/seed_0/events.out.tfevents.1742531427.DESKTOP-U3O8B5H.11904.1
/logs_ner/seed_0/events.out.tfevents.1742534893.DESKTOP-U3O8B5H.14104.0
/logs_ner/seed_0/events.out.tfevents.1742535919.DESKTOP-U3O8B5H.14104.1
/logs_ner/seed_0/events.out.tfevents.1742536567.DESKTOP-U3O8B5H.11136.0
/logs_ner/seed_0/events.out.tfevents.1742537539.DESKTOP-U3O8B5H.11136.1
/logs_ner/seed_0/events.out.tfevents.1742538119.DESKTOP-U3O8B5H.8680.0
/logs_ner/seed_0/events.out.tfevents.1742540436.DESKTOP-U3O8B5H.8680.1
/logs_repayment/events.out.tfevents.1742888152.DESKTOP-U3O8B5H.3584.0
/logs_repayment/events.out.tfevents.1742890775.DESKTOP-U3O8B5H.6340.0
/logs_repayment/events.out.tfevents.1742891151.DESKTOP-U3O8B5H.9964.0
/logs_repayment/events.out.tfevents.1742893885.DESKTOP-U3O8B5H.12032.0
/logs_repayment/events.out.tfevents.1742896256.DESKTOP-U3O8B5H.12032.1
/logs_repayment/events.out.tfevents.1742953278.DESKTOP-U3O8B5H.9672.0
/logs_repayment/events.out.tfevents.1742955624.DESKTOP-U3O8B5H.9672.1
/logs_repayment/events.out.tfevents.1742957041.DESKTOP-U3O8B5H.6900.0
/logs_repayment/events.out.tfevents.1742959423.DESKTOP-U3O8B5H.6900.1
/logs_repayment/events.out.tfevents.1742966658.DESKTOP-U3O8B5H.7856.0
/logs_repayment/events.out.tfevents.1742969310.DESKTOP-U3O8B5H.7856.1
/logs_income/events.out.tfevents.1743057915.DESKTOP-U3O8B5H.196.0
/logs_income/events.out.tfevents.1743059103.DESKTOP-U3O8B5H.196.1
/logs_ner/seed_0/events.out.tfevents.1744093403.DESKTOP-U3O8B5H.10748.0
/.idea/git_toolbox_prj.xml
/models/income_model.zip
/.idea/misc.xml
/models/income_model/best_model/model.safetensors
/models/income_model/checkpoint-25/model.safetensors
/models/income_model/checkpoint-75/model.safetensors
/models/repayment_model/best_model/model.safetensors
/models/repayment_model/checkpoint-75/model.safetensors
/models/repayment_model/checkpoint-125/model.safetensors
/.idea/modules.xml
/models/ner_model.zip
/models/classifier/checkpoint-11/optimizer.pt
/models/classifier/checkpoint-22/optimizer.pt
/models/classifier/checkpoint-33/optimizer.pt
/models/income_model/checkpoint-25/optimizer.pt
/models/income_model/checkpoint-75/optimizer.pt
/models/repayment_model/checkpoint-75/optimizer.pt
/models/repayment_model/checkpoint-125/optimizer.pt
/.idea/other.xml
/.idea/inspectionProfiles/profiles_settings.xml
/.idea/inspectionProfiles/Project_Default.xml
/.idea/pythonProject.iml
/models/classifier/checkpoint-11/pytorch_model.bin
/models/classifier/checkpoint-22/pytorch_model.bin
/models/classifier/checkpoint-33/pytorch_model.bin
/models/classifier/pytorch_model.bin
/models/repayment_model.zip
/models/classifier/checkpoint-11/rng_state.pth
/models/classifier/checkpoint-22/rng_state.pth
/models/classifier/checkpoint-33/rng_state.pth
/models/income_model/checkpoint-25/rng_state.pth
/models/income_model/checkpoint-75/rng_state.pth
/models/repayment_model/checkpoint-75/rng_state.pth
/models/repayment_model/checkpoint-125/rng_state.pth
/models/classifier/checkpoint-11/scheduler.pt
/models/classifier/checkpoint-22/scheduler.pt
/models/classifier/checkpoint-33/scheduler.pt
/models/income_model/checkpoint-25/scheduler.pt
/models/income_model/checkpoint-75/scheduler.pt
/models/repayment_model/checkpoint-75/scheduler.pt
/models/repayment_model/checkpoint-125/scheduler.pt
/models/classifier/special_tokens_map.json
/models/income_model/best_model/special_tokens_map.json
/models/repayment_model/best_model/special_tokens_map.json
/models/income_model/best_model/tokenizer.json
/models/repayment_model/best_model/tokenizer.json
/models/classifier/tokenizer_config.json
/models/income_model/best_model/tokenizer_config.json
/models/repayment_model/best_model/tokenizer_config.json
/models/classifier/checkpoint-11/trainer_state.json
/models/classifier/checkpoint-22/trainer_state.json
/models/classifier/checkpoint-33/trainer_state.json
/models/income_model/checkpoint-25/trainer_state.json
/models/income_model/checkpoint-75/trainer_state.json
/models/repayment_model/checkpoint-75/trainer_state.json
/models/repayment_model/checkpoint-125/trainer_state.json
/models/classifier/checkpoint-11/training_args.bin
/models/classifier/checkpoint-22/training_args.bin
/models/classifier/checkpoint-33/training_args.bin
/models/income_model/best_model/training_args.bin
/models/income_model/checkpoint-25/training_args.bin
/models/income_model/checkpoint-75/training_args.bin
/models/repayment_model/best_model/training_args.bin
/models/repayment_model/checkpoint-75/training_args.bin
/models/repayment_model/checkpoint-125/training_args.bin
/.idea/vcs.xml
/models/classifier/vocab.txt
/models/income_model/best_model/vocab.txt
/models/repayment_model/best_model/vocab.txt
app.py
@@ -7,7 +7,7 @@
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
import torch
from werkzeug.exceptions import BadRequest
from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig
from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig
import re
# Configure logging
@@ -27,6 +27,8 @@
        self.ner_path = "./models/ner_model/best_model"
        self.repayment_path = "./models/repayment_model/best_model"
        self.income_path = "./models/income_model/best_model"
        self.flight_path = "./models/flight_model/best_model"
        self.train_path = "./models/train_model/best_model"  # path to the train-ticket model
        # Check the model files
        self._check_model_files()
@@ -36,12 +38,16 @@
        self.ner_tokenizer, self.ner_model = self._load_ner()
        self.repayment_tokenizer, self.repayment_model = self._load_repayment()
        self.income_tokenizer, self.income_model = self._load_income()
        self.flight_tokenizer, self.flight_model = self._load_flight()
        self.train_tokenizer, self.train_model = self._load_train()  # load the train-ticket model
        
        # Put all models in eval mode
        self.classifier_model.eval()
        self.ner_model.eval()
        self.repayment_model.eval()
        self.income_model.eval()
        self.flight_model.eval()
        self.train_model.eval()  # train-ticket model too
        
    def _check_model_files(self):
        """检查模型文件是否存在"""
@@ -53,6 +59,10 @@
            raise RuntimeError("还款模型文件不存在,请先运行训练脚本")
        if not os.path.exists(self.income_path):
            raise RuntimeError("收入模型文件不存在,请先运行训练脚本")
        if not os.path.exists(self.flight_path):
            raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
        if not os.path.exists(self.train_path):
            raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
            
    def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]:
        """加载分类模型"""
@@ -94,6 +104,26 @@
            logger.error(f"加载收入模型失败: {str(e)}")
            raise
            
    def _load_flight(self):
        """Load the flight model"""
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.flight_path)
            model = AutoModelForTokenClassification.from_pretrained(self.flight_path)
            return tokenizer, model
        except Exception as e:
            logger.error(f"Failed to load the flight model: {str(e)}")
            raise
    def _load_train(self):
        """Load the train-ticket model"""
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.train_path)
            model = AutoModelForTokenClassification.from_pretrained(self.train_path)
            return tokenizer, model
        except Exception as e:
            logger.error(f"Failed to load the train-ticket model: {str(e)}")
            raise
    def classify_sms(self, text: str) -> str:
        """对短信进行分类"""
        try:
@@ -120,7 +150,7 @@
                "company": None,       # 寄件公司
                "address": None,       # 地址
                "pickup_code": None,   # 取件码
                "time": None          # 时间
                "time": None           # 添加时间字段
            }
            
            # Stage 1: extract the pickup code directly from the text
@@ -662,6 +692,222 @@
            logger.error(f"收入实体提取失败: {str(e)}")
            raise
    def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
        """Extract flight-related entities"""
        try:
            # Initialize the result dict
            result = {
                "flight": None,           # flight number
                "company": None,          # airline
                "start": None,            # origin
                "end": None,              # destination
                "date": None,             # date
                "time": None,             # time
                "departure_time": None,   # departure time
                "arrival_time": None,     # arrival time
                "ticket_num": None,       # ticket number
                "seat": None              # seat and related info
            }
            # Run the NER model over the text
            inputs = self.flight_tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=FlightNERConfig.MAX_LENGTH
            )
            with torch.no_grad():
                outputs = self.flight_model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=2)
            tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
            tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()]
            # Parse the BIO tags into entities
            current_entity = None
            for token, tag in zip(tokens, tags):
                if tag.startswith("B-"):
                    if current_entity:
                        entity_type = current_entity["type"].lower()
                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
                    current_entity = {"type": tag[2:], "text": token}
                elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
                    current_entity["text"] += token
                else:
                    if current_entity:
                        entity_type = current_entity["type"].lower()
                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
                        current_entity = None
            # Flush the trailing entity, if any
            if current_entity:
                entity_type = current_entity["type"].lower()
                result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
            # Normalize the flight number
            if result["flight"]:
                flight_no = result["flight"].upper()
                # keep only letters and digits
                flight_no = ''.join(c for c in flight_no if c.isalnum())
                # validate against the flight-number pattern
                valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
                if valid_pattern.match(flight_no):
                    result["flight"] = flight_no
                else:
                    # fall back for common near-misses: two letters followed by digits
                    if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit():
                        result["flight"] = flight_no
                    else:
                        result["flight"] = None
            # Clean up the date
            if result["date"]:
                date_str = result["date"]
                # keep digits and common date separators
                date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
                result["date"] = date_str
            # Clean up the time fields
            for time_field in ["time", "departure_time", "arrival_time"]:
                if result[time_field]:
                    time_str = result[time_field]
                    # keep digits and common time separators
                    time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
                    result[time_field] = time_str
            # Normalize the ticket number
            if result["ticket_num"]:
                ticket_num = result["ticket_num"]
                # keep only letters and digits
                ticket_num = ''.join(c for c in ticket_num if c.isalnum())
                result["ticket_num"] = ticket_num
            # Normalize the seat info
            if result["seat"]:
                seat_str = result["seat"]
                # strip stray spaces
                seat_str = seat_str.replace(" ", "").strip()
                result["seat"] = seat_str
            return result
        except Exception as e:
            logger.error(f"航班实体提取失败: {str(e)}")
            raise
    def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
        """Extract train-ticket entities"""
        try:
            # Initialize the result dict
            result = {
                "company": None,         # carrier, usually 12306
                "trips": None,           # train number
                "start": None,           # departure station
                "end": None,             # arrival station
                "date": None,            # date
                "time": None,            # time
                "seat": None,            # seat and related info
                "name": None             # passenger name
            }
            # Run the NER model over the text
            inputs = self.train_tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=TrainNERConfig.MAX_LENGTH
            )
            with torch.no_grad():
                outputs = self.train_model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=2)
            tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
            tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()]
            # Parse the BIO tags into entities
            current_entity = None
            for token, tag in zip(tokens, tags):
                if tag.startswith("B-"):
                    if current_entity:
                        entity_type = current_entity["type"].lower()
                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
                    current_entity = {"type": tag[2:], "text": token}
                elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
                    current_entity["text"] += token
                else:
                    if current_entity:
                        entity_type = current_entity["type"].lower()
                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
                        current_entity = None
            # Flush the trailing entity, if any
            if current_entity:
                entity_type = current_entity["type"].lower()
                result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
            # Normalize the company name (usually 12306)
            if result["company"]:
                result["company"] = result["company"].strip()
            # if no company entity was detected but the text mentions 12306, default to it
            elif "12306" in text:
                result["company"] = "12306"
            # Normalize the train number
            if result["trips"]:
                trips_no = result["trips"].upper()
                # keep letters, digits, and the joint-service separators '/' and '-'
                trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-'])
                # validate against the train-number patterns
                valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']]
                if any(pattern.match(trips_no) for pattern in valid_patterns):
                    result["trips"] = trips_no
                else:
                    # fall back for common near-misses
                    if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']):
                        result["trips"] = trips_no
                    elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']:
                        result["trips"] = trips_no
                    else:
                        result["trips"] = None
            # Clean up the date
            if result["date"]:
                date_str = result["date"]
                # keep digits and common date separators
                date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
                result["date"] = date_str
            # Clean up the time
            if result["time"]:
                time_str = result["time"]
                # keep digits and common time separators
                time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
                result["time"] = time_str
            # Normalize the seat info
            if result["seat"]:
                seat_str = result["seat"]
                # strip stray spaces
                seat_str = seat_str.replace(" ", "").strip()
                result["seat"] = seat_str
            # Normalize the passenger name
            if result["name"]:
                name = result["name"].strip()
                # keep alphanumerics plus the masking '*' and middle dot '·'
                name = ''.join(c for c in name if c.isalnum() or c in ['*', '·'])
                result["name"] = name
            return result
        except Exception as e:
            logger.error(f"火车票实体提取失败: {str(e)}")
            raise
# Create the Flask app
app = Flask(__name__)
model_manager = ModelManager()
@@ -695,6 +941,10 @@
            details = model_manager.extract_repayment_entities(text)
        elif category == "收入":
            details = model_manager.extract_income_entities(text)
        elif category == "航班":
            details = model_manager.extract_flight_entities(text)
        elif category == "火车票":  # 添加火车票类别处理
            details = model_manager.extract_train_entities(text)
        else:
            details = {}
        
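With these branches wired up, the two new extractors can be smoke-tested directly against the ModelManager. A minimal sketch, assuming both models have already been trained into ./models/flight_model/best_model and ./models/train_model/best_model (importing app instantiates ModelManager, so all model directories must exist); the sample messages below are made up:
# smoke_test.py - hypothetical manual check of the new extractors
from app import model_manager

flight_sms = "【东方航空】您预订的MU5137航班将于10月1日08:30从虹桥机场起飞"
train_sms = "【12306】您已购10月1日G102次列车车票,09:00北京南站开,5车12A号"

print(model_manager.extract_flight_entities(flight_sms))
print(model_manager.extract_train_entities(train_sms))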
check_labels.py
@@ -1,4 +1,5 @@
from ner_config import RepaymentNERConfig
from ner_config import RepaymentNERConfig, FlightNERConfig, TrainNERConfig
# Script: check for invalid label formats
@@ -6,14 +7,14 @@
    label_set = set()
    line_num = 0
    
    with open(RepaymentNERConfig.DATA_PATH, 'r', encoding='utf-8') as f:
    with open(FlightNERConfig.DATA_PATH, 'r', encoding='utf-8') as f:
        for line in f:
            line_num += 1
            line = line.strip()
            if line:
                try:
                    _, label = line.split(maxsplit=1)
                    if label not in RepaymentNERConfig.LABELS:
                    if label not in FlightNERConfig.LABELS:
                        print(f"行 {line_num}: 发现非法标签 '{label}'")
                        label_set.add(label)
                except Exception as e:
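Both this checker and load_data in the training scripts read the same format: one token and one BIO label per line, separated by whitespace, with a blank line ending each sample. A hypothetical fragment of data/flight.txt for illustration:
您 O
乘 O
坐 O
C B-FLIGHT
A I-FLIGHT
1 I-FLIGHT
2 I-FLIGHT
3 I-FLIGHT
4 I-FLIGHT

clean_flight_labels in train_flight_ner.py would keep this span, since the joined characters form CA1234 and match ^[A-Z]{2}\d{3,4}$.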
data/flight.txt
data/train.txt
File too large to display
ner_config.py
@@ -30,7 +30,7 @@
    # Cross-validation settings
    N_SPLITS = 3      # fewer folds in a CPU environment
    N_SEEDS = 1       # fewer seeds in a CPU environment
    # Make sure the label list is complete
    LABELS = [
    "O",
@@ -72,40 +72,40 @@
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # Data augmentation
    USE_DATA_AUGMENTATION = False
    AUGMENTATION_RATIO = 0.3
    # Training strategy
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25      # checkpoint save interval
    SAVE_TOTAL_LIMIT = 2 # cap on the number of kept checkpoints
    # Paths
    DATA_PATH = "data/repayment.txt"
    MODEL_PATH = "./models/repayment_model"
    LOG_PATH = "./logs_repayment"
    # Training settings
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU-environment settings
    MAX_GRAD_NORM = 1.0
    FP16 = False        # disable FP16 on CPU
    # Data-loading settings for CPU
    DATALOADER_NUM_WORKERS = 0  # 0 workers on CPU
    DATALOADER_PIN_MEMORY = False  # off on CPU
    # Cross-validation settings
    N_SPLITS = 3
    N_SEEDS = 1
    # Label list
    LABELS = [
        "O",
@@ -149,40 +149,40 @@
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # Data augmentation
    USE_DATA_AUGMENTATION = False
    AUGMENTATION_RATIO = 0.3
    # Training strategy
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25
    SAVE_TOTAL_LIMIT = 2
    # Paths
    DATA_PATH = "data/income.txt"
    MODEL_PATH = "./models/income_model"
    LOG_PATH = "./logs_income"
    # Training settings
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU-environment settings
    MAX_GRAD_NORM = 1.0
    FP16 = False
    # Data-loading settings for CPU
    DATALOADER_NUM_WORKERS = 0
    DATALOADER_PIN_MEMORY = False
    # Cross-validation settings
    N_SPLITS = 3
    N_SEEDS = 1
    # Label list
    LABELS = [
        "O",
@@ -211,4 +211,154 @@
        'max_integer_digits': 12,  # max digits in the integer part
        'currency_symbols': ['¥', '¥', 'RMB', '元'],  # currency symbols
        'decimal_context_range': 3  # context window for locating the decimal point
    }
class FlightNERConfig:
    # Model parameters (kept consistent with RepaymentNERConfig)
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 128
    BATCH_SIZE = 4
    EPOCHS = 10
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # Training strategy
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25
    SAVE_TOTAL_LIMIT = 2
    # Paths
    DATA_PATH = "data/flight.txt"
    MODEL_PATH = "./models/flight_model"
    LOG_PATH = "./logs_flight"
    # Training settings
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU-environment settings
    MAX_GRAD_NORM = 1.0
    FP16 = False
    DATALOADER_NUM_WORKERS = 0
    DATALOADER_PIN_MEMORY = False
    # Cross-validation settings
    N_SPLITS = 3
    N_SEEDS = 3  # more seeds for a more stable model
    # Label list (kept in line with the requirements)
    LABELS = [
        "O",
        "B-FLIGHT", "I-FLIGHT",     # flight number
        "B-COMPANY", "I-COMPANY",   # airline
        "B-START", "I-START",       # origin
        "B-END", "I-END",           # destination
        "B-DATE", "I-DATE",         # date
        "B-TIME", "I-TIME",         # time
        "B-DEPARTURE_TIME", "I-DEPARTURE_TIME",  # departure time
        "B-ARRIVAL_TIME", "I-ARRIVAL_TIME",      # arrival time
        "B-TICKET_NUM", "I-TICKET_NUM",  # ticket number
        "B-SEAT", "I-SEAT"  # seat and related info
    ]
    # Entity length limits (keys match LABELS)
    MAX_ENTITY_LENGTH = {
        "FLIGHT": 10,       # flight number
        "COMPANY": 15,      # airline
        "START": 10,        # origin
        "END": 10,          # destination
        "DATE": 15,         # date
        "TIME": 10,         # time
        "DEPARTURE_TIME": 10,  # departure time
        "ARRIVAL_TIME": 10,    # arrival time
        "TICKET_NUM": 10,      # ticket number
        "SEAT": 10             # seat and related info
    }
    # Flight-number settings
    FLIGHT_CONFIG = {
        'pattern': r'[A-Z]{2}\d{3,4}',
        'min_length': 4,
        'max_length': 7,
        'carrier_codes': ['CA', 'MU', 'CZ', 'HU', '3U', 'ZH', 'FM', 'MF', 'SC', '9C']  # common carrier codes
    }
class TrainNERConfig:
    # Model parameters
    MODEL_NAME = "bert-base-chinese"
    MAX_LENGTH = 128
    BATCH_SIZE = 4
    EPOCHS = 10
    LEARNING_RATE = 3e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    # Training strategy
    GRADIENT_ACCUMULATION_STEPS = 4
    EVAL_STEPS = 25
    LOGGING_STEPS = 10
    SAVE_STEPS = 25
    SAVE_TOTAL_LIMIT = 2
    # Paths
    DATA_PATH = "data/train.txt"
    MODEL_PATH = "./models/train_model"
    LOG_PATH = "./logs_train"
    # Training settings
    SEED = 42
    TEST_SIZE = 0.1
    EARLY_STOPPING_PATIENCE = 2
    # CPU-environment settings
    MAX_GRAD_NORM = 1.0
    FP16 = False
    DATALOADER_NUM_WORKERS = 0
    DATALOADER_PIN_MEMORY = False
    # Cross-validation settings
    N_SPLITS = 3
    N_SEEDS = 3  # more seeds for a more stable model
    # Label list
    LABELS = [
        "O",
        "B-COMPANY", "I-COMPANY",   # company (usually 12306)
        "B-TRIPS", "I-TRIPS",       # train number
        "B-START", "I-START",       # departure station
        "B-END", "I-END",           # arrival station
        "B-DATE", "I-DATE",         # date
        "B-TIME", "I-TIME",         # time
        "B-SEAT", "I-SEAT",         # seat and related info
        "B-NAME", "I-NAME"          # passenger name
    ]
    # Entity length limits (keys match LABELS)
    MAX_ENTITY_LENGTH = {
        "COMPANY": 8,          # 12306
        "TRIPS": 8,            # train number
        "START": 10,           # departure station
        "END": 10,             # arrival station
        "DATE": 15,            # date
        "TIME": 10,            # time
        "SEAT": 10,            # seat and related info
        "NAME": 10             # passenger name
    }
    # Train-number settings
    TRIPS_CONFIG = {
        'patterns': [
            r'[GDCZTKY]\d{1,2}',            # G1, D1, C1, ...
            r'[GDCZTKY]\d{1,2}/\d{1,2}',    # joint services such as G1/2
            r'[GDCZTKY]\d{1,2}-\d{1,2}',    # joint services such as G1-2
            r'\d{1,4}',                     # plain numeric services such as 1234
            r'[A-Z]\d{1,4}'                 # special services such as Z1234
        ],
        'min_length': 1,
        'max_length': 8,
        'train_types': ['G', 'D', 'C', 'Z', 'T', 'K', 'Y']  # train-type prefixes
    }
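As a quick sanity check on the two new config classes, a sketch of how the validation patterns behave (run from the repository root so ner_config is importable). One caveat worth noting: FLIGHT_CONFIG's pattern admits only two-letter alphabetic carrier codes, even though carrier_codes lists '3U' and '9C':
import re
from ner_config import FlightNERConfig, TrainNERConfig

flight_re = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
print(bool(flight_re.match("CA1234")))  # True
print(bool(flight_re.match("3U8888")))  # False: '3U' does not fit [A-Z]{2}

# match() anchors only at the start, mirroring the checks in app.py
trip_res = [re.compile(p) for p in TrainNERConfig.TRIPS_CONFIG['patterns']]
for s in ["G102", "G1/2", "1234", "ABC"]:
    print(s, any(r.match(s) for r in trip_res))  # True, True, True, False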
train_flight_ner.py
@@ -0,0 +1,297 @@
# train_flight_ner.py
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from seqeval.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import f1_score as binary_f1_score  # for per-entity binary scores
import random
import re
from ner_config import FlightNERConfig
# Set random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(FlightNERConfig.SEED)
class NERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        # Build the label <-> id mappings
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        # Print the label mapping
        print("Label mapping:")
        for label, idx in self.label2id.items():
            print(f"{label}: {idx}")
        # Encode the texts
        self.encodings = self.tokenize_and_align_labels()
    def tokenize_and_align_labels(self):
        tokenized_inputs = self.tokenizer(
            self.texts,
            is_split_into_words=True,  # keep word_ids aligned with the per-token labels
            truncation=True,
            padding=True,
            max_length=FlightNERConfig.MAX_LENGTH,
            return_offsets_mapping=True,
            return_tensors=None
        )
        labels = []
        for i, label in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(i)
            previous_word_idx = None
            label_ids = []
            current_entity = None
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # first sub-token of a new word
                    label_ids.append(self.label2id[label[word_idx]])
                    if label[word_idx].startswith("B-"):
                        current_entity = label[word_idx][2:]
                    elif label[word_idx] == "O":
                        current_entity = None
                else:
                    # continuation sub-tokens of the same word
                    if current_entity:
                        label_ids.append(self.label2id[f"I-{current_entity}"])
                    else:
                        label_ids.append(self.label2id["O"])
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.texts)
def load_data(file_path):
    texts, labels = [], []
    current_words, current_labels = [], []
    def clean_flight_labels(words, labels):
        """Clean flight-number annotations so malformed spans are dropped"""
        i = 0
        while i < len(words):
            if labels[i].startswith("B-FLIGHT"):  # updated label name
                # find the end of the flight-number span
                j = i + 1
                while j < len(words) and labels[j].startswith("I-FLIGHT"):  # updated label name
                    j += 1
                # check the assembled flight-number sequence
                flight_words = words[i:j]
                flight_str = ''.join(flight_words)
                # does it match the expected flight-number format?
                valid_pattern = re.compile(r'^[A-Z]{2}\d{3,4}$')
                if not valid_pattern.match(flight_str):
                    # relabel malformed spans as O
                    for k in range(i, j):
                        labels[k] = "O"
                i = j
            else:
                i += 1
        return words, labels
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                try:
                    word, label = line.split(maxsplit=1)
                    current_words.append(word)
                    current_labels.append(label)
                except Exception as e:
                    print(f"Error processing line '{line}': {e}")
                    continue
            elif current_words:  # a blank line ends the current sample
                # clean the flight-number annotations
                current_words, current_labels = clean_flight_labels(current_words, current_labels)
                texts.append(current_words)
                labels.append(current_labels)
                current_words, current_labels = [], []
    if current_words:  # flush the last sample
        current_words, current_labels = clean_flight_labels(current_words, current_labels)
        texts.append(current_words)
        labels.append(current_labels)
    return texts, labels
def compute_metrics(p):
    """Compute evaluation metrics"""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Drop predictions and labels at special-token positions (-100)
    true_predictions = [
        [FlightNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [FlightNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    # Overall entity-level metrics
    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }
    # Per-entity-type metrics
    for entity_type in ["FLIGHT", "COMPANY", "START", "END", "DATE", "TIME", "DEPARTURE_TIME", "ARRIVAL_TIME", "TICKET_NUM", "SEAT"]:
        # Convert the tags to binary indicators for this entity type
        binary_preds = []
        binary_labels = []
        for pred_seq, label_seq in zip(true_predictions, true_labels):
            pred_binary = []
            label_binary = []
            for pred, label in zip(pred_seq, label_seq):
                # does the tag belong to the current entity type?
                pred_is_entity = pred.endswith(entity_type)
                label_is_entity = label.endswith(entity_type)
                pred_binary.append(1 if pred_is_entity else 0)
                label_binary.append(1 if label_is_entity else 0)
            binary_preds.append(pred_binary)
            binary_labels.append(label_binary)
        # F1 for this entity type; the flattened 0/1 lists go through sklearn's
        # binary f1_score (seqeval's f1_score expects BIO tag sequences, not ints)
        try:
            entity_f1 = binary_f1_score(
                sum(binary_labels, []),  # flatten
                sum(binary_preds, []),   # flatten
                average='binary'
            )
            results[f"{entity_type}_f1"] = entity_f1
        except Exception as e:
            print(f"Failed to compute F1 for {entity_type}: {str(e)}")
            results[f"{entity_type}_f1"] = 0.0
    return results
def augment_data(texts, labels):
    """Simple data augmentation"""
    augmented_texts = []
    augmented_labels = []
    for text, label in zip(texts, labels):
        # keep the original sample
        augmented_texts.append(text)
        augmented_labels.append(label)
        # randomly drop some O-labeled tokens
        new_text = []
        new_label = []
        for t, l in zip(text, label):
            if l == "O" and random.random() < 0.3:
                continue
            new_text.append(t)
            new_label.append(l)
        augmented_texts.append(new_text)
        augmented_labels.append(new_label)
    return augmented_texts, augmented_labels
def main():
    # Load the data
    texts, labels = load_data(FlightNERConfig.DATA_PATH)
    print(f"Loaded dataset: {len(texts)} samples")
    # Split into train and validation sets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=FlightNERConfig.TEST_SIZE, random_state=FlightNERConfig.SEED
    )
    # Data augmentation
    train_texts, train_labels = augment_data(train_texts, train_labels)
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(FlightNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        FlightNERConfig.MODEL_NAME,
        num_labels=len(FlightNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(FlightNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(FlightNERConfig.LABELS)}
    )
    # Build the datasets
    train_dataset = NERDataset(train_texts, train_labels, tokenizer, FlightNERConfig.LABELS)
    val_dataset = NERDataset(val_texts, val_labels, tokenizer, FlightNERConfig.LABELS)
    # Training arguments
    training_args = TrainingArguments(
        output_dir=FlightNERConfig.MODEL_PATH,
        num_train_epochs=FlightNERConfig.EPOCHS,
        per_device_train_batch_size=FlightNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=FlightNERConfig.BATCH_SIZE,
        learning_rate=FlightNERConfig.LEARNING_RATE,
        warmup_ratio=FlightNERConfig.WARMUP_RATIO,
        weight_decay=FlightNERConfig.WEIGHT_DECAY,
        gradient_accumulation_steps=FlightNERConfig.GRADIENT_ACCUMULATION_STEPS,
        logging_steps=FlightNERConfig.LOGGING_STEPS,
        save_total_limit=2,
        no_cuda=True,
        evaluation_strategy="steps",
        eval_steps=FlightNERConfig.EVAL_STEPS,
        save_strategy="steps",
        save_steps=FlightNERConfig.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="overall_f1",
        greater_is_better=True,
        logging_dir=FlightNERConfig.LOG_PATH,
        logging_first_step=True,
        report_to=["tensorboard"],
    )
    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=FlightNERConfig.EARLY_STOPPING_PATIENCE)]
    )
    # Train the model
    trainer.train()
    # Evaluate
    eval_results = trainer.evaluate()
    print("\nEvaluation results:")
    for key, value in eval_results.items():
        print(f"{key}: {value:.4f}")
    # Save the final model
    model.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
    tokenizer.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
if __name__ == "__main__":
    main()
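Once training finishes, the saved best_model can be loaded for a quick inference check the same way app.py loads it; a minimal sketch with a made-up message:
# hypothetical inference check against the saved flight model
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from ner_config import FlightNERConfig

path = f"{FlightNERConfig.MODEL_PATH}/best_model"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForTokenClassification.from_pretrained(path)
model.eval()

text = "您预订的CA1234航班将于10月1日08:30起飞"
inputs = tokenizer(text, return_tensors="pt", truncation=True,
                   max_length=FlightNERConfig.MAX_LENGTH)
with torch.no_grad():
    logits = model(**inputs).logits
tags = [model.config.id2label[i] for i in logits.argmax(dim=-1)[0].tolist()]
print(list(zip(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]), tags)))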
train_train_ner.py
@@ -0,0 +1,289 @@
# train_train_ner.py
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from seqeval.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import f1_score as binary_f1_score  # for per-entity binary scores
import random
import os
import re
from ner_config import TrainNERConfig
# Set random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(TrainNERConfig.SEED)
class NERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        # Build the label <-> id mappings
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        # Print the label mapping
        print("Label mapping:")
        for label, idx in self.label2id.items():
            print(f"{label}: {idx}")
        # Encode the texts
        self.encodings = self.tokenize_and_align_labels()
    def tokenize_and_align_labels(self):
        tokenized_inputs = self.tokenizer(
            self.texts,
            is_split_into_words=True,  # keep word_ids aligned with the per-token labels
            truncation=True,
            padding=True,
            max_length=TrainNERConfig.MAX_LENGTH,
            return_offsets_mapping=True,
            return_tensors=None
        )
        labels = []
        for i, label in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(i)
            previous_word_idx = None
            label_ids = []
            current_entity = None
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # first sub-token of a new word
                    label_ids.append(self.label2id[label[word_idx]])
                    if label[word_idx].startswith("B-"):
                        current_entity = label[word_idx][2:]
                    elif label[word_idx] == "O":
                        current_entity = None
                else:
                    # continuation sub-tokens of the same word
                    if current_entity:
                        label_ids.append(self.label2id[f"I-{current_entity}"])
                    else:
                        label_ids.append(self.label2id["O"])
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.texts)
def load_data(file_path):
    texts, labels = [], []
    current_words, current_labels = [], []
    def clean_trips_labels(words, labels):
        """Clean train-number annotations so malformed spans are dropped"""
        i = 0
        while i < len(words):
            if labels[i].startswith("B-TRIPS"):  # updated label name
                # find the end of the train-number span
                j = i + 1
                while j < len(words) and labels[j].startswith("I-TRIPS"):  # updated label name
                    j += 1
                # check the assembled train-number sequence
                train_words = words[i:j]
                train_str = ''.join(train_words)
                # does it match any of the expected train-number formats?
                valid_patterns = [
                    re.compile(r'^[GDCZTKY]\d{1,2}$'),
                    re.compile(r'^[GDCZTKY]\d{1,2}/\d{1,2}$'),
                    re.compile(r'^[GDCZTKY]\d{1,2}-\d{1,2}$'),
                    re.compile(r'^\d{1,4}$'),
                    re.compile(r'^[A-Z]\d{1,4}$')
                ]
                is_valid = any(pattern.match(train_str) for pattern in valid_patterns)
                if not is_valid:
                    # relabel malformed spans as O
                    for k in range(i, j):
                        labels[k] = "O"
                i = j
            else:
                i += 1
        return words, labels
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                try:
                    word, label = line.split(maxsplit=1)
                    current_words.append(word)
                    current_labels.append(label)
                except Exception as e:
                    print(f"Error processing line '{line}': {e}")
                    continue
            elif current_words:  # a blank line ends the current sample
                # clean the train-number annotations
                current_words, current_labels = clean_trips_labels(current_words, current_labels)
                texts.append(current_words)
                labels.append(current_labels)
                current_words, current_labels = [], []
    if current_words:  # flush the last sample
        current_words, current_labels = clean_trips_labels(current_words, current_labels)
        texts.append(current_words)
        labels.append(current_labels)
    return texts, labels
def compute_metrics(p):
    """Compute evaluation metrics"""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Drop predictions and labels at special-token positions (-100)
    true_predictions = [
        [TrainNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [TrainNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    # Overall entity-level metrics
    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }
    # Per-entity-type metrics
    for entity_type in ["COMPANY", "TRIPS", "START", "END", "DATE", "TIME", "SEAT", "NAME"]:
        # Convert the tags to binary indicators for this entity type
        binary_preds = []
        binary_labels = []
        for pred_seq, label_seq in zip(true_predictions, true_labels):
            pred_binary = []
            label_binary = []
            for pred, label in zip(pred_seq, label_seq):
                # does the tag belong to the current entity type?
                pred_is_entity = pred.endswith(entity_type)
                label_is_entity = label.endswith(entity_type)
                pred_binary.append(1 if pred_is_entity else 0)
                label_binary.append(1 if label_is_entity else 0)
            binary_preds.append(pred_binary)
            binary_labels.append(label_binary)
        # F1 for this entity type; the flattened 0/1 lists go through sklearn's
        # binary f1_score (seqeval's f1_score expects BIO tag sequences, not ints)
        try:
            entity_f1 = binary_f1_score(
                sum(binary_labels, []),  # flatten
                sum(binary_preds, []),   # flatten
                average='binary'
            )
            results[f"{entity_type}_f1"] = entity_f1
        except Exception as e:
            print(f"Failed to compute F1 for {entity_type}: {str(e)}")
            results[f"{entity_type}_f1"] = 0.0
    return results
def augment_data(texts, labels):
    """Simple data augmentation"""
    augmented_texts = []
    augmented_labels = []
    for text, label in zip(texts, labels):
        # keep the original sample
        augmented_texts.append(text)
        augmented_labels.append(label)
        # randomly drop some O-labeled tokens
        new_text = []
        new_label = []
        for t, l in zip(text, label):
            if l == "O" and random.random() < 0.3:
                continue
            new_text.append(t)
            new_label.append(l)
        augmented_texts.append(new_text)
        augmented_labels.append(new_label)
    return augmented_texts, augmented_labels
def main():
    # Load the data
    texts, labels = load_data(TrainNERConfig.DATA_PATH)
    print(f"Loaded dataset: {len(texts)} samples")
    # Split into train and validation sets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=TrainNERConfig.TEST_SIZE, random_state=TrainNERConfig.SEED
    )
    # Data augmentation
    train_texts, train_labels = augment_data(train_texts, train_labels)
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(TrainNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        TrainNERConfig.MODEL_NAME,
        num_labels=len(TrainNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(TrainNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(TrainNERConfig.LABELS)}
    )
    # Build the datasets
    train_dataset = NERDataset(train_texts, train_labels, tokenizer, TrainNERConfig.LABELS)
    val_dataset = NERDataset(val_texts, val_labels, tokenizer, TrainNERConfig.LABELS)
    # Training arguments
    training_args = TrainingArguments(
        output_dir=TrainNERConfig.MODEL_PATH,
        num_train_epochs=TrainNERConfig.EPOCHS,
        per_device_train_batch_size=TrainNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=TrainNERConfig.BATCH_SIZE,
        learning_rate=TrainNERConfig.LEARNING_RATE,
        warmup_ratio=TrainNERConfig.WARMUP_RATIO,
        weight_decay=TrainNERConfig.WEIGHT_DECAY,
        gradient_accumulation_steps=TrainNERConfig.GRADIENT_ACCUMULATION_STEPS
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )
    trainer.train()
    # Evaluate
    eval_results = trainer.evaluate()
    print("\nEvaluation results:")
    for key, value in eval_results.items():
        print(f"{key}: {value:.4f}")
    # Save the final model
    model.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")
    tokenizer.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")
if __name__ == "__main__":
    main()
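For reference, the label-alignment step in both training scripts relies on word_ids() mapping each sub-token back to an entry in the original token list; a minimal illustration with bert-base-chinese (the exact sub-token split may vary with the vocabulary):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
words = ["您", "乘", "坐", "CA1234", "航", "班"]
enc = tokenizer(words, is_split_into_words=True)
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
print(enc.word_ids())
# 'CA1234' stays a single word, so every sub-token it produces shares the
# same word id and inherits the I- continuation of its B- label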