New files added in this change (对比新文件) |
| | |
| | | /models/ner_model/ |
| | | /.idea/.gitignore |
| | | /models/classifier/checkpoint-11/config.json |
| | | /models/classifier/checkpoint-22/config.json |
| | | /models/classifier/checkpoint-33/config.json |
| | | /models/classifier/config.json |
| | | /models/income_model/best_model/config.json |
| | | /models/income_model/checkpoint-25/config.json |
| | | /models/income_model/checkpoint-75/config.json |
| | | /models/repayment_model/best_model/config.json |
| | | /models/repayment_model/checkpoint-75/config.json |
| | | /models/repayment_model/checkpoint-125/config.json |
| | | /models/classifier/confusion_matrix.png |
| | | /logs_ner/seed_0/events.out.tfevents.1742464122.DESKTOP-U3O8B5H.5008.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742469089.DESKTOP-U3O8B5H.8812.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742470664.DESKTOP-U3O8B5H.8812.1 |
| | | /logs_ner/seed_1/events.out.tfevents.1742470670.DESKTOP-U3O8B5H.8812.2 |
| | | /logs_ner/seed_1/events.out.tfevents.1742472137.DESKTOP-U3O8B5H.8812.3 |
| | | /logs_ner/seed_0/events.out.tfevents.1742522311.DESKTOP-U3O8B5H.9272.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742523686.DESKTOP-U3O8B5H.9272.1 |
| | | /logs_ner/seed_0/events.out.tfevents.1742529705.DESKTOP-U3O8B5H.11904.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742531427.DESKTOP-U3O8B5H.11904.1 |
| | | /logs_ner/seed_0/events.out.tfevents.1742534893.DESKTOP-U3O8B5H.14104.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742535919.DESKTOP-U3O8B5H.14104.1 |
| | | /logs_ner/seed_0/events.out.tfevents.1742536567.DESKTOP-U3O8B5H.11136.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742537539.DESKTOP-U3O8B5H.11136.1 |
| | | /logs_ner/seed_0/events.out.tfevents.1742538119.DESKTOP-U3O8B5H.8680.0 |
| | | /logs_ner/seed_0/events.out.tfevents.1742540436.DESKTOP-U3O8B5H.8680.1 |
| | | /logs_repayment/events.out.tfevents.1742888152.DESKTOP-U3O8B5H.3584.0 |
| | | /logs_repayment/events.out.tfevents.1742890775.DESKTOP-U3O8B5H.6340.0 |
| | | /logs_repayment/events.out.tfevents.1742891151.DESKTOP-U3O8B5H.9964.0 |
| | | /logs_repayment/events.out.tfevents.1742893885.DESKTOP-U3O8B5H.12032.0 |
| | | /logs_repayment/events.out.tfevents.1742896256.DESKTOP-U3O8B5H.12032.1 |
| | | /logs_repayment/events.out.tfevents.1742953278.DESKTOP-U3O8B5H.9672.0 |
| | | /logs_repayment/events.out.tfevents.1742955624.DESKTOP-U3O8B5H.9672.1 |
| | | /logs_repayment/events.out.tfevents.1742957041.DESKTOP-U3O8B5H.6900.0 |
| | | /logs_repayment/events.out.tfevents.1742959423.DESKTOP-U3O8B5H.6900.1 |
| | | /logs_repayment/events.out.tfevents.1742966658.DESKTOP-U3O8B5H.7856.0 |
| | | /logs_repayment/events.out.tfevents.1742969310.DESKTOP-U3O8B5H.7856.1 |
| | | /logs_income/events.out.tfevents.1743057915.DESKTOP-U3O8B5H.196.0 |
| | | /logs_income/events.out.tfevents.1743059103.DESKTOP-U3O8B5H.196.1 |
| | | /logs_ner/seed_0/events.out.tfevents.1744093403.DESKTOP-U3O8B5H.10748.0 |
| | | /.idea/git_toolbox_prj.xml |
| | | /models/income_model.zip |
| | | /.idea/misc.xml |
| | | /models/income_model/best_model/model.safetensors |
| | | /models/income_model/checkpoint-25/model.safetensors |
| | | /models/income_model/checkpoint-75/model.safetensors |
| | | /models/repayment_model/best_model/model.safetensors |
| | | /models/repayment_model/checkpoint-75/model.safetensors |
| | | /models/repayment_model/checkpoint-125/model.safetensors |
| | | /.idea/modules.xml |
| | | /models/ner_model.zip |
| | | /models/classifier/checkpoint-11/optimizer.pt |
| | | /models/classifier/checkpoint-22/optimizer.pt |
| | | /models/classifier/checkpoint-33/optimizer.pt |
| | | /models/income_model/checkpoint-25/optimizer.pt |
| | | /models/income_model/checkpoint-75/optimizer.pt |
| | | /models/repayment_model/checkpoint-75/optimizer.pt |
| | | /models/repayment_model/checkpoint-125/optimizer.pt |
| | | /.idea/other.xml |
| | | /.idea/inspectionProfiles/profiles_settings.xml |
| | | /.idea/inspectionProfiles/Project_Default.xml |
| | | /.idea/pythonProject.iml |
| | | /models/classifier/checkpoint-11/pytorch_model.bin |
| | | /models/classifier/checkpoint-22/pytorch_model.bin |
| | | /models/classifier/checkpoint-33/pytorch_model.bin |
| | | /models/classifier/pytorch_model.bin |
| | | /models/repayment_model.zip |
| | | /models/classifier/checkpoint-11/rng_state.pth |
| | | /models/classifier/checkpoint-22/rng_state.pth |
| | | /models/classifier/checkpoint-33/rng_state.pth |
| | | /models/income_model/checkpoint-25/rng_state.pth |
| | | /models/income_model/checkpoint-75/rng_state.pth |
| | | /models/repayment_model/checkpoint-75/rng_state.pth |
| | | /models/repayment_model/checkpoint-125/rng_state.pth |
| | | /models/classifier/checkpoint-11/scheduler.pt |
| | | /models/classifier/checkpoint-22/scheduler.pt |
| | | /models/classifier/checkpoint-33/scheduler.pt |
| | | /models/income_model/checkpoint-25/scheduler.pt |
| | | /models/income_model/checkpoint-75/scheduler.pt |
| | | /models/repayment_model/checkpoint-75/scheduler.pt |
| | | /models/repayment_model/checkpoint-125/scheduler.pt |
| | | /models/classifier/special_tokens_map.json |
| | | /models/income_model/best_model/special_tokens_map.json |
| | | /models/repayment_model/best_model/special_tokens_map.json |
| | | /models/income_model/best_model/tokenizer.json |
| | | /models/repayment_model/best_model/tokenizer.json |
| | | /models/classifier/tokenizer_config.json |
| | | /models/income_model/best_model/tokenizer_config.json |
| | | /models/repayment_model/best_model/tokenizer_config.json |
| | | /models/classifier/checkpoint-11/trainer_state.json |
| | | /models/classifier/checkpoint-22/trainer_state.json |
| | | /models/classifier/checkpoint-33/trainer_state.json |
| | | /models/income_model/checkpoint-25/trainer_state.json |
| | | /models/income_model/checkpoint-75/trainer_state.json |
| | | /models/repayment_model/checkpoint-75/trainer_state.json |
| | | /models/repayment_model/checkpoint-125/trainer_state.json |
| | | /models/classifier/checkpoint-11/training_args.bin |
| | | /models/classifier/checkpoint-22/training_args.bin |
| | | /models/classifier/checkpoint-33/training_args.bin |
| | | /models/income_model/best_model/training_args.bin |
| | | /models/income_model/checkpoint-25/training_args.bin |
| | | /models/income_model/checkpoint-75/training_args.bin |
| | | /models/repayment_model/best_model/training_args.bin |
| | | /models/repayment_model/checkpoint-75/training_args.bin |
| | | /models/repayment_model/checkpoint-125/training_args.bin |
| | | /.idea/vcs.xml |
| | | /models/classifier/vocab.txt |
| | | /models/income_model/best_model/vocab.txt |
| | | /models/repayment_model/best_model/vocab.txt |
| | |
| | | from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification |
| | | import torch |
| | | from werkzeug.exceptions import BadRequest |
| | | from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig |
| | | from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig |
| | | import re |
| | | |
| | | # 配置日志 |
| | |
| | | self.ner_path = "./models/ner_model/best_model" |
| | | self.repayment_path = "./models/repayment_model/best_model" |
| | | self.income_path = "./models/income_model/best_model" |
| | | self.flight_path = "./models/flight_model/best_model" |
| | | self.train_path = "./models/train_model/best_model" # 添加火车票模型路径 |
| | | |
| | | # 检查模型文件 |
| | | self._check_model_files() |
| | |
| | | self.ner_tokenizer, self.ner_model = self._load_ner() |
| | | self.repayment_tokenizer, self.repayment_model = self._load_repayment() |
| | | self.income_tokenizer, self.income_model = self._load_income() |
| | | self.flight_tokenizer, self.flight_model = self._load_flight() |
| | | self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型 |
| | | |
| | | # 将模型设置为评估模式 |
| | | self.classifier_model.eval() |
| | | self.ner_model.eval() |
| | | self.repayment_model.eval() |
| | | self.income_model.eval() |
| | | self.flight_model.eval() |
| | | self.train_model.eval() # 设置火车票模型为评估模式 |
| | | |
| | | def _check_model_files(self): |
| | | """检查模型文件是否存在""" |
| | |
| | | raise RuntimeError("还款模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.income_path): |
| | | raise RuntimeError("收入模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.flight_path): |
| | | raise RuntimeError("航班模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.train_path): |
| | | raise RuntimeError("火车票模型文件不存在,请先运行训练脚本") |
| | | |
| | | def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]: |
| | | """加载分类模型""" |
| | |
| | | logger.error(f"加载收入模型失败: {str(e)}") |
| | | raise |
| | | |
| | | def _load_flight(self): |
| | | """加载航班模型""" |
| | | try: |
| | | tokenizer = AutoTokenizer.from_pretrained(self.flight_path) |
| | | model = AutoModelForTokenClassification.from_pretrained(self.flight_path) |
| | | return tokenizer, model |
| | | except Exception as e: |
| | | logger.error(f"加载航班模型失败: {str(e)}") |
| | | raise |
| | | |
| | | def _load_train(self): |
| | | """加载火车票模型""" |
| | | try: |
| | | tokenizer = AutoTokenizer.from_pretrained(self.train_path) |
| | | model = AutoModelForTokenClassification.from_pretrained(self.train_path) |
| | | return tokenizer, model |
| | | except Exception as e: |
| | | logger.error(f"加载火车票模型失败: {str(e)}") |
| | | raise |
| | | |
| | | def classify_sms(self, text: str) -> str: |
| | | """对短信进行分类""" |
| | | try: |
| | |
| | | "company": None, # 寄件公司 |
| | | "address": None, # 地址 |
| | | "pickup_code": None, # 取件码 |
| | | "time": None # 时间 |
| | | "time": None # 添加时间字段 |
| | | } |
| | | |
| | | # 第一阶段:直接从文本中提取取件码 |
| | |
| | | logger.error(f"收入实体提取失败: {str(e)}") |
| | | raise |
| | | |
| | | def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]: |
| | | """提取航班相关实体""" |
| | | try: |
| | | # 初始化结果字典 |
| | | result = { |
| | | "flight": None, # 航班号 |
| | | "company": None, # 航空公司 |
| | | "start": None, # 出发地 |
| | | "end": None, # 目的地 |
| | | "date": None, # 日期 |
| | | "time": None, # 时间 |
| | | "departure_time": None, # 起飞时间 |
| | | "arrival_time": None, # 到达时间 |
| | | "ticket_num": None, # 机票号码 |
| | | "seat": None # 座位等信息 |
| | | } |
| | | |
| | | # 使用NER模型提取实体 |
| | | inputs = self.flight_tokenizer( |
| | | text, |
| | | return_tensors="pt", |
| | | truncation=True, |
| | | max_length=FlightNERConfig.MAX_LENGTH |
| | | ) |
| | | |
| | | with torch.no_grad(): |
| | | outputs = self.flight_model(**inputs) |
| | | |
| | | predictions = torch.argmax(outputs.logits, dim=2) |
| | | tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) |
| | | tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()] |
| | | |
| | | # 解析实体 |
| | | current_entity = None |
| | | |
| | | for token, tag in zip(tokens, tags): |
| | | if tag.startswith("B-"): |
| | | if current_entity: |
| | | entity_type = current_entity["type"].lower() |
| | | result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() |
| | | current_entity = {"type": tag[2:], "text": token} |
| | | elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]: |
| | | current_entity["text"] += token |
| | | else: |
| | | if current_entity: |
| | | entity_type = current_entity["type"].lower() |
| | | result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() |
| | | current_entity = None |
| | | |
| | | # 处理最后一个实体 |
| | | if current_entity: |
| | | entity_type = current_entity["type"].lower() |
| | | result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() |
| | | |
| | | # 处理航班号格式 |
| | | if result["flight"]: |
| | | flight_no = result["flight"].upper() |
| | | # 清理航班号,只保留字母和数字 |
| | | flight_no = ''.join(c for c in flight_no if c.isalnum()) |
| | | # 验证航班号格式 |
| | | valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern']) |
| | | if valid_pattern.match(flight_no): |
| | | result["flight"] = flight_no |
| | | else: |
| | | # 尝试修复常见错误 |
| | | if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit(): |
| | | result["flight"] = flight_no |
| | | else: |
| | | result["flight"] = None |
| | | |
| | | # 清理日期格式 |
| | | if result["date"]: |
| | | date_str = result["date"] |
| | | # 保留数字和常见日期分隔符 |
| | | date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']) |
| | | result["date"] = date_str |
| | | |
| | | # 清理时间格式 |
| | | for time_field in ["time", "departure_time", "arrival_time"]: |
| | | if result[time_field]: |
| | | time_str = result[time_field] |
| | | # 保留数字和常见时间分隔符 |
| | | time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点']) |
| | | result[time_field] = time_str |
| | | |
| | | # 处理机票号码 |
| | | if result["ticket_num"]: |
| | | ticket_num = result["ticket_num"] |
| | | # 清理机票号码,只保留字母和数字 |
| | | ticket_num = ''.join(c for c in ticket_num if c.isalnum()) |
| | | result["ticket_num"] = ticket_num |
| | | |
| | | # 处理座位信息 |
| | | if result["seat"]: |
| | | seat_str = result["seat"] |
| | | # 移除可能的额外空格和特殊字符 |
| | | seat_str = seat_str.replace(" ", "").strip() |
| | | result["seat"] = seat_str |
| | | |
| | | return result |
| | | except Exception as e: |
| | | logger.error(f"航班实体提取失败: {str(e)}") |
| | | raise |
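A minimal usage sketch for the flight extractor, assuming the ModelManager above has been constructed successfully; the sample SMS and the expected field values are illustrative, not taken from the project's test data:

# Illustrative only - actual values depend on the trained flight model.
manager = ModelManager()
sms = "【东航】您预订的MU5101航班将于3月15日08:30从上海虹桥起飞,前往北京首都机场。"
flight_info = manager.extract_flight_entities(sms)
# Expected shape (values are assumptions):
# {"flight": "MU5101", "company": "东航", "start": "上海虹桥", "end": "北京首都机场",
#  "date": "3月15日", "departure_time": "08:30", ...}
print(flight_info)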
| | | |
| | | def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]: |
| | | """提取火车票相关实体""" |
| | | try: |
| | | # 初始化结果字典 |
| | | result = { |
| | | "company": None, # 12306 |
| | | "trips": None, # 车次 |
| | | "start": None, # 出发站 |
| | | "end": None, # 到达站 |
| | | "date": None, # 日期 |
| | | "time": None, # 时间 |
| | | "seat": None, # 座位等信息 |
| | | "name": None # 用户姓名 |
| | | } |
| | | |
| | | # 使用NER模型提取实体 |
| | | inputs = self.train_tokenizer( |
| | | text, |
| | | return_tensors="pt", |
| | | truncation=True, |
| | | max_length=TrainNERConfig.MAX_LENGTH |
| | | ) |
| | | |
| | | with torch.no_grad(): |
| | | outputs = self.train_model(**inputs) |
| | | |
| | | predictions = torch.argmax(outputs.logits, dim=2) |
| | | tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) |
| | | tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()] |
| | | |
| | | # 解析实体 |
| | | current_entity = None |
| | | |
| | | for token, tag in zip(tokens, tags): |
| | | if tag.startswith("B-"): |
| | | if current_entity: |
| | | entity_type = current_entity["type"].lower() |
| | | result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() |
| | | current_entity = {"type": tag[2:], "text": token} |
| | | elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]: |
| | | current_entity["text"] += token |
| | | else: |
| | | if current_entity: |
| | | entity_type = current_entity["type"].lower() |
| | | result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() |
| | | current_entity = None |
| | | |
| | | # 处理最后一个实体 |
| | | if current_entity: |
| | | entity_type = current_entity["type"].lower() |
| | | result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() |
| | | |
| | | # 处理公司名称,通常为12306 |
| | | if result["company"]: |
| | | company = result["company"].strip() |
| | | result["company"] = company |
| | | # 如果文本中检测不到公司名称,但包含12306,则默认为12306 |
| | | elif "12306" in text: |
| | | result["company"] = "12306" |
| | | |
| | | # 处理车次格式 |
| | | if result["trips"]: |
| | | trips_no = result["trips"].upper() |
| | | # 清理车次号,只保留字母和数字 |
| | | trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-']) |
| | | |
| | | # 验证车次格式 |
| | | valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']] |
| | | if any(pattern.match(trips_no) for pattern in valid_patterns): |
| | | result["trips"] = trips_no |
| | | else: |
| | | # 尝试修复常见错误 |
| | | if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']): |
| | | result["trips"] = trips_no |
| | | elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']: |
| | | result["trips"] = trips_no |
| | | else: |
| | | result["trips"] = None |
| | | |
| | | # 清理日期格式 |
| | | if result["date"]: |
| | | date_str = result["date"] |
| | | # 保留数字和常见日期分隔符 |
| | | date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']) |
| | | result["date"] = date_str |
| | | |
| | | # 清理时间格式 |
| | | if result["time"]: |
| | | time_str = result["time"] |
| | | # 保留数字和常见时间分隔符 |
| | | time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点']) |
| | | result["time"] = time_str |
| | | |
| | | # 处理座位信息 |
| | | if result["seat"]: |
| | | seat_str = result["seat"] |
| | | # 移除可能的额外空格和特殊字符 |
| | | seat_str = seat_str.replace(" ", "").strip() |
| | | result["seat"] = seat_str |
| | | |
| | | # 处理乘客姓名 |
| | | if result["name"]: |
| | | name = result["name"].strip() |
| | | # 移除可能的标点符号 |
| | | name = ''.join(c for c in name if c.isalnum() or c in ['*', '·']) |
| | | result["name"] = name |
| | | |
| | | return result |
| | | except Exception as e: |
| | | logger.error(f"火车票实体提取失败: {str(e)}") |
| | | raise |
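Continuing with the same manager, a matching sketch for the train-ticket extractor; again the SMS and the returned values are illustrative assumptions:

# Illustrative only - actual values depend on the trained train-ticket model.
sms = "【12306】王*明您好,您已购3月20日G102次列车,北京南09:00开,5车12A号座。"
train_info = manager.extract_train_entities(sms)
# Keys mirror the result dict above: company, trips, start, end, date, time, seat, name.
print(train_info)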
| | | |
| | | # 创建Flask应用 |
| | | app = Flask(__name__) |
| | | model_manager = ModelManager() |
| | |
| | | details = model_manager.extract_repayment_entities(text) |
| | | elif category == "收入": |
| | | details = model_manager.extract_income_entities(text) |
| | | elif category == "航班": |
| | | details = model_manager.extract_flight_entities(text) |
| | | elif category == "火车票": # 添加火车票类别处理 |
| | | details = model_manager.extract_train_entities(text) |
| | | else: |
| | | details = {} |
| | | |
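The Flask route wrapping this dispatch is not part of the diff; below is a minimal sketch of one way it could be wired up. The /parse_sms path and the text/category/details JSON fields are assumptions for illustration, not the project's actual API; the dict-based dispatch is shown as an equivalent alternative to the if/elif chain above.

from flask import request, jsonify

@app.route("/parse_sms", methods=["POST"])  # hypothetical endpoint name
def parse_sms():
    data = request.get_json(silent=True)
    if not data or "text" not in data:
        raise BadRequest("missing 'text' field")  # BadRequest is already imported above
    text = data["text"]
    category = model_manager.classify_sms(text)
    # Equivalent to the if/elif chain above, written as a dispatch table.
    extractors = {
        "还款": model_manager.extract_repayment_entities,
        "收入": model_manager.extract_income_entities,
        "航班": model_manager.extract_flight_entities,
        "火车票": model_manager.extract_train_entities,
    }
    details = extractors.get(category, lambda _t: {})(text)
    return jsonify({"category": category, "details": details})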
| | |
| | | from ner_config import RepaymentNERConfig |
| | | from ner_config import RepaymentNERConfig, FlightNERConfig, TrainNERConfig |
| | | |
| | | |
| | | # 脚本:校验数据文件中是否存在非法标签 |
| | | |
| | |
| | | label_set = set() |
| | | line_num = 0 |
| | | |
| | | with open(RepaymentNERConfig.DATA_PATH, 'r', encoding='utf-8') as f: |
| | | with open(FlightNERConfig.DATA_PATH, 'r', encoding='utf-8') as f: |
| | | for line in f: |
| | | line_num += 1 |
| | | line = line.strip() |
| | | if line: |
| | | try: |
| | | _, label = line.split(maxsplit=1) |
| | | if label not in RepaymentNERConfig.LABELS: |
| | | if label not in FlightNERConfig.LABELS: |
| | | print(f"行 {line_num}: 发现非法标签 '{label}'") |
| | | label_set.add(label) |
| | | except Exception as e: |
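The validation fragment above is cut off in the diff; a self-contained sketch of the same check follows. The check_labels.py filename and the function name are illustrative; FlightNERConfig comes from ner_config.py shown further down.

# check_labels.py - standalone sketch of the label-validation script above.
from ner_config import FlightNERConfig

def check_labels(data_path: str, allowed_labels: set) -> set:
    """Report any tag in the data file that is not in the configured label set."""
    bad_labels = set()
    with open(data_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # blank lines separate samples
            try:
                _, label = line.split(maxsplit=1)
            except ValueError:
                print(f"行 {line_num}: 无法解析 '{line}'")
                continue
            if label not in allowed_labels:
                print(f"行 {line_num}: 发现非法标签 '{label}'")
                bad_labels.add(label)
    return bad_labels

if __name__ == "__main__":
    check_labels(FlightNERConfig.DATA_PATH, set(FlightNERConfig.LABELS))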
| | |
| | | # 交叉验证配置 |
| | | N_SPLITS = 3 # CPU环境下减少折数 |
| | | N_SEEDS = 1 # CPU环境下减少种子数量 |
| | | |
| | | |
| | | # 确保标签列表完整 |
| | | LABELS = [ |
| | | "O", |
| | |
| | | LEARNING_RATE = 3e-5 |
| | | WARMUP_RATIO = 0.1 |
| | | WEIGHT_DECAY = 0.01 |
| | | |
| | | |
| | | # 数据增强配置 |
| | | USE_DATA_AUGMENTATION = False |
| | | AUGMENTATION_RATIO = 0.3 |
| | | |
| | | |
| | | # 训练策略 |
| | | GRADIENT_ACCUMULATION_STEPS = 4 |
| | | EVAL_STEPS = 25 |
| | | LOGGING_STEPS = 10 |
| | | SAVE_STEPS = 25 # 添加保存步数 |
| | | SAVE_TOTAL_LIMIT = 2 # 添加保存检查点数量限制 |
| | | |
| | | |
| | | # 路径配置 |
| | | DATA_PATH = "data/repayment.txt" |
| | | MODEL_PATH = "./models/repayment_model" |
| | | LOG_PATH = "./logs_repayment" |
| | | |
| | | |
| | | # 训练配置优化 |
| | | SEED = 42 |
| | | TEST_SIZE = 0.1 |
| | | EARLY_STOPPING_PATIENCE = 2 |
| | | |
| | | |
| | | # CPU环境配置 |
| | | MAX_GRAD_NORM = 1.0 |
| | | FP16 = False # CPU环境下关闭FP16 |
| | | |
| | | |
| | | # CPU环境下的数据加载优化 |
| | | DATALOADER_NUM_WORKERS = 0 # CPU环境下设为0 |
| | | DATALOADER_PIN_MEMORY = False # CPU环境下关闭 |
| | | |
| | | |
| | | # 交叉验证配置 |
| | | N_SPLITS = 3 |
| | | N_SEEDS = 1 |
| | | |
| | | |
| | | # 标签列表 |
| | | LABELS = [ |
| | | "O", |
| | |
| | | LEARNING_RATE = 3e-5 |
| | | WARMUP_RATIO = 0.1 |
| | | WEIGHT_DECAY = 0.01 |
| | | |
| | | |
| | | # 数据增强配置 |
| | | USE_DATA_AUGMENTATION = False |
| | | AUGMENTATION_RATIO = 0.3 |
| | | |
| | | |
| | | # 训练策略 |
| | | GRADIENT_ACCUMULATION_STEPS = 4 |
| | | EVAL_STEPS = 25 |
| | | LOGGING_STEPS = 10 |
| | | SAVE_STEPS = 25 |
| | | SAVE_TOTAL_LIMIT = 2 |
| | | |
| | | |
| | | # 路径配置 |
| | | DATA_PATH = "data/income.txt" |
| | | MODEL_PATH = "./models/income_model" |
| | | LOG_PATH = "./logs_income" |
| | | |
| | | |
| | | # 训练配置优化 |
| | | SEED = 42 |
| | | TEST_SIZE = 0.1 |
| | | EARLY_STOPPING_PATIENCE = 2 |
| | | |
| | | |
| | | # CPU环境配置 |
| | | MAX_GRAD_NORM = 1.0 |
| | | FP16 = False |
| | | |
| | | |
| | | # CPU环境下的数据加载优化 |
| | | DATALOADER_NUM_WORKERS = 0 |
| | | DATALOADER_PIN_MEMORY = False |
| | | |
| | | |
| | | # 交叉验证配置 |
| | | N_SPLITS = 3 |
| | | N_SEEDS = 1 |
| | | |
| | | |
| | | # 标签列表 |
| | | LABELS = [ |
| | | "O", |
| | |
| | | 'max_integer_digits': 12, # 整数部分最大位数 |
| | | 'currency_symbols': ['¥', '¥', 'RMB', '元'], # 货币符号 |
| | | 'decimal_context_range': 3 # 查找小数点的上下文范围 |
| | | } |
| | | |
| | | class FlightNERConfig: |
| | | # 优化模型参数 (与 RepaymentNERConfig 保持一致) |
| | | MODEL_NAME = "bert-base-chinese" |
| | | MAX_LENGTH = 128 |
| | | BATCH_SIZE = 4 |
| | | EPOCHS = 10 |
| | | LEARNING_RATE = 3e-5 |
| | | WARMUP_RATIO = 0.1 |
| | | WEIGHT_DECAY = 0.01 |
| | | |
| | | # 训练策略 |
| | | GRADIENT_ACCUMULATION_STEPS = 4 |
| | | EVAL_STEPS = 25 |
| | | LOGGING_STEPS = 10 |
| | | SAVE_STEPS = 25 |
| | | SAVE_TOTAL_LIMIT = 2 |
| | | |
| | | # 路径配置 |
| | | DATA_PATH = "data/flight.txt" |
| | | MODEL_PATH = "./models/flight_model" |
| | | LOG_PATH = "./logs_flight" |
| | | |
| | | # 训练配置 |
| | | SEED = 42 |
| | | TEST_SIZE = 0.1 |
| | | EARLY_STOPPING_PATIENCE = 2 |
| | | |
| | | # CPU环境配置 |
| | | MAX_GRAD_NORM = 1.0 |
| | | FP16 = False |
| | | DATALOADER_NUM_WORKERS = 0 |
| | | DATALOADER_PIN_MEMORY = False |
| | | |
| | | # 交叉验证配置 |
| | | N_SPLITS = 3 |
| | | N_SEEDS = 3 # 增加种子数量以提高模型稳定性 |
| | | |
| | | # 标签列表 - 保持与需求一致 |
| | | LABELS = [ |
| | | "O", |
| | | "B-FLIGHT", "I-FLIGHT", # 航班号 |
| | | "B-COMPANY", "I-COMPANY", # 航空公司 |
| | | "B-START", "I-START", # 出发地 |
| | | "B-END", "I-END", # 目的地 |
| | | "B-DATE", "I-DATE", # 日期 |
| | | "B-TIME", "I-TIME", # 时间 |
| | | "B-DEPARTURE_TIME", "I-DEPARTURE_TIME", # 起飞时间 |
| | | "B-ARRIVAL_TIME", "I-ARRIVAL_TIME", # 到达时间 |
| | | "B-TICKET_NUM", "I-TICKET_NUM", # 机票号码 |
| | | "B-SEAT", "I-SEAT" # 座位等信息 |
| | | ] |
| | | |
| | | # 实体长度限制 - 更新键名与LABELS一致 |
| | | MAX_ENTITY_LENGTH = { |
| | | "FLIGHT": 10, # 航班号 |
| | | "COMPANY": 15, # 航空公司 |
| | | "START": 10, # 出发地 |
| | | "END": 10, # 目的地 |
| | | "DATE": 15, # 日期 |
| | | "TIME": 10, # 时间 |
| | | "DEPARTURE_TIME": 10, # 起飞时间 |
| | | "ARRIVAL_TIME": 10, # 到达时间 |
| | | "TICKET_NUM": 10, # 用户姓名 |
| | | "SEAT": 10 # 座位等信息 |
| | | } |
| | | |
| | | # 航班号配置 |
| | | FLIGHT_CONFIG = { |
| | | 'pattern': r'[A-Z]{2}\d{3,4}', |
| | | 'min_length': 4, |
| | | 'max_length': 7, |
| | | 'carrier_codes': ['CA', 'MU', 'CZ', 'HU', '3U', 'ZH', 'FM', 'MF', 'SC', '9C'] # 常见航司代码 |
| | | } |
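As a quick, illustrative check of how FLIGHT_CONFIG['pattern'] behaves (not project code): re.match anchors only at the start, and the two-letter requirement excludes the digit-leading carrier codes such as 3U and 9C listed in carrier_codes.

import re

pattern = re.compile(r'[A-Z]{2}\d{3,4}')  # FLIGHT_CONFIG['pattern']
print(bool(pattern.match("CA1234")))   # True
print(bool(pattern.match("MU567")))    # True
print(bool(pattern.match("3U8633")))   # False - pattern requires two letters, although '3U' is in carrier_codes
print(bool(pattern.match("CA1234X")))  # True - match() only anchors at the start; use fullmatch() for a strict check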
| | | |
| | | class TrainNERConfig: |
| | | # 模型参数 |
| | | MODEL_NAME = "bert-base-chinese" |
| | | MAX_LENGTH = 128 |
| | | BATCH_SIZE = 4 |
| | | EPOCHS = 10 |
| | | LEARNING_RATE = 3e-5 |
| | | WARMUP_RATIO = 0.1 |
| | | WEIGHT_DECAY = 0.01 |
| | | |
| | | # 训练策略 |
| | | GRADIENT_ACCUMULATION_STEPS = 4 |
| | | EVAL_STEPS = 25 |
| | | LOGGING_STEPS = 10 |
| | | SAVE_STEPS = 25 |
| | | SAVE_TOTAL_LIMIT = 2 |
| | | |
| | | # 路径配置 |
| | | DATA_PATH = "data/train.txt" |
| | | MODEL_PATH = "./models/train_model" |
| | | LOG_PATH = "./logs_train" |
| | | |
| | | # 训练配置 |
| | | SEED = 42 |
| | | TEST_SIZE = 0.1 |
| | | EARLY_STOPPING_PATIENCE = 2 |
| | | |
| | | # CPU环境配置 |
| | | MAX_GRAD_NORM = 1.0 |
| | | FP16 = False |
| | | DATALOADER_NUM_WORKERS = 0 |
| | | DATALOADER_PIN_MEMORY = False |
| | | |
| | | # 交叉验证配置 |
| | | N_SPLITS = 3 |
| | | N_SEEDS = 3 # 增加种子数量以提高模型稳定性 |
| | | |
| | | # 标签列表 |
| | | LABELS = [ |
| | | "O", |
| | | "B-COMPANY", "I-COMPANY", # 车次 |
| | | "B-TRIPS", "I-TRIPS", # 车次 |
| | | "B-START", "I-START", # 出发站 |
| | | "B-END", "I-END", # 到达站 |
| | | "B-DATE", "I-DATE", # 日期 |
| | | "B-TIME", "I-TIME", # 时间 |
| | | "B-SEAT", "I-SEAT", # 座位等信息 |
| | | "B-NAME", "I-NAME" # 用户姓名 |
| | | ] |
| | | |
| | | # 实体长度限制 - 更新键名与LABELS一致 |
| | | MAX_ENTITY_LENGTH = { |
| | | "COMPANY": 8, # 12306 |
| | | "TRIPS": 8, # 车次 |
| | | "START": 10, # 出发站 |
| | | "END": 10, # 到达站 |
| | | "DATE": 15, # 日期 |
| | | "TIME": 10, # 时间 |
| | | "SEAT": 10, # 座位等信息 |
| | | "NAME": 10 # 用户姓名 |
| | | } |
| | | |
| | | # 车次配置 |
| | | TRIPS_CONFIG = { |
| | | 'patterns': [ |
| | | r'[GDCZTKY]\d{1,2}', # G1, D1, C1等 |
| | | r'[GDCZTKY]\d{1,2}/\d{1,2}', # G1/2等联运车次 |
| | | r'[GDCZTKY]\d{1,2}-\d{1,2}', # G1-2等联运车次 |
| | | r'\d{1,4}', # 普通车次如1234次 |
| | | r'[A-Z]\d{1,4}' # Z1234等特殊车次 |
| | | ], |
| | | 'min_length': 1, |
| | | 'max_length': 8, |
| | | 'train_types': ['G', 'D', 'C', 'Z', 'T', 'K', 'Y'] # 车次类型前缀 |
| | | } |
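A similar illustrative check for the train-number patterns; looks_like_trip is a hypothetical helper mirroring the validation done in extract_train_entities above.

import re
from ner_config import TrainNERConfig

def looks_like_trip(candidate: str) -> bool:
    # Candidate passes if any of TRIPS_CONFIG['patterns'] matches from the start.
    patterns = [re.compile(p) for p in TrainNERConfig.TRIPS_CONFIG['patterns']]
    return any(p.match(candidate) for p in patterns)

print(looks_like_trip("G102"), looks_like_trip("G1/2"), looks_like_trip("XYZ"))  # True True False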
| | |
| | | # train_flight_ner.py |
| | | from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer |
| | | from transformers.trainer_callback import EarlyStoppingCallback |
| | | import torch |
| | | from torch.utils.data import Dataset |
| | | import numpy as np |
| | | from sklearn.model_selection import train_test_split |
| | | from seqeval.metrics import f1_score, precision_score, recall_score |
| | | from sklearn.metrics import f1_score as binary_f1_score  # 逐实体的二分类F1使用sklearn实现;seqeval的f1_score只接受BIO标签序列 |
| | | import random |
| | | import re |
| | | from ner_config import FlightNERConfig |
| | | |
| | | # 设置随机种子 |
| | | def set_seed(seed): |
| | | random.seed(seed) |
| | | np.random.seed(seed) |
| | | torch.manual_seed(seed) |
| | | if torch.cuda.is_available(): |
| | | torch.cuda.manual_seed_all(seed) |
| | | |
| | | set_seed(FlightNERConfig.SEED) |
| | | |
| | | class NERDataset(Dataset): |
| | | def __init__(self, texts, labels, tokenizer, label_list): |
| | | self.texts = texts |
| | | self.labels = labels |
| | | self.tokenizer = tokenizer |
| | | # 创建标签到ID的映射 |
| | | self.label2id = {label: i for i, label in enumerate(label_list)} |
| | | self.id2label = {i: label for i, label in enumerate(label_list)} |
| | | |
| | | # 打印标签映射信息 |
| | | print("标签映射:") |
| | | for label, idx in self.label2id.items(): |
| | | print(f"{label}: {idx}") |
| | | |
| | | # 对文本进行编码 |
| | | self.encodings = self.tokenize_and_align_labels() |
| | | |
| | | def tokenize_and_align_labels(self): |
| | | tokenized_inputs = self.tokenizer( |
| | | [''.join(text) for text in self.texts], |
| | | truncation=True, |
| | | padding=True, |
| | | max_length=FlightNERConfig.MAX_LENGTH, |
| | | return_offsets_mapping=True, |
| | | return_tensors=None |
| | | ) |
| | | |
| | | labels = [] |
| | | for i, label in enumerate(self.labels): |
| | | word_ids = tokenized_inputs.word_ids(i) |
| | | previous_word_idx = None |
| | | label_ids = [] |
| | | current_entity = None |
| | | |
| | | for word_idx in word_ids: |
| | | if word_idx is None: |
| | | label_ids.append(-100) |
| | | elif word_idx != previous_word_idx: |
| | | # 新词开始 |
| | | label_ids.append(self.label2id[label[word_idx]]) |
| | | if label[word_idx].startswith("B-"): |
| | | current_entity = label[word_idx][2:] |
| | | elif label[word_idx] == "O": |
| | | current_entity = None |
| | | else: |
| | | # 同一个词的后续token |
| | | if current_entity: |
| | | label_ids.append(self.label2id[f"I-{current_entity}"]) |
| | | else: |
| | | label_ids.append(self.label2id["O"]) |
| | | |
| | | previous_word_idx = word_idx |
| | | |
| | | labels.append(label_ids) |
| | | |
| | | tokenized_inputs["labels"] = labels |
| | | return tokenized_inputs |
| | | |
| | | def __getitem__(self, idx): |
| | | return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} |
| | | |
| | | def __len__(self): |
| | | return len(self.texts) |
| | | |
| | | def load_data(file_path): |
| | | texts, labels = [], [] |
| | | current_words, current_labels = [], [] |
| | | |
| | | def clean_flight_labels(words, labels): |
| | | """清理航班号标注,确保格式正确""" |
| | | i = 0 |
| | | while i < len(words): |
| | | if labels[i].startswith("B-FLIGHT"): # 已修改标签名称 |
| | | # 找到航班号的结束位置 |
| | | j = i + 1 |
| | | while j < len(words) and labels[j].startswith("I-FLIGHT"): # 已修改标签名称 |
| | | j += 1 |
| | | |
| | | # 检查并修正航班号序列 |
| | | flight_words = words[i:j] |
| | | flight_str = ''.join(flight_words) |
| | | |
| | | # 检查格式是否符合航班号规范 |
| | | valid_pattern = re.compile(r'^[A-Z]{2}\d{3,4}$') |
| | | if not valid_pattern.match(flight_str): |
| | | # 将格式不正确的标签改为O |
| | | for k in range(i, j): |
| | | labels[k] = "O" |
| | | |
| | | i = j |
| | | else: |
| | | i += 1 |
| | | |
| | | return words, labels |
| | | |
| | | with open(file_path, 'r', encoding='utf-8') as f: |
| | | for line in f: |
| | | line = line.strip() |
| | | if line: |
| | | try: |
| | | word, label = line.split(maxsplit=1) |
| | | current_words.append(word) |
| | | current_labels.append(label) |
| | | except Exception as e: |
| | | print(f"错误:处理行时出错: '{line}'") |
| | | continue |
| | | elif current_words: # 遇到空行且当前有数据 |
| | | # 清理航班号标注 |
| | | current_words, current_labels = clean_flight_labels(current_words, current_labels) |
| | | texts.append(current_words) |
| | | labels.append(current_labels) |
| | | current_words, current_labels = [], [] |
| | | |
| | | if current_words: # 处理最后一个样本 |
| | | current_words, current_labels = clean_flight_labels(current_words, current_labels) |
| | | texts.append(current_words) |
| | | labels.append(current_labels) |
| | | |
| | | return texts, labels |
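load_data expects a CoNLL-style file: one character and its BIO tag per line, separated by whitespace, with a blank line between messages. The excerpt below is an assumption about what data/flight.txt looks like, written to a temporary file so it can be fed through load_data.

# Illustrative excerpt of the expected annotation format (not real project data).
from train_flight_ner import load_data  # module name per the header comment above

sample = (
    "您 O\n"
    "乘 O\n"
    "坐 O\n"
    "C B-FLIGHT\n"
    "A I-FLIGHT\n"
    "1 I-FLIGHT\n"
    "2 I-FLIGHT\n"
    "3 I-FLIGHT\n"
    "4 I-FLIGHT\n"
    "\n"
)
with open("flight_sample.txt", "w", encoding="utf-8") as f:
    f.write(sample)

texts, labels = load_data("flight_sample.txt")
print(texts[0])   # ['您', '乘', '坐', 'C', 'A', '1', '2', '3', '4']
print(labels[0])  # ['O', 'O', 'O', 'B-FLIGHT', 'I-FLIGHT', ...]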
| | | |
| | | def compute_metrics(p): |
| | | """计算评估指标""" |
| | | predictions, labels = p |
| | | predictions = np.argmax(predictions, axis=2) |
| | | |
| | | # 移除特殊token的预测和标签 |
| | | true_predictions = [ |
| | | [FlightNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100] |
| | | for prediction, label in zip(predictions, labels) |
| | | ] |
| | | true_labels = [ |
| | | [FlightNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100] |
| | | for prediction, label in zip(predictions, labels) |
| | | ] |
| | | |
| | | # 计算总体评估指标 |
| | | results = { |
| | | "overall_f1": f1_score(true_labels, true_predictions), |
| | | "overall_precision": precision_score(true_labels, true_predictions), |
| | | "overall_recall": recall_score(true_labels, true_predictions) |
| | | } |
| | | |
| | | # 计算每个实体类型的指标 |
| | | for entity_type in ["FLIGHT", "COMPANY", "START", "END", "DATE", "TIME", "DEPARTURE_TIME", "ARRIVAL_TIME", "TICKET_NUM", "SEAT"]: |
| | | # 将标签转换为二进制形式 |
| | | binary_preds = [] |
| | | binary_labels = [] |
| | | |
| | | for pred_seq, label_seq in zip(true_predictions, true_labels): |
| | | pred_binary = [] |
| | | label_binary = [] |
| | | |
| | | for pred, label in zip(pred_seq, label_seq): |
| | | # 检查标签是否属于当前实体类型 |
| | | pred_is_entity = pred.endswith(entity_type) |
| | | label_is_entity = label.endswith(entity_type) |
| | | |
| | | pred_binary.append(1 if pred_is_entity else 0) |
| | | label_binary.append(1 if label_is_entity else 0) |
| | | |
| | | binary_preds.append(pred_binary) |
| | | binary_labels.append(label_binary) |
| | | |
| | | # 计算当前实体类型的F1分数 |
| | | try: |
| | | entity_f1 = binary_f1_score( |
| | | sum(binary_labels, []), # 展平列表 |
| | | sum(binary_preds, []), # 展平列表 |
| | | average='binary' # 使用二进制评估 |
| | | ) |
| | | results[f"{entity_type}_f1"] = entity_f1 |
| | | except Exception as e: |
| | | print(f"计算{entity_type}的F1分数时出错: {str(e)}") |
| | | results[f"{entity_type}_f1"] = 0.0 |
| | | |
| | | return results |
| | | |
| | | def augment_data(texts, labels): |
| | | """数据增强""" |
| | | augmented_texts = [] |
| | | augmented_labels = [] |
| | | for text, label in zip(texts, labels): |
| | | # 原始数据 |
| | | augmented_texts.append(text) |
| | | augmented_labels.append(label) |
| | | |
| | | # 删除一些无关字符 |
| | | new_text = [] |
| | | new_label = [] |
| | | for t, l in zip(text, label): |
| | | if l == "O" and random.random() < 0.3: |
| | | continue |
| | | new_text.append(t) |
| | | new_label.append(l) |
| | | augmented_texts.append(new_text) |
| | | augmented_labels.append(new_label) |
| | | |
| | | return augmented_texts, augmented_labels |
| | | |
| | | def main(): |
| | | # 加载数据 |
| | | texts, labels = load_data(FlightNERConfig.DATA_PATH) |
| | | print(f"加载的数据集大小:{len(texts)}个样本") |
| | | |
| | | # 划分数据集 |
| | | train_texts, val_texts, train_labels, val_labels = train_test_split( |
| | | texts, labels, test_size=FlightNERConfig.TEST_SIZE, random_state=FlightNERConfig.SEED |
| | | ) |
| | | |
| | | # 数据增强 |
| | | train_texts, train_labels = augment_data(train_texts, train_labels) |
| | | |
| | | # 加载分词器和模型 |
| | | tokenizer = AutoTokenizer.from_pretrained(FlightNERConfig.MODEL_NAME) |
| | | model = AutoModelForTokenClassification.from_pretrained( |
| | | FlightNERConfig.MODEL_NAME, |
| | | num_labels=len(FlightNERConfig.LABELS), |
| | | id2label={i: label for i, label in enumerate(FlightNERConfig.LABELS)}, |
| | | label2id={label: i for i, label in enumerate(FlightNERConfig.LABELS)} |
| | | ) |
| | | |
| | | # 创建数据集 |
| | | train_dataset = NERDataset(train_texts, train_labels, tokenizer, FlightNERConfig.LABELS) |
| | | val_dataset = NERDataset(val_texts, val_labels, tokenizer, FlightNERConfig.LABELS) |
| | | |
| | | # 训练参数 |
| | | training_args = TrainingArguments( |
| | | output_dir=FlightNERConfig.MODEL_PATH, |
| | | num_train_epochs=FlightNERConfig.EPOCHS, |
| | | per_device_train_batch_size=FlightNERConfig.BATCH_SIZE, |
| | | per_device_eval_batch_size=FlightNERConfig.BATCH_SIZE, |
| | | learning_rate=FlightNERConfig.LEARNING_RATE, |
| | | warmup_ratio=FlightNERConfig.WARMUP_RATIO, |
| | | weight_decay=FlightNERConfig.WEIGHT_DECAY, |
| | | gradient_accumulation_steps=FlightNERConfig.GRADIENT_ACCUMULATION_STEPS, |
| | | logging_steps=FlightNERConfig.LOGGING_STEPS, |
| | | save_total_limit=2, |
| | | no_cuda=True, |
| | | evaluation_strategy="steps", |
| | | eval_steps=FlightNERConfig.EVAL_STEPS, |
| | | save_strategy="steps", |
| | | save_steps=FlightNERConfig.SAVE_STEPS, |
| | | load_best_model_at_end=True, |
| | | metric_for_best_model="overall_f1", |
| | | greater_is_better=True, |
| | | logging_dir=FlightNERConfig.LOG_PATH, |
| | | logging_first_step=True, |
| | | report_to=["tensorboard"], |
| | | ) |
| | | |
| | | # 训练器 |
| | | trainer = Trainer( |
| | | model=model, |
| | | args=training_args, |
| | | train_dataset=train_dataset, |
| | | eval_dataset=val_dataset, |
| | | compute_metrics=compute_metrics, |
| | | callbacks=[EarlyStoppingCallback(early_stopping_patience=FlightNERConfig.EARLY_STOPPING_PATIENCE)] |
| | | ) |
| | | |
| | | # 训练模型 |
| | | trainer.train() |
| | | |
| | | # 评估结果 |
| | | eval_results = trainer.evaluate() |
| | | print("\n评估结果:") |
| | | for key, value in eval_results.items(): |
| | | print(f"{key}: {value:.4f}") |
| | | |
| | | # 保存最终模型 |
| | | model.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model") |
| | | tokenizer.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model") |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| | |
| | | # train_train_ner.py |
| | | from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer |
| | | from transformers.trainer_callback import EarlyStoppingCallback |
| | | import torch |
| | | from torch.utils.data import Dataset |
| | | import numpy as np |
| | | from sklearn.model_selection import train_test_split |
| | | from seqeval.metrics import f1_score, precision_score, recall_score |
| | | from sklearn.metrics import f1_score as binary_f1_score  # 逐实体的二分类F1使用sklearn实现;seqeval的f1_score只接受BIO标签序列 |
| | | import random |
| | | import os |
| | | import re |
| | | from ner_config import TrainNERConfig |
| | | |
| | | # 设置随机种子 |
| | | def set_seed(seed): |
| | | random.seed(seed) |
| | | np.random.seed(seed) |
| | | torch.manual_seed(seed) |
| | | if torch.cuda.is_available(): |
| | | torch.cuda.manual_seed_all(seed) |
| | | |
| | | set_seed(TrainNERConfig.SEED) |
| | | |
| | | class NERDataset(Dataset): |
| | | def __init__(self, texts, labels, tokenizer, label_list): |
| | | self.texts = texts |
| | | self.labels = labels |
| | | self.tokenizer = tokenizer |
| | | # 创建标签到ID的映射 |
| | | self.label2id = {label: i for i, label in enumerate(label_list)} |
| | | self.id2label = {i: label for i, label in enumerate(label_list)} |
| | | |
| | | # 打印标签映射信息 |
| | | print("标签映射:") |
| | | for label, idx in self.label2id.items(): |
| | | print(f"{label}: {idx}") |
| | | |
| | | # 对文本进行编码 |
| | | self.encodings = self.tokenize_and_align_labels() |
| | | |
| | | def tokenize_and_align_labels(self): |
| | | tokenized_inputs = self.tokenizer( |
| | | [''.join(text) for text in self.texts], |
| | | truncation=True, |
| | | padding=True, |
| | | max_length=TrainNERConfig.MAX_LENGTH, |
| | | return_offsets_mapping=True, |
| | | return_tensors=None |
| | | ) |
| | | |
| | | labels = [] |
| | | for i, label in enumerate(self.labels): |
| | | word_ids = tokenized_inputs.word_ids(i) |
| | | previous_word_idx = None |
| | | label_ids = [] |
| | | current_entity = None |
| | | |
| | | for word_idx in word_ids: |
| | | if word_idx is None: |
| | | label_ids.append(-100) |
| | | elif word_idx != previous_word_idx: |
| | | # 新词开始 |
| | | label_ids.append(self.label2id[label[word_idx]]) |
| | | if label[word_idx].startswith("B-"): |
| | | current_entity = label[word_idx][2:] |
| | | elif label[word_idx] == "O": |
| | | current_entity = None |
| | | else: |
| | | # 同一个词的后续token |
| | | if current_entity: |
| | | label_ids.append(self.label2id[f"I-{current_entity}"]) |
| | | else: |
| | | label_ids.append(self.label2id["O"]) |
| | | |
| | | previous_word_idx = word_idx |
| | | |
| | | labels.append(label_ids) |
| | | |
| | | tokenized_inputs["labels"] = labels |
| | | return tokenized_inputs |
| | | |
| | | def __getitem__(self, idx): |
| | | return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} |
| | | |
| | | def __len__(self): |
| | | return len(self.texts) |
| | | |
| | | def load_data(file_path): |
| | | texts, labels = [], [] |
| | | current_words, current_labels = [], [] |
| | | |
| | | def clean_trips_labels(words, labels): |
| | | """清理车次标注,确保格式正确""" |
| | | i = 0 |
| | | while i < len(words): |
| | | if labels[i].startswith("B-TRIPS"): # 修改标签名 |
| | | # 找到车次的结束位置 |
| | | j = i + 1 |
| | | while j < len(words) and labels[j].startswith("I-TRIPS"): # 修改标签名 |
| | | j += 1 |
| | | |
| | | # 检查并修正车次序列 |
| | | train_words = words[i:j] |
| | | train_str = ''.join(train_words) |
| | | |
| | | # 检查格式是否符合车次规范 |
| | | valid_patterns = [ |
| | | re.compile(r'^[GDCZTKY]\d{1,2}$'), |
| | | re.compile(r'^[GDCZTKY]\d{1,2}/\d{1,2}$'), |
| | | re.compile(r'^[GDCZTKY]\d{1,2}-\d{1,2}$'), |
| | | re.compile(r'^\d{1,4}$'), |
| | | re.compile(r'^[A-Z]\d{1,4}$') |
| | | ] |
| | | |
| | | is_valid = any(pattern.match(train_str) for pattern in valid_patterns) |
| | | if not is_valid: |
| | | # 将格式不正确的标签改为O |
| | | for k in range(i, j): |
| | | labels[k] = "O" |
| | | |
| | | i = j |
| | | else: |
| | | i += 1 |
| | | |
| | | return words, labels |
| | | |
| | | with open(file_path, 'r', encoding='utf-8') as f: |
| | | for line in f: |
| | | line = line.strip() |
| | | if line: |
| | | try: |
| | | word, label = line.split(maxsplit=1) |
| | | current_words.append(word) |
| | | current_labels.append(label) |
| | | except Exception as e: |
| | | print(f"错误:处理行时出错: '{line}'") |
| | | continue |
| | | elif current_words: # 遇到空行且当前有数据 |
| | | # 清理车次标注 |
| | | current_words, current_labels = clean_trips_labels(current_words, current_labels) |
| | | texts.append(current_words) |
| | | labels.append(current_labels) |
| | | current_words, current_labels = [], [] |
| | | |
| | | if current_words: # 处理最后一个样本 |
| | | current_words, current_labels = clean_trips_labels(current_words, current_labels) |
| | | texts.append(current_words) |
| | | labels.append(current_labels) |
| | | |
| | | return texts, labels |
| | | |
| | | def compute_metrics(p): |
| | | """计算评估指标""" |
| | | predictions, labels = p |
| | | predictions = np.argmax(predictions, axis=2) |
| | | |
| | | # 移除特殊token的预测和标签 |
| | | true_predictions = [ |
| | | [TrainNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100] |
| | | for prediction, label in zip(predictions, labels) |
| | | ] |
| | | true_labels = [ |
| | | [TrainNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100] |
| | | for prediction, label in zip(predictions, labels) |
| | | ] |
| | | |
| | | # 计算总体评估指标 |
| | | results = { |
| | | "overall_f1": f1_score(true_labels, true_predictions), |
| | | "overall_precision": precision_score(true_labels, true_predictions), |
| | | "overall_recall": recall_score(true_labels, true_predictions) |
| | | } |
| | | |
| | | # 计算每个实体类型的指标 |
| | | for entity_type in ["COMPANY", "TRIPS", "START", "END", "DATE", "TIME", "SEAT", "NAME"]: |
| | | # 将标签转换为二进制形式 |
| | | binary_preds = [] |
| | | binary_labels = [] |
| | | |
| | | for pred_seq, label_seq in zip(true_predictions, true_labels): |
| | | pred_binary = [] |
| | | label_binary = [] |
| | | |
| | | for pred, label in zip(pred_seq, label_seq): |
| | | # 检查标签是否属于当前实体类型 |
| | | pred_is_entity = pred.endswith(entity_type) |
| | | label_is_entity = label.endswith(entity_type) |
| | | |
| | | pred_binary.append(1 if pred_is_entity else 0) |
| | | label_binary.append(1 if label_is_entity else 0) |
| | | |
| | | binary_preds.append(pred_binary) |
| | | binary_labels.append(label_binary) |
| | | |
| | | # 计算当前实体类型的F1分数 |
| | | try: |
| | | entity_f1 = binary_f1_score( |
| | | sum(binary_labels, []), # 展平列表 |
| | | sum(binary_preds, []), # 展平列表 |
| | | average='binary' # 使用二进制评估 |
| | | ) |
| | | results[f"{entity_type}_f1"] = entity_f1 |
| | | except Exception as e: |
| | | print(f"计算{entity_type}的F1分数时出错: {str(e)}") |
| | | results[f"{entity_type}_f1"] = 0.0 |
| | | |
| | | return results |
| | | |
| | | def augment_data(texts, labels): |
| | | """数据增强""" |
| | | augmented_texts = [] |
| | | augmented_labels = [] |
| | | for text, label in zip(texts, labels): |
| | | # 原始数据 |
| | | augmented_texts.append(text) |
| | | augmented_labels.append(label) |
| | | |
| | | # 删除一些无关字符 |
| | | new_text = [] |
| | | new_label = [] |
| | | for t, l in zip(text, label): |
| | | if l == "O" and random.random() < 0.3: |
| | | continue |
| | | new_text.append(t) |
| | | new_label.append(l) |
| | | augmented_texts.append(new_text) |
| | | augmented_labels.append(new_label) |
| | | |
| | | return augmented_texts, augmented_labels |
| | | |
| | | def main(): |
| | | # 加载数据 |
| | | texts, labels = load_data(TrainNERConfig.DATA_PATH) |
| | | print(f"加载的数据集大小:{len(texts)}个样本") |
| | | |
| | | # 划分数据集 |
| | | train_texts, val_texts, train_labels, val_labels = train_test_split( |
| | | texts, labels, test_size=TrainNERConfig.TEST_SIZE, random_state=TrainNERConfig.SEED |
| | | ) |
| | | |
| | | # 数据增强 |
| | | train_texts, train_labels = augment_data(train_texts, train_labels) |
| | | |
| | | # 加载分词器和模型 |
| | | tokenizer = AutoTokenizer.from_pretrained(TrainNERConfig.MODEL_NAME) |
| | | model = AutoModelForTokenClassification.from_pretrained( |
| | | TrainNERConfig.MODEL_NAME, |
| | | num_labels=len(TrainNERConfig.LABELS), |
| | | id2label={i: label for i, label in enumerate(TrainNERConfig.LABELS)}, |
| | | label2id={label: i for i, label in enumerate(TrainNERConfig.LABELS)} |
| | | ) |
| | | |
| | | # 创建数据集 |
| | | train_dataset = NERDataset(train_texts, train_labels, tokenizer, TrainNERConfig.LABELS) |
| | | val_dataset = NERDataset(val_texts, val_labels, tokenizer, TrainNERConfig.LABELS) |
| | | |
| | | # 训练参数 |
| | | training_args = TrainingArguments( |
| | | output_dir=TrainNERConfig.MODEL_PATH, |
| | | num_train_epochs=TrainNERConfig.EPOCHS, |
| | | per_device_train_batch_size=TrainNERConfig.BATCH_SIZE, |
| | | per_device_eval_batch_size=TrainNERConfig.BATCH_SIZE, |
| | | learning_rate=TrainNERConfig.LEARNING_RATE, |
| | | warmup_ratio=TrainNERConfig.WARMUP_RATIO, |
| | | weight_decay=TrainNERConfig.WEIGHT_DECAY, |
| | | gradient_accumulation_steps=TrainNERConfig.GRADIENT_ACCUMULATION_STEPS |
| | | ) |
| | | |
| | | trainer = Trainer( |
| | | model=model, |
| | | args=training_args, |
| | | train_dataset=train_dataset, |
| | | eval_dataset=val_dataset, |
| | | compute_metrics=compute_metrics |
| | | ) |
| | | |
| | | trainer.train() |
| | | # 评估结果 |
| | | eval_results = trainer.evaluate() |
| | | print("\n评估结果:") |
| | | for key, value in eval_results.items(): |
| | | print(f"{key}: {value:.4f}") |
| | | |
| | | # 保存最终模型 |
| | | model.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model") |
| | | tokenizer.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model") |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
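After either training script finishes, the saved best_model directory can be smoke-tested independently of the Flask service. A minimal sketch using the paths configured above; the sample sentence and the printed tagging are illustrative only.

# Quick smoke test of a saved NER model (illustrative sentence).
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

model_dir = "./models/train_model/best_model"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForTokenClassification.from_pretrained(model_dir)
model.eval()

text = "您已购3月20日G102次列车北京南到上海虹桥"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
with torch.no_grad():
    logits = model(**inputs).logits
pred_ids = logits.argmax(dim=-1)[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for token, pred_id in zip(tokens, pred_ids):
    print(token, model.config.id2label[pred_id])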