From e46063e99f4cc6b22de4e32227289da3122646dd Mon Sep 17 00:00:00 2001 From: cloudroam <cloudroam> Date: 星期三, 16 四月 2025 16:55:20 +0800 Subject: [PATCH] fix --- app.py | 228 ++++++++++++++++++++++++++++++++++++-------------------- 1 files changed, 145 insertions(+), 83 deletions(-) diff --git a/app.py b/app.py index 47aa332..ba71f7e 100644 --- a/app.py +++ b/app.py @@ -27,8 +27,8 @@ self.ner_path = "./models/ner_model/best_model" self.repayment_path = "./models/repayment_model/best_model" self.income_path = "./models/income_model/best_model" - self.flight_path = "./models/flight_model/best_model" - self.train_path = "./models/train_model/best_model" # 添加火车票模型路径 + # self.flight_path = "./models/flight_model/best_model" + # self.train_path = "./models/train_model/best_model" # 添加火车票模型路径 # 检查模型文件 self._check_model_files() @@ -38,16 +38,16 @@ self.ner_tokenizer, self.ner_model = self._load_ner() self.repayment_tokenizer, self.repayment_model = self._load_repayment() self.income_tokenizer, self.income_model = self._load_income() - self.flight_tokenizer, self.flight_model = self._load_flight() - self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型 + # self.flight_tokenizer, self.flight_model = self._load_flight() + # self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型 # 将模型设置为评估模式 self.classifier_model.eval() self.ner_model.eval() self.repayment_model.eval() self.income_model.eval() - self.flight_model.eval() - self.train_model.eval() # 设置火车票模型为评估模式 + # self.flight_model.eval() + # self.train_model.eval() # 设置火车票模型为评估模式 def _check_model_files(self): """检查模型文件是否存在""" @@ -59,10 +59,10 @@ raise RuntimeError("还款模型文件不存在,请先运行训练脚本") if not os.path.exists(self.income_path): raise RuntimeError("收入模型文件不存在,请先运行训练脚本") - if not os.path.exists(self.flight_path): - raise RuntimeError("航班模型文件不存在,请先运行训练脚本") - if not os.path.exists(self.train_path): - raise RuntimeError("火车票模型文件不存在,请先运行训练脚本") + # if not os.path.exists(self.flight_path): + # raise RuntimeError("航班模型文件不存在,请先运行训练脚本") + # if not os.path.exists(self.train_path): + # raise RuntimeError("火车票模型文件不存在,请先运行训练脚本") def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]: """加载分类模型""" @@ -382,67 +382,60 @@ if not amount_text: return None + # 尝试直接在上下文中使用正则表达式查找更完整的金额 + # 如果在同一句话里有类似"应还金额5,800元"这样的模式 + amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context) + if amount_match: + return amount_match.group(1) # 直接返回匹配到的金额,保留原始格式 + + # 尝试查找最低还款金额 + min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context) + if min_amount_match and "MIN_CODE" in current_entity["type"]: + return min_amount_match.group(1) # 直接返回匹配到的最低还款金额,保留原始格式 + # 在上下文中查找完整金额 amount_index = context.find(amount_text) if amount_index != -1: # 扩大搜索范围,查找完整金额 - search_start = max(0, amount_index - 10) # 增加向前搜索范围 + search_start = max(0, amount_index - 10) search_end = min(len(context), amount_index + len(amount_text) + 10) search_text = context[search_start:search_end] - # 使用更精确的正则表达式查找金额模式 - amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)') + # 使用正则表达式查找金额 + amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?') matches = list(amount_pattern.finditer(search_text)) - # 找到最接近且最长的完整金额 - best_match = None - min_distance = float('inf') - max_length = 0 - target_pos = amount_index - search_start - - for match in matches: - match_pos = match.start() - distance = abs(match_pos - target_pos) - match_text = match.group(1) + if matches: + # 选择最接近的匹配结果 + best_match = None + min_distance = float('inf') - # 优先选择更长的匹配,除非距离差异太大 - if len(match_text) > max_length or (len(match_text) == max_length and distance < min_distance): - try: - # 验证金额是否合理 - value = float(match_text) - if value > 0 and value <= 9999999.99: # 设置合理的金额范围 - best_match = match_text - min_distance = distance - max_length = len(match_text) - except ValueError: - continue + for match in matches: + distance = abs(match.start() - (amount_index - search_start)) + if distance < min_distance: + min_distance = distance + best_match = match.group(1) # 只取数字部分,保留逗号 + + if best_match: + return best_match - if best_match: - amount_text = best_match - + # 如果上述方法都没找到,则保留原始提取结果但验证其有效性 # 移除货币符号和无效词 for symbol in RepaymentNERConfig.AMOUNT_CONFIG['currency_symbols']: amount_text = amount_text.replace(symbol, '') for word in RepaymentNERConfig.AMOUNT_CONFIG['invalid_words']: amount_text = amount_text.replace(word, '') - # 处理金额中的逗号 - amount_text = amount_text.replace(',', '') - + # 验证金额有效性 + clean_amount = amount_text.replace(',', '') try: - # 转换为浮点数 - value = float(amount_text) - - # 验证整数位数 - integer_part = str(int(value)) - if len(integer_part) <= RepaymentNERConfig.AMOUNT_CONFIG['max_integer_digits']: - # 保持原始小数位数 - if '.' in amount_text: - decimal_places = len(amount_text.split('.')[1]) - return f"{value:.{decimal_places}f}" - return str(int(value)) + value = float(clean_amount) + if value > 0: + # 返回原始格式 + return amount_text except ValueError: pass + return None # 实体提取 @@ -473,12 +466,11 @@ # 处理银行名称 if entities["BANK"]: - # 修改银行名称处理逻辑 bank_parts = [] - seen = set() # 用于去重 + seen = set() for bank in entities["BANK"]: bank = bank.strip() - if bank and bank not in seen: # 避免重复 + if bank and bank not in seen: bank_parts.append(bank) seen.add(bank) bank = "".join(bank_parts) @@ -487,7 +479,14 @@ # 处理还款类型 if entities["TYPE"]: - type_ = "".join(entities["TYPE"]).strip() + type_parts = [] + seen = set() + for type_ in entities["TYPE"]: + type_ = type_.strip() + if type_ and type_ not in seen: + type_parts.append(type_) + seen.add(type_) + type_ = "".join(type_parts) if len(type_) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["TYPE"]: result["type"] = type_ @@ -500,32 +499,74 @@ # 处理日期 if entities["DATE"]: date = "".join(entities["DATE"]) - date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日']) + date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日', '-']) if date: result["date"] = date # 处理金额 - amount_candidates = [] - for amount in entities["PICKUP_CODE"]: - cleaned_amount = clean_amount(amount, text) - if cleaned_amount: - try: - value = float(cleaned_amount) - amount_candidates.append((cleaned_amount, value)) - except ValueError: - continue + # 先尝试使用正则表达式直接匹配金额 + amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text) + if amount_match: + amount = amount_match.group(1) # 保留原始格式(带逗号) + # 验证金额有效性 + try: + value = float(amount.replace(',', '')) + if value > 0: + result["amount"] = amount + except ValueError: + pass - # 选择最大的有效金额 - if amount_candidates: - # 按金额大小排序,选择最大的 - result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] + # 如果正则没有匹配到,使用NER结果 + if not result["amount"]: + amount_candidates = [] + # 从识别的实体中获取 + for amount in entities["PICKUP_CODE"]: + cleaned_amount = clean_amount(amount, text) + if cleaned_amount: + try: + value = float(cleaned_amount.replace(',', '')) + amount_candidates.append((cleaned_amount, value)) + except ValueError: + continue + + # 如果还是没有找到,尝试从文本中提取 + if not amount_candidates: + # 使用更宽松的正则表达式匹配金额 + amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)') + matches = list(amount_pattern.finditer(text)) + + for match in matches: + amount_text = match.group(1) # 获取数字部分,保留逗号 + try: + value = float(amount_text.replace(',', '')) + amount_candidates.append((amount_text, value)) + except ValueError: + continue + + # 选择最大的有效金额 + if amount_candidates: + result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] # 处理最低还款金额 - for amount in entities["MIN_CODE"]: - cleaned_amount = clean_amount(amount, text) # 传入原始文本作为上下文 - if cleaned_amount: - result["min_amount"] = cleaned_amount - break + # 先尝试使用正则表达式直接匹配最低还款金额 + min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text) + if min_amount_match: + min_amount = min_amount_match.group(1) # 保留原始格式(带逗号) + # 验证金额有效性 + try: + value = float(min_amount.replace(',', '')) + if value > 0: + result["min_amount"] = min_amount + except ValueError: + pass + + # 如果正则没有匹配到,使用NER结果 + if not result["min_amount"] and entities["MIN_CODE"]: + for amount in entities["MIN_CODE"]: + cleaned_amount = clean_amount(amount, text) + if cleaned_amount: + result["min_amount"] = cleaned_amount + break return result @@ -571,9 +612,8 @@ search_end = min(len(context), amount_index + len(amount_text) + 10) search_text = context[search_start:search_end] - # 使用正则表达式查找金额模式 - import re - amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)') + # 使用更精确的正则表达式查找金额模式,支持带逗号的金额 + amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)') matches = list(amount_pattern.finditer(search_text)) # 找到最接近且最长的完整金额 @@ -671,12 +711,34 @@ result["datetime"] = datetime # 处理收入金额 - if entities["PICKUP_CODE"]: - for amount in entities["PICKUP_CODE"]: - cleaned_amount = clean_amount(amount, text) - if cleaned_amount: - result["amount"] = cleaned_amount - break + amount_candidates = [] + # 首先从识别的实体中获取 + for amount in entities["PICKUP_CODE"]: + cleaned_amount = clean_amount(amount, text) + if cleaned_amount: + try: + value = float(cleaned_amount) + amount_candidates.append((cleaned_amount, value)) + except ValueError: + continue + + # 如果没有找到有效金额,直接从文本中尝试提取 + if not amount_candidates: + # 直接在整个文本中寻找金额模式 + amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)') + matches = list(amount_pattern.finditer(text)) + + for match in matches: + amount_text = match.group(1) + try: + value = float(amount_text.replace(',', '')) + amount_candidates.append((amount_text, value)) + except ValueError: + continue + + # 选择最合适的有效金额 + if amount_candidates: + result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] # 处理余额 if entities["BALANCE"]: -- Gitblit v1.9.3