cloudroam
2025-04-15 acc5c1281b50c12e4d04c81b899410f6ca2cacac
train_flight_ner.py
@@ -0,0 +1,297 @@
# train_flight_ner.py
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score as binary_f1_score
from seqeval.metrics import f1_score, precision_score, recall_score
import random
import re
from ner_config import FlightNERConfig
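# The script expects ner_config.FlightNERConfig to expose the attributes used
# below. A minimal sketch (all values are illustrative placeholders, not the
# actual project config):
#
#   class FlightNERConfig:
#       SEED = 42
#       MODEL_NAME = "bert-base-chinese"      # hypothetical backbone
#       DATA_PATH = "data/flight_ner.txt"     # hypothetical path
#       MODEL_PATH = "models/flight_ner"
#       LOG_PATH = "logs/flight_ner"
#       MAX_LENGTH = 128
#       TEST_SIZE = 0.1
#       EPOCHS = 10
#       BATCH_SIZE = 16
#       LEARNING_RATE = 2e-5
#       WARMUP_RATIO = 0.1
#       WEIGHT_DECAY = 0.01
#       GRADIENT_ACCUMULATION_STEPS = 1
#       LOGGING_STEPS = 50
#       EVAL_STEPS = 100
#       SAVE_STEPS = 100
#       EARLY_STOPPING_PATIENCE = 3
#       LABELS = ["O"] + [f"{p}-{t}" for t in (
#           "FLIGHT", "COMPANY", "START", "END", "DATE", "TIME",
#           "DEPARTURE_TIME", "ARRIVAL_TIME", "TICKET_NUM", "SEAT")
#           for p in ("B", "I")]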
# Set random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(FlightNERConfig.SEED)
class NERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        # Build label <-> ID mappings
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        # Print the label mapping
        print("Label mapping:")
        for label, idx in self.label2id.items():
            print(f"{label}: {idx}")
        # Encode the texts
        self.encodings = self.tokenize_and_align_labels()
    def tokenize_and_align_labels(self):
        # Tokenize the pre-split tokens so that word_ids() maps each sub-token
        # back to an index into the original token (and label) list; joining
        # the tokens into one string would misalign labels for multi-character
        # tokens such as "CA1234". return_offsets_mapping is not requested:
        # it is unused and would otherwise leak into the model inputs via
        # __getitem__.
        tokenized_inputs = self.tokenizer(
            self.texts,
            is_split_into_words=True,
            truncation=True,
            padding=True,
            max_length=FlightNERConfig.MAX_LENGTH,
            return_tensors=None
        )
        labels = []
        for i, label in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(i)
            previous_word_idx = None
            label_ids = []
            current_entity = None
            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens ([CLS], [SEP], padding) are ignored in the loss
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First sub-token of a new word: keep its original label
                    label_ids.append(self.label2id[label[word_idx]])
                    if label[word_idx] == "O":
                        current_entity = None
                    else:
                        # Track the entity type for both B- and I- tags
                        current_entity = label[word_idx][2:]
                else:
                    # Subsequent sub-tokens of the same word continue the
                    # current entity with an I- tag, or fall back to O
                    if current_entity:
                        label_ids.append(self.label2id[f"I-{current_entity}"])
                    else:
                        label_ids.append(self.label2id["O"])
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
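    # Alignment sketch (illustrative, assuming a fast BERT-style tokenizer):
    # for a sample with tokens ["C","A","1","2","3","4"] labeled
    # [B-FLIGHT, I-FLIGHT, ..., I-FLIGHT], word_ids() yields something like
    # [None, 0, 1, 2, 3, 4, 5, None], so [CLS]/[SEP] positions get -100 and
    # every sub-token inherits the label of the token it came from.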
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.texts)
def load_data(file_path):
    texts, labels = [], []
    current_words, current_labels = [], []
    def clean_flight_labels(words, labels):
        """清理航班号标注,确保格式正确"""
        i = 0
        while i < len(words):
            if labels[i].startswith("B-FLIGHT"):
                # Find the end of the flight-number span
                j = i + 1
                while j < len(words) and labels[j].startswith("I-FLIGHT"):
                    j += 1
                # Collect the span and validate it against the flight-number
                # pattern (two uppercase letters + 3-4 digits)
                flight_words = words[i:j]
                flight_str = ''.join(flight_words)
                valid_pattern = re.compile(r'^[A-Z]{2}\d{3,4}$')
                if not valid_pattern.match(flight_str):
                    # Relabel malformed spans as O
                    for k in range(i, j):
                        labels[k] = "O"
                i = j
            else:
                i += 1
        return words, labels
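    # Example (illustrative): a span joining to "CA1234" keeps its
    # B-/I-FLIGHT tags, while "CA12" or "ca1234" fails the pattern above
    # and is relabeled to all-O.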
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                try:
                    word, label = line.split(maxsplit=1)
                    current_words.append(word)
                    current_labels.append(label)
                except Exception as e:
                    print(f"Error: failed to parse line '{line}': {e}")
                    continue
            elif current_words:  # blank line ends the current sample
                # Clean flight-number annotations
                current_words, current_labels = clean_flight_labels(current_words, current_labels)
                texts.append(current_words)
                labels.append(current_labels)
                current_words, current_labels = [], []
    if current_words:  # flush the last sample if the file has no trailing blank line
        current_words, current_labels = clean_flight_labels(current_words, current_labels)
        texts.append(current_words)
        labels.append(current_labels)
    return texts, labels
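# Expected input format (CoNLL-style; the sample values are illustrative):
# one token and its BIO tag per line, with a blank line between samples, e.g.
#
#   国 B-COMPANY
#   航 I-COMPANY
#   C B-FLIGHT
#   A I-FLIGHT
#   1 I-FLIGHT
#   2 I-FLIGHT
#   3 I-FLIGHT
#   4 I-FLIGHT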
def compute_metrics(p):
    """计算评估指标"""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Drop special-token positions (label == -100) from predictions and references
    true_predictions = [
        [FlightNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [FlightNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    # Overall entity-level metrics (seqeval)
    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }
    # Per-entity-type metrics
    for entity_type in ["FLIGHT", "COMPANY", "START", "END", "DATE", "TIME",
                        "DEPARTURE_TIME", "ARRIVAL_TIME", "TICKET_NUM", "SEAT"]:
        # Convert the tag sequences to binary indicators for this entity type
        binary_preds = []
        binary_labels = []
        for pred_seq, label_seq in zip(true_predictions, true_labels):
            pred_binary = []
            label_binary = []
            for pred, label in zip(pred_seq, label_seq):
                # Exact match on the tag's type suffix; a plain endswith()
                # would wrongly count DEPARTURE_TIME/ARRIVAL_TIME as TIME
                pred_is_entity = pred != "O" and pred.split("-", 1)[1] == entity_type
                label_is_entity = label != "O" and label.split("-", 1)[1] == entity_type
                pred_binary.append(1 if pred_is_entity else 0)
                label_binary.append(1 if label_is_entity else 0)
            binary_preds.append(pred_binary)
            binary_labels.append(label_binary)
        # Token-level binary F1 for this entity type; sklearn's f1_score is
        # used here because seqeval's f1_score expects BIO tag sequences,
        # not flattened 0/1 lists
        try:
            entity_f1 = binary_f1_score(
                sum(binary_labels, []),  # flatten
                sum(binary_preds, []),   # flatten
                average='binary'
            )
            results[f"{entity_type}_f1"] = entity_f1
        except Exception as e:
            print(f"Error computing F1 for {entity_type}: {str(e)}")
            results[f"{entity_type}_f1"] = 0.0
    return results
def augment_data(texts, labels):
    """数据增强"""
    augmented_texts = []
    augmented_labels = []
    for text, label in zip(texts, labels):
        # Keep the original sample
        augmented_texts.append(text)
        augmented_labels.append(label)
        # Variant: randomly drop non-entity (O) characters
        new_text = []
        new_label = []
        for t, l in zip(text, label):
            if l == "O" and random.random() < 0.3:
                continue
            new_text.append(t)
            new_label.append(l)
        augmented_texts.append(new_text)
        augmented_labels.append(new_label)
    return augmented_texts, augmented_labels
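# Note: this doubles the training set, pairing every original sample with a
# noisier copy; entity tokens are never dropped, only O tokens (each with
# probability 0.3).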
def main():
    # Load data
    texts, labels = load_data(FlightNERConfig.DATA_PATH)
    print(f"加载的数据集大小:{len(texts)}个样本")
    # 划分数据集
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=FlightNERConfig.TEST_SIZE, random_state=FlightNERConfig.SEED
    )
    # Data augmentation (training split only)
    train_texts, train_labels = augment_data(train_texts, train_labels)
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(FlightNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        FlightNERConfig.MODEL_NAME,
        num_labels=len(FlightNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(FlightNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(FlightNERConfig.LABELS)}
    )
    # Build datasets
    train_dataset = NERDataset(train_texts, train_labels, tokenizer, FlightNERConfig.LABELS)
    val_dataset = NERDataset(val_texts, val_labels, tokenizer, FlightNERConfig.LABELS)
    # Training arguments
    training_args = TrainingArguments(
        output_dir=FlightNERConfig.MODEL_PATH,
        num_train_epochs=FlightNERConfig.EPOCHS,
        per_device_train_batch_size=FlightNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=FlightNERConfig.BATCH_SIZE,
        learning_rate=FlightNERConfig.LEARNING_RATE,
        warmup_ratio=FlightNERConfig.WARMUP_RATIO,
        weight_decay=FlightNERConfig.WEIGHT_DECAY,
        gradient_accumulation_steps=FlightNERConfig.GRADIENT_ACCUMULATION_STEPS,
        logging_steps=FlightNERConfig.LOGGING_STEPS,
        save_total_limit=2,
        no_cuda=True,
        evaluation_strategy="steps",
        eval_steps=FlightNERConfig.EVAL_STEPS,
        save_strategy="steps",
        save_steps=FlightNERConfig.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="overall_f1",
        greater_is_better=True,
        logging_dir=FlightNERConfig.LOG_PATH,
        logging_first_step=True,
        report_to=["tensorboard"],
    )
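    # Note: Trainer prepends "eval_" when resolving metric_for_best_model, so
    # "overall_f1" here matches the "overall_f1" key from compute_metrics.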
    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=FlightNERConfig.EARLY_STOPPING_PATIENCE)]
    )
    # Train
    trainer.train()
    # Final evaluation
    eval_results = trainer.evaluate()
    print("\nEvaluation results:")
    for key, value in eval_results.items():
        print(f"{key}: {value:.4f}")
    # Save the final (best) model
    model.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
    tokenizer.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
if __name__ == "__main__":
    main()
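
# --- Usage sketch (illustrative, not part of the training script) ------------
# A minimal inference example, assuming training saved the best model to
# f"{FlightNERConfig.MODEL_PATH}/best_model"; the sample sentence is made up.
#
#   from transformers import pipeline
#   ner = pipeline(
#       "token-classification",
#       model=f"{FlightNERConfig.MODEL_PATH}/best_model",
#       aggregation_strategy="simple",
#   )
#   print(ner("国航CA1234于10月1日08:30从北京起飞"))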