fix
cloudroam
2025-04-16 e46063e99f4cc6b22de4e32227289da3122646dd
app.py
@@ -27,8 +27,8 @@
        self.ner_path = "./models/ner_model/best_model"
        self.repayment_path = "./models/repayment_model/best_model"
        self.income_path = "./models/income_model/best_model"
        self.flight_path = "./models/flight_model/best_model"
        self.train_path = "./models/train_model/best_model"  # 添加火车票模型路径
        # self.flight_path = "./models/flight_model/best_model"
        # self.train_path = "./models/train_model/best_model"  # 添加火车票模型路径
        # 检查模型文件
        self._check_model_files()
@@ -38,16 +38,16 @@
        self.ner_tokenizer, self.ner_model = self._load_ner()
        self.repayment_tokenizer, self.repayment_model = self._load_repayment()
        self.income_tokenizer, self.income_model = self._load_income()
        self.flight_tokenizer, self.flight_model = self._load_flight()
        self.train_tokenizer, self.train_model = self._load_train()  # 加载火车票模型
        # self.flight_tokenizer, self.flight_model = self._load_flight()
        # self.train_tokenizer, self.train_model = self._load_train()  # 加载火车票模型
        
        # 将模型设置为评估模式
        self.classifier_model.eval()
        self.ner_model.eval()
        self.repayment_model.eval()
        self.income_model.eval()
        self.flight_model.eval()
        self.train_model.eval()  # 设置火车票模型为评估模式
        # self.flight_model.eval()
        # self.train_model.eval()  # 设置火车票模型为评估模式
        
    def _check_model_files(self):
        """检查模型文件是否存在"""
@@ -59,10 +59,10 @@
            raise RuntimeError("还款模型文件不存在,请先运行训练脚本")
        if not os.path.exists(self.income_path):
            raise RuntimeError("收入模型文件不存在,请先运行训练脚本")
        if not os.path.exists(self.flight_path):
            raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
        if not os.path.exists(self.train_path):
            raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
        # if not os.path.exists(self.flight_path):
        #     raise RuntimeError("航班模型文件不存在,请先运行训练脚本")
        # if not os.path.exists(self.train_path):
        #     raise RuntimeError("火车票模型文件不存在,请先运行训练脚本")
            
    def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]:
        """加载分类模型"""
@@ -382,67 +382,60 @@
                if not amount_text:
                    return None
                
                # 尝试直接在上下文中使用正则表达式查找更完整的金额
                # 如果在同一句话里有类似"应还金额5,800元"这样的模式
                amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context)
                if amount_match:
                    return amount_match.group(1)  # 直接返回匹配到的金额,保留原始格式
                # 尝试查找最低还款金额
                min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context)
                if min_amount_match and "MIN_CODE" in current_entity["type"]:
                    return min_amount_match.group(1)  # 直接返回匹配到的最低还款金额,保留原始格式
                # 在上下文中查找完整金额
                amount_index = context.find(amount_text)
                if amount_index != -1:
                    # 扩大搜索范围,查找完整金额
                    search_start = max(0, amount_index - 10)  # 增加向前搜索范围
                    search_start = max(0, amount_index - 10)
                    search_end = min(len(context), amount_index + len(amount_text) + 10)
                    search_text = context[search_start:search_end]
                    
                    # 使用更精确的正则表达式查找金额模式
                    amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)')
                    # 使用正则表达式查找金额
                    amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?')
                    matches = list(amount_pattern.finditer(search_text))
                    
                    # 找到最接近且最长的完整金额
                    best_match = None
                    min_distance = float('inf')
                    max_length = 0
                    target_pos = amount_index - search_start
                    for match in matches:
                        match_pos = match.start()
                        distance = abs(match_pos - target_pos)
                        match_text = match.group(1)
                    if matches:
                        # 选择最接近的匹配结果
                        best_match = None
                        min_distance = float('inf')
                        
                        # 优先选择更长的匹配,除非距离差异太大
                        if len(match_text) > max_length or (len(match_text) == max_length and distance < min_distance):
                            try:
                                # 验证金额是否合理
                                value = float(match_text)
                                if value > 0 and value <= 9999999.99:  # 设置合理的金额范围
                                    best_match = match_text
                                    min_distance = distance
                                    max_length = len(match_text)
                            except ValueError:
                                continue
                        for match in matches:
                            distance = abs(match.start() - (amount_index - search_start))
                            if distance < min_distance:
                                min_distance = distance
                                best_match = match.group(1)  # 只取数字部分,保留逗号
                        if best_match:
                            return best_match
                
                    if best_match:
                        amount_text = best_match
                # 如果上述方法都没找到,则保留原始提取结果但验证其有效性
                # 移除货币符号和无效词
                for symbol in RepaymentNERConfig.AMOUNT_CONFIG['currency_symbols']:
                    amount_text = amount_text.replace(symbol, '')
                for word in RepaymentNERConfig.AMOUNT_CONFIG['invalid_words']:
                    amount_text = amount_text.replace(word, '')
                
                # 处理金额中的逗号
                amount_text = amount_text.replace(',', '')
                # 验证金额有效性
                clean_amount = amount_text.replace(',', '')
                try:
                    # 转换为浮点数
                    value = float(amount_text)
                    # 验证整数位数
                    integer_part = str(int(value))
                    if len(integer_part) <= RepaymentNERConfig.AMOUNT_CONFIG['max_integer_digits']:
                        # 保持原始小数位数
                        if '.' in amount_text:
                            decimal_places = len(amount_text.split('.')[1])
                            return f"{value:.{decimal_places}f}"
                        return str(int(value))
                    value = float(clean_amount)
                    if value > 0:
                        # 返回原始格式
                        return amount_text
                except ValueError:
                    pass
                return None
            # 实体提取
@@ -473,12 +466,11 @@
            # 处理银行名称
            if entities["BANK"]:
                # 修改银行名称处理逻辑
                bank_parts = []
                seen = set()  # 用于去重
                seen = set()
                for bank in entities["BANK"]:
                    bank = bank.strip()
                    if bank and bank not in seen:  # 避免重复
                    if bank and bank not in seen:
                        bank_parts.append(bank)
                        seen.add(bank)
                bank = "".join(bank_parts)
@@ -487,7 +479,14 @@
            # 处理还款类型
            if entities["TYPE"]:
                type_ = "".join(entities["TYPE"]).strip()
                type_parts = []
                seen = set()
                for type_ in entities["TYPE"]:
                    type_ = type_.strip()
                    if type_ and type_ not in seen:
                        type_parts.append(type_)
                        seen.add(type_)
                type_ = "".join(type_parts)
                if len(type_) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["TYPE"]:
                    result["type"] = type_
@@ -500,32 +499,74 @@
            # 处理日期
            if entities["DATE"]:
                date = "".join(entities["DATE"])
                date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日'])
                date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日', '-'])
                if date:
                    result["date"] = date
            # 处理金额
            amount_candidates = []
            for amount in entities["PICKUP_CODE"]:
                cleaned_amount = clean_amount(amount, text)
                if cleaned_amount:
                    try:
                        value = float(cleaned_amount)
                        amount_candidates.append((cleaned_amount, value))
                    except ValueError:
                        continue
            # 先尝试使用正则表达式直接匹配金额
            amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text)
            if amount_match:
                amount = amount_match.group(1)  # 保留原始格式(带逗号)
                # 验证金额有效性
                try:
                    value = float(amount.replace(',', ''))
                    if value > 0:
                        result["amount"] = amount
                except ValueError:
                    pass
            
            # 选择最大的有效金额
            if amount_candidates:
                # 按金额大小排序,选择最大的
                result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
            # 如果正则没有匹配到,使用NER结果
            if not result["amount"]:
                amount_candidates = []
                # 从识别的实体中获取
                for amount in entities["PICKUP_CODE"]:
                    cleaned_amount = clean_amount(amount, text)
                    if cleaned_amount:
                        try:
                            value = float(cleaned_amount.replace(',', ''))
                            amount_candidates.append((cleaned_amount, value))
                        except ValueError:
                            continue
                # 如果还是没有找到,尝试从文本中提取
                if not amount_candidates:
                    # 使用更宽松的正则表达式匹配金额
                    amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)')
                    matches = list(amount_pattern.finditer(text))
                    for match in matches:
                        amount_text = match.group(1)  # 获取数字部分,保留逗号
                        try:
                            value = float(amount_text.replace(',', ''))
                            amount_candidates.append((amount_text, value))
                        except ValueError:
                            continue
                # 选择最大的有效金额
                if amount_candidates:
                    result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
            # 处理最低还款金额
            for amount in entities["MIN_CODE"]:
                cleaned_amount = clean_amount(amount, text)  # 传入原始文本作为上下文
                if cleaned_amount:
                    result["min_amount"] = cleaned_amount
                    break
            # 先尝试使用正则表达式直接匹配最低还款金额
            min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text)
            if min_amount_match:
                min_amount = min_amount_match.group(1)  # 保留原始格式(带逗号)
                # 验证金额有效性
                try:
                    value = float(min_amount.replace(',', ''))
                    if value > 0:
                        result["min_amount"] = min_amount
                except ValueError:
                    pass
            # 如果正则没有匹配到,使用NER结果
            if not result["min_amount"] and entities["MIN_CODE"]:
                for amount in entities["MIN_CODE"]:
                    cleaned_amount = clean_amount(amount, text)
                    if cleaned_amount:
                        result["min_amount"] = cleaned_amount
                        break
            return result
        
@@ -571,9 +612,8 @@
                    search_end = min(len(context), amount_index + len(amount_text) + 10)
                    search_text = context[search_start:search_end]
                    
                    # 使用正则表达式查找金额模式
                    import re
                    amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)')
                    # 使用更精确的正则表达式查找金额模式,支持带逗号的金额
                    amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)')
                    matches = list(amount_pattern.finditer(search_text))
                    
                    # 找到最接近且最长的完整金额
@@ -671,12 +711,34 @@
                    result["datetime"] = datetime
            # 处理收入金额
            if entities["PICKUP_CODE"]:
                for amount in entities["PICKUP_CODE"]:
                    cleaned_amount = clean_amount(amount, text)
                    if cleaned_amount:
                        result["amount"] = cleaned_amount
                        break
            amount_candidates = []
            # 首先从识别的实体中获取
            for amount in entities["PICKUP_CODE"]:
                cleaned_amount = clean_amount(amount, text)
                if cleaned_amount:
                    try:
                        value = float(cleaned_amount)
                        amount_candidates.append((cleaned_amount, value))
                    except ValueError:
                        continue
            # 如果没有找到有效金额,直接从文本中尝试提取
            if not amount_candidates:
                # 直接在整个文本中寻找金额模式
                amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)')
                matches = list(amount_pattern.finditer(text))
                for match in matches:
                    amount_text = match.group(1)
                    try:
                        value = float(amount_text.replace(',', ''))
                        amount_candidates.append((amount_text, value))
                    except ValueError:
                        continue
            # 选择最合适的有效金额
            if amount_candidates:
                result["amount"] = max(amount_candidates, key=lambda x: x[1])[0]
            # 处理余额
            if entities["BALANCE"]: