import os
import random

import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback, ProgressCallback  # noqa: F401

from ner_config import IncomeNERConfig


def set_seed(seed):
    """Seed Python, NumPy and PyTorch (including all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed at import time so data splitting and model init are reproducible.
set_seed(IncomeNERConfig.SEED)


class IncomeDataset(Dataset):
    """Token-classification dataset built from pre-split words and BIO labels.

    All samples are tokenized once in ``__init__``; ``__getitem__`` only turns
    the cached encodings of one sample into tensors.
    """

    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        self.encodings = self.tokenize_and_align_labels()

    def _continuation_label_id(self, word_label):
        """Return the label id for a non-first sub-token of a word tagged ``word_label``."""
        # Continuation pieces belong to the same entity as their own word:
        # B-X -> I-X, I-X -> I-X, anything else -> O.
        if word_label.startswith("B-"):
            return self.label2id[f"I-{word_label[2:]}"]
        if word_label.startswith("I-"):
            return self.label2id[word_label]
        return self.label2id["O"]

    def tokenize_and_align_labels(self):
        """Tokenize every sample and align word-level labels to sub-tokens.

        Special and padding tokens get ``-100`` so the loss and the metrics
        ignore them. The first sub-token of a word gets the word's own label;
        later sub-tokens of the same word get the matching ``I-`` label.
        ``word_ids`` is used, so the tokenizer must be a fast tokenizer.

        Returns:
            The tokenizer's ``BatchEncoding`` with an added ``"labels"`` entry.

        Raises:
            IndexError: if a sample has fewer labels than words.
        """
        tokenized_inputs = self.tokenizer(
            self.texts,
            is_split_into_words=True,
            truncation=True,
            # NOTE: pads every sample to the longest one in this dataset, which
            # also makes ``group_by_length`` ineffective.
            padding=True,
            max_length=IncomeNERConfig.MAX_LENGTH,
            return_tensors=None,
        )

        labels = []
        for i, label_seq in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                else:
                    try:
                        word_label = label_seq[word_idx]
                    except IndexError:
                        print(f"错误:样本 {i} 的标签序列长度与文本不匹配")
                        print(f"文本长度: {len(self.texts[i])}")
                        print(f"标签长度: {len(label_seq)}")
                        print(f"word_idx: {word_idx}")
                        raise
                    if word_idx != previous_word_idx:
                        label_ids.append(self.label2id[word_label])
                    else:
                        # Previously this read label_seq[word_idx - 1], i.e. the
                        # label of the *previous* word (or the last word for 0).
                        label_ids.append(self._continuation_label_id(word_label))
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors, one per encoding field."""
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.texts)


def compute_metrics(eval_pred):
    """Compute entity-level precision/recall/F1 with seqeval.

    Positions whose gold label is ``-100`` (special tokens, padding) are
    dropped before converting ids back to label strings.
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [IncomeNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [IncomeNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    return {
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions),
        "overall_f1": f1_score(true_labels, true_predictions),
    }


def _build_datasets(texts, labels, tokenizer):
    """Split samples into train/validation ``IncomeDataset`` objects (seeded split)."""
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=IncomeNERConfig.TEST_SIZE, random_state=IncomeNERConfig.SEED
    )
    train_dataset = IncomeDataset(train_texts, train_labels, tokenizer, IncomeNERConfig.LABELS)
    val_dataset = IncomeDataset(val_texts, val_labels, tokenizer, IncomeNERConfig.LABELS)
    return train_dataset, val_dataset


def _build_training_args():
    """Return the TrainingArguments shared by fresh and resumed training runs."""
    return TrainingArguments(
        output_dir=IncomeNERConfig.MODEL_PATH,
        num_train_epochs=IncomeNERConfig.EPOCHS,
        per_device_train_batch_size=IncomeNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=IncomeNERConfig.BATCH_SIZE * 2,
        warmup_ratio=IncomeNERConfig.WARMUP_RATIO,
        weight_decay=IncomeNERConfig.WEIGHT_DECAY,
        logging_dir=IncomeNERConfig.LOG_PATH,
        logging_steps=IncomeNERConfig.LOGGING_STEPS,
        # NOTE(review): newer transformers releases rename this to
        # ``eval_strategy``; adjust to the installed version.
        evaluation_strategy="steps",
        eval_steps=IncomeNERConfig.EVAL_STEPS,
        save_strategy="steps",
        save_steps=IncomeNERConfig.SAVE_STEPS,
        save_total_limit=IncomeNERConfig.SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        # Trainer prefixes this with "eval_" when looking it up.
        metric_for_best_model="overall_f1",
        greater_is_better=True,
        max_grad_norm=IncomeNERConfig.MAX_GRAD_NORM,
        gradient_accumulation_steps=IncomeNERConfig.GRADIENT_ACCUMULATION_STEPS,
        fp16=IncomeNERConfig.FP16,
        dataloader_num_workers=IncomeNERConfig.DATALOADER_NUM_WORKERS,
        dataloader_pin_memory=IncomeNERConfig.DATALOADER_PIN_MEMORY,
        save_safetensors=True,
        optim="adamw_torch",
        disable_tqdm=False,
        report_to=["tensorboard"],
        group_by_length=True,
        length_column_name="length",
        remove_unused_columns=True,
        label_smoothing_factor=0.1,
        # Without this Trainer re-seeds with its own default, ignoring the config seed.
        seed=IncomeNERConfig.SEED,
    )


def _build_trainer(model, train_dataset, val_dataset):
    """Create the Trainer (with early stopping) used by both training entry points."""
    return Trainer(
        model=model,
        args=_build_training_args(),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        # Trainer already registers a ProgressCallback when tqdm is enabled;
        # passing another one produced duplicate progress bars.
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=IncomeNERConfig.EARLY_STOPPING_PATIENCE,
                early_stopping_threshold=0.001,
            ),
        ],
    )


def _save_model(trainer, tokenizer, save_path):
    """Save the trainer's model and the tokenizer into ``save_path``."""
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)


def train_income_model(texts, labels):
    """Fine-tune a fresh token-classification model on ``texts``/``labels``.

    The best checkpoint (by validation F1) is saved to
    ``<MODEL_PATH>/best_model``. If training fails or is interrupted with
    Ctrl+C, the current model is saved best-effort to
    ``<MODEL_PATH>/interrupted_model`` and the error is re-raised.

    Returns:
        (model, tokenizer)
    """
    # Load the pre-trained model and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(IncomeNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        IncomeNERConfig.MODEL_NAME,
        num_labels=len(IncomeNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(IncomeNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(IncomeNERConfig.LABELS)},
    )

    train_dataset, val_dataset = _build_datasets(texts, labels, tokenizer)
    trainer = _build_trainer(model, train_dataset, val_dataset)

    try:
        print("\n开始训练模型...")
        train_result = trainer.train()

        print("\n训练完成!")
        print(f"训练时长: {train_result.metrics['train_runtime']:.2f}秒")

        metrics = train_result.metrics
        print("\n训练指标:")
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                print(f"- {key}: {value:.4f}")

        # Final evaluation on the validation split (best model already loaded).
        final_eval = trainer.evaluate()
        print("\n最终评估结果:")
        print(f"F1分数: {final_eval['eval_overall_f1']:.4f}")
        print(f"精确率: {final_eval['eval_overall_precision']:.4f}")
        print(f"召回率: {final_eval['eval_overall_recall']:.4f}")

        print("\n保存模型...")
        save_path = f"{IncomeNERConfig.MODEL_PATH}/best_model"
        _save_model(trainer, tokenizer, save_path)
        print(f"模型已保存到: {save_path}")

        return model, tokenizer

    # KeyboardInterrupt is not an Exception subclass; include it so a manual
    # Ctrl+C also leaves an interrupted_model behind to resume from.
    except (Exception, KeyboardInterrupt) as e:
        print(f"\n训练过程中断: {str(e)}")
        try:
            save_path = f"{IncomeNERConfig.MODEL_PATH}/interrupted_model"
            _save_model(trainer, tokenizer, save_path)
            print(f"已保存中断时的模型到: {save_path}")
        except Exception as save_error:
            print(f"保存中断模型失败: {str(save_error)}")
        raise


def load_data(file_path):
    """Load a ``word label``-per-line file into parallel word/label sequences.

    Blank lines separate samples. A word that looks like part of a URL starts
    a skip run: it and the following words are dropped until one of the
    punctuation marks ``】 , 。 :`` is seen (that mark itself is kept). Lines
    whose label is not in ``IncomeNERConfig.LABELS`` are skipped with a warning.

    Returns:
        (texts, labels): lists of word lists and matching label lists.
    """
    texts = []
    labels = []
    current_words = []
    current_labels = []
    skip_url = False
    prev_word = None
    url_indicators = {'u', 'ur', 'url', 'http', 'https', 'www', 'com', 'cn'}

    def is_decimal_part(word, previous):
        # "12.50" as one word, or a bare "." right after a run of digits, is an
        # amount rather than a URL fragment; without this check every decimal
        # amount was truncated at its decimal point.
        if word.replace('.', '', 1).isdigit():
            return True
        return word == '.' and previous is not None and previous.isdigit()

    def is_url_part(word, previous):
        if is_decimal_part(word, previous):
            return False
        return (word.lower() in url_indicators or
                '.' in word or
                '/' in word or
                word.startswith('?'))

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:  # sample separator
                if current_words:
                    texts.append(current_words)
                    labels.append(current_labels)
                    current_words = []
                    current_labels = []
                # Reset even for an empty sample so skip mode never leaks
                # into the next sample.
                skip_url = False
                prev_word = None
                continue

            try:
                word, label = line.split(maxsplit=1)
            except ValueError as e:
                print(f"错误:第{line_num}行处理失败 '{line}': {str(e)}")
                continue

            # URL handling.
            if is_url_part(word, prev_word):
                skip_url = True
            elif word in ['】', ',', '。', ':']:
                skip_url = False
            prev_word = word

            if skip_url:
                continue

            # Label validation.
            if label not in IncomeNERConfig.LABELS:
                print(f"警告:第{line_num}行发现非法标签 '{label}',已跳过")
                continue

            current_words.append(word)
            current_labels.append(label)

    # Flush the last sample when the file does not end with a blank line.
    if current_words:
        texts.append(current_words)
        labels.append(current_labels)

    return texts, labels


def clean_text(word):
    """Normalize one word: strip ``*`` and map currency symbols to '元'."""
    # Remove masking characters.
    word = word.strip('*')
    # Unify currency symbols.
    if word in ['¥', '¥', 'RMB']:
        return '元'
    return word


def preprocess_data(texts, labels):
    """Clean words and relabel PICKUP_CODE entities preceded by a balance keyword.

    Samples whose word and label counts differ are skipped with a warning.
    A ``B-PICKUP_CODE`` whose preceding 5 words contain '余额', '余' or '结余'
    becomes ``B-BALANCE``, and its following ``I-PICKUP_CODE`` tags become
    ``I-BALANCE``. NOTE(review): assumes "B-BALANCE"/"I-BALANCE" are in
    ``IncomeNERConfig.LABELS`` - confirm.

    Returns:
        (processed_texts, processed_labels)
    """
    processed_texts = []
    processed_labels = []

    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        if len(text) != len(label_seq):
            print(f"警告:样本 {i} 的文本和标签长度不匹配,已跳过")
            continue

        cleaned_text = [clean_text(word) for word in text]

        is_balance = False
        new_labels = []
        for j, label in enumerate(label_seq):
            if label.startswith("B-PICKUP_CODE"):
                # Look at the preceding words to decide whether this is a balance.
                context = ''.join(cleaned_text[max(0, j - 5):j])
                is_balance = any(kw in context for kw in ['余额', '余', '结余'])
                new_labels.append("B-BALANCE" if is_balance else label)
            elif label.startswith("I-PICKUP_CODE"):
                new_labels.append("I-BALANCE" if is_balance else label)
            else:
                # Any other tag ends the current entity; don't let the balance
                # decision leak into a later, unrelated I- tag.
                is_balance = False
                new_labels.append(label)

        processed_texts.append(cleaned_text)
        processed_labels.append(new_labels)

    return processed_texts, processed_labels


def resume_training(checkpoint_path):
    """Continue training from ``checkpoint_path`` using the same setup as a fresh run.

    NOTE: an ``interrupted_model`` directory is written with ``save_model``, so
    it presumably holds weights/config/tokenizer but no optimizer or trainer
    state; Trainer then starts its step counter from zero with these weights.

    The resulting model is saved to ``<MODEL_PATH>/best_model``.

    Returns:
        (model, tokenizer)
    """
    print(f"从检查点恢复训练: {checkpoint_path}")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)

    # Reload and rebuild the datasets exactly as for a fresh run.
    texts, labels = load_data(IncomeNERConfig.DATA_PATH)
    texts, labels = preprocess_data(texts, labels)
    train_dataset, val_dataset = _build_datasets(texts, labels, tokenizer)

    # Same TrainingArguments/callbacks as train_income_model, so resuming does
    # not silently change label smoothing, seeding, early-stopping threshold, etc.
    trainer = _build_trainer(model, train_dataset, val_dataset)
    trainer.train(resume_from_checkpoint=checkpoint_path)

    # Previously the resumed model was never saved outside the rolling checkpoints.
    save_path = f"{IncomeNERConfig.MODEL_PATH}/best_model"
    _save_model(trainer, tokenizer, save_path)
    print(f"模型已保存到: {save_path}")

    return model, tokenizer


def main():
    """Load and validate the data, then either resume an interrupted run or train anew."""
    print("正在加载数据...")
    texts, labels = load_data(IncomeNERConfig.DATA_PATH)

    print("正在预处理数据...")
    texts, labels = preprocess_data(texts, labels)

    print("验证数据集...")
    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        if len(text) != len(label_seq):
            print(f"错误:样本 {i} 的文本和标签长度不匹配")
            print(f"文本({len(text)}): {text}")
            print(f"标签({len(label_seq)}): {label_seq}")
            return

    print(f"数据验证通过,共 {len(texts)} 个有效样本")

    # Offer to continue from a previously interrupted run.
    interrupted_model_path = f"{IncomeNERConfig.MODEL_PATH}/interrupted_model"
    if os.path.exists(interrupted_model_path):
        print("\n发现中断的训练模型")
        if input("是否从中断处继续训练? (y/n) ").strip().lower() == 'y':
            resume_training(interrupted_model_path)
            return

    print("\n开始新的训练...")
    train_income_model(texts, labels)


if __name__ == "__main__":
    main()