From acc5c1281b50c12e4d04c81b899410f6ca2cacac Mon Sep 17 00:00:00 2001
From: cloudroam <cloudroam>
Date: 星期二, 15 四月 2025 15:13:30 +0800
Subject: [PATCH] add: 增加航班和火车票

---
 app.py |  254 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 252 insertions(+), 2 deletions(-)

diff --git a/app.py b/app.py
index f697651..47aa332 100644
--- a/app.py
+++ b/app.py
@@ -7,7 +7,7 @@
 from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
 import torch
 from werkzeug.exceptions import BadRequest
-from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig
+from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig
 import re
 
 # 配置日志
@@ -27,6 +27,8 @@
         self.ner_path = "./models/ner_model/best_model"
         self.repayment_path = "./models/repayment_model/best_model"
         self.income_path = "./models/income_model/best_model"
+        self.flight_path = "./models/flight_model/best_model"  
+        self.train_path = "./models/train_model/best_model"  # 添加火车票模型路径
 
         # 检查模型文件
         self._check_model_files()
@@ -36,12 +38,16 @@
         self.ner_tokenizer, self.ner_model = self._load_ner()
         self.repayment_tokenizer, self.repayment_model = self._load_repayment()
         self.income_tokenizer, self.income_model = self._load_income()
+        self.flight_tokenizer, self.flight_model = self._load_flight()
+        self.train_tokenizer, self.train_model = self._load_train()  # 加载火车票模型
         
         # 将模型设置为评估模式
         self.classifier_model.eval()
         self.ner_model.eval()
         self.repayment_model.eval()
         self.income_model.eval()
+        self.flight_model.eval()
+        self.train_model.eval()  # 设置火车票模型为评估模式
         
     def _check_model_files(self):
         """检查模型文件是否存在"""
@@ -53,6 +59,10 @@
             raise RuntimeError("还款模型文件不存在,请先运行训练脚本")
         if not os.path.exists(self.income_path):
             raise RuntimeError("收入模型文件不存在,请先运行训练脚本")
+        if not os.path.exists(self.flight_path):
+            raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
+        if not os.path.exists(self.train_path):
+            raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
             
     def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]:
         """加载分类模型"""
@@ -94,6 +104,26 @@
             logger.error(f"加载收入模型失败: {str(e)}")
             raise
             
+    def _load_flight(self):
+        """加载航班模型"""
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(self.flight_path)
+            model = AutoModelForTokenClassification.from_pretrained(self.flight_path)
+            return tokenizer, model
+        except Exception as e:
+            logger.error(f"加载航班模型失败: {str(e)}")
+            raise
+
+    def _load_train(self):
+        """加载火车票模型"""
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(self.train_path)
+            model = AutoModelForTokenClassification.from_pretrained(self.train_path)
+            return tokenizer, model
+        except Exception as e:
+            logger.error(f"加载火车票模型失败: {str(e)}")
+            raise
+
     def classify_sms(self, text: str) -> str:
         """对短信进行分类"""
         try:
@@ -120,7 +150,7 @@
                 "company": None,       # 寄件公司
                 "address": None,       # 地址
                 "pickup_code": None,   # 取件码
-                "time": None          # 时间
+                "time": None           # 添加时间字段
             }
             
             # 第一阶段:直接从文本中提取取件码
@@ -662,6 +692,222 @@
             logger.error(f"收入实体提取失败: {str(e)}")
             raise
 
+    def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
+        """提取航班相关实体"""
+        try:
+            # 初始化结果字典
+            result = {
+                "flight": None,           # 航班号
+                "company": None,          # 航空公司
+                "start": None,            # 出发地
+                "end": None,              # 目的地
+                "date": None,             # 日期
+                "time": None,             # 时间
+                "departure_time": None,   # 起飞时间
+                "arrival_time": None,     # 到达时间
+                "ticket_num": None,       # 机票号码
+                "seat": None              # 座位等信息
+            }
+            
+            # 使用NER模型提取实体
+            inputs = self.flight_tokenizer(
+                text, 
+                return_tensors="pt", 
+                truncation=True, 
+                max_length=FlightNERConfig.MAX_LENGTH
+            )
+            
+            with torch.no_grad():
+                outputs = self.flight_model(**inputs)
+            
+            predictions = torch.argmax(outputs.logits, dim=2)
+            tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+            tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()]
+
+            # 解析实体
+            current_entity = None
+            
+            for token, tag in zip(tokens, tags):
+                if tag.startswith("B-"):
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                    current_entity = {"type": tag[2:], "text": token}
+                elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
+                    current_entity["text"] += token
+                else:
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                        current_entity = None
+
+            # 处理最后一个实体
+            if current_entity:
+                entity_type = current_entity["type"].lower()
+                result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+
+            # 处理航班号格式
+            if result["flight"]:
+                flight_no = result["flight"].upper()
+                # 清理航班号,只保留字母和数字
+                flight_no = ''.join(c for c in flight_no if c.isalnum())
+                # 验证航班号格式
+                valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
+                if valid_pattern.match(flight_no):
+                    result["flight"] = flight_no
+                else:
+                    # 尝试修复常见错误
+                    if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit():
+                        result["flight"] = flight_no
+                    else:
+                        result["flight"] = None
+
+            # 清理日期格式
+            if result["date"]:
+                date_str = result["date"]
+                # 保留数字和常见日期分隔符
+                date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
+                result["date"] = date_str
+
+            # 清理时间格式
+            for time_field in ["time", "departure_time", "arrival_time"]:
+                if result[time_field]:
+                    time_str = result[time_field]
+                    # 保留数字和常见时间分隔符
+                    time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
+                    result[time_field] = time_str
+                    
+            # 处理机票号码
+            if result["ticket_num"]:
+                ticket_num = result["ticket_num"]
+                # 清理机票号码,只保留字母和数字
+                ticket_num = ''.join(c for c in ticket_num if c.isalnum())
+                result["ticket_num"] = ticket_num
+                
+            # 处理座位信息
+            if result["seat"]:
+                seat_str = result["seat"]
+                # 移除可能的额外空格和特殊字符
+                seat_str = seat_str.replace(" ", "").strip()
+                result["seat"] = seat_str
+
+            return result
+        except Exception as e:
+            logger.error(f"航班实体提取失败: {str(e)}")
+            raise
+
+    def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
+        """提取火车票相关实体"""
+        try:
+            # 初始化结果字典
+            result = {
+                "company": None,         # 12306
+                "trips": None,           # 车次
+                "start": None,           # 出发站
+                "end": None,             # 到达站
+                "date": None,            # 日期
+                "time": None,            # 时间
+                "seat": None,            # 座位等信息
+                "name": None             # 用户姓名
+            }
+            
+            # 使用NER模型提取实体
+            inputs = self.train_tokenizer(
+                text, 
+                return_tensors="pt", 
+                truncation=True, 
+                max_length=TrainNERConfig.MAX_LENGTH
+            )
+            
+            with torch.no_grad():
+                outputs = self.train_model(**inputs)
+            
+            predictions = torch.argmax(outputs.logits, dim=2)
+            tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+            tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()]
+
+            # 解析实体
+            current_entity = None
+            
+            for token, tag in zip(tokens, tags):
+                if tag.startswith("B-"):
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                    current_entity = {"type": tag[2:], "text": token}
+                elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
+                    current_entity["text"] += token
+                else:
+                    if current_entity:
+                        entity_type = current_entity["type"].lower()
+                        result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+                        current_entity = None
+
+            # 处理最后一个实体
+            if current_entity:
+                entity_type = current_entity["type"].lower()
+                result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+
+            # 处理公司名称,通常为12306
+            if result["company"]:
+                company = result["company"].strip()
+                # 如果文本中检测不到公司名称,但包含12306,则默认为12306
+                result["company"] = company
+            elif "12306" in text:
+                result["company"] = "12306"
+
+            # 处理车次格式
+            if result["trips"]:
+                trips_no = result["trips"].upper()
+                # 清理车次号,只保留字母和数字
+                trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-'])
+                
+                # 验证车次格式
+                valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']]
+                if any(pattern.match(trips_no) for pattern in valid_patterns):
+                    result["trips"] = trips_no
+                else:
+                    # 尝试修复常见错误
+                    if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']):
+                        result["trips"] = trips_no
+                    elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']:
+                        result["trips"] = trips_no
+                    else:
+                        result["trips"] = None
+
+            # 清理日期格式
+            if result["date"]:
+                date_str = result["date"]
+                # 保留数字和常见日期分隔符
+                date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
+                result["date"] = date_str
+
+            # 清理时间格式
+            if result["time"]:
+                time_str = result["time"]
+                # 保留数字和常见时间分隔符
+                time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
+                result["time"] = time_str
+
+            # 处理座位信息
+            if result["seat"]:
+                seat_str = result["seat"]
+                # 移除可能的额外空格和特殊字符
+                seat_str = seat_str.replace(" ", "").strip()
+                result["seat"] = seat_str
+
+            # 处理乘客姓名
+            if result["name"]:
+                name = result["name"].strip()
+                # 移除可能的标点符号
+                name = ''.join(c for c in name if c.isalnum() or c in ['*', '·'])
+                result["name"] = name
+
+            return result
+        except Exception as e:
+            logger.error(f"火车票实体提取失败: {str(e)}")
+            raise
+
 # 创建Flask应用
 app = Flask(__name__)
 model_manager = ModelManager()
@@ -695,6 +941,10 @@
             details = model_manager.extract_repayment_entities(text)
         elif category == "收入":
             details = model_manager.extract_income_entities(text)
+        elif category == "航班":
+            details = model_manager.extract_flight_entities(text)
+        elif category == "火车票":  # 添加火车票类别处理
+            details = model_manager.extract_train_entities(text)
         else:
             details = {}
         

--
Gitblit v1.9.3