From e46063e99f4cc6b22de4e32227289da3122646dd Mon Sep 17 00:00:00 2001 From: cloudroam <cloudroam> Date: 星期三, 16 四月 2025 16:55:20 +0800 Subject: [PATCH] fix --- app.py | 462 ++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 files changed, 387 insertions(+), 75 deletions(-) diff --git a/app.py b/app.py index f697651..ba71f7e 100644 --- a/app.py +++ b/app.py @@ -7,7 +7,7 @@ from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification import torch from werkzeug.exceptions import BadRequest -from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig +from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig import re # 配置日志 @@ -27,6 +27,8 @@ self.ner_path = "./models/ner_model/best_model" self.repayment_path = "./models/repayment_model/best_model" self.income_path = "./models/income_model/best_model" + # self.flight_path = "./models/flight_model/best_model" + # self.train_path = "./models/train_model/best_model" # 添加火车票模型路径 # 检查模型文件 self._check_model_files() @@ -36,12 +38,16 @@ self.ner_tokenizer, self.ner_model = self._load_ner() self.repayment_tokenizer, self.repayment_model = self._load_repayment() self.income_tokenizer, self.income_model = self._load_income() + # self.flight_tokenizer, self.flight_model = self._load_flight() + # self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型 # 将模型设置为评估模式 self.classifier_model.eval() self.ner_model.eval() self.repayment_model.eval() self.income_model.eval() + # self.flight_model.eval() + # self.train_model.eval() # 设置火车票模型为评估模式 def _check_model_files(self): """检查模型文件是否存在""" @@ -53,6 +59,10 @@ raise RuntimeError("还款模型文件不存在,请先运行训练脚本") if not os.path.exists(self.income_path): raise RuntimeError("收入模型文件不存在,请先运行训练脚本") + # if not os.path.exists(self.flight_path): + # raise RuntimeError("航班模型文件不存在,请先运行训练脚本") + # if not os.path.exists(self.train_path): + # raise RuntimeError("火车票模型文件不存在,请先运行训练脚本") def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]: """加载分类模型""" @@ -94,6 +104,26 @@ logger.error(f"加载收入模型失败: {str(e)}") raise + def _load_flight(self): + """加载航班模型""" + try: + tokenizer = AutoTokenizer.from_pretrained(self.flight_path) + model = AutoModelForTokenClassification.from_pretrained(self.flight_path) + return tokenizer, model + except Exception as e: + logger.error(f"加载航班模型失败: {str(e)}") + raise + + def _load_train(self): + """加载火车票模型""" + try: + tokenizer = AutoTokenizer.from_pretrained(self.train_path) + model = AutoModelForTokenClassification.from_pretrained(self.train_path) + return tokenizer, model + except Exception as e: + logger.error(f"加载火车票模型失败: {str(e)}") + raise + def classify_sms(self, text: str) -> str: """对短信进行分类""" try: @@ -120,7 +150,7 @@ "company": None, # 寄件公司 "address": None, # 地址 "pickup_code": None, # 取件码 - "time": None # 时间 + "time": None # 添加时间字段 } # 第一阶段:直接从文本中提取取件码 @@ -352,67 +382,60 @@ if not amount_text: return None + # 尝试直接在上下文中使用正则表达式查找更完整的金额 + # 如果在同一句话里有类似"应还金额5,800元"这样的模式 + amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context) + if amount_match: + return amount_match.group(1) # 直接返回匹配到的金额,保留原始格式 + + # 尝试查找最低还款金额 + min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context) + if min_amount_match and "MIN_CODE" in current_entity["type"]: + return min_amount_match.group(1) # 直接返回匹配到的最低还款金额,保留原始格式 + # 在上下文中查找完整金额 amount_index = context.find(amount_text) if amount_index != -1: # 扩大搜索范围,查找完整金额 - search_start = max(0, amount_index - 10) # 增加向前搜索范围 + search_start = max(0, amount_index - 10) search_end = min(len(context), amount_index + len(amount_text) + 10) search_text = context[search_start:search_end] - # 使用更精确的正则表达式查找金额模式 - amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)') + # 使用正则表达式查找金额 + amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?') matches = list(amount_pattern.finditer(search_text)) - # 找到最接近且最长的完整金额 - best_match = None - min_distance = float('inf') - max_length = 0 - target_pos = amount_index - search_start - - for match in matches: - match_pos = match.start() - distance = abs(match_pos - target_pos) - match_text = match.group(1) + if matches: + # 选择最接近的匹配结果 + best_match = None + min_distance = float('inf') - # 优先选择更长的匹配,除非距离差异太大 - if len(match_text) > max_length or (len(match_text) == max_length and distance < min_distance): - try: - # 验证金额是否合理 - value = float(match_text) - if value > 0 and value <= 9999999.99: # 设置合理的金额范围 - best_match = match_text - min_distance = distance - max_length = len(match_text) - except ValueError: - continue + for match in matches: + distance = abs(match.start() - (amount_index - search_start)) + if distance < min_distance: + min_distance = distance + best_match = match.group(1) # 只取数字部分,保留逗号 + + if best_match: + return best_match - if best_match: - amount_text = best_match - + # 如果上述方法都没找到,则保留原始提取结果但验证其有效性 # 移除货币符号和无效词 for symbol in RepaymentNERConfig.AMOUNT_CONFIG['currency_symbols']: amount_text = amount_text.replace(symbol, '') for word in RepaymentNERConfig.AMOUNT_CONFIG['invalid_words']: amount_text = amount_text.replace(word, '') - # 处理金额中的逗号 - amount_text = amount_text.replace(',', '') - + # 验证金额有效性 + clean_amount = amount_text.replace(',', '') try: - # 转换为浮点数 - value = float(amount_text) - - # 验证整数位数 - integer_part = str(int(value)) - if len(integer_part) <= RepaymentNERConfig.AMOUNT_CONFIG['max_integer_digits']: - # 保持原始小数位数 - if '.' in amount_text: - decimal_places = len(amount_text.split('.')[1]) - return f"{value:.{decimal_places}f}" - return str(int(value)) + value = float(clean_amount) + if value > 0: + # 返回原始格式 + return amount_text except ValueError: pass + return None # 实体提取 @@ -443,12 +466,11 @@ # 处理银行名称 if entities["BANK"]: - # 修改银行名称处理逻辑 bank_parts = [] - seen = set() # 用于去重 + seen = set() for bank in entities["BANK"]: bank = bank.strip() - if bank and bank not in seen: # 避免重复 + if bank and bank not in seen: bank_parts.append(bank) seen.add(bank) bank = "".join(bank_parts) @@ -457,7 +479,14 @@ # 处理还款类型 if entities["TYPE"]: - type_ = "".join(entities["TYPE"]).strip() + type_parts = [] + seen = set() + for type_ in entities["TYPE"]: + type_ = type_.strip() + if type_ and type_ not in seen: + type_parts.append(type_) + seen.add(type_) + type_ = "".join(type_parts) if len(type_) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["TYPE"]: result["type"] = type_ @@ -470,32 +499,74 @@ # 处理日期 if entities["DATE"]: date = "".join(entities["DATE"]) - date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日']) + date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日', '-']) if date: result["date"] = date # 处理金额 - amount_candidates = [] - for amount in entities["PICKUP_CODE"]: - cleaned_amount = clean_amount(amount, text) - if cleaned_amount: - try: - value = float(cleaned_amount) - amount_candidates.append((cleaned_amount, value)) - except ValueError: - continue + # 先尝试使用正则表达式直接匹配金额 + amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text) + if amount_match: + amount = amount_match.group(1) # 保留原始格式(带逗号) + # 验证金额有效性 + try: + value = float(amount.replace(',', '')) + if value > 0: + result["amount"] = amount + except ValueError: + pass - # 选择最大的有效金额 - if amount_candidates: - # 按金额大小排序,选择最大的 - result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] + # 如果正则没有匹配到,使用NER结果 + if not result["amount"]: + amount_candidates = [] + # 从识别的实体中获取 + for amount in entities["PICKUP_CODE"]: + cleaned_amount = clean_amount(amount, text) + if cleaned_amount: + try: + value = float(cleaned_amount.replace(',', '')) + amount_candidates.append((cleaned_amount, value)) + except ValueError: + continue + + # 如果还是没有找到,尝试从文本中提取 + if not amount_candidates: + # 使用更宽松的正则表达式匹配金额 + amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)') + matches = list(amount_pattern.finditer(text)) + + for match in matches: + amount_text = match.group(1) # 获取数字部分,保留逗号 + try: + value = float(amount_text.replace(',', '')) + amount_candidates.append((amount_text, value)) + except ValueError: + continue + + # 选择最大的有效金额 + if amount_candidates: + result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] # 处理最低还款金额 - for amount in entities["MIN_CODE"]: - cleaned_amount = clean_amount(amount, text) # 传入原始文本作为上下文 - if cleaned_amount: - result["min_amount"] = cleaned_amount - break + # 先尝试使用正则表达式直接匹配最低还款金额 + min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text) + if min_amount_match: + min_amount = min_amount_match.group(1) # 保留原始格式(带逗号) + # 验证金额有效性 + try: + value = float(min_amount.replace(',', '')) + if value > 0: + result["min_amount"] = min_amount + except ValueError: + pass + + # 如果正则没有匹配到,使用NER结果 + if not result["min_amount"] and entities["MIN_CODE"]: + for amount in entities["MIN_CODE"]: + cleaned_amount = clean_amount(amount, text) + if cleaned_amount: + result["min_amount"] = cleaned_amount + break return result @@ -541,9 +612,8 @@ search_end = min(len(context), amount_index + len(amount_text) + 10) search_text = context[search_start:search_end] - # 使用正则表达式查找金额模式 - import re - amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)') + # 使用更精确的正则表达式查找金额模式,支持带逗号的金额 + amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)') matches = list(amount_pattern.finditer(search_text)) # 找到最接近且最长的完整金额 @@ -641,12 +711,34 @@ result["datetime"] = datetime # 处理收入金额 - if entities["PICKUP_CODE"]: - for amount in entities["PICKUP_CODE"]: - cleaned_amount = clean_amount(amount, text) - if cleaned_amount: - result["amount"] = cleaned_amount - break + amount_candidates = [] + # 首先从识别的实体中获取 + for amount in entities["PICKUP_CODE"]: + cleaned_amount = clean_amount(amount, text) + if cleaned_amount: + try: + value = float(cleaned_amount) + amount_candidates.append((cleaned_amount, value)) + except ValueError: + continue + + # 如果没有找到有效金额,直接从文本中尝试提取 + if not amount_candidates: + # 直接在整个文本中寻找金额模式 + amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)') + matches = list(amount_pattern.finditer(text)) + + for match in matches: + amount_text = match.group(1) + try: + value = float(amount_text.replace(',', '')) + amount_candidates.append((amount_text, value)) + except ValueError: + continue + + # 选择最合适的有效金额 + if amount_candidates: + result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] # 处理余额 if entities["BALANCE"]: @@ -660,6 +752,222 @@ except Exception as e: logger.error(f"收入实体提取失败: {str(e)}") + raise + + def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]: + """提取航班相关实体""" + try: + # 初始化结果字典 + result = { + "flight": None, # 航班号 + "company": None, # 航空公司 + "start": None, # 出发地 + "end": None, # 目的地 + "date": None, # 日期 + "time": None, # 时间 + "departure_time": None, # 起飞时间 + "arrival_time": None, # 到达时间 + "ticket_num": None, # 机票号码 + "seat": None # 座位等信息 + } + + # 使用NER模型提取实体 + inputs = self.flight_tokenizer( + text, + return_tensors="pt", + truncation=True, + max_length=FlightNERConfig.MAX_LENGTH + ) + + with torch.no_grad(): + outputs = self.flight_model(**inputs) + + predictions = torch.argmax(outputs.logits, dim=2) + tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) + tags = [self.flight_model.config.id2label[p] for p in predictions[0].numpy()] + + # 解析实体 + current_entity = None + + for token, tag in zip(tokens, tags): + if tag.startswith("B-"): + if current_entity: + entity_type = current_entity["type"].lower() + result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() + current_entity = {"type": tag[2:], "text": token} + elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]: + current_entity["text"] += token + else: + if current_entity: + entity_type = current_entity["type"].lower() + result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() + current_entity = None + + # 处理最后一个实体 + if current_entity: + entity_type = current_entity["type"].lower() + result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() + + # 处理航班号格式 + if result["flight"]: + flight_no = result["flight"].upper() + # 清理航班号,只保留字母和数字 + flight_no = ''.join(c for c in flight_no if c.isalnum()) + # 验证航班号格式 + valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern']) + if valid_pattern.match(flight_no): + result["flight"] = flight_no + else: + # 尝试修复常见错误 + if len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length'] and flight_no[:2].isalpha() and flight_no[2:].isdigit(): + result["flight"] = flight_no + else: + result["flight"] = None + + # 清理日期格式 + if result["date"]: + date_str = result["date"] + # 保留数字和常见日期分隔符 + date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']) + result["date"] = date_str + + # 清理时间格式 + for time_field in ["time", "departure_time", "arrival_time"]: + if result[time_field]: + time_str = result[time_field] + # 保留数字和常见时间分隔符 + time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点']) + result[time_field] = time_str + + # 处理机票号码 + if result["ticket_num"]: + ticket_num = result["ticket_num"] + # 清理机票号码,只保留字母和数字 + ticket_num = ''.join(c for c in ticket_num if c.isalnum()) + result["ticket_num"] = ticket_num + + # 处理座位信息 + if result["seat"]: + seat_str = result["seat"] + # 移除可能的额外空格和特殊字符 + seat_str = seat_str.replace(" ", "").strip() + result["seat"] = seat_str + + return result + except Exception as e: + logger.error(f"航班实体提取失败: {str(e)}") + raise + + def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]: + """提取火车票相关实体""" + try: + # 初始化结果字典 + result = { + "company": None, # 12306 + "trips": None, # 车次 + "start": None, # 出发站 + "end": None, # 到达站 + "date": None, # 日期 + "time": None, # 时间 + "seat": None, # 座位等信息 + "name": None # 用户姓名 + } + + # 使用NER模型提取实体 + inputs = self.train_tokenizer( + text, + return_tensors="pt", + truncation=True, + max_length=TrainNERConfig.MAX_LENGTH + ) + + with torch.no_grad(): + outputs = self.train_model(**inputs) + + predictions = torch.argmax(outputs.logits, dim=2) + tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) + tags = [self.train_model.config.id2label[p] for p in predictions[0].numpy()] + + # 解析实体 + current_entity = None + + for token, tag in zip(tokens, tags): + if tag.startswith("B-"): + if current_entity: + entity_type = current_entity["type"].lower() + result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() + current_entity = {"type": tag[2:], "text": token} + elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]: + current_entity["text"] += token + else: + if current_entity: + entity_type = current_entity["type"].lower() + result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() + current_entity = None + + # 处理最后一个实体 + if current_entity: + entity_type = current_entity["type"].lower() + result[entity_type] = current_entity["text"].replace("[UNK]", "").replace("##", "").strip() + + # 处理公司名称,通常为12306 + if result["company"]: + company = result["company"].strip() + # 如果文本中检测不到公司名称,但包含12306,则默认为12306 + result["company"] = company + elif "12306" in text: + result["company"] = "12306" + + # 处理车次格式 + if result["trips"]: + trips_no = result["trips"].upper() + # 清理车次号,只保留字母和数字 + trips_no = ''.join(c for c in trips_no if c.isalnum() or c in ['/', '-']) + + # 验证车次格式 + valid_patterns = [re.compile(pattern) for pattern in TrainNERConfig.TRIPS_CONFIG['patterns']] + if any(pattern.match(trips_no) for pattern in valid_patterns): + result["trips"] = trips_no + else: + # 尝试修复常见错误 + if len(trips_no) >= TrainNERConfig.TRIPS_CONFIG['min_length'] and any(trips_no.startswith(t) for t in TrainNERConfig.TRIPS_CONFIG['train_types']): + result["trips"] = trips_no + elif trips_no.isdigit() and 1 <= len(trips_no) <= TrainNERConfig.TRIPS_CONFIG['max_length']: + result["trips"] = trips_no + else: + result["trips"] = None + + # 清理日期格式 + if result["date"]: + date_str = result["date"] + # 保留数字和常见日期分隔符 + date_str = ''.join(c for c in date_str if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']) + result["date"] = date_str + + # 清理时间格式 + if result["time"]: + time_str = result["time"] + # 保留数字和常见时间分隔符 + time_str = ''.join(c for c in time_str if c.isdigit() or c in [':', '时', '分', '点']) + result["time"] = time_str + + # 处理座位信息 + if result["seat"]: + seat_str = result["seat"] + # 移除可能的额外空格和特殊字符 + seat_str = seat_str.replace(" ", "").strip() + result["seat"] = seat_str + + # 处理乘客姓名 + if result["name"]: + name = result["name"].strip() + # 移除可能的标点符号 + name = ''.join(c for c in name if c.isalnum() or c in ['*', '·']) + result["name"] = name + + return result + except Exception as e: + logger.error(f"火车票实体提取失败: {str(e)}") raise # 创建Flask应用 @@ -695,6 +1003,10 @@ details = model_manager.extract_repayment_entities(text) elif category == "收入": details = model_manager.extract_income_entities(text) + elif category == "航班": + details = model_manager.extract_flight_entities(text) + elif category == "火车票": # 添加火车票类别处理 + details = model_manager.extract_train_entities(text) else: details = {} -- Gitblit v1.9.3