from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from seqeval.metrics import f1_score, precision_score, recall_score
import random
import os

from ner_config import RepaymentNERConfig


# Set all random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(RepaymentNERConfig.SEED)


class RepaymentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        self.encodings = self.tokenize_and_align_labels()

    def tokenize_and_align_labels(self):
        """Tokenize the texts and align labels with the resulting sub-tokens."""
        # Offset mappings are deliberately not requested: they are not model
        # inputs, and __getitem__ would otherwise feed them to the model.
        tokenized_inputs = self.tokenizer(
            self.texts,  # list of pre-split word sequences
            is_split_into_words=True,  # input is already split into words
            truncation=True,
            padding=True,
            max_length=RepaymentNERConfig.MAX_LENGTH,
            return_tensors=None
        )

        labels = []
        for i, label_seq in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(i)
            previous_word_idx = None
            label_ids = []

            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens such as [CLS], [SEP], [PAD]
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First token of a new word
                    try:
                        label_ids.append(self.label2id[label_seq[word_idx]])
                    except IndexError:
                        print(f"Error: label sequence of sample {i} does not match its text length")
                        print(f"Text length: {len(self.texts[i])}")
                        print(f"Label length: {len(label_seq)}")
                        print(f"word_idx: {word_idx}")
                        raise
                else:
                    # Subsequent sub-tokens of the same word: derive the label
                    # from the current word's own label (word_idx), not from
                    # the previous word's label
                    if label_seq[word_idx].startswith("B-"):
                        current_type = label_seq[word_idx][2:]
                        label_ids.append(self.label2id[f"I-{current_type}"])
                    elif label_seq[word_idx].startswith("I-"):
                        label_ids.append(self.label2id[label_seq[word_idx]])
                    else:
                        label_ids.append(self.label2id["O"])
                previous_word_idx = word_idx

            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.texts)
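
# Illustration of the label alignment above (hypothetical sample; texts here are
# lists of single characters, so most tokens are the first token of a "word" and
# the else-branch only fires when one word is split into several sub-tokens):
#   words:  还    款    码    5              0              0
#   labels: O     O     O     B-PICKUP_CODE  I-PICKUP_CODE  I-PICKUP_CODE
#   tokens: [CLS] 还 款 码 5 0 0 [SEP]
#   ids:    -100  O  O  O  B-PICKUP_CODE I-PICKUP_CODE I-PICKUP_CODE -100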


def compute_metrics(p):
    """Entity-level precision/recall/F1 via seqeval; -100 positions are ignored."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [RepaymentNERConfig.LABELS[pred] for (pred, lab) in zip(prediction, label) if lab != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [RepaymentNERConfig.LABELS[lab] for (pred, lab) in zip(prediction, label) if lab != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }

    return results
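
# Example (hypothetical): seqeval scores at the entity level, so with
#   true = [["B-PICKUP_CODE", "I-PICKUP_CODE", "O"]]
#   pred = [["B-PICKUP_CODE", "O", "O"]]
# the predicted entity counts as wrong, because a predicted span must match the
# gold span exactly in both type and boundaries.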


def train_repayment_model(texts, labels):
    # Load the pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(RepaymentNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        RepaymentNERConfig.MODEL_NAME,
        num_labels=len(RepaymentNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(RepaymentNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(RepaymentNERConfig.LABELS)}
    )

    # Split into training and validation sets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels,
        test_size=RepaymentNERConfig.TEST_SIZE,
        random_state=RepaymentNERConfig.SEED
    )

    # Build the datasets
    train_dataset = RepaymentDataset(train_texts, train_labels, tokenizer, RepaymentNERConfig.LABELS)
    val_dataset = RepaymentDataset(val_texts, val_labels, tokenizer, RepaymentNERConfig.LABELS)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=RepaymentNERConfig.MODEL_PATH,
        num_train_epochs=RepaymentNERConfig.EPOCHS,
        per_device_train_batch_size=RepaymentNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=RepaymentNERConfig.BATCH_SIZE * 2,  # evaluation can use a larger batch
        warmup_ratio=RepaymentNERConfig.WARMUP_RATIO,
        weight_decay=RepaymentNERConfig.WEIGHT_DECAY,
        logging_dir=RepaymentNERConfig.LOG_PATH,
        logging_steps=RepaymentNERConfig.LOGGING_STEPS,
        evaluation_strategy="steps",
        eval_steps=RepaymentNERConfig.EVAL_STEPS,
        save_strategy="steps",
        save_steps=RepaymentNERConfig.EVAL_STEPS,
        save_total_limit=RepaymentNERConfig.SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        metric_for_best_model="overall_f1",
        greater_is_better=True,
        max_grad_norm=RepaymentNERConfig.MAX_GRAD_NORM,
        gradient_accumulation_steps=RepaymentNERConfig.GRADIENT_ACCUMULATION_STEPS,
        fp16=RepaymentNERConfig.FP16,
        dataloader_num_workers=RepaymentNERConfig.DATALOADER_NUM_WORKERS,
        dataloader_pin_memory=RepaymentNERConfig.DATALOADER_PIN_MEMORY,
        save_safetensors=True,
        optim="adamw_torch",
        disable_tqdm=False,
        report_to=["tensorboard"],
        # Batch samples of similar length together to reduce padding (note: the
        # dataset above already pads everything to one length, so this is
        # effectively a no-op here)
        group_by_length=True,
        length_column_name="length",
        remove_unused_columns=True,
        label_smoothing_factor=0.1,  # label smoothing
    )

    # Build the trainer. Trainer adds its own ProgressCallback when
    # disable_tqdm=False, so one is not added again here.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=RepaymentNERConfig.EARLY_STOPPING_PATIENCE,
                early_stopping_threshold=0.001
            )
        ]
    )

    try:
        # Train the model
        print("\nStarting training...")
        train_result = trainer.train()

        # Report training results
        print("\nTraining finished!")
        print(f"Training time: {train_result.metrics['train_runtime']:.2f}s")

        # Safely collect and print metrics
        metrics = train_result.metrics
        print("\nTraining metrics:")
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                print(f"- {key}: {value:.4f}")

        # Final evaluation
        final_eval = trainer.evaluate()
        print("\nFinal evaluation:")
        print(f"F1: {final_eval['eval_overall_f1']:.4f}")
        print(f"Precision: {final_eval['eval_overall_precision']:.4f}")
        print(f"Recall: {final_eval['eval_overall_recall']:.4f}")

        # Save the best model
        print("\nSaving model...")
        save_path = f"{RepaymentNERConfig.MODEL_PATH}/best_model"
        trainer.save_model(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model saved to: {save_path}")

        return model, tokenizer

    except Exception as e:
        print(f"\nTraining interrupted: {str(e)}")
        # Try to save the current model
        try:
            save_path = f"{RepaymentNERConfig.MODEL_PATH}/interrupted_model"
            trainer.save_model(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Saved interrupted model to: {save_path}")
        except Exception as save_error:
            print(f"Failed to save interrupted model: {str(save_error)}")
        raise
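
# Note: RepaymentDataset pads every sample up front to the longest sequence in
# the whole dataset. A lower-memory alternative (a sketch, not wired in here) is
# per-batch dynamic padding:
#
#     from transformers import DataCollatorForTokenClassification
#     data_collator = DataCollatorForTokenClassification(tokenizer)
#     trainer = Trainer(..., data_collator=data_collator)
#
# which requires the dataset to return unpadded examples.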


def validate_labels(labels, valid_labels):
    """Check that every label in the dataset is a known label."""
    label_set = set()
    for seq in labels:
        label_set.update(seq)

    invalid_labels = label_set - set(valid_labels)
    if invalid_labels:
        raise ValueError(f"Invalid labels found: {invalid_labels}")


def clean_text(text: str) -> str:
    """Normalize fullwidth punctuation to its halfwidth equivalent."""
    text = text.replace('￥', '¥')
    text = text.replace('，', ',')
    text = text.replace('。', '.')
    text = text.replace('：', ':')
    text = text.replace('（', '(')
    text = text.replace('）', ')')
    return text
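
# Example: clean_text("（￥５００，００）") -> "(¥５００,００)" -- only the six
# characters listed above are mapped; fullwidth digits and letters pass through.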


def preprocess_data(texts, labels):
    """Preprocess the data: clean the text and specialize minimum-payment labels."""
    processed_texts = []
    processed_labels = []

    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        if len(text) != len(label_seq):
            print(f"Warning: sample {i} has mismatched text/label lengths, skipped")
            continue

        # Clean the text
        cleaned_text = [clean_text(word) for word in text]

        # Relabel amount annotations
        is_min_amount = False
        new_labels = []
        for j, (word, label) in enumerate(zip(cleaned_text, label_seq)):
            if label.startswith("B-PICKUP_CODE"):
                # Check whether this is a minimum-payment amount
                context = ''.join(cleaned_text[max(0, j-5):j])
                if any(kw in context for kw in RepaymentNERConfig.AMOUNT_CONFIG['min_amount_keywords']):
                    is_min_amount = True
                    new_labels.append("B-MIN_CODE")
                else:
                    is_min_amount = False
                    new_labels.append(label)
            elif label.startswith("I-PICKUP_CODE"):
                if is_min_amount:
                    new_labels.append("I-MIN_CODE")
                else:
                    new_labels.append(label)
            else:
                new_labels.append(label)

        processed_texts.append(cleaned_text)
        processed_labels.append(new_labels)

    return processed_texts, processed_labels
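
# Illustration of the relabeling above (hypothetical sample, assuming "最低还款"
# appears in AMOUNT_CONFIG['min_amount_keywords']):
#   text:   最 低 还 款 5 0 0
#   before: O  O  O  O  B-PICKUP_CODE I-PICKUP_CODE I-PICKUP_CODE
#   after:  O  O  O  O  B-MIN_CODE    I-MIN_CODE    I-MIN_CODE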


def load_data(file_path):
    """Load and parse the annotated data."""
    texts = []
    labels = []
    current_words = []
    current_labels = []
    skip_url = False
    url_indicators = {'u', 'ur', 'url', 'http', 'https', 'www', 'com', 'cn'}

    def is_url_part(word):
        return (word.lower() in url_indicators or
                '.' in word or
                '/' in word or
                word.startswith('?'))

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()

            if not line:  # blank line separates samples
                if current_words:
                    texts.append(current_words)
                    labels.append(current_labels)
                    current_words = []
                    current_labels = []
                skip_url = False
                continue

            try:
                word, label = line.split(maxsplit=1)

                # URL-skipping logic
                if is_url_part(word):
                    skip_url = True
                elif word in ['】', ',', '。', ':']:
                    skip_url = False

                # Label validation
                if not skip_url:
                    if label not in RepaymentNERConfig.LABELS:
                        print(f"Warning: invalid label '{label}' on line {line_num}, skipped")
                        continue
                    current_words.append(word)
                    current_labels.append(label)

            except Exception as e:
                print(f"Error: failed to process line {line_num} '{line}': {str(e)}")
                continue

    # Handle the last sample
    if current_words:
        texts.append(current_words)
        labels.append(current_labels)

    return texts, labels
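
# Expected input format (CoNLL-style, inferred from the parser above): one
# "<token> <label>" pair per line, with a blank line between samples, e.g.:
#
#   您 O
#   的 O
#   还 O
#   款 O
#   码 O
#   5 B-PICKUP_CODE
#   0 I-PICKUP_CODE
#   0 I-PICKUP_CODE
#
#   (blank line, then the next sample)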


def validate_dataset(texts, labels):
    """Validate dataset integrity and BIO-scheme consistency."""
    stats = {
        "total_samples": len(texts),
        "total_tokens": sum(len(t) for t in texts),
        "entity_counts": {},
        "avg_length": 0,
        "errors": []
    }

    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        # Length check
        if len(text) != len(label_seq):
            stats["errors"].append(f"Sample {i}: text and label lengths do not match")
            continue

        # Count entities
        current_entity = None
        for j, (word, label) in enumerate(zip(text, label_seq)):
            if label.startswith("B-"):
                entity_type = label[2:]
                stats["entity_counts"][entity_type] = stats["entity_counts"].get(entity_type, 0) + 1
                current_entity = entity_type
            elif label.startswith("I-"):
                if not current_entity:
                    stats["errors"].append(f"Sample {i}: I- label at position {j} has no preceding B- label")
                elif label[2:] != current_entity:
                    stats["errors"].append(f"Sample {i}: I- label type at position {j} does not match its B- label")
            else:
                current_entity = None

    stats["avg_length"] = stats["total_tokens"] / stats["total_samples"] if stats["total_samples"] > 0 else 0

    return stats
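
# Example of the returned stats (purely hypothetical numbers):
#   {"total_samples": 1200, "total_tokens": 54000,
#    "entity_counts": {"PICKUP_CODE": 980, "MIN_CODE": 310},
#    "avg_length": 45.0, "errors": []}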


def resume_training(checkpoint_path):
    """Resume training from a checkpoint."""
    print(f"Resuming training from checkpoint: {checkpoint_path}")

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)

    # Reload the data
    texts, labels = load_data(RepaymentNERConfig.DATA_PATH)
    texts, labels = preprocess_data(texts, labels)

    # Rebuild the datasets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels,
        test_size=RepaymentNERConfig.TEST_SIZE,
        random_state=RepaymentNERConfig.SEED
    )

    train_dataset = RepaymentDataset(train_texts, train_labels, tokenizer, RepaymentNERConfig.LABELS)
    val_dataset = RepaymentDataset(val_texts, val_labels, tokenizer, RepaymentNERConfig.LABELS)

    # Build a trainer and continue training
    training_args = TrainingArguments(
        output_dir=RepaymentNERConfig.MODEL_PATH,
        num_train_epochs=RepaymentNERConfig.EPOCHS,
        # ... remaining arguments identical to those in train_repayment_model ...
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=RepaymentNERConfig.EARLY_STOPPING_PATIENCE)]
    )

    # Continue training
    trainer.train(resume_from_checkpoint=checkpoint_path)

    return model, tokenizer
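
# Caveat: trainer.save_model() writes only the model weights and config, not the
# optimizer/scheduler/trainer state that full checkpoint-*/ directories contain.
# Pointing resume_from_checkpoint at interrupted_model therefore reloads the
# weights but restarts the training schedule from scratch.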


def main():
    # Load data
    print("Loading data...")
    texts, labels = load_data(RepaymentNERConfig.DATA_PATH)

    # Preprocess
    print("Preprocessing data...")
    texts, labels = preprocess_data(texts, labels)

    # Sanity-check lengths
    print("Checking data...")
    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        if len(text) != len(label_seq):
            print(f"Error: sample {i} has mismatched text/label lengths")
            print(f"Text ({len(text)}): {text}")
            print(f"Labels ({len(label_seq)}): {label_seq}")
            return

    print(f"Data check passed, {len(texts)} valid samples")

    # Validate the dataset
    print("Validating dataset...")
    stats = validate_dataset(texts, labels)

    print("\n=== Dataset statistics ===")
    print(f"Total samples: {stats['total_samples']}")
    print(f"Average length: {stats['avg_length']:.2f}")
    print("\nEntity counts:")
    for entity_type, count in stats['entity_counts'].items():
        print(f"- {entity_type}: {count}")

    if stats['errors']:
        print("\nProblems found:")
        for error in stats['errors']:
            print(f"- {error}")
        if input("Continue training anyway? (y/n) ").lower() != 'y':
            return

    # Check for an interrupted training run
    interrupted_model_path = f"{RepaymentNERConfig.MODEL_PATH}/interrupted_model"
    if os.path.exists(interrupted_model_path):
        print("\nFound an interrupted training run")
        if input("Resume from the interrupted model? (y/n) ").lower() == 'y':
            model, tokenizer = resume_training(interrupted_model_path)
            return

    # Normal training flow
    print("\nStarting a new training run...")
    model, tokenizer = train_repayment_model(texts, labels)


if __name__ == "__main__":
    main()