| | |
| | | from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification |
| | | import torch |
| | | from werkzeug.exceptions import BadRequest |
| | | from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig |
| | | from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig |
| | | import re |
| | | |
| | | # 配置日志 |
| | |
| | | self.ner_path = "./models/ner_model/best_model" |
| | | self.repayment_path = "./models/repayment_model/best_model" |
| | | self.income_path = "./models/income_model/best_model" |
| | | self.flight_path = "./models/flight_model/best_model" |
| | | self.train_path = "./models/train_model/best_model" # 添加火车票模型路径 |
| | | |
| | | # 检查模型文件 |
| | | self._check_model_files() |
| | |
| | | self.ner_tokenizer, self.ner_model = self._load_ner() |
| | | self.repayment_tokenizer, self.repayment_model = self._load_repayment() |
| | | self.income_tokenizer, self.income_model = self._load_income() |
| | | self.flight_tokenizer, self.flight_model = self._load_flight() |
| | | self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型 |
| | | |
| | | # 将模型设置为评估模式 |
| | | self.classifier_model.eval() |
| | | self.ner_model.eval() |
| | | self.repayment_model.eval() |
| | | self.income_model.eval() |
| | | self.flight_model.eval() |
| | | self.train_model.eval() # 设置火车票模型为评估模式 |
| | | |
| | | def _check_model_files(self): |
| | | """检查模型文件是否存在""" |
| | |
| | | raise RuntimeError("还款模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.income_path): |
| | | raise RuntimeError("收入模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.flight_path): |
| | | raise RuntimeError("航班模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.train_path): |
| | | raise RuntimeError("火车票模型文件不存在,请先运行训练脚本") |
| | | |
| | | def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]: |
| | | """加载分类模型""" |
| | |
| | | logger.error(f"加载收入模型失败: {str(e)}") |
| | | raise |
| | | |
def _load_flight(self):
    """Load the flight NER tokenizer and token-classification model.

    Both objects are read with ``from_pretrained`` from ``self.flight_path``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: whatever loading raised, re-raised after being logged.
    """
    try:
        model_dir = self.flight_path
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        loaded_model = AutoModelForTokenClassification.from_pretrained(model_dir)
    except Exception as exc:
        logger.error(f"加载航班模型失败: {str(exc)}")
        raise
    return loaded_tokenizer, loaded_model
| | | |
def _load_train(self):
    """Load the train-ticket NER tokenizer and token-classification model.

    Both objects are read with ``from_pretrained`` from ``self.train_path``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: whatever loading raised, re-raised after being logged.
    """
    try:
        model_dir = self.train_path
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        loaded_model = AutoModelForTokenClassification.from_pretrained(model_dir)
    except Exception as exc:
        logger.error(f"加载火车票模型失败: {str(exc)}")
        raise
    return loaded_tokenizer, loaded_model
| | | |
| | | def classify_sms(self, text: str) -> str: |
| | | """对短信进行分类""" |
| | | try: |
| | |
| | | "company": None, # 寄件公司 |
| | | "address": None, # 地址 |
| | | "pickup_code": None, # 取件码 |
| | | "time": None # 时间 |
| | | "time": None # 添加时间字段 |
| | | } |
| | | |
| | | # 第一阶段:直接从文本中提取取件码 |
| | |
| | | logger.error(f"收入实体提取失败: {str(e)}") |
| | | raise |
| | | |
def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
    """Extract flight-ticket entities from an SMS with the flight NER model.

    The text is tokenized (truncated to ``FlightNERConfig.MAX_LENGTH``) and
    tagged by ``self.flight_model``. The predicted BIO tags are merged into
    entity strings. Each entity is stored under its lower-cased label name;
    when the same label occurs more than once, the last occurrence wins.
    Several fields are then normalised: flight number, date, times, ticket
    number and seat.

    Args:
        text: Raw SMS text.

    Returns:
        A dict with the keys ``flight``, ``company``, ``start``, ``end``,
        ``date``, ``time``, ``departure_time``, ``arrival_time``,
        ``ticket_num`` and ``seat``. Entities that were not found stay
        ``None``. A model label that is not one of these keys is added under
        its lower-cased name.

    Raises:
        Exception: any error is logged and re-raised unchanged.
    """
    try:
        result = {
            "flight": None,          # flight number
            "company": None,         # airline
            "start": None,           # departure place
            "end": None,             # destination
            "date": None,            # date
            "time": None,            # time
            "departure_time": None,  # departure time
            "arrival_time": None,    # arrival time
            "ticket_num": None,      # ticket number
            "seat": None,            # seat and related info
        }

        def _store(entity):
            # Write a finished entity into ``result``. "[UNK]" placeholders and
            # WordPiece "##" continuation markers are removed from the text.
            entity_type = entity["type"].lower()
            result[entity_type] = (
                entity["text"].replace("[UNK]", "").replace("##", "").strip()
            )

        inputs = self.flight_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=FlightNERConfig.MAX_LENGTH,
        )

        with torch.no_grad():
            outputs = self.flight_model(**inputs)

        predictions = torch.argmax(outputs.logits, dim=2)
        tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()]

        # Special tokens such as [CLS]/[SEP] are never part of the SMS text.
        # Treat them as outside any entity whatever tag was predicted for them;
        # otherwise e.g. "[CLS]" can be glued onto a flight number and make it
        # fail validation. [UNK] is kept so it still occupies its position
        # (its text is stripped in _store).
        special_tokens = set(self.flight_tokenizer.all_special_tokens)
        special_tokens.discard(self.flight_tokenizer.unk_token)

        # Decode BIO tags into entities.
        current_entity = None
        for token, tag in zip(tokens, tags):
            if token in special_tokens:
                tag = "O"
            if tag.startswith("B-"):
                if current_entity:
                    _store(current_entity)
                current_entity = {"type": tag[2:], "text": token}
            elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
                current_entity["text"] += token
            else:
                if current_entity:
                    _store(current_entity)
                current_entity = None

        # Flush an entity that runs to the end of the sequence.
        if current_entity:
            _store(current_entity)

        # Flight number: keep only letters/digits, then validate.
        if result["flight"]:
            flight_no = "".join(c for c in result["flight"].upper() if c.isalnum())
            valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
            if valid_pattern.match(flight_no):
                result["flight"] = flight_no
            elif (
                len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length']
                and flight_no[:2].isalpha()
                and flight_no[2:].isdigit()
            ):
                # Accept a two-letter prefix followed by digits even if the
                # configured pattern did not match.
                result["flight"] = flight_no
            else:
                result["flight"] = None

        # Date: keep digits and common date separators.
        if result["date"]:
            result["date"] = "".join(
                c for c in result["date"]
                if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']
            )

        # Times: keep digits and common time separators.
        for time_field in ["time", "departure_time", "arrival_time"]:
            if result[time_field]:
                result[time_field] = "".join(
                    c for c in result[time_field]
                    if c.isdigit() or c in [':', '时', '分', '点']
                )

        # Ticket number: keep only letters and digits.
        if result["ticket_num"]:
            result["ticket_num"] = "".join(c for c in result["ticket_num"] if c.isalnum())

        # Seat: drop spaces.
        if result["seat"]:
            result["seat"] = result["seat"].replace(" ", "").strip()

        return result
    except Exception as e:
        logger.error(f"航班实体提取失败: {str(e)}")
        raise
| | | |
def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
    """Extract train-ticket entities from an SMS with the train NER model.

    The text is tokenized (truncated to ``TrainNERConfig.MAX_LENGTH``) and
    tagged by ``self.train_model``. The predicted BIO tags are merged into
    entity strings. Each entity is stored under its lower-cased label name;
    when the same label occurs more than once, the last occurrence wins.
    Several fields are then normalised: company, train number, date, time,
    seat and passenger name.

    Args:
        text: Raw SMS text.

    Returns:
        A dict with the keys ``company``, ``trips``, ``start``, ``end``,
        ``date``, ``time``, ``seat`` and ``name``. Entities that were not
        found stay ``None``. ``company`` falls back to ``"12306"`` when the
        model found none but the text contains "12306".

    Raises:
        Exception: any error is logged and re-raised unchanged.
    """
    try:
        result = {
            "company": None,  # ticketing company (e.g. 12306)
            "trips": None,    # train number
            "start": None,    # departure station
            "end": None,      # arrival station
            "date": None,     # date
            "time": None,     # time
            "seat": None,     # seat and related info
            "name": None,     # passenger name
        }

        def _store(entity):
            # Write a finished entity into ``result``. "[UNK]" placeholders and
            # WordPiece "##" continuation markers are removed from the text.
            entity_type = entity["type"].lower()
            result[entity_type] = (
                entity["text"].replace("[UNK]", "").replace("##", "").strip()
            )

        inputs = self.train_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=TrainNERConfig.MAX_LENGTH,
        )

        with torch.no_grad():
            outputs = self.train_model(**inputs)

        predictions = torch.argmax(outputs.logits, dim=2)
        tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()]

        # Special tokens such as [CLS]/[SEP] are never part of the SMS text.
        # Treat them as outside any entity whatever tag was predicted for them,
        # so their literal text never leaks into an entity value. [UNK] is kept
        # so it still occupies its position (its text is stripped in _store).
        special_tokens = set(self.train_tokenizer.all_special_tokens)
        special_tokens.discard(self.train_tokenizer.unk_token)

        # Decode BIO tags into entities.
        current_entity = None
        for token, tag in zip(tokens, tags):
            if token in special_tokens:
                tag = "O"
            if tag.startswith("B-"):
                if current_entity:
                    _store(current_entity)
                current_entity = {"type": tag[2:], "text": token}
            elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
                current_entity["text"] += token
            else:
                if current_entity:
                    _store(current_entity)
                current_entity = None

        # Flush an entity that runs to the end of the sequence.
        if current_entity:
            _store(current_entity)

        # Company: use the model's value if present.
        if result["company"]:
            result["company"] = result["company"].strip()
        elif "12306" in text:
            # Otherwise default to 12306 when the text mentions it.
            result["company"] = "12306"

        # Train number: keep letters, digits, '/' and '-', then validate.
        if result["trips"]:
            trips_config = TrainNERConfig.TRIPS_CONFIG
            trips_no = "".join(
                c for c in result["trips"].upper() if c.isalnum() or c in ['/', '-']
            )
            valid_patterns = [re.compile(pattern) for pattern in trips_config['patterns']]
            if any(pattern.match(trips_no) for pattern in valid_patterns):
                result["trips"] = trips_no
            elif len(trips_no) >= trips_config['min_length'] and any(
                trips_no.startswith(t) for t in trips_config['train_types']
            ):
                # Accept a known train-type prefix even if no pattern matched.
                result["trips"] = trips_no
            elif trips_no.isdigit() and 1 <= len(trips_no) <= trips_config['max_length']:
                # Accept purely numeric train numbers up to the configured length.
                result["trips"] = trips_no
            else:
                result["trips"] = None

        # Date: keep digits and common date separators.
        if result["date"]:
            result["date"] = "".join(
                c for c in result["date"]
                if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']
            )

        # Time: keep digits and common time separators.
        if result["time"]:
            result["time"] = "".join(
                c for c in result["time"]
                if c.isdigit() or c in [':', '时', '分', '点']
            )

        # Seat: drop spaces.
        if result["seat"]:
            result["seat"] = result["seat"].replace(" ", "").strip()

        # Passenger name: keep letters/digits plus masking '*' and middle dot '·'.
        if result["name"]:
            result["name"] = "".join(
                c for c in result["name"].strip() if c.isalnum() or c in ['*', '·']
            )

        return result
    except Exception as e:
        logger.error(f"火车票实体提取失败: {str(e)}")
        raise
| | | |
# Create the Flask application and one module-level ModelManager.
# ModelManager() is constructed when this module is imported, so the models it
# loads in its constructor are reused by every request handled in this process.
app = Flask(import_name=__name__)
model_manager = ModelManager()
| | |
| | | details = model_manager.extract_repayment_entities(text) |
| | | elif category == "收入": |
| | | details = model_manager.extract_income_entities(text) |
| | | elif category == "航班": |
| | | details = model_manager.extract_flight_entities(text) |
| | | elif category == "火车票": # 添加火车票类别处理 |
| | | details = model_manager.extract_train_entities(text) |
| | | else: |
| | | details = {} |
| | | |