| | |
| | | from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForTokenClassification |
| | | import torch |
| | | from werkzeug.exceptions import BadRequest |
| | | from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig |
| | | from ner_config import NERConfig, RepaymentNERConfig, IncomeNERConfig, FlightNERConfig, TrainNERConfig |
| | | import re |
| | | |
| | | # 配置日志 |
| | |
| | | self.ner_path = "./models/ner_model/best_model" |
| | | self.repayment_path = "./models/repayment_model/best_model" |
| | | self.income_path = "./models/income_model/best_model" |
| | | # self.flight_path = "./models/flight_model/best_model" |
| | | # self.train_path = "./models/train_model/best_model" # 添加火车票模型路径 |
| | | |
| | | # 检查模型文件 |
| | | self._check_model_files() |
| | |
| | | self.ner_tokenizer, self.ner_model = self._load_ner() |
| | | self.repayment_tokenizer, self.repayment_model = self._load_repayment() |
| | | self.income_tokenizer, self.income_model = self._load_income() |
| | | # self.flight_tokenizer, self.flight_model = self._load_flight() |
| | | # self.train_tokenizer, self.train_model = self._load_train() # 加载火车票模型 |
| | | |
| | | # 将模型设置为评估模式 |
| | | self.classifier_model.eval() |
| | | self.ner_model.eval() |
| | | self.repayment_model.eval() |
| | | self.income_model.eval() |
| | | # self.flight_model.eval() |
| | | # self.train_model.eval() # 设置火车票模型为评估模式 |
| | | |
| | | def _check_model_files(self): |
| | | """检查模型文件是否存在""" |
| | |
| | | raise RuntimeError("还款模型文件不存在,请先运行训练脚本") |
| | | if not os.path.exists(self.income_path): |
| | | raise RuntimeError("收入模型文件不存在,请先运行训练脚本") |
| | | # if not os.path.exists(self.flight_path): |
| | | # raise RuntimeError("航班模型文件不存在,请先运行训练脚本") |
| | | # if not os.path.exists(self.train_path): |
| | | # raise RuntimeError("火车票模型文件不存在,请先运行训练脚本") |
| | | |
| | | def _load_classifier(self) -> Tuple[BertTokenizer, BertForSequenceClassification]: |
| | | """加载分类模型""" |
| | |
| | | logger.error(f"加载收入模型失败: {str(e)}") |
| | | raise |
| | | |
def _load_flight(self):
    """Load the flight NER tokenizer and token-classification model.

    Both are loaded from the directory stored in ``self.flight_path``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: any error raised while loading (including a missing
            ``flight_path`` attribute) is logged and re-raised unchanged.
    """
    # NOTE(review): the assignment of self.flight_path in __init__ is
    # commented out in this file, so this call presumably fails until that
    # line is restored — confirm before enabling the flight model.
    try:
        model_dir = self.flight_path
        flight_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        flight_model = AutoModelForTokenClassification.from_pretrained(model_dir)
    except Exception as exc:
        logger.error(f"加载航班模型失败: {str(exc)}")
        raise
    return flight_tokenizer, flight_model
| | | |
def _load_train(self):
    """Load the train-ticket NER tokenizer and token-classification model.

    Both are loaded from the directory stored in ``self.train_path``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: any error raised while loading (including a missing
            ``train_path`` attribute) is logged and re-raised unchanged.
    """
    # NOTE(review): the assignment of self.train_path in __init__ is
    # commented out in this file, so this call presumably fails until that
    # line is restored — confirm before enabling the train model.
    try:
        model_dir = self.train_path
        train_tokenizer = AutoTokenizer.from_pretrained(model_dir)
        train_model = AutoModelForTokenClassification.from_pretrained(model_dir)
    except Exception as exc:
        logger.error(f"加载火车票模型失败: {str(exc)}")
        raise
    return train_tokenizer, train_model
| | | |
| | | def classify_sms(self, text: str) -> str: |
| | | """对短信进行分类""" |
| | | try: |
| | |
| | | "company": None, # 寄件公司 |
| | | "address": None, # 地址 |
| | | "pickup_code": None, # 取件码 |
| | | "time": None # 时间 |
| | | "time": None # 添加时间字段 |
| | | } |
| | | |
| | | # 第一阶段:直接从文本中提取取件码 |
| | |
| | | if not amount_text: |
| | | return None |
| | | |
| | | # 尝试直接在上下文中使用正则表达式查找更完整的金额 |
| | | # 如果在同一句话里有类似"应还金额5,800元"这样的模式 |
| | | amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context) |
| | | if amount_match: |
| | | return amount_match.group(1) # 直接返回匹配到的金额,保留原始格式 |
| | | |
| | | # 尝试查找最低还款金额 |
| | | min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', context) |
| | | if min_amount_match and "MIN_CODE" in current_entity["type"]: |
| | | return min_amount_match.group(1) # 直接返回匹配到的最低还款金额,保留原始格式 |
| | | |
| | | # 在上下文中查找完整金额 |
| | | amount_index = context.find(amount_text) |
| | | if amount_index != -1: |
| | | # 扩大搜索范围,查找完整金额 |
| | | search_start = max(0, amount_index - 10) # 增加向前搜索范围 |
| | | search_start = max(0, amount_index - 10) |
| | | search_end = min(len(context), amount_index + len(amount_text) + 10) |
| | | search_text = context[search_start:search_end] |
| | | |
| | | # 使用更精确的正则表达式查找金额模式 |
| | | amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)') |
| | | # 使用正则表达式查找金额 |
| | | amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?') |
| | | matches = list(amount_pattern.finditer(search_text)) |
| | | |
| | | # 找到最接近且最长的完整金额 |
| | | best_match = None |
| | | min_distance = float('inf') |
| | | max_length = 0 |
| | | target_pos = amount_index - search_start |
| | | |
| | | for match in matches: |
| | | match_pos = match.start() |
| | | distance = abs(match_pos - target_pos) |
| | | match_text = match.group(1) |
| | | if matches: |
| | | # 选择最接近的匹配结果 |
| | | best_match = None |
| | | min_distance = float('inf') |
| | | |
| | | # 优先选择更长的匹配,除非距离差异太大 |
| | | if len(match_text) > max_length or (len(match_text) == max_length and distance < min_distance): |
| | | try: |
| | | # 验证金额是否合理 |
| | | value = float(match_text) |
| | | if value > 0 and value <= 9999999.99: # 设置合理的金额范围 |
| | | best_match = match_text |
| | | min_distance = distance |
| | | max_length = len(match_text) |
| | | except ValueError: |
| | | continue |
| | | for match in matches: |
| | | distance = abs(match.start() - (amount_index - search_start)) |
| | | if distance < min_distance: |
| | | min_distance = distance |
| | | best_match = match.group(1) # 只取数字部分,保留逗号 |
| | | |
| | | if best_match: |
| | | return best_match |
| | | |
| | | if best_match: |
| | | amount_text = best_match |
| | | |
| | | # 如果上述方法都没找到,则保留原始提取结果但验证其有效性 |
| | | # 移除货币符号和无效词 |
| | | for symbol in RepaymentNERConfig.AMOUNT_CONFIG['currency_symbols']: |
| | | amount_text = amount_text.replace(symbol, '') |
| | | for word in RepaymentNERConfig.AMOUNT_CONFIG['invalid_words']: |
| | | amount_text = amount_text.replace(word, '') |
| | | |
| | | # 处理金额中的逗号 |
| | | amount_text = amount_text.replace(',', '') |
| | | |
| | | # 验证金额有效性 |
| | | clean_amount = amount_text.replace(',', '') |
| | | try: |
| | | # 转换为浮点数 |
| | | value = float(amount_text) |
| | | |
| | | # 验证整数位数 |
| | | integer_part = str(int(value)) |
| | | if len(integer_part) <= RepaymentNERConfig.AMOUNT_CONFIG['max_integer_digits']: |
| | | # 保持原始小数位数 |
| | | if '.' in amount_text: |
| | | decimal_places = len(amount_text.split('.')[1]) |
| | | return f"{value:.{decimal_places}f}" |
| | | return str(int(value)) |
| | | value = float(clean_amount) |
| | | if value > 0: |
| | | # 返回原始格式 |
| | | return amount_text |
| | | except ValueError: |
| | | pass |
| | | |
| | | return None |
| | | |
| | | # 实体提取 |
| | |
| | | |
| | | # 处理银行名称 |
| | | if entities["BANK"]: |
| | | # 修改银行名称处理逻辑 |
| | | bank_parts = [] |
| | | seen = set() # 用于去重 |
| | | seen = set() |
| | | for bank in entities["BANK"]: |
| | | bank = bank.strip() |
| | | if bank and bank not in seen: # 避免重复 |
| | | if bank and bank not in seen: |
| | | bank_parts.append(bank) |
| | | seen.add(bank) |
| | | bank = "".join(bank_parts) |
| | |
| | | |
| | | # 处理还款类型 |
| | | if entities["TYPE"]: |
| | | type_ = "".join(entities["TYPE"]).strip() |
| | | type_parts = [] |
| | | seen = set() |
| | | for type_ in entities["TYPE"]: |
| | | type_ = type_.strip() |
| | | if type_ and type_ not in seen: |
| | | type_parts.append(type_) |
| | | seen.add(type_) |
| | | type_ = "".join(type_parts) |
| | | if len(type_) <= RepaymentNERConfig.MAX_ENTITY_LENGTH["TYPE"]: |
| | | result["type"] = type_ |
| | | |
| | |
| | | # 处理日期 |
| | | if entities["DATE"]: |
| | | date = "".join(entities["DATE"]) |
| | | date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日']) |
| | | date = ''.join(c for c in date if c.isdigit() or c in ['年', '月', '日', '-']) |
| | | if date: |
| | | result["date"] = date |
| | | |
| | | # 处理金额 |
| | | amount_candidates = [] |
| | | for amount in entities["PICKUP_CODE"]: |
| | | cleaned_amount = clean_amount(amount, text) |
| | | if cleaned_amount: |
| | | try: |
| | | value = float(cleaned_amount) |
| | | amount_candidates.append((cleaned_amount, value)) |
| | | except ValueError: |
| | | continue |
| | | # 先尝试使用正则表达式直接匹配金额 |
| | | amount_match = re.search(r'(?:应还|还款)?金额([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text) |
| | | if amount_match: |
| | | amount = amount_match.group(1) # 保留原始格式(带逗号) |
| | | # 验证金额有效性 |
| | | try: |
| | | value = float(amount.replace(',', '')) |
| | | if value > 0: |
| | | result["amount"] = amount |
| | | except ValueError: |
| | | pass |
| | | |
| | | # 选择最大的有效金额 |
| | | if amount_candidates: |
| | | # 按金额大小排序,选择最大的 |
| | | result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] |
| | | # 如果正则没有匹配到,使用NER结果 |
| | | if not result["amount"]: |
| | | amount_candidates = [] |
| | | # 从识别的实体中获取 |
| | | for amount in entities["PICKUP_CODE"]: |
| | | cleaned_amount = clean_amount(amount, text) |
| | | if cleaned_amount: |
| | | try: |
| | | value = float(cleaned_amount.replace(',', '')) |
| | | amount_candidates.append((cleaned_amount, value)) |
| | | except ValueError: |
| | | continue |
| | | |
| | | # 如果还是没有找到,尝试从文本中提取 |
| | | if not amount_candidates: |
| | | # 使用更宽松的正则表达式匹配金额 |
| | | amount_pattern = re.compile(r'([\d,]+\.?\d*)(?:元|块钱|块|万元|万)') |
| | | matches = list(amount_pattern.finditer(text)) |
| | | |
| | | for match in matches: |
| | | amount_text = match.group(1) # 获取数字部分,保留逗号 |
| | | try: |
| | | value = float(amount_text.replace(',', '')) |
| | | amount_candidates.append((amount_text, value)) |
| | | except ValueError: |
| | | continue |
| | | |
| | | # 选择最大的有效金额 |
| | | if amount_candidates: |
| | | result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] |
| | | |
| | | # 处理最低还款金额 |
| | | for amount in entities["MIN_CODE"]: |
| | | cleaned_amount = clean_amount(amount, text) # 传入原始文本作为上下文 |
| | | if cleaned_amount: |
| | | result["min_amount"] = cleaned_amount |
| | | break |
| | | # 先尝试使用正则表达式直接匹配最低还款金额 |
| | | min_amount_match = re.search(r'最低还款([\d,]+\.?\d*)(?:元|块钱|块|万元|万)?', text) |
| | | if min_amount_match: |
| | | min_amount = min_amount_match.group(1) # 保留原始格式(带逗号) |
| | | # 验证金额有效性 |
| | | try: |
| | | value = float(min_amount.replace(',', '')) |
| | | if value > 0: |
| | | result["min_amount"] = min_amount |
| | | except ValueError: |
| | | pass |
| | | |
| | | # 如果正则没有匹配到,使用NER结果 |
| | | if not result["min_amount"] and entities["MIN_CODE"]: |
| | | for amount in entities["MIN_CODE"]: |
| | | cleaned_amount = clean_amount(amount, text) |
| | | if cleaned_amount: |
| | | result["min_amount"] = cleaned_amount |
| | | break |
| | | |
| | | return result |
| | | |
| | |
| | | search_end = min(len(context), amount_index + len(amount_text) + 10) |
| | | search_text = context[search_start:search_end] |
| | | |
| | | # 使用正则表达式查找金额模式 |
| | | import re |
| | | amount_pattern = re.compile(r'(\d{1,10}(?:\.\d{1,2})?)') |
| | | # 使用更精确的正则表达式查找金额模式,支持带逗号的金额 |
| | | amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)') |
| | | matches = list(amount_pattern.finditer(search_text)) |
| | | |
| | | # 找到最接近且最长的完整金额 |
| | |
| | | result["datetime"] = datetime |
| | | |
| | | # 处理收入金额 |
| | | if entities["PICKUP_CODE"]: |
| | | for amount in entities["PICKUP_CODE"]: |
| | | cleaned_amount = clean_amount(amount, text) |
| | | if cleaned_amount: |
| | | result["amount"] = cleaned_amount |
| | | break |
| | | amount_candidates = [] |
| | | # 首先从识别的实体中获取 |
| | | for amount in entities["PICKUP_CODE"]: |
| | | cleaned_amount = clean_amount(amount, text) |
| | | if cleaned_amount: |
| | | try: |
| | | value = float(cleaned_amount) |
| | | amount_candidates.append((cleaned_amount, value)) |
| | | except ValueError: |
| | | continue |
| | | |
| | | # 如果没有找到有效金额,直接从文本中尝试提取 |
| | | if not amount_candidates: |
| | | # 直接在整个文本中寻找金额模式 |
| | | amount_pattern = re.compile(r'(\d{1,3}(?:,\d{3})*(?:\.\d{1,2})?|\d+(?:\.\d{1,2})?)') |
| | | matches = list(amount_pattern.finditer(text)) |
| | | |
| | | for match in matches: |
| | | amount_text = match.group(1) |
| | | try: |
| | | value = float(amount_text.replace(',', '')) |
| | | amount_candidates.append((amount_text, value)) |
| | | except ValueError: |
| | | continue |
| | | |
| | | # 选择最合适的有效金额 |
| | | if amount_candidates: |
| | | result["amount"] = max(amount_candidates, key=lambda x: x[1])[0] |
| | | |
| | | # 处理余额 |
| | | if entities["BALANCE"]: |
| | |
| | | |
| | | except Exception as e: |
| | | logger.error(f"收入实体提取失败: {str(e)}") |
| | | raise |
| | | |
def extract_flight_entities(self, text: str) -> Dict[str, Optional[str]]:
    """Extract flight-ticket entities from an SMS with the flight NER model.

    The text is tokenized with ``self.flight_tokenizer`` (truncated to
    ``FlightNERConfig.MAX_LENGTH``) and tagged by ``self.flight_model``.
    The ``B-``/``I-`` tag sequence is decoded into entity strings, which
    are then cleaned field by field.

    Args:
        text: Raw SMS text.

    Returns:
        A dict with the keys ``flight``, ``company``, ``start``, ``end``,
        ``date``, ``time``, ``departure_time``, ``arrival_time``,
        ``ticket_num`` and ``seat``. A value stays ``None`` when the model
        tagged no such entity. ``flight`` is also set to ``None`` when the
        number fails validation. Entity types outside this list are stored
        under their lower-cased label as well. When a type occurs more than
        once, the last occurrence wins.

    Raises:
        Exception: any error from tokenization, inference or post-processing
            is logged and re-raised unchanged.
    """
    try:
        result = {
            "flight": None,          # flight number
            "company": None,         # airline
            "start": None,           # departure city / airport
            "end": None,             # destination
            "date": None,            # date
            "time": None,            # time
            "departure_time": None,  # departure time
            "arrival_time": None,    # arrival time
            "ticket_num": None,      # ticket number
            "seat": None,            # seat and similar info
        }

        inputs = self.flight_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=FlightNERConfig.MAX_LENGTH,
        )

        with torch.no_grad():
            outputs = self.flight_model(**inputs)

        # .tolist() yields plain Python ints for the id2label lookup and,
        # unlike .numpy(), also works when the tensor lives on a GPU.
        predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
        tokens = self.flight_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        tags = [self.flight_model.config.id2label[p] for p in predictions]

        def _store(entity):
            # Drop WordPiece artefacts ("[UNK]", "##") and record the entity
            # under its lower-cased label.
            result[entity["type"].lower()] = (
                entity["text"].replace("[UNK]", "").replace("##", "").strip()
            )

        # Decode the BIO tag sequence into contiguous entities.
        current_entity = None
        for token, tag in zip(tokens, tags):
            if tag.startswith("B-"):
                if current_entity:
                    _store(current_entity)
                current_entity = {"type": tag[2:], "text": token}
            elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
                current_entity["text"] += token
            else:
                if current_entity:
                    _store(current_entity)
                current_entity = None

        # Flush an entity that runs to the end of the sequence.
        if current_entity:
            _store(current_entity)

        # Flight number: keep ASCII letters/digits only. str.isalnum() would
        # also keep CJK characters (e.g. "CA1234航班"), which then makes an
        # otherwise valid number fail validation.
        if result["flight"]:
            flight_no = re.sub(r"[^A-Za-z0-9]", "", result["flight"]).upper()
            valid_pattern = re.compile(FlightNERConfig.FLIGHT_CONFIG['pattern'])
            if valid_pattern.match(flight_no):
                result["flight"] = flight_no
            elif (
                len(flight_no) >= FlightNERConfig.FLIGHT_CONFIG['min_length']
                and flight_no[:2].isalpha()
                and flight_no[2:].isdigit()
            ):
                # Fallback for numbers the configured pattern rejects:
                # two letters followed only by digits.
                result["flight"] = flight_no
            else:
                result["flight"] = None

        # Date: keep digits and common date separators.
        if result["date"]:
            result["date"] = ''.join(
                c for c in result["date"]
                if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']
            )

        # Times: keep digits and common time separators.
        for time_field in ["time", "departure_time", "arrival_time"]:
            if result[time_field]:
                result[time_field] = ''.join(
                    c for c in result[time_field]
                    if c.isdigit() or c in [':', '时', '分', '点']
                )

        # Ticket number: keep ASCII letters and digits only (case preserved).
        if result["ticket_num"]:
            result["ticket_num"] = re.sub(r"[^A-Za-z0-9]", "", result["ticket_num"])

        # Seat: remove embedded spaces.
        if result["seat"]:
            result["seat"] = result["seat"].replace(" ", "").strip()

        return result
    except Exception as e:
        logger.error(f"航班实体提取失败: {str(e)}")
        raise
| | | |
def extract_train_entities(self, text: str) -> Dict[str, Optional[str]]:
    """Extract train-ticket entities from an SMS with the train NER model.

    The text is tokenized with ``self.train_tokenizer`` (truncated to
    ``TrainNERConfig.MAX_LENGTH``) and tagged by ``self.train_model``.
    The ``B-``/``I-`` tag sequence is decoded into entity strings, which
    are then cleaned field by field.

    Args:
        text: Raw SMS text.

    Returns:
        A dict with the keys ``company``, ``trips``, ``start``, ``end``,
        ``date``, ``time``, ``seat`` and ``name``. A value stays ``None``
        when the model tagged no such entity. ``trips`` is also set to
        ``None`` when the train number fails validation. ``company``
        defaults to ``"12306"`` when the model found none but the text
        contains it. Entity types outside this list are stored under their
        lower-cased label as well. When a type occurs more than once, the
        last occurrence wins.

    Raises:
        Exception: any error from tokenization, inference or post-processing
            is logged and re-raised unchanged.
    """
    try:
        result = {
            "company": None,  # issuer, usually 12306
            "trips": None,    # train number
            "start": None,    # departure station
            "end": None,      # arrival station
            "date": None,     # date
            "time": None,     # time
            "seat": None,     # seat and similar info
            "name": None,     # passenger name
        }

        inputs = self.train_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=TrainNERConfig.MAX_LENGTH,
        )

        with torch.no_grad():
            outputs = self.train_model(**inputs)

        # .tolist() yields plain Python ints for the id2label lookup and,
        # unlike .numpy(), also works when the tensor lives on a GPU.
        predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
        tokens = self.train_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        tags = [self.train_model.config.id2label[p] for p in predictions]

        def _store(entity):
            # Drop WordPiece artefacts ("[UNK]", "##") and record the entity
            # under its lower-cased label.
            result[entity["type"].lower()] = (
                entity["text"].replace("[UNK]", "").replace("##", "").strip()
            )

        # Decode the BIO tag sequence into contiguous entities.
        current_entity = None
        for token, tag in zip(tokens, tags):
            if tag.startswith("B-"):
                if current_entity:
                    _store(current_entity)
                current_entity = {"type": tag[2:], "text": token}
            elif tag.startswith("I-") and current_entity and tag[2:] == current_entity["type"]:
                current_entity["text"] += token
            else:
                if current_entity:
                    _store(current_entity)
                current_entity = None

        # Flush an entity that runs to the end of the sequence.
        if current_entity:
            _store(current_entity)

        # Company: use the model's answer when present; otherwise, if the
        # SMS mentions 12306, default to it.
        if result["company"]:
            result["company"] = result["company"].strip()
        elif "12306" in text:
            result["company"] = "12306"

        # Train number: keep ASCII letters, digits, "/" and "-" only.
        # str.isalnum() would also keep CJK characters such as the "次" in
        # "G1234次".
        if result["trips"]:
            trips_no = re.sub(r"[^A-Za-z0-9/-]", "", result["trips"]).upper()
            trips_config = TrainNERConfig.TRIPS_CONFIG
            valid_patterns = [re.compile(pattern) for pattern in trips_config['patterns']]
            if any(pattern.match(trips_no) for pattern in valid_patterns):
                result["trips"] = trips_no
            elif (
                len(trips_no) >= trips_config['min_length']
                and any(trips_no.startswith(t) for t in trips_config['train_types'])
            ):
                # Fallback: starts with a known train-type prefix.
                result["trips"] = trips_no
            elif trips_no.isdigit() and 1 <= len(trips_no) <= trips_config['max_length']:
                # Fallback: purely numeric train number of acceptable length.
                result["trips"] = trips_no
            else:
                result["trips"] = None

        # Date: keep digits and common date separators.
        if result["date"]:
            result["date"] = ''.join(
                c for c in result["date"]
                if c.isdigit() or c in ['年', '月', '日', '-', '/', '.']
            )

        # Time: keep digits and common time separators.
        if result["time"]:
            result["time"] = ''.join(
                c for c in result["time"]
                if c.isdigit() or c in [':', '时', '分', '点']
            )

        # Seat: remove embedded spaces.
        if result["seat"]:
            result["seat"] = result["seat"].replace(" ", "").strip()

        # Passenger name: drop punctuation but keep letters (including CJK),
        # digits, the "*" mask character and the "·" name separator.
        if result["name"]:
            result["name"] = ''.join(
                c for c in result["name"].strip()
                if c.isalnum() or c in ['*', '·']
            )

        return result
    except Exception as e:
        logger.error(f"火车票实体提取失败: {str(e)}")
        raise
| | | |
| | | # 创建Flask应用 |
| | |
| | | details = model_manager.extract_repayment_entities(text) |
| | | elif category == "收入": |
| | | details = model_manager.extract_income_entities(text) |
| | | elif category == "航班": |
| | | details = model_manager.extract_flight_entities(text) |
| | | elif category == "火车票": # 添加火车票类别处理 |
| | | details = model_manager.extract_train_entities(text) |
| | | else: |
| | | details = {} |
| | | |