From acc5c1281b50c12e4d04c81b899410f6ca2cacac Mon Sep 17 00:00:00 2001
From: cloudroam <cloudroam>
Date: 星期二, 15 四月 2025 15:13:30 +0800
Subject: [PATCH] add: 增加航班和火车票

---
 app.py | 254 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 252 insertions(+), 2 deletions(-)

diff --git a/app.py b/app.py
index f697651..47aa332 100644
--- a/app.py
+++ b/app.py
@@ -7,7 +7,7 @@
 from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
 import torch
 from werkzeug.exceptions import BadRequest
-from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig
+from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig
 import re
 
 # 配置日志
@@ -27,6 +27,8 @@
         self.ner_path = "./models/ner_model/best_model"
         self.repayment_path = "./models/repayment_model/best_model"
         self.income_path = "./models/income_model/best_model"
+        self.flight_path = "./models/flight_model/best_model"
+        self.train_path = "./models/train_model/best_model" # 添加火车票模型路径
 
         # 检查模型文件
         self._check_model_files()
@@ -36,12 +38,16 @@
         self.ner_tokenizer, self.ner_model = self._load_ner()
         self.repayment_tokenizer, self.repayment_model = self._load_repayment()
         self.income_tokenizer, self.income_model = self._load_income()
+        self.flight_tokenizer, self.flight_model = self._load_flight()
+        self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型
 
         # 将模型设置为评估模式
         self.classifier_model.eval()
         self.ner_model.eval()
         self.repayment_model.eval()
         self.income_model.eval()
+        self.flight_model.eval()
+        self.train_model.eval() # 设置火车票模型为评估模式
 
     def _check_model_files(self):
         """检查模型文件是否存在"""
@@ -53,6 +59,10 @@
             raise RuntimeError("还款模型文件不存在,请先运行训练脚本")
         if not os.path.exists(self.income_path):
             raise RuntimeError("收入模型文件不存在,请先运行训练脚本")
+        if not os.path.exists(self.flight_path):
+            raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
+        if not os.path.exists(self.train_path):
+            raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
 
     def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]:
         """加载分类模型"""
@@ -94,6 +104,26 @@
             logger.error(f"加载收入模型失败: {str(e)}")
             raise
 
+    def _load_flight(self):
+        """加载航班模型"""
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(self.flight_path)
+            model = AutoModelForTokenClassification.from_pretrained(self.flight_path)
+            return tokenizer, model
+        except Exception as e:
+            logger.error(f"加载航班模型失败: {str(e)}")
+            raise
+
+    def _load_train(self):
+        """加载火车票模型"""
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(self.train_path)
+            model = AutoModelForTokenClassification.from_pretrained(self.train_path)
+            return tokenizer, model
+        except Exception as e:
+            logger.error(f"加载火车票模型失败: {str(e)}")
+            raise
+
     def classify_sms(self, text: str) -> str:
         """对短信进行分类"""
         try:
@@ -120,7 +150,7 @@
                 "company": None, # 寄件公司
                 "address": None, # 地址
                 "pickup_code": None, # 取件码
-                "time": None # 时间
+                "time": None # 添加时间字段
             }
 
             # 第一阶段:直接从文本中提取取件码
@@ -662,6 +692,222 @@
             logger.error(f"收入实体提取失败: {str(e)}")
             raise
 
+    def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
+        """提取航班相关实体"""
+        try:
+            # 初始化结果字典
+            result = {
+                "flight": None, # 航班号
+                "company": None, # 航空公司
+                "start": None, # 出发地
+                "end": None, # 目的地
+                "date": None, # 日期
+                "time": None, # 时间
+                "departure_time": None, # 起飞时间
+                "arrival_time": None, # 到达时间
+                "ticket_num": None, # 机票号码
+                "seat": None # 座位等信息
+            }
+
+            # 使用NER模型提取实体
+            inputs = self.flight_tokenizer(
+                text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=FlightNERConfig.MAX_LENGTH
+            )
+
+            with torch.no_grad():
+                outputs = self.flight_model(**inputs)
+
+            predictions = torch.argmax(outputs.logits, dim=2)
+            tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+            tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()]
+
+            # 解析实体
+            current_entity = None
+
+            for token, tag in zip(tokens, tags):
+                if tag.startswith("B-"):
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                    current_entity = {"type": tag[2:], "text": token}
+                elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
+                    current_entity["text"] += token
+                else:
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                        current_entity = None
+
+            # 处理最后一个实体
+            if current_entity:
+                entity_type = current_entity["type"].lower()
+                result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+
+            # 处理航班号格式
+            if result["flight"]:
+                flight_no = result["flight"].upper()
+                # 清理航班号,只保留字母和数字
+                flight_no = ''.join(c for c in flight_no if c.isalnum())
+                # 验证航班号格式
+                valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
+                if valid_pattern.match(flight_no):
+                    result["flight"] = flight_no
+                else:
+                    # 尝试修复常见错误
+                    if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit():
+                        result["flight"] = flight_no
+                    else:
+                        result["flight"] = None
+
+            # 清理日期格式
+            if result["date"]:
+                date_str = result["date"]
+                # 保留数字和常见日期分隔符
+                date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
+                result["date"] = date_str
+
+            # 清理时间格式
+            for time_field in ["time", "departure_time", "arrival_time"]:
+                if result[time_field]:
+                    time_str = result[time_field]
+                    # 保留数字和常见时间分隔符
+                    time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
+                    result[time_field] = time_str
+
+            # 处理机票号码
+            if result["ticket_num"]:
+                ticket_num = result["ticket_num"]
+                # 清理机票号码,只保留字母和数字
+                ticket_num = ''.join(c for c in ticket_num if c.isalnum())
+                result["ticket_num"] = ticket_num
+
+            # 处理座位信息
+            if result["seat"]:
+                seat_str = result["seat"]
+                # 移除可能的额外空格和特殊字符
+                seat_str = seat_str.replace(" ", "").strip()
+                result["seat"] = seat_str
+
+            return result
+        except Exception as e:
+            logger.error(f"航班实体提取失败: {str(e)}")
+            raise
+
+    def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
+        """提取火车票相关实体"""
+        try:
+            # 初始化结果字典
+            result = {
+                "company": None, # 12306
+                "trips": None, # 车次
+                "start": None, # 出发站
+                "end": None, # 到达站
+                "date": None, # 日期
+                "time": None, # 时间
+                "seat": None, # 座位等信息
+                "name": None # 用户姓名
+            }
+
+            # 使用NER模型提取实体
+            inputs = self.train_tokenizer(
+                text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=TrainNERConfig.MAX_LENGTH
+            )
+
+            with torch.no_grad():
+                outputs = self.train_model(**inputs)
+
+            predictions = torch.argmax(outputs.logits, dim=2)
+            tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+            tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()]
+
+            # 解析实体
+            current_entity = None
+
+            for token, tag in zip(tokens, tags):
+                if tag.startswith("B-"):
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                    current_entity = {"type": tag[2:], "text": token}
+                elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
+                    current_entity["text"] += token
+                else:
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                        current_entity = None
+
+            # 处理最后一个实体
+            if current_entity:
+                entity_type = current_entity["type"].lower()
+                result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+
+            # 处理公司名称,通常为12306
+            if result["company"]:
+                company = result["company"].strip()
+                # 如果文本中检测不到公司名称,但包含12306,则默认为12306
+                result["company"] = company
+            elif "12306" in text:
+                result["company"] = "12306"
+
+            # 处理车次格式
+            if result["trips"]:
+                trips_no = result["trips"].upper()
+                # 清理车次号,只保留字母和数字
+                trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-'])
+
+                # 验证车次格式
+                valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']]
+                if any(pattern.match(trips_no) for pattern in valid_patterns):
+                    result["trips"] = trips_no
+                else:
+                    # 尝试修复常见错误
+                    if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']):
+                        result["trips"] = trips_no
+                    elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']:
+                        result["trips"] = trips_no
+                    else:
+                        result["trips"] = None
+
+            # 清理日期格式
+            if result["date"]:
+                date_str = result["date"]
+                # 保留数字和常见日期分隔符
+                date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
+                result["date"] = date_str
+
+            # 清理时间格式
+            if result["time"]:
+                time_str = result["time"]
+                # 保留数字和常见时间分隔符
+                time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
+                result["time"] = time_str
+
+            # 处理座位信息
+            if result["seat"]:
+                seat_str = result["seat"]
+                # 移除可能的额外空格和特殊字符
+                seat_str = seat_str.replace(" ", "").strip()
+                result["seat"] = seat_str
+
+            # 处理乘客姓名
+            if result["name"]:
+                name = result["name"].strip()
+                # 移除可能的标点符号
+                name = ''.join(c for c in name if c.isalnum() or c in ['*', '·'])
+                result["name"] = name
+
+            return result
+        except Exception as e:
+            logger.error(f"火车票实体提取失败: {str(e)}")
+            raise
+
 # 创建Flask应用
 app = Flask(__name__)
 model_manager = ModelManager()
@@ -695,6 +941,10 @@
             details = model_manager.extract_repayment_entities(text)
         elif category == "收入":
             details = model_manager.extract_income_entities(text)
+        elif category == "航班":
+            details = model_manager.extract_flight_entities(text)
+        elif category == "火车票": # 添加火车票类别处理
+            details = model_manager.extract_train_entities(text)
         else:
             details = {}
 
-- 
Gitblit v1.9.3