From acc5c1281b50c12e4d04c81b899410f6ca2cacac Mon Sep 17 00:00:00 2001
From: cloudroam <cloudroam>
Date: Tue, 15 Apr 2025 15:13:30 +0800
Subject: [PATCH] add: add flight and train ticket NER training

---
 train_train_ner.py |  289 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 289 insertions(+), 0 deletions(-)

diff --git a/train_train_ner.py b/train_train_ner.py
index e69de29..3ff7fb7 100644
--- a/train_train_ner.py
+++ b/train_train_ner.py
@@ -0,0 +1,289 @@
+# train_train_ner.py 
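+# Fine-tunes a token-classification (NER) model that tags ticket fields
+# (COMPANY, TRIPS, START, END, DATE, TIME, SEAT, NAME) in flight and
+# train ticket text.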
+from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import f1_score as token_f1_score
+from seqeval.metrics import f1_score, precision_score, recall_score
+import random
+import re
+from ner_config import TrainNERConfig
+
+# Set random seeds for reproducibility
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+set_seed(TrainNERConfig.SEED)
+
+class NERDataset(Dataset):
+    def __init__(self, texts, labels, tokenizer, label_list):
+        self.texts = texts
+        self.labels = labels
+        self.tokenizer = tokenizer
+        # Build label <-> id mappings
+        self.label2id = {label: i for i, label in enumerate(label_list)}
+        self.id2label = {i: label for i, label in enumerate(label_list)}
+        
+        # Print the label mapping for sanity checking
+        print("Label mapping:")
+        for label, idx in self.label2id.items():
+            print(f"{label}: {idx}")
+            
+        # Tokenize the texts and align labels to sub-tokens
+        self.encodings = self.tokenize_and_align_labels()
+
+    def tokenize_and_align_labels(self):
+        # Note: word_ids() requires a fast (Rust-backed) tokenizer. Offset
+        # mappings are not requested, since extra keys in the encodings would
+        # be passed through to the model's forward() by the Trainer.
+        tokenized_inputs = self.tokenizer(
+            [''.join(text) for text in self.texts],
+            truncation=True,
+            padding=True,
+            max_length=TrainNERConfig.MAX_LENGTH,
+            return_tensors=None
+        )
+
+        labels = []
+        for i, label in enumerate(self.labels):
+            word_ids = tokenized_inputs.word_ids(i)
+            previous_word_idx = None
+            label_ids = []
+            current_entity = None
+            
+            for word_idx in word_ids:
+                if word_idx is None:
+                    # Special tokens ([CLS], [SEP], padding) are masked out
+                    label_ids.append(-100)
+                elif word_idx != previous_word_idx:
+                    # First token of a new word: use its label directly
+                    label_ids.append(self.label2id[label[word_idx]])
+                    if label[word_idx] == "O":
+                        current_entity = None
+                    else:
+                        # Track the entity type for both B- and I- labels
+                        current_entity = label[word_idx][2:]
+                else:
+                    # Later sub-tokens of the same word continue the entity
+                    if current_entity:
+                        label_ids.append(self.label2id[f"I-{current_entity}"])
+                    else:
+                        label_ids.append(self.label2id["O"])
+                
+                previous_word_idx = word_idx
+            
+            labels.append(label_ids)
+
+        tokenized_inputs["labels"] = labels
+        return tokenized_inputs
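+    # Alignment example (hypothetical): characters "G", "1", "2" tagged
+    # [B-TRIPS, I-TRIPS, I-TRIPS] keep their tags on the first sub-token of
+    # each word, while [CLS]/[SEP]/padding positions get -100 and are
+    # ignored by the loss.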
+
+    def __getitem__(self, idx):
+        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+
+    def __len__(self):
+        return len(self.texts)
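+
+# NERDataset usage sketch (hypothetical): NERDataset(texts, labels, tokenizer,
+# TrainNERConfig.LABELS), where texts is a list of character lists and labels
+# the matching BIO tag lists, as produced by load_data() below.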
+
+def load_data(file_path):
+    texts, labels = [], []
+    current_words, current_labels = [], []
+    
+    def clean_trips_labels(words, labels):
+        """清理车次标注,确保格式正确"""
+        i = 0
+        while i < len(words):
+            if labels[i].startswith("B-TRIPS"):  # 修改标签名
+                # 找到车次的结束位置
+                j = i + 1
+                while j < len(words) and labels[j].startswith("I-TRIPS"):
+                    j += 1
+                
+                # Reassemble the candidate train number for validation
+                train_words = words[i:j]
+                train_str = ''.join(train_words)
+                
+                # Accept only strings that look like train numbers,
+                # e.g. "G12", "G1/2", "T3-4", "4102", "K1234"
+                valid_patterns = [
+                    re.compile(r'^[GDCZTKY]\d{1,2}$'),
+                    re.compile(r'^[GDCZTKY]\d{1,2}/\d{1,2}$'),
+                    re.compile(r'^[GDCZTKY]\d{1,2}-\d{1,2}$'),
+                    re.compile(r'^\d{1,4}$'),
+                    re.compile(r'^[A-Z]\d{1,4}$')
+                ]
+                
+                is_valid = any(pattern.match(train_str) for pattern in valid_patterns)
+                if not is_valid:
+                    # Relabel the malformed span as O
+                    for k in range(i, j):
+                        labels[k] = "O"
+                
+                i = j
+            else:
+                i += 1
+        
+        return words, labels
+    
+    with open(file_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                try:
+                    word, label = line.split(maxsplit=1)
+                    current_words.append(word)
+                    current_labels.append(label)
+                except ValueError as e:
+                    print(f"Error: failed to parse line '{line}': {e}")
+                    continue
+            elif current_words:  # A blank line ends the current sample
+                # Clean the train-number annotations
+                current_words, current_labels = clean_trips_labels(current_words, current_labels)
+                texts.append(current_words)
+                labels.append(current_labels)
+                current_words, current_labels = [], []
+    
+    if current_words:  # Handle a trailing sample without a final blank line
+        current_words, current_labels = clean_trips_labels(current_words, current_labels)
+        texts.append(current_words)
+        labels.append(current_labels)
+    
+    return texts, labels
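+
+# Input format expected by load_data: one character and its BIO tag per line,
+# with a blank line between samples. Illustrative example only:
+#
+#   乘  O
+#   坐  O
+#   G   B-TRIPS
+#   1   I-TRIPS
+#   2   I-TRIPS
+#   次  O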
+
+def compute_metrics(p):
+    """计算评估指标"""
+    predictions, labels = p
+    predictions = np.argmax(predictions, axis=2)
+
+    # Drop positions labeled -100 (special tokens and padding)
+    true_predictions = [
+        [TrainNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+    true_labels = [
+        [TrainNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+
+    # Overall entity-level metrics via seqeval
+    results = {
+        "overall_f1": f1_score(true_labels, true_predictions),
+        "overall_precision": precision_score(true_labels, true_predictions),
+        "overall_recall": recall_score(true_labels, true_predictions)
+    }
+    
+    # Per-entity-type metrics
+    for entity_type in ["COMPANY", "TRIPS", "START", "END", "DATE", "TIME", "SEAT", "NAME"]:
+        # Convert each sequence to a binary is-this-type indicator per token
+        binary_preds = []
+        binary_labels = []
+        
+        for pred_seq, label_seq in zip(true_predictions, true_labels):
+            pred_binary = []
+            label_binary = []
+            
+            for pred, label in zip(pred_seq, label_seq):
+                # Check whether each tag belongs to the current entity type
+                pred_is_entity = pred.endswith(entity_type)
+                label_is_entity = label.endswith(entity_type)
+                
+                pred_binary.append(1 if pred_is_entity else 0)
+                label_binary.append(1 if label_is_entity else 0)
+            
+            binary_preds.append(pred_binary)
+            binary_labels.append(label_binary)
+        
+        # Token-level binary F1 for this entity type. sklearn's f1_score
+        # (imported as token_f1_score) is used here: seqeval expects BIO
+        # string sequences and would fail on these flattened 0/1 lists.
+        try:
+            entity_f1 = token_f1_score(
+                sum(binary_labels, []),  # flatten
+                sum(binary_preds, []),   # flatten
+                average='binary'
+            )
+            results[f"{entity_type}_f1"] = entity_f1
+        except Exception as e:
+            print(f"Error computing F1 for {entity_type}: {str(e)}")
+            results[f"{entity_type}_f1"] = 0.0
+    
+    return results
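+
+# Design note on compute_metrics: the overall_* scores are entity-level
+# (seqeval), while the per-type *_f1 scores are token-level binary F1, so the
+# two scales are not directly comparable.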
+
+def augment_data(texts, labels):
+    """数据增强"""
+    augmented_texts = []
+    augmented_labels = []
+    for text, label in zip(texts, labels):
+        # Keep the original sample
+        augmented_texts.append(text)
+        augmented_labels.append(label)
+        
+        # Add a copy with ~30% of the 'O'-labeled characters removed
+        new_text = []
+        new_label = []
+        for t, l in zip(text, label):
+            if l == "O" and random.random() < 0.3:
+                continue
+            new_text.append(t)
+            new_label.append(l)
+        augmented_texts.append(new_text)
+        augmented_labels.append(new_label)
+    
+    return augmented_texts, augmented_labels
+
+def main():
+    # Load the data
+    texts, labels = load_data(TrainNERConfig.DATA_PATH)
+    print(f"Loaded dataset: {len(texts)} samples")
+    
+    # Split into training and validation sets
+    train_texts, val_texts, train_labels, val_labels = train_test_split(
+        texts, labels, test_size=TrainNERConfig.TEST_SIZE, random_state=TrainNERConfig.SEED
+    )
+    
+    # Augment the training split only
+    train_texts, train_labels = augment_data(train_texts, train_labels)
+    
+    # Load the tokenizer and model
+    tokenizer = AutoTokenizer.from_pretrained(TrainNERConfig.MODEL_NAME)
+    model = AutoModelForTokenClassification.from_pretrained(
+        TrainNERConfig.MODEL_NAME,
+        num_labels=len(TrainNERConfig.LABELS),
+        id2label={i: label for i, label in enumerate(TrainNERConfig.LABELS)},
+        label2id={label: i for i, label in enumerate(TrainNERConfig.LABELS)}
+    )
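+    # id2label/label2id are stored in the model config, so inference code can
+    # recover the label names from the saved checkpoint without ner_config.py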
+    
+    # Build the datasets
+    train_dataset = NERDataset(train_texts, train_labels, tokenizer, TrainNERConfig.LABELS)
+    val_dataset = NERDataset(val_texts, val_labels, tokenizer, TrainNERConfig.LABELS)
+    
+    # Training hyperparameters
+    training_args = TrainingArguments(
+        output_dir=TrainNERConfig.MODEL_PATH,
+        num_train_epochs=TrainNERConfig.EPOCHS,
+        per_device_train_batch_size=TrainNERConfig.BATCH_SIZE,
+        per_device_eval_batch_size=TrainNERConfig.BATCH_SIZE,
+        learning_rate=TrainNERConfig.LEARNING_RATE,
+        warmup_ratio=TrainNERConfig.WARMUP_RATIO,
+        weight_decay=TrainNERConfig.WEIGHT_DECAY,
+        gradient_accumulation_steps=TrainNERConfig.GRADIENT_ACCUMULATION_STEPS
+    )
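+    # Note: with the default evaluation strategy ("no"), no evaluation runs
+    # during training; metrics come from the explicit trainer.evaluate() below.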
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+        compute_metrics=compute_metrics
+    )
+
+    trainer.train()
+    # Final evaluation
+    eval_results = trainer.evaluate()
+    print("\nEvaluation results:")
+    for key, value in eval_results.items():
+        print(f"{key}: {value:.4f}")
+
+    # Save the final model and tokenizer
+    model.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")
+    tokenizer.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")
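+    # Inference sketch (assuming the same paths): reload later with
+    #   AutoTokenizer.from_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")
+    #   AutoModelForTokenClassification.from_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")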
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file

--
Gitblit v1.9.3