From e46063e99f4cc6b22de4e32227289da3122646dd Mon Sep 17 00:00:00 2001
From: cloudroam <cloudroam>
Date: 星期三, 16 四月 2025 16:55:20 +0800
Subject: [PATCH] fix
---
app.py | 462 ++++++++++++++++++++++++++++++++++++++++++++++++---------
1 files changed, 387 insertions(+), 75 deletions(-)
diff --git a/app.py b/app.py
index f697651..ba71f7e 100644
--- a/app.py
+++ b/app.py
@@ -7,7 +7,7 @@
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
import torch
from werkzeug.exceptions import BadRequest
-from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig
+from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig
import re
# 配置日志
@@ -27,6 +27,8 @@
self.ner_path = "./models/ner_model/best_model"
self.repayment_path = "./models/repayment_model/best_model"
self.income_path = "./models/income_model/best_model"
+ # self.flight_path = "./models/flight_model/best_model"
+ # self.train_path = "./models/train_model/best_model" # 添加火车票模型路径
# 检查模型文件
self._check_model_files()
@@ -36,12 +38,16 @@
self.ner_tokenizer, self.ner_model = self._load_ner()
self.repayment_tokenizer, self.repayment_model = self._load_repayment()
self.income_tokenizer, self.income_model = self._load_income()
+ # self.flight_tokenizer, self.flight_model = self._load_flight()
+ # self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型
# 将模型设置为评估模式
self.classifier_model.eval()
self.ner_model.eval()
self.repayment_model.eval()
self.income_model.eval()
+ # self.flight_model.eval()
+ # self.train_model.eval() # 设置火车票模型为评估模式
def _check_model_files(self):
"""检查模型文件是否存在"""
@@ -53,6 +59,10 @@
raise RuntimeError("还款模型文件不存在,请先运行训练脚本")
if not os.path.exists(self.income_path):
raise RuntimeError("收入模型文件不存在,请先运行训练脚本")
+ # if not os.path.exists(self.flight_path):
+ # raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
+ # if not os.path.exists(self.train_path):
+ # raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]:
"""加载分类模型"""
@@ -94,6 +104,26 @@
logger.error(f"加载收入模型失败: {str(e)}")
raise
+ def _load_flight(self):
+ """加载航班模型"""
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(self.flight_path)
+ model = AutoModelForTokenClassification.from_pretrained(self.flight_path)
+ return tokenizer, model
+ except Exception as e:
+ logger.error(f"加载航班模型失败: {str(e)}")
+ raise
+
+ def _load_train(self):
+ """加载火车票模型"""
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(self.train_path)
+ model = AutoModelForTokenClassification.from_pretrained(self.train_path)
+ return tokenizer, model
+ except Exception as e:
+ logger.error(f"加载火车票模型失败: {str(e)}")
+ raise
+
def classify_sms(self, text: str) -> str:
"""对短信进行分类"""
try:
@@ -120,7 +150,7 @@
"company": None, # 寄件公司
"address": None, # 地址
"pickup_code": None, # 取件码
- "time": None # 时间
+ "time": None # 添加时间字段
}
# 第一阶段:直接从文本中提取取件码
@@ -352,67 +382,60 @@
if not amount_text:
return None
+ # 尝试直接在上下文中使用正则表达式查找更完整的金额
+ # 如果在同一句话里有类似"应还金额5,800元"这样的模式
+ amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context)
+ if amount_match:
+ return amount_match.group(1) # 直接返回匹配到的金额,保留原始格式
+
+ # 尝试查找最低还款金额
+ min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context)
+ if min_amount_match and "MIN_CODE" in current_entity["type"]:
+ return min_amount_match.group(1) # 直接返回匹配到的最低还款金额,保留原始格式
+
# 在上下文中查找完整金额
amount_index = context.find(amount_text)
if amount_index != -1:
# 扩大搜索范围,查找完整金额
- search_start = max(0, amount_index - 10) # 增加向前搜索范围
+ search_start = max(0, amount_index - 10)
search_end = min(len(context), amount_index + len(amount_text) + 10)
search_text = context[search_start:search_end]
- # 使用更精确的正则表达式查找金额模式
- amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)')
+ # 使用正则表达式查找金额
+ amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?')
matches = list(amount_pattern.finditer(search_text))
- # 找到最接近且最长的完整金额
- best_match = None
- min_distance = float('inf')
- max_length = 0
- target_pos = amount_index - search_start
-
- for match in matches:
- match_pos = match.start()
- distance = abs(match_pos - target_pos)
- match_text = match.group(1)
+ if matches:
+ # 选择最接近的匹配结果
+ best_match = None
+ min_distance = float('inf')
- # 优先选择更长的匹配,除非距离差异太大
- if len(match_text) > max_length or (len(match_text) == max_length and distance < min_distance):
- try:
- # 验证金额是否合理
- value = float(match_text)
- if value > 0 and value <= 9999999.99: # 设置合理的金额范围
- best_match = match_text
- min_distance = distance
- max_length = len(match_text)
- except ValueError:
- continue
+ for match in matches:
+ distance = abs(match.start() - (amount_index - search_start))
+ if distance < min_distance:
+ min_distance = distance
+ best_match = match.group(1) # 只取数字部分,保留逗号
+
+ if best_match:
+ return best_match
- if best_match:
- amount_text = best_match
-
+ # 如果上述方法都没找到,则保留原始提取结果但验证其有效性
# 移除货币符号和无效词
for symbol in RepaymentNERConfig.AMOUNT_CONFIG['currency_symbols']:
amount_text = amount_text.replace(symbol, '')
for word in RepaymentNERConfig.AMOUNT_CONFIG['invalid_words']:
amount_text = amount_text.replace(word, '')
- # 处理金额中的逗号
- amount_text = amount_text.replace(',', '')
-
+ # 验证金额有效性
+ clean_amount = amount_text.replace(',', '')
try:
- # 转换为浮点数
- value = float(amount_text)
-
- # 验证整数位数
- integer_part = str(int(value))
- if len(integer_part) <= RepaymentNERConfig.AMOUNT_CONFIG['max_integer_digits']:
- # 保持原始小数位数
- if '.' in amount_text:
- decimal_places = len(amount_text.split('.')[1])
- return f"{value:.{decimal_places}f}"
- return str(int(value))
+ value = float(clean_amount)
+ if value > 0:
+ # 返回原始格式
+ return amount_text
except ValueError:
pass
+
return None
# 实体提取
@@ -443,12 +466,11 @@
# 处理银行名称
if entities["BANK"]:
- # 修改银行名称处理逻辑
bank_parts = []
- seen = set() # 用于去重
+ seen = set()
for bank in entities["BANK"]:
bank = bank.strip()
- if bank and bank not in seen: # 避免重复
+ if bank and bank not in seen:
bank_parts.append(bank)
seen.add(bank)
bank = "".join(bank_parts)
@@ -457,7 +479,14 @@
# 处理还款类型
if entities["TYPE"]:
- type_ = "".join(entities["TYPE"]).strip()
+ type_parts = []
+ seen = set()
+ for type_ in entities["TYPE"]:
+ type_ = type_.strip()
+ if type_ and type_ not in seen:
+ type_parts.append(type_)
+ seen.add(type_)
+ type_ = "".join(type_parts)
if len(type_) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["TYPE"]:
result["type"] = type_
@@ -470,32 +499,74 @@
# 处理日期
if entities["DATE"]:
date = "".join(entities["DATE"])
- date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日'])
+ date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日', '-'])
if date:
result["date"] = date
# 处理金额
- amount_candidates = []
- for amount in entities["PICKUP_CODE"]:
- cleaned_amount = clean_amount(amount, text)
- if cleaned_amount:
- try:
- value = float(cleaned_amount)
- amount_candidates.append((cleaned_amount, value))
- except ValueError:
- continue
+ # 先尝试使用正则表达式直接匹配金额
+ amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text)
+ if amount_match:
+ amount = amount_match.group(1) # 保留原始格式(带逗号)
+ # 验证金额有效性
+ try:
+ value = float(amount.replace(',', ''))
+ if value > 0:
+ result["amount"] = amount
+ except ValueError:
+ pass
- # 选择最大的有效金额
- if amount_candidates:
- # 按金额大小排序,选择最大的
- result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
+ # 如果正则没有匹配到,使用NER结果
+ if not result["amount"]:
+ amount_candidates = []
+ # 从识别的实体中获取
+ for amount in entities["PICKUP_CODE"]:
+ cleaned_amount = clean_amount(amount, text)
+ if cleaned_amount:
+ try:
+ value = float(cleaned_amount.replace(',', ''))
+ amount_candidates.append((cleaned_amount, value))
+ except ValueError:
+ continue
+
+ # 如果还是没有找到,尝试从文本中提取
+ if not amount_candidates:
+ # 使用更宽松的正则表达式匹配金额
+ amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)')
+ matches = list(amount_pattern.finditer(text))
+
+ for match in matches:
+ amount_text = match.group(1) # 获取数字部分,保留逗号
+ try:
+ value = float(amount_text.replace(',', ''))
+ amount_candidates.append((amount_text, value))
+ except ValueError:
+ continue
+
+ # 选择最大的有效金额
+ if amount_candidates:
+ result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
# 处理最低还款金额
- for amount in entities["MIN_CODE"]:
- cleaned_amount = clean_amount(amount, text) # 传入原始文本作为上下文
- if cleaned_amount:
- result["min_amount"] = cleaned_amount
- break
+ # 先尝试使用正则表达式直接匹配最低还款金额
+ min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text)
+ if min_amount_match:
+ min_amount = min_amount_match.group(1) # 保留原始格式(带逗号)
+ # 验证金额有效性
+ try:
+ value = float(min_amount.replace(',', ''))
+ if value > 0:
+ result["min_amount"] = min_amount
+ except ValueError:
+ pass
+
+ # 如果正则没有匹配到,使用NER结果
+ if not result["min_amount"] and entities["MIN_CODE"]:
+ for amount in entities["MIN_CODE"]:
+ cleaned_amount = clean_amount(amount, text)
+ if cleaned_amount:
+ result["min_amount"] = cleaned_amount
+ break
return result
@@ -541,9 +612,8 @@
search_end = min(len(context), amount_index + len(amount_text) + 10)
search_text = context[search_start:search_end]
- # 使用正则表达式查找金额模式
- import re
- amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)')
+ # 使用更精确的正则表达式查找金额模式,支持带逗号的金额
+ amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)')
matches = list(amount_pattern.finditer(search_text))
# 找到最接近且最长的完整金额
@@ -641,12 +711,34 @@
result["datetime"] = datetime
# 处理收入金额
- if entities["PICKUP_CODE"]:
- for amount in entities["PICKUP_CODE"]:
- cleaned_amount = clean_amount(amount, text)
- if cleaned_amount:
- result["amount"] = cleaned_amount
- break
+ amount_candidates = []
+ # 首先从识别的实体中获取
+ for amount in entities["PICKUP_CODE"]:
+ cleaned_amount = clean_amount(amount, text)
+ if cleaned_amount:
+ try:
+ value = float(cleaned_amount)
+ amount_candidates.append((cleaned_amount, value))
+ except ValueError:
+ continue
+
+ # 如果没有找到有效金额,直接从文本中尝试提取
+ if not amount_candidates:
+ # 直接在整个文本中寻找金额模式
+ amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)')
+ matches = list(amount_pattern.finditer(text))
+
+ for match in matches:
+ amount_text = match.group(1)
+ try:
+ value = float(amount_text.replace(',', ''))
+ amount_candidates.append((amount_text, value))
+ except ValueError:
+ continue
+
+ # 选择最合适的有效金额
+ if amount_candidates:
+ result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
# 处理余额
if entities["BALANCE"]:
@@ -660,6 +752,222 @@
except Exception as e:
logger.error(f"收入实体提取失败: {str(e)}")
+ raise
+
+ def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
+ """提取航班相关实体"""
+ try:
+ # 初始化结果字典
+ result = {
+ "flight": None, # 航班号
+ "company": None, # 航空公司
+ "start": None, # 出发地
+ "end": None, # 目的地
+ "date": None, # 日期
+ "time": None, # 时间
+ "departure_time": None, # 起飞时间
+ "arrival_time": None, # 到达时间
+ "ticket_num": None, # 机票号码
+ "seat": None # 座位等信息
+ }
+
+ # 使用NER模型提取实体
+ inputs = self.flight_tokenizer(
+ text,
+ return_tensors="pt",
+ truncation=True,
+ max_length=FlightNERConfig.MAX_LENGTH
+ )
+
+ with torch.no_grad():
+ outputs = self.flight_model(**inputs)
+
+ predictions = torch.argmax(outputs.logits, dim=2)
+ tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+ tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()]
+
+ # 解析实体
+ current_entity = None
+
+ for token, tag in zip(tokens, tags):
+ if tag.startswith("B-"):
+ if current_entity:
+ entity_type = current_entity["type"].lower()
+ result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+ current_entity = {"type": tag[2:], "text": token}
+ elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
+ current_entity["text"] += token
+ else:
+ if current_entity:
+ entity_type = current_entity["type"].lower()
+ result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+ current_entity = None
+
+ # 处理最后一个实体
+ if current_entity:
+ entity_type = current_entity["type"].lower()
+ result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+
+ # 处理航班号格式
+ if result["flight"]:
+ flight_no = result["flight"].upper()
+ # 清理航班号,只保留字母和数字
+ flight_no = ''.join(c for c in flight_no if c.isalnum())
+ # 验证航班号格式
+ valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
+ if valid_pattern.match(flight_no):
+ result["flight"] = flight_no
+ else:
+ # 尝试修复常见错误
+ if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit():
+ result["flight"] = flight_no
+ else:
+ result["flight"] = None
+
+ # 清理日期格式
+ if result["date"]:
+ date_str = result["date"]
+ # 保留数字和常见日期分隔符
+ date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
+ result["date"] = date_str
+
+ # 清理时间格式
+ for time_field in ["time", "departure_time", "arrival_time"]:
+ if result[time_field]:
+ time_str = result[time_field]
+ # 保留数字和常见时间分隔符
+ time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
+ result[time_field] = time_str
+
+ # 处理机票号码
+ if result["ticket_num"]:
+ ticket_num = result["ticket_num"]
+ # 清理机票号码,只保留字母和数字
+ ticket_num = ''.join(c for c in ticket_num if c.isalnum())
+ result["ticket_num"] = ticket_num
+
+ # 处理座位信息
+ if result["seat"]:
+ seat_str = result["seat"]
+ # 移除可能的额外空格和特殊字符
+ seat_str = seat_str.replace(" ", "").strip()
+ result["seat"] = seat_str
+
+ return result
+ except Exception as e:
+ logger.error(f"航班实体提取失败: {str(e)}")
+ raise
+
+ def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
+ """提取火车票相关实体"""
+ try:
+ # 初始化结果字典
+ result = {
+ "company": None, # 12306
+ "trips": None, # 车次
+ "start": None, # 出发站
+ "end": None, # 到达站
+ "date": None, # 日期
+ "time": None, # 时间
+ "seat": None, # 座位等信息
+ "name": None # 用户姓名
+ }
+
+ # 使用NER模型提取实体
+ inputs = self.train_tokenizer(
+ text,
+ return_tensors="pt",
+ truncation=True,
+ max_length=TrainNERConfig.MAX_LENGTH
+ )
+
+ with torch.no_grad():
+ outputs = self.train_model(**inputs)
+
+ predictions = torch.argmax(outputs.logits, dim=2)
+ tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+ tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()]
+
+ # 解析实体
+ current_entity = None
+
+ for token, tag in zip(tokens, tags):
+ if tag.startswith("B-"):
+ if current_entity:
+ entity_type = current_entity["type"].lower()
+ result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+ current_entity = {"type": tag[2:], "text": token}
+ elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
+ current_entity["text"] += token
+ else:
+ if current_entity:
+ entity_type = current_entity["type"].lower()
+ result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+ current_entity = None
+
+ # 处理最后一个实体
+ if current_entity:
+ entity_type = current_entity["type"].lower()
+ result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
+
+ # 处理公司名称,通常为12306
+ if result["company"]:
+ company = result["company"].strip()
+ # 如果文本中检测不到公司名称,但包含12306,则默认为12306
+ result["company"] = company
+ elif "12306" in text:
+ result["company"] = "12306"
+
+ # 处理车次格式
+ if result["trips"]:
+ trips_no = result["trips"].upper()
+ # 清理车次号,只保留字母和数字
+ trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-'])
+
+ # 验证车次格式
+ valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']]
+ if any(pattern.match(trips_no) for pattern in valid_patterns):
+ result["trips"] = trips_no
+ else:
+ # 尝试修复常见错误
+ if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']):
+ result["trips"] = trips_no
+ elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']:
+ result["trips"] = trips_no
+ else:
+ result["trips"] = None
+
+ # 清理日期格式
+ if result["date"]:
+ date_str = result["date"]
+ # 保留数字和常见日期分隔符
+ date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
+ result["date"] = date_str
+
+ # 清理时间格式
+ if result["time"]:
+ time_str = result["time"]
+ # 保留数字和常见时间分隔符
+ time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
+ result["time"] = time_str
+
+ # 处理座位信息
+ if result["seat"]:
+ seat_str = result["seat"]
+ # 移除可能的额外空格和特殊字符
+ seat_str = seat_str.replace(" ", "").strip()
+ result["seat"] = seat_str
+
+ # 处理乘客姓名
+ if result["name"]:
+ name = result["name"].strip()
+ # 移除可能的标点符号
+ name = ''.join(c for c in name if c.isalnum() or c in ['*', '·'])
+ result["name"] = name
+
+ return result
+ except Exception as e:
+ logger.error(f"火车票实体提取失败: {str(e)}")
raise
# 创建Flask应用
@@ -695,6 +1003,10 @@
details = model_manager.extract_repayment_entities(text)
elif category == "收入":
details = model_manager.extract_income_entities(text)
+ elif category == "航班":
+ details = model_manager.extract_flight_entities(text)
+ elif category == "火车票": # 添加火车票类别处理
+ details = model_manager.extract_train_entities(text)
else:
details = {}
--
Gitblit v1.9.3