# -*- coding: utf-8 -*-
|
import os
|
import logging
|
import datetime
|
from typing import Dict, Optional, Tuple
|
|
from flask import Flask, request, jsonify
|
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification
|
import torch
|
from werkzeug.exceptions import BadRequest
|
from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig
|
import re
|
|
# 配置日志
|
logging.basicConfig(
|
level=logging.INFO,
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
handlers=[
|
logging.FileHandler('app.log'),
|
logging.StreamHandler()
|
]
|
)
|
logger = logging.getLogger(__name__)
|
|
class ModelManager:
|
def __init__(self):
|
self.classifier_path = "./models/classifier"
|
self.ner_path = "./models/ner_model/best_model"
|
self.repayment_path = "./models/repayment_model/best_model"
|
self.income_path = "./models/income_model/best_model"
|
# self.flight_path = "./models/flight_model/best_model"
|
# self.train_path = "./models/train_model/best_model" # 添加火车票模型路径
|
|
# 检查模型文件
|
self._check_model_files()
|
|
# 加载模型
|
self.classifier_tokenizer, self.classifier_model = self._load_classifier()
|
self.ner_tokenizer, self.ner_model = self._load_ner()
|
self.repayment_tokenizer, self.repayment_model = self._load_repayment()
|
self.income_tokenizer, self.income_model = self._load_income()
|
# self.flight_tokenizer, self.flight_model = self._load_flight()
|
# self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型
|
|
# 将模型设置为评估模式
|
self.classifier_model.eval()
|
self.ner_model.eval()
|
self.repayment_model.eval()
|
self.income_model.eval()
|
# self.flight_model.eval()
|
# self.train_model.eval() # 设置火车票模型为评估模式
|
|
def _check_model_files(self):
|
"""检查模型文件是否存在"""
|
if not os.path.exists(self.classifier_path):
|
raise RuntimeError("分类模型文件不存在,请先运行训练脚本")
|
if not os.path.exists(self.ner_path):
|
raise RuntimeError("NER模型文件不存在,请先运行训练脚本")
|
if not os.path.exists(self.repayment_path):
|
raise RuntimeError("还款模型文件不存在,请先运行训练脚本")
|
if not os.path.exists(self.income_path):
|
raise RuntimeError("收入模型文件不存在,请先运行训练脚本")
|
# if not os.path.exists(self.flight_path):
|
# raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
|
# if not os.path.exists(self.train_path):
|
# raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
|
|
def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]:
|
"""加载分类模型"""
|
try:
|
tokenizer = BertTokenizer.from_pretrained(self.classifier_path)
|
model = BertForSequenceClassification.from_pretrained(self.classifier_path)
|
return tokenizer, model
|
except Exception as e:
|
logger.error(f"加载分类模型失败: {str(e)}")
|
raise
|
|
def _load_ner(self) -> Tuple[AutoTokenizer, AutoModelForTokenClassification]:
|
"""加载NER模型"""
|
try:
|
tokenizer = AutoTokenizer.from_pretrained(self.ner_path)
|
model = AutoModelForTokenClassification.from_pretrained(self.ner_path)
|
return tokenizer, model
|
except Exception as e:
|
logger.error(f"加载NER模型失败: {str(e)}")
|
raise
|
|
def _load_repayment(self):
|
"""加载还款模型"""
|
try:
|
tokenizer = AutoTokenizer.from_pretrained(self.repayment_path)
|
model = AutoModelForTokenClassification.from_pretrained(self.repayment_path)
|
return tokenizer, model
|
except Exception as e:
|
logger.error(f"加载还款模型失败: {str(e)}")
|
raise
|
|
def _load_income(self):
|
"""加载收入模型"""
|
try:
|
tokenizer = AutoTokenizer.from_pretrained(self.income_path)
|
model = AutoModelForTokenClassification.from_pretrained(self.income_path)
|
return tokenizer, model
|
except Exception as e:
|
logger.error(f"加载收入模型失败: {str(e)}")
|
raise
|
|
def _load_flight(self):
|
"""加载航班模型"""
|
try:
|
tokenizer = AutoTokenizer.from_pretrained(self.flight_path)
|
model = AutoModelForTokenClassification.from_pretrained(self.flight_path)
|
return tokenizer, model
|
except Exception as e:
|
logger.error(f"加载航班模型失败: {str(e)}")
|
raise
|
|
def _load_train(self):
|
"""加载火车票模型"""
|
try:
|
tokenizer = AutoTokenizer.from_pretrained(self.train_path)
|
model = AutoModelForTokenClassification.from_pretrained(self.train_path)
|
return tokenizer, model
|
except Exception as e:
|
logger.error(f"加载火车票模型失败: {str(e)}")
|
raise
|
|
def classify_sms(self, text: str) -> Tuple[str, float]:
|
"""对短信进行分类,并返回置信度"""
|
try:
|
inputs = self.classifier_tokenizer(
|
text,
|
return_tensors="pt",
|
truncation=True,
|
max_length=64
|
)
|
with torch.no_grad():
|
outputs = self.classifier_model(**inputs)
|
|
# 获取预测标签及其对应的概率
|
logits = outputs.logits
|
probabilities = torch.softmax(logits, dim=1)
|
pred_id = logits.argmax().item()
|
confidence = probabilities[0, pred_id].item() # 获取预测标签的置信度
|
|
return self.classifier_model.config.id2label[pred_id], confidence
|
except Exception as e:
|
logger.error(f"短信分类失败: {str(e)}")
|
raise
|
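    # Illustrative usage of classify_sms (not part of the original code; the label
    # string and confidence value below are assumptions that depend on the trained
    # classifier's id2label mapping):
    #   category, confidence = model_manager.classify_sms("您的包裹已到驿站,凭A8-1-15来取件")
    #   # e.g. category == "快递", confidence ≈ 0.97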
|
def is_marketing_sms(self, text: str) -> bool:
|
"""判断是否为营销/广告类短信,采用评分系统"""
|
# 特定字符串模式检查:直接匹配明显的营销/通知短信
|
marketing_patterns = [
|
# 百度类通知
|
r"百度智能云.*?尊敬的用户",
|
r"百度.*?账户.*?tokens",
|
r"AppBuilder.*?账户",
|
r"账户有.*?免费额度",
|
r".*?免费额度.*?过期",
|
r"dwz\.cn\/[A-Za-z0-9]+"
|
]
|
|
# 对特定模式直接判断
|
for pattern in marketing_patterns:
|
if re.search(pattern, text):
|
return True # 直接认为是营销短信
|
|
# 评分系统:根据短信内容特征进行评分,超过阈值判定为营销短信
|
score = 0
|
|
# 强营销特征关键词(高权重)
|
strong_marketing_keywords = [
|
"有奖", "免费赠送", "抽奖", "中奖", "优惠券", "折扣券", "特价", "秒杀",
|
"限时抢购", "促销", "推广", "广告", "代金券", "0元购", "tokens调用量"
|
]
|
|
# 一般营销特征关键词(中等权重)
|
general_marketing_keywords = [
|
"活动", "优惠", "折扣", "限时", "抢购", "特价", "promotion", "推广",
|
"开业", "集点", "集赞", "关注", "公众号", "小程序", "注册有礼", "免费额度"
|
]
|
|
# 弱营销特征关键词(低权重,可能出现在正常短信中)
|
weak_marketing_keywords = [
|
"尊敬的用户", "尊敬的客户", "您好", "注册", "登录", "账户", "账号",
|
"会员", "积分", "权益", "提醒", "即将", "有效期", "过期", "升级",
|
"更新", "下载", "APP", "应用", "平台", "网址", "点击", "工单"
|
]
|
|
# 短网址和链接(独立评估,结合其他特征判断)
|
url_patterns = [
|
"dwz.cn", "t.cn", "短网址", "http://", "https://", "cmbt.cn"
|
]
|
|
# 业务短信特征(用于反向识别,降低误判率)
|
# 快递短信特征
|
express_keywords = [
|
"快递", "包裹", "取件码", "取件", "签收", "派送", "配送", "物流",
|
"驿站", "在途", "揽收", "暂存", "已到达", "丰巢", "柜取件", "柜机"
|
]
|
|
# 还款短信特征
|
repayment_keywords = [
|
"还款", "账单", "信用卡", "借款", "贷款", "逾期", "欠款", "最低还款",
|
"应还金额", "到期还款", "还清", "应还", "还款日", "账单¥", "账单¥", "查账还款"
|
]
|
|
# 收入短信特征
|
income_keywords = [
|
"收入", "转账", "入账", "到账", "支付", "工资", "报销", "余额",
|
"成功收款", "收到", "款项"
|
]
|
|
# 航班/火车票特征
|
travel_keywords = [
|
"航班", "航空", "飞机", "机票", "火车", "铁路", "列车", "车票",
|
"出发", "抵达", "起飞", "登机", "候车", "检票"
|
]
|
|
# 额外增加:通知类短信特征(通常不需要处理的短信)
|
notification_keywords = [
|
"余额不足", "话费不足", "话费余额", "通讯费", "流量用尽", "流量不足",
|
"停机", "恢复通话", "自动充值", "交费", "缴费",
|
"消费提醒", "交易提醒", "动账", "短信通知", "验证码", "校验码", "安全码"
|
]
|
|
# 运营商标识
|
telecom_keywords = [
|
"中国电信", "中国移动", "中国联通", "电信", "移动", "联通",
|
"携号转网", "号码服务", "通讯服务", "189.cn", "10086", "10010"
|
]
|
|
# 银行和金融机构标识
|
bank_keywords = [
|
"信用卡", "储蓄卡", "借记卡", "储蓄", "银联",
|
"建设银行", "工商银行", "农业银行", "中国银行", "交通银行",
|
"招商银行", "浦发银行", "民生银行", "兴业银行", "广发银行",
|
"平安银行", "中信银行", "光大银行", "华夏银行", "邮储银行",
|
"农商银行", "支付宝", "微信支付", "京东金融", "度小满", "陆金所"
|
]
|
|
# 特殊情况检查:招商银行账单短信,不应被过滤
|
if ("招商银行" in text and ("账单" in text or "还款日" in text)) or "cmbt.cn" in text:
|
if "还款" in text or "账单" in text or "消费卡" in text:
|
return False # 是还款短信,不过滤
|
|
# 计算评分
|
# 首先检查业务短信特征,如果明确是业务短信,直接返回False
|
has_express_feature = any(keyword in text for keyword in express_keywords)
|
has_repayment_feature = any(keyword in text for keyword in repayment_keywords)
|
has_income_feature = any(keyword in text for keyword in income_keywords)
|
has_travel_feature = any(keyword in text for keyword in travel_keywords)
|
|
# 检查是否为百度通知
|
is_baidu_notification = "百度" in text and "尊敬的用户" in text
|
if is_baidu_notification:
|
return True # 百度通知应被过滤
|
|
# 如果短信中包含多个业务关键词(≥2个),很可能是重要的业务短信
|
business_score = (has_express_feature + has_repayment_feature +
|
has_income_feature + has_travel_feature)
|
if business_score >= 2 and not is_baidu_notification:
|
return False # 多个业务特征同时存在,不太可能是营销短信
|
|
# 检查强营销特征
|
for keyword in strong_marketing_keywords:
|
if keyword in text:
|
score += 3
|
|
# 检查一般营销特征
|
for keyword in general_marketing_keywords:
|
if keyword in text:
|
score += 2
|
|
# 检查弱营销特征
|
for keyword in weak_marketing_keywords:
|
if keyword in text:
|
score += 1
|
|
# 检查URL特征(结合是否存在业务特征)
|
has_url = any(pattern in text for pattern in url_patterns)
|
|
# 降低业务特征短信的营销判定分数
|
if has_express_feature and not is_baidu_notification:
|
score -= 3 # 快递特征明显减分
|
|
if has_repayment_feature:
|
score -= 3 # 还款特征明显减分
|
|
if has_income_feature:
|
score -= 2 # 收入特征减分
|
|
if has_travel_feature:
|
score -= 2 # 旅行特征减分
|
|
# 检查通知类短信特征(但不包括重要的业务短信)
|
if not has_express_feature and not has_repayment_feature: # 确保不是快递和还款短信
|
notification_count = sum(1 for keyword in notification_keywords if keyword in text)
|
if notification_count >= 2: # 需要至少2个通知关键词才判定
|
score += notification_count # 增加判定为营销/通知短信的可能性
|
|
# 检查运营商和银行标识(结合其他特征判断)
|
has_telecom_feature = any(keyword in text for keyword in telecom_keywords)
|
has_bank_feature = any(keyword in text for keyword in bank_keywords)
|
|
# URL的评分处理
|
if has_url:
|
if (has_express_feature or has_repayment_feature or has_income_feature or has_travel_feature) and not is_baidu_notification:
|
# URL在业务短信中可能是正常的追踪链接,不增加评分
|
pass
|
else:
|
# 纯URL且无业务特征,可能是营销短信
|
score += 2
|
|
# 特殊情况:运营商余额通知
|
if has_telecom_feature and "余额" in text and not has_income_feature:
|
score += 2
|
|
# 设置判定阈值
|
threshold = 4 # 需要至少4分才判定为营销短信
|
|
return score >= threshold
|
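    # Scoring recap (illustrative, derived from the rules in is_marketing_sms above):
    # a text such as "限时抢购,全场特价" hits strong keywords (+3 each) and general
    # keywords (+2 each) with no business-feature deductions, so it easily clears the
    # threshold of 4 and is treated as a marketing SMS.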
|
def is_notification_sms(self, text: str) -> bool:
|
"""判断是否为通知类短信(如银行交易通知、运营商提醒等)"""
|
# 银行交易通知特征(不包括还款提醒)
|
bank_transaction_patterns = [
|
r"您尾号\d+的.+消费",
|
r"您.+账户消费[\d,.]+元",
|
r"交易[\d,.]+元",
|
r"支付宝.+消费",
|
r"微信支付.+消费",
|
r"\d{1,2}月\d{1,2}日\d{1,2}[::]\d{1,2}消费",
|
r"银行卡([支付|消费|扣款])"
|
]
|
|
# 排除规则:包含以下关键词的短信不应被判定为通知短信
|
business_keywords = [
|
# 还款关键词
|
"还款", "账单", "应还", "到期还款", "还款日", "最低还款", "账单¥", "账单¥", "查账还款",
|
# 快递关键词
|
"快递", "包裹", "取件码", "取件", "签收", "派送", "配送",
|
# 收入关键词
|
"收入", "转账", "入账", "到账", "支付成功", "工资"
|
]
|
|
# 运营商余额通知特征
|
telecom_balance_patterns = [
|
r"余额[不足|低于][\d,.]+元",
|
r"话费[不足|仅剩][\d,.]+元",
|
r"流量[不足|即将用尽]",
|
r"[电信|移动|联通].+余额",
|
r"[停机|停号]提醒",
|
r"为了保障您的正常通讯",
|
]
|
|
# 首先检查是否包含业务关键词,有则不应判定为通知短信
|
for keyword in business_keywords:
|
if keyword in text:
|
return False # 包含业务关键词,不是需要过滤的通知短信
|
|
# 检查银行交易通知模式
|
for pattern in bank_transaction_patterns:
|
if re.search(pattern, text):
|
logger.debug(f"识别到银行交易通知短信:{text[:30]}...")
|
return True
|
|
# 检查运营商余额通知模式
|
for pattern in telecom_balance_patterns:
|
if re.search(pattern, text):
|
logger.debug(f"识别到运营商余额通知短信:{text[:30]}...")
|
return True
|
|
return False
|
|
def extract_entities(self, text: str) -> Dict[str, Optional[str]]:
|
"""提取文本中的实体"""
|
try:
|
# 初始化结果字典
|
result = {
|
"post": None, # 快递公司
|
"company": None, # 寄件公司
|
"address": None, # 地址
|
"pickup_code": None, # 取件码
|
"time": None # 添加时间字段
|
}
|
|
# 第一阶段:直接从文本中提取取件码
|
pickup_code = self.extract_pickup_code_from_text(text)
|
if pickup_code:
|
result["pickup_code"] = pickup_code
|
|
# 第二阶段:使用NER模型提取其他实体
|
inputs = self.ner_tokenizer(
|
text,
|
return_tensors="pt",
|
truncation=True,
|
max_length=64
|
)
|
with torch.no_grad():
|
outputs = self.ner_model(**inputs)
|
|
predictions = torch.argmax(outputs.logits, dim=2)
|
tokens = self.ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
tags = [self.ner_model.config.id2label[p] for p in predictions[0].numpy()]
|
|
# 解析实体
|
current_entity = None
|
entity_start = None
|
for i, (token, tag) in enumerate(zip(tokens, tags)):
|
# 跳过取件码实体,因为我们已经单独处理了
|
if tag.startswith("B-") and tag[2:] != "PICKUP_CODE":
|
# 保存之前的实体
|
if current_entity and current_entity["type"] != "PICKUP_CODE":
|
key = current_entity["type"].lower()
|
result[key] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
# 开始新实体
|
current_entity = {"type": tag[2:], "text": token}
|
entity_start = i
|
elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"] and current_entity["type"] != "PICKUP_CODE":
|
current_entity["text"] += token
|
elif tag == "O" or tag.startswith("B-"):
|
# 结束当前实体
|
if current_entity and current_entity["type"] != "PICKUP_CODE":
|
key = current_entity["type"].lower()
|
result[key] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
current_entity = None
|
|
# 处理最后一个实体
|
if current_entity and current_entity["type"] != "PICKUP_CODE":
|
key = current_entity["type"].lower()
|
result[key] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
|
# 如果第一阶段没有提取到取件码,使用NER模型的结果
|
if not result["pickup_code"]:
|
# 重新解析一次,只提取取件码
|
current_entity = None
|
for i, (token, tag) in enumerate(zip(tokens, tags)):
|
if tag.startswith("B-") and tag[2:] == "PICKUP_CODE":
|
current_entity = {"type": tag[2:], "text": token}
|
elif tag.startswith("I-") and current_entity and tag[2:] == "PICKUP_CODE":
|
current_entity["text"] += token
|
elif (tag == "O" or tag.startswith("B-")) and current_entity:
|
result["pickup_code"] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
current_entity = None
|
|
# 处理最后一个实体
|
if current_entity and current_entity["type"] == "PICKUP_CODE":
|
result["pickup_code"] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
|
# 后处理:清理和验证取件码
|
if result["pickup_code"]:
|
# 清理取件码
|
code = result["pickup_code"]
|
# 移除取件码后的额外文字
|
for word in ["的", "来", "到", "取", "件", "码"]:
|
if word in code:
|
code = code[:code.index(word)]
|
|
# 只保留字母、数字和连字符
|
code = ''.join(c for c in code if c.isalnum() or c == "-")
|
|
# 确保格式正确
|
parts = code.split("-")
|
valid_parts = []
|
for part in parts:
|
if part and any(c.isalnum() for c in part):
|
valid_parts.append(part)
|
|
if valid_parts:
|
result["pickup_code"] = "-".join(valid_parts)
|
else:
|
result["pickup_code"] = None
|
|
# 清理公司名称
|
if result["company"]:
|
company = result["company"]
|
invalid_words = ["包裹", "快递", "到"]
|
for word in invalid_words:
|
if company.endswith(word):
|
company = company[:-len(word)]
|
result["company"] = company.strip()
|
|
# 清理地址
|
if result["address"]:
|
address = result["address"]
|
invalid_suffixes = [",请尽快取件", ",询", "请尽快取件"]
|
for suffix in invalid_suffixes:
|
if address.endswith(suffix):
|
address = address[:-len(suffix)]
|
result["address"] = address.strip()
|
|
return result
|
except Exception as e:
|
logger.error(f"实体提取失败: {str(e)}")
|
raise
|
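    # Illustrative output shape of extract_entities (entity values depend on the trained
    # NER model, so the dictionary below is only a hypothetical example; the pickup code
    # comes from the deterministic regex stage):
    #   extract_entities("【丰巢】您的包裹已到小区丰巢柜,凭A8-1-15取件")
    #   -> {"post": "丰巢", "company": None, "address": "小区丰巢柜",
    #       "pickup_code": "A8-1-15", "time": None}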
|
def extract_pickup_code_from_text(self, text: str) -> Optional[str]:
|
"""直接从文本提取取件码"""
|
# 先尝试直接匹配完整的取件码格式
|
pickup_patterns = [
|
# 带有上下文的取件码模式(包括字母数字组合)
|
r'凭\s*([A-Za-z0-9]{1,3}-\d{1,2}-\d{1,6})\s*(来|到)?', # 匹配"凭xx-xx-xxxx"格式
|
r'码\s*([A-Za-z0-9]{1,3}-\d{1,2}-\d{1,6})', # 匹配"码xx-xx-xxxx"格式
|
r'提货码\s*([A-Za-z0-9]{1,3}-\d{1,2}-\d{1,6})', # 匹配"提货码xx-xx-xxxx"格式
|
r'取件码\s*([A-Za-z0-9]{1,3}-\d{1,2}-\d{1,6})', # 匹配"取件码xx-xx-xxxx"格式
|
|
# 分组提取模式,对有上下文的取件码
|
r'凭\s*([A-Za-z0-9]{1,3})-(\d{1,2})-(\d{1,6})\s*(来|到)?', # 匹配"凭xx-xx-xxxx"
|
r'码\s*([A-Za-z0-9]{1,3})-(\d{1,2})-(\d{1,6})', # 匹配"码xx-xx-xxxx"
|
r'提货码\s*([A-Za-z0-9]{1,3})-(\d{1,2})-(\d{1,6})', # 匹配"提货码xx-xx-xxxx"
|
r'取件码\s*([A-Za-z0-9]{1,3})-(\d{1,2})-(\d{1,6})', # 匹配"取件码xx-xx-xxxx"
|
|
# 独立的取件码模式
|
r'([A-Za-z0-9]{1,3}-\d{1,2}-\d{1,6})', # 匹配独立的xx-xx-xxxx格式,支持字母
|
r'(\d{1,2}-\d{1,2}-\d{4,6})', # 匹配独立的xx-xx-xxxx格式,纯数字
|
]
|
|
# 首先尝试整体匹配
|
for pattern in pickup_patterns[:4]: # 前4个是整体匹配模式
|
matches = re.findall(pattern, text)
|
if matches:
|
return matches[0] if isinstance(matches[0], str) else matches[0][0]
|
|
# 然后尝试分组匹配
|
for pattern in pickup_patterns[4:8]: # 中间4个是分组匹配模式
|
matches = re.findall(pattern, text)
|
if matches:
|
# 分组匹配,第一部分可以是字母数字组合
|
match = matches[0]
|
if len(match) >= 3:
|
# 确保第2、3部分是数字
|
if match[1].isdigit() and match[2].isdigit():
|
return f"{match[0]}-{match[1]}-{match[2]}"
|
|
# 最后尝试独立模式
|
for pattern in pickup_patterns[8:10]: # 最后2个是独立匹配模式
|
matches = re.findall(pattern, text)
|
if matches:
|
return matches[0]
|
|
# 特殊处理包含"取件码"的文本
|
code_indicators = ["取件码", "提货码"]
|
for indicator in code_indicators:
|
if indicator in text:
|
idx = text.find(indicator)
|
# 检查紧跟着的文本
|
search_text = text[idx:idx+35] # 扩大搜索范围
|
# 尝试匹配"取件码A8-1-15"这样的格式
|
special_match = re.search(r'[码]\s*([A-Za-z0-9]+[-\s]+\d+[-\s]+\d+)', search_text)
|
if special_match:
|
# 规范化格式
|
code = special_match.group(1)
|
# 将空格替换为连字符,确保格式一致
|
code = re.sub(r'\s+', '-', code)
|
# 确保连字符格式正确
|
code = re.sub(r'-+', '-', code)
|
return code
|
|
# 特殊处理"凭"后面的数字序列
|
if "凭" in text:
|
pickup_index = text.find("凭")
|
if pickup_index != -1:
|
# 在"凭"之后的25个字符内寻找取件码
|
search_text = text[pickup_index:pickup_index+35]
|
|
# 尝试提取形如"凭A8-1-15"或"凭22-4-1111"的格式
|
alpha_num_match = re.search(r'凭\s*([A-Za-z0-9]+)[-\s]+(\d+)[-\s]+(\d+)', search_text)
|
if alpha_num_match:
|
return f"{alpha_num_match.group(1)}-{alpha_num_match.group(2)}-{alpha_num_match.group(3)}"
|
|
# 提取所有字母数字序列
|
parts = re.findall(r'[A-Za-z0-9]+', search_text)
|
if len(parts) >= 3:
|
# 组合前三个部分
|
return f"{parts[0]}-{parts[1]}-{parts[2]}"
|
|
# 查找形如"A8-1-15"的取件码
|
alpha_num_codes = re.findall(r'([A-Za-z]\d+)-(\d+)-(\d+)', text)
|
if alpha_num_codes:
|
match = alpha_num_codes[0]
|
return f"{match[0]}-{match[1]}-{match[2]}"
|
|
return None
|
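    # Illustrative examples for extract_pickup_code_from_text (regex-only, so these
    # follow directly from the patterns above):
    #   extract_pickup_code_from_text("您的包裹已到驿站,凭A8-1-15来取件")  -> "A8-1-15"
    #   extract_pickup_code_from_text("取件码3-2-1024,请尽快取件")        -> "3-2-1024"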
|
def extract_repayment_entities(self, text: str) -> Dict[str, Optional[str]]:
|
"""提取还款相关实体"""
|
try:
|
result = {
|
"bank": None, # 还款主体
|
"type": None, # 还款类型
|
"amount": None, # 还款金额
|
"date": None, # 还款日期
|
"number": None, # 账号尾号
|
"min_amount": None # 最低还款金额
|
}
|
|
inputs = self.repayment_tokenizer(
|
text,
|
return_tensors="pt",
|
truncation=True,
|
max_length=RepaymentNERConfig.MAX_LENGTH
|
)
|
|
with torch.no_grad():
|
outputs = self.repayment_model(**inputs)
|
|
predictions = torch.argmax(outputs.logits, dim=2)
|
tokens = self.repayment_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
tags = [self.repayment_model.config.id2label[p] for p in predictions[0].numpy()]
|
|
def clean_amount(amount_text: str, context: str) -> Optional[str]:
|
"""清理和标准化金额"""
|
if not amount_text:
|
return None
|
|
# 尝试直接在上下文中使用正则表达式查找更完整的金额
|
# 如果在同一句话里有类似"应还金额5,800元"这样的模式
|
amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context)
|
if amount_match:
|
return amount_match.group(1) # 直接返回匹配到的金额,保留原始格式
|
|
# 尝试查找最低还款金额
|
min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context)
|
if min_amount_match and "MIN_CODE" in current_entity["type"]:
|
return min_amount_match.group(1) # 直接返回匹配到的最低还款金额,保留原始格式
|
|
# 在上下文中查找完整金额
|
amount_index = context.find(amount_text)
|
if amount_index != -1:
|
# 扩大搜索范围,查找完整金额
|
search_start = max(0, amount_index - 10)
|
search_end = min(len(context), amount_index + len(amount_text) + 10)
|
search_text = context[search_start:search_end]
|
|
# 使用正则表达式查找金额
|
amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?')
|
matches = list(amount_pattern.finditer(search_text))
|
|
if matches:
|
# 选择最接近的匹配结果
|
best_match = None
|
min_distance = float('inf')
|
|
for match in matches:
|
distance = abs(match.start() - (amount_index - search_start))
|
if distance < min_distance:
|
min_distance = distance
|
best_match = match.group(1) # 只取数字部分,保留逗号
|
|
if best_match:
|
return best_match
|
|
# 如果上述方法都没找到,则保留原始提取结果但验证其有效性
|
# 移除货币符号和无效词
|
for symbol in RepaymentNERConfig.AMOUNT_CONFIG['currency_symbols']:
|
amount_text = amount_text.replace(symbol, '')
|
for word in RepaymentNERConfig.AMOUNT_CONFIG['invalid_words']:
|
amount_text = amount_text.replace(word, '')
|
|
# 验证金额有效性
|
clean_amount = amount_text.replace(',', '')
|
try:
|
value = float(clean_amount)
|
if value > 0:
|
# 返回原始格式
|
return amount_text
|
except ValueError:
|
pass
|
|
return None
|
|
# 实体提取
|
current_entity = None
|
entities = {
|
"BANK": [],
|
"TYPE": [],
|
"PICKUP_CODE": [],
|
"DATE": [],
|
"NUMBER": [],
|
"MIN_CODE": []
|
}
|
|
for token, tag in zip(tokens, tags):
|
if tag.startswith("B-"):
|
if current_entity:
|
entities[current_entity["type"]].append(current_entity["text"])
|
current_entity = {"type": tag[2:], "text": token.replace("##", "")}
|
elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
|
current_entity["text"] += token.replace("##", "")
|
else:
|
if current_entity:
|
entities[current_entity["type"]].append(current_entity["text"])
|
current_entity = None
|
|
if current_entity:
|
entities[current_entity["type"]].append(current_entity["text"])
|
|
# 处理银行名称
|
if entities["BANK"]:
|
bank_parts = []
|
seen = set()
|
for bank in entities["BANK"]:
|
bank = bank.strip()
|
if bank and bank not in seen:
|
bank_parts.append(bank)
|
seen.add(bank)
|
bank = "".join(bank_parts)
|
if len(bank) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["BANK"]:
|
result["bank"] = bank
|
|
# 处理还款类型
|
if entities["TYPE"]:
|
type_parts = []
|
seen = set()
|
for type_ in entities["TYPE"]:
|
type_ = type_.strip()
|
if type_ and type_ not in seen:
|
type_parts.append(type_)
|
seen.add(type_)
|
type_ = "".join(type_parts)
|
if len(type_) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["TYPE"]:
|
result["type"] = type_
|
|
# 处理账号尾号
|
if entities["NUMBER"]:
|
number = "".join(c for c in "".join(entities["NUMBER"]) if c.isdigit())
|
if number and len(number) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["NUMBER"]:
|
result["number"] = number
|
|
# 处理日期
|
if entities["DATE"]:
|
date = "".join(entities["DATE"])
|
date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日', '-'])
|
if date:
|
result["date"] = date
|
|
# 处理金额
|
# 尝试匹配带¥符号的账单金额模式
|
amount_match = re.search(r'账单¥([\d,]+\.?\d*)', text)
|
if not amount_match:
|
                # 再尝试另一种宽度的¥符号(全角/半角)的账单金额模式
|
amount_match = re.search(r'账单¥([\d,]+\.?\d*)', text)
|
if not amount_match:
|
# 尝试匹配一般金额模式
|
amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text)
|
|
if amount_match:
|
amount = amount_match.group(1) # 保留原始格式(带逗号)
|
# 验证金额有效性
|
try:
|
value = float(amount.replace(',', ''))
|
if value > 0:
|
result["amount"] = amount
|
except ValueError:
|
pass
|
|
# 如果正则没有匹配到,使用NER结果
|
if not result["amount"]:
|
amount_candidates = []
|
# 从识别的实体中获取
|
for amount in entities["PICKUP_CODE"]:
|
cleaned_amount = clean_amount(amount, text)
|
if cleaned_amount:
|
try:
|
value = float(cleaned_amount.replace(',', ''))
|
amount_candidates.append((cleaned_amount, value))
|
except ValueError:
|
continue
|
|
# 如果还是没有找到,尝试从文本中提取
|
if not amount_candidates:
|
# 使用多个正则表达式匹配不同格式的金额
|
# 1. 匹配带¥符号格式
|
matches = list(re.finditer(r'¥([\d,]+\.?\d*)', text))
|
                    # 2. 匹配另一种宽度的¥符号(全角/半角)格式
|
matches.extend(list(re.finditer(r'¥([\d,]+\.?\d*)', text)))
|
# 3. 匹配一般金额格式
|
matches.extend(list(re.finditer(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)', text)))
|
|
for match in matches:
|
amount_text = match.group(1) # 获取数字部分,保留逗号
|
try:
|
value = float(amount_text.replace(',', ''))
|
amount_candidates.append((amount_text, value))
|
except ValueError:
|
continue
|
|
# 选择最大的有效金额
|
if amount_candidates:
|
result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
|
|
# 处理最低还款金额
|
# 先尝试使用正则表达式直接匹配最低还款金额
|
min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text)
|
if min_amount_match:
|
min_amount = min_amount_match.group(1) # 保留原始格式(带逗号)
|
# 验证金额有效性
|
try:
|
value = float(min_amount.replace(',', ''))
|
if value > 0:
|
result["min_amount"] = min_amount
|
except ValueError:
|
pass
|
|
# 如果正则没有匹配到,使用NER结果
|
if not result["min_amount"] and entities["MIN_CODE"]:
|
for amount in entities["MIN_CODE"]:
|
cleaned_amount = clean_amount(amount, text)
|
if cleaned_amount:
|
result["min_amount"] = cleaned_amount
|
break
|
|
return result
|
|
except Exception as e:
|
logger.error(f"还款实体提取失败: {str(e)}")
|
raise
|
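    # Illustrative output shape of extract_repayment_entities (all values below are
    # made up; actual values depend on the NER model and the regex fallbacks):
    #   {"bank": "招商银行", "type": "信用卡", "amount": "5,800.00",
    #    "date": "12月20日", "number": "1234", "min_amount": "580.00"}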
|
def extract_income_entities(self, text: str) -> Dict[str, Optional[str]]:
|
"""提取收入相关实体"""
|
try:
|
result = {
|
"bank": None, # 银行名称
|
"number": None, # 账号尾号
|
"datetime": None, # 交易时间
|
"amount": None, # 收入金额
|
"balance": None # 余额
|
}
|
|
inputs = self.income_tokenizer(
|
text,
|
return_tensors="pt",
|
truncation=True,
|
max_length=IncomeNERConfig.MAX_LENGTH
|
)
|
|
with torch.no_grad():
|
outputs = self.income_model(**inputs)
|
|
predictions = torch.argmax(outputs.logits, dim=2)
|
tokens = self.income_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
tags = [self.income_model.config.id2label[p] for p in predictions[0].numpy()]
|
|
def clean_amount(amount_text: str, context: str) -> Optional[str]:
|
"""清理和标准化金额"""
|
if not amount_text:
|
return None
|
|
# 在上下文中查找完整金额
|
amount_index = context.find(amount_text)
|
if amount_index != -1:
|
# 扩大搜索范围,查找完整金额
|
search_start = max(0, amount_index - 10)
|
search_end = min(len(context), amount_index + len(amount_text) + 10)
|
search_text = context[search_start:search_end]
|
|
# 使用更精确的正则表达式查找金额模式,支持带逗号的金额
|
amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)')
|
matches = list(amount_pattern.finditer(search_text))
|
|
# 找到最接近且最长的完整金额
|
best_match = None
|
min_distance = float('inf')
|
max_length = 0
|
target_pos = amount_index - search_start
|
|
for match in matches:
|
match_pos = match.start()
|
distance = abs(match_pos - target_pos)
|
match_text = match.group(1)
|
|
if len(match_text) > max_length or (len(match_text) == max_length and distance < min_distance):
|
try:
|
                                value = float(match_text.replace(',', ''))  # 去掉千分位逗号再转数值,否则带逗号的金额会抛 ValueError
|
if value > 0:
|
best_match = match_text
|
min_distance = distance
|
max_length = len(match_text)
|
except ValueError:
|
continue
|
|
if best_match:
|
amount_text = best_match
|
|
# 移除货币符号和无效词
|
for symbol in IncomeNERConfig.AMOUNT_CONFIG['currency_symbols']:
|
amount_text = amount_text.replace(symbol, '')
|
for word in IncomeNERConfig.AMOUNT_CONFIG['invalid_words']:
|
amount_text = amount_text.replace(word, '')
|
|
# 处理金额中的逗号
|
amount_text = amount_text.replace(',', '')
|
|
try:
|
value = float(amount_text)
|
if '.' in amount_text:
|
decimal_places = len(amount_text.split('.')[1])
|
return f"{value:.{decimal_places}f}"
|
return str(int(value))
|
except ValueError:
|
pass
|
return None
|
|
# 实体提取
|
current_entity = None
|
entities = {
|
"BANK": [],
|
"NUMBER": [],
|
"DATATIME": [],
|
"PICKUP_CODE": [],
|
"BALANCE": []
|
}
|
|
for token, tag in zip(tokens, tags):
|
if tag.startswith("B-"):
|
if current_entity:
|
entities[current_entity["type"]].append(current_entity["text"])
|
current_entity = {"type": tag[2:], "text": token.replace("##", "")}
|
elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
|
current_entity["text"] += token.replace("##", "")
|
else:
|
if current_entity:
|
entities[current_entity["type"]].append(current_entity["text"])
|
current_entity = None
|
|
if current_entity:
|
entities[current_entity["type"]].append(current_entity["text"])
|
|
# 处理银行名称
|
if entities["BANK"]:
|
bank_parts = []
|
seen = set()
|
for bank in entities["BANK"]:
|
bank = bank.strip()
|
if bank and bank not in seen:
|
bank_parts.append(bank)
|
seen.add(bank)
|
bank = "".join(bank_parts)
|
if len(bank) <= IncomeNERConfig.MAX_ENTITY_LENGTH["BANK"]:
|
result["bank"] = bank
|
|
# 处理账号尾号
|
if entities["NUMBER"]:
|
number = "".join(c for c in "".join(entities["NUMBER"]) if c.isdigit())
|
if number and len(number) <= IncomeNERConfig.MAX_ENTITY_LENGTH["NUMBER"]:
|
result["number"] = number
|
|
# 处理交易时间
|
if entities["DATATIME"]:
|
datetime = "".join(entities["DATATIME"])
|
datetime = ''.join(c for c in datetime if c.isdigit() or c in ['年', '月', '日', '时', '分', ':', '-'])
|
if datetime:
|
result["datetime"] = datetime
|
|
# 处理收入金额
|
# 先尝试使用正则表达式直接匹配收入金额,包括"收入金额"格式
|
amount_match = re.search(r'收入金额([\d,]+\.?\d*)元', text)
|
if not amount_match:
|
# 尝试匹配一般收入格式
|
amount_match = re.search(r'收入([\d,]+\.?\d*)元', text)
|
|
if amount_match:
|
amount = amount_match.group(1) # 保留原始格式(带逗号)
|
# 验证金额有效性
|
try:
|
value = float(amount.replace(',', ''))
|
if value > 0:
|
result["amount"] = amount
|
except ValueError:
|
pass
|
|
# 如果正则没有匹配到,继续尝试NER结果
|
if not result["amount"]:
|
amount_candidates = []
|
# 首先从识别的实体中获取
|
for amount in entities["PICKUP_CODE"]:
|
cleaned_amount = clean_amount(amount, text)
|
if cleaned_amount:
|
try:
|
value = float(cleaned_amount)
|
amount_candidates.append((cleaned_amount, value))
|
except ValueError:
|
continue
|
|
# 如果没有找到有效金额,直接从文本中尝试提取
|
if not amount_candidates:
|
# 尝试多种模式匹配金额
|
# 1. 匹配"收入金额xxx元"模式
|
matches = list(re.finditer(r'收入金额([\d,]+\.?\d*)元', text))
|
# 2. 匹配"收入xxx元"模式
|
matches.extend(list(re.finditer(r'收入([\d,]+\.?\d*)元', text)))
|
# 3. 匹配带元结尾的金额
|
matches.extend(list(re.finditer(r'([0-9,]+\.[0-9]+)元', text)))
|
# 4. 匹配普通数字(可能是余额),但排除已识别为余额的金额
|
if "余额" in text:
|
balance_match = re.search(r'余额([\d,]+\.?\d*)元', text)
|
if balance_match:
|
balance_value = balance_match.group(1)
|
# 只匹配不等于余额的金额
|
all_numbers = re.finditer(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)', text)
|
for match in all_numbers:
|
if match.group(1) != balance_value:
|
matches.append(match)
|
else:
|
matches.extend(list(re.finditer(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)', text)))
|
|
for match in matches:
|
amount_text = match.group(1)
|
try:
|
value = float(amount_text.replace(',', ''))
|
amount_candidates.append((amount_text, value))
|
except ValueError:
|
continue
|
|
                # 从金额候选中排除余额值
                # 注意:此时 result["balance"] 还未写入(余额在下方才提取),
                # 因此先用正则预取余额,再从候选金额中排除
                balance_match_pre = re.search(r'余额([\d,]+\.?\d*)元', text)
                if balance_match_pre:
                    try:
                        balance_value = float(balance_match_pre.group(1).replace(',', ''))
                        amount_candidates = [(amt_text, value) for amt_text, value in amount_candidates
                                             if abs(value - balance_value) > 0.01]
                    except ValueError:
                        pass
|
|
# 选择适当的金额作为收入
|
if amount_candidates:
|
has_income_amount_keyword = "收入金额" in text
|
|
if has_income_amount_keyword:
|
# 查找"收入金额"附近的数字
|
idx = text.find("收入金额")
|
if idx != -1:
|
closest_amount = None
|
min_distance = float('inf')
|
for amount_text, value in amount_candidates:
|
# 找到这个数字在原文中的位置
|
amount_idx = text.find(amount_text)
|
if amount_idx != -1:
|
distance = abs(amount_idx - idx)
|
if distance < min_distance:
|
min_distance = distance
|
closest_amount = amount_text
|
|
if closest_amount:
|
result["amount"] = closest_amount
|
else:
|
# 如果无法找到最近的金额,使用最大金额策略
|
result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
|
else:
|
# 如果没有"收入金额"关键词,则使用最大金额策略
|
result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
|
|
# 处理余额
|
# 先尝试使用正则表达式直接匹配余额
|
balance_match = re.search(r'余额([\d,]+\.?\d*)元', text)
|
if balance_match:
|
balance = balance_match.group(1) # 保留原始格式(带逗号)
|
# 验证金额有效性
|
try:
|
value = float(balance.replace(',', ''))
|
if value > 0:
|
result["balance"] = balance
|
except ValueError:
|
pass
|
|
# 如果正则没有匹配到,使用NER结果
|
if not result["balance"] and entities["BALANCE"]:
|
for amount in entities["BALANCE"]:
|
cleaned_amount = clean_amount(amount, text)
|
if cleaned_amount:
|
result["balance"] = cleaned_amount
|
break
|
|
return result
|
|
except Exception as e:
|
logger.error(f"收入实体提取失败: {str(e)}")
|
raise
|
|
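    # NOTE: the two extraction methods below rely on self.flight_tokenizer/self.flight_model
    # and self.train_tokenizer/self.train_model, whose loading is currently commented out in
    # __init__; calling them before re-enabling those loaders will raise AttributeError.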
def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
|
"""提取航班相关实体"""
|
try:
|
# 初始化结果字典
|
result = {
|
"flight": None, # 航班号
|
"company": None, # 航空公司
|
"start": None, # 出发地
|
"end": None, # 目的地
|
"date": None, # 日期
|
"time": None, # 时间
|
"departure_time": None, # 起飞时间
|
"arrival_time": None, # 到达时间
|
"ticket_num": None, # 机票号码
|
"seat": None # 座位等信息
|
}
|
|
# 使用NER模型提取实体
|
inputs = self.flight_tokenizer(
|
text,
|
return_tensors="pt",
|
truncation=True,
|
max_length=FlightNERConfig.MAX_LENGTH
|
)
|
|
with torch.no_grad():
|
outputs = self.flight_model(**inputs)
|
|
predictions = torch.argmax(outputs.logits, dim=2)
|
tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()]
|
|
# 解析实体
|
current_entity = None
|
|
for token, tag in zip(tokens, tags):
|
if tag.startswith("B-"):
|
if current_entity:
|
entity_type = current_entity["type"].lower()
|
result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
current_entity = {"type": tag[2:], "text": token}
|
elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
|
current_entity["text"] += token
|
else:
|
if current_entity:
|
entity_type = current_entity["type"].lower()
|
result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
current_entity = None
|
|
# 处理最后一个实体
|
if current_entity:
|
entity_type = current_entity["type"].lower()
|
result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
|
# 处理航班号格式
|
if result["flight"]:
|
flight_no = result["flight"].upper()
|
# 清理航班号,只保留字母和数字
|
flight_no = ''.join(c for c in flight_no if c.isalnum())
|
# 验证航班号格式
|
valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
|
if valid_pattern.match(flight_no):
|
result["flight"] = flight_no
|
else:
|
# 尝试修复常见错误
|
if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit():
|
result["flight"] = flight_no
|
else:
|
result["flight"] = None
|
|
# 清理日期格式
|
if result["date"]:
|
date_str = result["date"]
|
# 保留数字和常见日期分隔符
|
date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
|
result["date"] = date_str
|
|
# 清理时间格式
|
for time_field in ["time", "departure_time", "arrival_time"]:
|
if result[time_field]:
|
time_str = result[time_field]
|
# 保留数字和常见时间分隔符
|
time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
|
result[time_field] = time_str
|
|
# 处理机票号码
|
if result["ticket_num"]:
|
ticket_num = result["ticket_num"]
|
# 清理机票号码,只保留字母和数字
|
ticket_num = ''.join(c for c in ticket_num if c.isalnum())
|
result["ticket_num"] = ticket_num
|
|
# 处理座位信息
|
if result["seat"]:
|
seat_str = result["seat"]
|
# 移除可能的额外空格和特殊字符
|
seat_str = seat_str.replace(" ", "").strip()
|
result["seat"] = seat_str
|
|
return result
|
except Exception as e:
|
logger.error(f"航班实体提取失败: {str(e)}")
|
raise
|
|
def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
|
"""提取火车票相关实体"""
|
try:
|
# 初始化结果字典
|
result = {
|
"company": None, # 12306
|
"trips": None, # 车次
|
"start": None, # 出发站
|
"end": None, # 到达站
|
"date": None, # 日期
|
"time": None, # 时间
|
"seat": None, # 座位等信息
|
"name": None # 用户姓名
|
}
|
|
# 使用NER模型提取实体
|
inputs = self.train_tokenizer(
|
text,
|
return_tensors="pt",
|
truncation=True,
|
max_length=TrainNERConfig.MAX_LENGTH
|
)
|
|
with torch.no_grad():
|
outputs = self.train_model(**inputs)
|
|
predictions = torch.argmax(outputs.logits, dim=2)
|
tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()]
|
|
# 解析实体
|
current_entity = None
|
|
for token, tag in zip(tokens, tags):
|
if tag.startswith("B-"):
|
if current_entity:
|
entity_type = current_entity["type"].lower()
|
result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
current_entity = {"type": tag[2:], "text": token}
|
elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
|
current_entity["text"] += token
|
else:
|
if current_entity:
|
entity_type = current_entity["type"].lower()
|
result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
current_entity = None
|
|
# 处理最后一个实体
|
if current_entity:
|
entity_type = current_entity["type"].lower()
|
result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip()
|
|
# 处理公司名称,通常为12306
|
if result["company"]:
|
company = result["company"].strip()
|
# 如果文本中检测不到公司名称,但包含12306,则默认为12306
|
result["company"] = company
|
elif "12306" in text:
|
result["company"] = "12306"
|
|
# 处理车次格式
|
if result["trips"]:
|
trips_no = result["trips"].upper()
|
                # 清理车次号,保留字母、数字和 / - 分隔符
|
trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-'])
|
|
# 验证车次格式
|
valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']]
|
if any(pattern.match(trips_no) for pattern in valid_patterns):
|
result["trips"] = trips_no
|
else:
|
# 尝试修复常见错误
|
if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']):
|
result["trips"] = trips_no
|
elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']:
|
result["trips"] = trips_no
|
else:
|
result["trips"] = None
|
|
# 清理日期格式
|
if result["date"]:
|
date_str = result["date"]
|
# 保留数字和常见日期分隔符
|
date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.'])
|
result["date"] = date_str
|
|
# 清理时间格式
|
if result["time"]:
|
time_str = result["time"]
|
# 保留数字和常见时间分隔符
|
time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点'])
|
result["time"] = time_str
|
|
# 处理座位信息
|
if result["seat"]:
|
seat_str = result["seat"]
|
# 移除可能的额外空格和特殊字符
|
seat_str = seat_str.replace(" ", "").strip()
|
result["seat"] = seat_str
|
|
# 处理乘客姓名
|
if result["name"]:
|
name = result["name"].strip()
|
# 移除可能的标点符号
|
name = ''.join(c for c in name if c.isalnum() or c in ['*', '·'])
|
result["name"] = name
|
|
return result
|
except Exception as e:
|
logger.error(f"火车票实体提取失败: {str(e)}")
|
raise
|
|
# 创建Flask应用
|
app = Flask(__name__)
|
model_manager = ModelManager()
|
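# Note: instantiating ModelManager at import time loads every model up front and raises
# RuntimeError if any of the expected directories under ./models is missing
# (see _check_model_files above).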
|
# 添加保存短信到文件的函数
|
def save_sms_to_file(text: str, category: str = None, confidence: float = None) -> bool:
|
"""
|
将短信内容保存到本地文件
|
|
Args:
|
text: 短信内容
|
category: 分类结果
|
confidence: 分类置信度
|
|
Returns:
|
bool: 保存成功返回True,否则返回False
|
"""
|
try:
|
# 确保日志目录存在
|
log_dir = "./sms_logs"
|
if not os.path.exists(log_dir):
|
os.makedirs(log_dir)
|
|
# 创建基于日期的文件名
|
today = datetime.datetime.now().strftime("%Y-%m-%d")
|
file_path = os.path.join(log_dir, f"sms_log_{today}.txt")
|
|
# 获取当前时间
|
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
# 准备要写入的内容
|
category_info = f"分类: {category}, 置信度: {confidence:.4f}" if category and confidence else "未分类"
|
log_content = f"[{current_time}] {category_info}\n{text}\n{'='*50}\n"
|
|
# 以追加模式写入文件
|
with open(file_path, 'a', encoding='utf-8') as f:
|
f.write(log_content)
|
|
return True
|
except Exception as e:
|
logger.error(f"保存短信到文件失败: {str(e)}")
|
return False
|
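# Illustrative log entry written by save_sms_to_file (file name and values are examples only):
#   ./sms_logs/sms_log_2024-01-01.txt
#   [2024-01-01 09:30:00] 分类: 快递, 置信度: 0.9731
#   您的包裹已到驿站,凭A8-1-15来取件
#   ==================================================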
|
@app.route("/health", methods=["GET"])
|
def health_check():
|
"""健康检查接口"""
|
return jsonify({"status": "healthy"})
|
|
@app.route("/process-sms", methods=["POST"])
|
def process_sms():
|
"""处理短信的接口"""
|
try:
|
# 验证请求数据
|
if not request.is_json:
|
raise BadRequest("请求必须是JSON格式")
|
|
data = request.get_json()
|
if "content" not in data:
|
raise BadRequest("请求中必须包含'content'字段")
|
|
text = data["content"]
|
if not isinstance(text, str) or not text.strip():
|
raise BadRequest("短信内容不能为空")
|
|
# 保存原始短信内容到文件
|
save_sms_to_file(text)
|
|
# 特定短信识别逻辑 - 针对百度通知和招商银行账单
|
# 识别百度通知
|
if "百度智能云" in text and "尊敬的用户" in text and "免费额度" in text:
|
logger.info(f"直接识别为百度通知短信: {text[:30]}...")
|
category = "其他"
|
save_sms_to_file(text, category, 1.0) # 记录分类结果
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": {}
|
}
|
})
|
|
# 识别招商银行账单
|
if "招商银行" in text and ("账单¥" in text or "账单¥" in text or "还款日" in text):
|
logger.info(f"直接识别为招商银行还款短信: {text[:30]}...")
|
category = "还款"
|
details = model_manager.extract_repayment_entities(text)
|
save_sms_to_file(text, category, 1.0) # 记录分类结果
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": details
|
}
|
})
|
|
# 处理短信
|
category, confidence = model_manager.classify_sms(text)
|
|
# 保存短信内容和分类结果
|
save_sms_to_file(text, category, confidence)
|
|
# 如果是明确的业务短信类别,直接进入处理流程
|
if category in ["快递", "还款", "收入", "航班", "火车票"] and confidence > 0.5:
|
# 对百度通知的特殊处理
|
if category == "快递" and "百度" in text and "尊敬的用户" in text:
|
logger.info(f"纠正百度通知短信的分类: {text[:30]}...")
|
category = "其他"
|
save_sms_to_file(text, category, confidence) # 更新分类结果
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": {}
|
}
|
})
|
|
# 对于高置信度的业务分类,直接进入实体提取
|
if category == "快递":
|
details = model_manager.extract_entities(text)
|
elif category == "还款":
|
details = model_manager.extract_repayment_entities(text)
|
elif category == "收入":
|
details = model_manager.extract_income_entities(text)
|
elif category == "航班":
|
details = model_manager.extract_flight_entities(text)
|
elif category == "火车票":
|
details = model_manager.extract_train_entities(text)
|
|
logger.info(f"高置信度业务短信: {text[:30]}..., category: {category}, confidence: {confidence:.4f}")
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": details
|
}
|
})
|
|
# 检查是否为营销/广告短信
|
if model_manager.is_marketing_sms(text):
|
# 如果是营销/广告短信,直接归类为"其他"
|
logger.info(f"检测到营销/广告短信: {text[:30]}...")
|
category = "其他"
|
save_sms_to_file(text, category, confidence) # 更新分类结果
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": {}
|
}
|
})
|
|
# 检查是否为通知类短信
|
if model_manager.is_notification_sms(text):
|
# 如果是通知类短信,直接归类为"其他"
|
logger.info(f"检测到通知类短信: {text[:30]}...")
|
category = "其他"
|
save_sms_to_file(text, category, confidence) # 更新分类结果
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": {}
|
}
|
})
|
|
# 置信度阈值,低于此阈值的分类结果被视为"其他"
|
confidence_threshold = 0.7
|
if confidence < confidence_threshold:
|
logger.info(f"短信分类置信度低({confidence:.4f}),归类为'其他': {text[:30]}...")
|
category = "其他"
|
save_sms_to_file(text, category, confidence) # 更新分类结果
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": {}
|
}
|
})
|
|
# 根据分类结果调用对应的实体提取函数
|
if category == "快递":
|
details = model_manager.extract_entities(text)
|
elif category == "还款":
|
details = model_manager.extract_repayment_entities(text)
|
elif category == "收入":
|
details = model_manager.extract_income_entities(text)
|
elif category == "航班":
|
details = model_manager.extract_flight_entities(text)
|
elif category == "火车票":
|
details = model_manager.extract_train_entities(text)
|
else:
|
details = {}
|
|
# 记录处理结果
|
logger.info(f"Successfully processed SMS: {text[:30]}..., category: {category}, confidence: {confidence:.4f}")
|
|
return jsonify({
|
"status": "success",
|
"data": {
|
"category": category,
|
"details": details
|
}
|
})
|
|
except BadRequest as e:
|
logger.warning(f"Invalid request: {str(e)}")
|
return jsonify({
|
"status": "error",
|
"message": str(e)
|
}), 400
|
|
except Exception as e:
|
logger.error(f"Error processing SMS: {str(e)}")
|
return jsonify({
|
"status": "error",
|
"message": "服务器内部错误"
|
}), 500
|
|
if __name__ == "__main__":
|
app.run(host="0.0.0.0", port=5000)
|
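# Minimal sketch of how to call the service once it is running on port 5000
# (the SMS text and the extracted values below are illustrative assumptions):
#
#   curl http://127.0.0.1:5000/health
#   -> {"status": "healthy"}
#
#   curl -X POST http://127.0.0.1:5000/process-sms \
#        -H "Content-Type: application/json" \
#        -d '{"content": "您的包裹已到驿站,凭A8-1-15来取件"}'
#   -> {"status": "success",
#       "data": {"category": "快递",
#                "details": {"post": null, "company": null, "address": null,
#                            "pickup_code": "A8-1-15", "time": null}}}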