From acc5c1281b50c12e4d04c81b899410f6ca2cacac Mon Sep 17 00:00:00 2001
From: cloudroam <cloudroam>
Date: 星期二, 15 四月 2025 15:13:30 +0800
Subject: [PATCH] add: 增加航班和火车票
---
train_flight_ner.py | 297 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 297 insertions(+), 0 deletions(-)
diff --git a/train_flight_ner.py b/train_flight_ner.py
index e69de29..0641e41 100644
--- a/train_flight_ner.py
+++ b/train_flight_ner.py
@@ -0,0 +1,297 @@
+# train_flight_ner.py
+from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
+from transformers.trainer_callback import EarlyStoppingCallback
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from sklearn.model_selection import train_test_split
+from seqeval.metrics import f1_score, precision_score, recall_score
+import random
+import re
+from ner_config import FlightNERConfig
+
+# 设置随机种子
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+set_seed(FlightNERConfig.SEED)
+
+class NERDataset(Dataset):
+ def __init__(self, texts, labels, tokenizer, label_list):
+ self.texts = texts
+ self.labels = labels
+ self.tokenizer = tokenizer
+ # 创建标签到ID的映射
+ self.label2id = {label: i for i, label in enumerate(label_list)}
+ self.id2label = {i: label for i, label in enumerate(label_list)}
+
+ # 打印标签映射信息
+ print("标签映射:")
+ for label, idx in self.label2id.items():
+ print(f"{label}: {idx}")
+
+ # 对文本进行编码
+ self.encodings = self.tokenize_and_align_labels()
+
+ def tokenize_and_align_labels(self):
+ tokenized_inputs = self.tokenizer(
+ [''.join(text) for text in self.texts],
+ truncation=True,
+ padding=True,
+ max_length=FlightNERConfig.MAX_LENGTH,
+ return_offsets_mapping=True,
+ return_tensors=None
+ )
+
+ labels = []
+ for i, label in enumerate(self.labels):
+ word_ids = tokenized_inputs.word_ids(i)
+ previous_word_idx = None
+ label_ids = []
+ current_entity = None
+
+ for word_idx in word_ids:
+ if word_idx is None:
+ label_ids.append(-100)
+ elif word_idx != previous_word_idx:
+ # 新词开始
+ label_ids.append(self.label2id[label[word_idx]])
+ if label[word_idx].startswith("B-"):
+ current_entity = label[word_idx][2:]
+ elif label[word_idx] == "O":
+ current_entity = None
+ else:
+ # 同一个词的后续token
+ if current_entity:
+ label_ids.append(self.label2id[f"I-{current_entity}"])
+ else:
+ label_ids.append(self.label2id["O"])
+
+ previous_word_idx = word_idx
+
+ labels.append(label_ids)
+
+ tokenized_inputs["labels"] = labels
+ return tokenized_inputs
+
+ def __getitem__(self, idx):
+ return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
+
+ def __len__(self):
+ return len(self.texts)
+
+def load_data(file_path):
+ texts, labels = [], []
+ current_words, current_labels = [], []
+
+ def clean_flight_labels(words, labels):
+ """清理航班号标注,确保格式正确"""
+ i = 0
+ while i < len(words):
+ if labels[i].startswith("B-FLIGHT"): # 已修改标签名称
+ # 找到航班号的结束位置
+ j = i + 1
+ while j < len(words) and labels[j].startswith("I-FLIGHT"): # 已修改标签名称
+ j += 1
+
+ # 检查并修正航班号序列
+ flight_words = words[i:j]
+ flight_str = ''.join(flight_words)
+
+ # 检查格式是否符合航班号规范
+ valid_pattern = re.compile(r'^[A-Z]{2}\d{3,4}$')
+ if not valid_pattern.match(flight_str):
+ # 将格式不正确的标签改为O
+ for k in range(i, j):
+ labels[k] = "O"
+
+ i = j
+ else:
+ i += 1
+
+ return words, labels
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ try:
+ word, label = line.split(maxsplit=1)
+ current_words.append(word)
+ current_labels.append(label)
+ except Exception as e:
+ print(f"错误:处理行时出错: '{line}'")
+ continue
+ elif current_words: # 遇到空行且当前有数据
+ # 清理航班号标注
+ current_words, current_labels = clean_flight_labels(current_words, current_labels)
+ texts.append(current_words)
+ labels.append(current_labels)
+ current_words, current_labels = [], []
+
+ if current_words: # 处理最后一个样本
+ current_words, current_labels = clean_flight_labels(current_words, current_labels)
+ texts.append(current_words)
+ labels.append(current_labels)
+
+ return texts, labels
+
+def compute_metrics(p):
+ """计算评估指标"""
+ predictions, labels = p
+ predictions = np.argmax(predictions, axis=2)
+
+ # 移除特殊token的预测和标签
+ true_predictions = [
+ [FlightNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+ true_labels = [
+ [FlightNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ # 计算总体评估指标
+ results = {
+ "overall_f1": f1_score(true_labels, true_predictions),
+ "overall_precision": precision_score(true_labels, true_predictions),
+ "overall_recall": recall_score(true_labels, true_predictions)
+ }
+
+ # 计算每个实体类型的指标
+ for entity_type in ["FLIGHT", "COMPANY", "START", "END", "DATE", "TIME", "DEPARTURE_TIME", "ARRIVAL_TIME","TICKET_NUM","SEAT"]:
+ # 将标签转换为二进制形式
+ binary_preds = []
+ binary_labels = []
+
+ for pred_seq, label_seq in zip(true_predictions, true_labels):
+ pred_binary = []
+ label_binary = []
+
+ for pred, label in zip(pred_seq, label_seq):
+ # 检查标签是否属于当前实体类型
+ pred_is_entity = pred.endswith(entity_type)
+ label_is_entity = label.endswith(entity_type)
+
+ pred_binary.append(1 if pred_is_entity else 0)
+ label_binary.append(1 if label_is_entity else 0)
+
+ binary_preds.append(pred_binary)
+ binary_labels.append(label_binary)
+
+ # 计算当前实体类型的F1分数
+ try:
+ entity_f1 = f1_score(
+ sum(binary_labels, []), # 展平列表
+ sum(binary_preds, []), # 展平列表
+ average='binary' # 使用二进制评估
+ )
+ results[f"{entity_type}_f1"] = entity_f1
+ except Exception as e:
+ print(f"计算{entity_type}的F1分数时出错: {str(e)}")
+ results[f"{entity_type}_f1"] = 0.0
+
+ return results
+
+def augment_data(texts, labels):
+ """数据增强"""
+ augmented_texts = []
+ augmented_labels = []
+ for text, label in zip(texts, labels):
+ # 原始数据
+ augmented_texts.append(text)
+ augmented_labels.append(label)
+
+ # 删除一些无关字符
+ new_text = []
+ new_label = []
+ for t, l in zip(text, label):
+ if l == "O" and random.random() < 0.3:
+ continue
+ new_text.append(t)
+ new_label.append(l)
+ augmented_texts.append(new_text)
+ augmented_labels.append(new_label)
+
+ return augmented_texts, augmented_labels
+
+def main():
+ # 加载数据
+ texts, labels = load_data(FlightNERConfig.DATA_PATH)
+ print(f"加载的数据集大小:{len(texts)}个样本")
+
+ # 划分数据集
+ train_texts, val_texts, train_labels, val_labels = train_test_split(
+ texts, labels, test_size=FlightNERConfig.TEST_SIZE, random_state=FlightNERConfig.SEED
+ )
+
+ # 数据增强
+ train_texts, train_labels = augment_data(train_texts, train_labels)
+
+ # 加载分词器和模型
+ tokenizer = AutoTokenizer.from_pretrained(FlightNERConfig.MODEL_NAME)
+ model = AutoModelForTokenClassification.from_pretrained(
+ FlightNERConfig.MODEL_NAME,
+ num_labels=len(FlightNERConfig.LABELS),
+ id2label={i: label for i, label in enumerate(FlightNERConfig.LABELS)},
+ label2id={label: i for i, label in enumerate(FlightNERConfig.LABELS)}
+ )
+
+ # 创建数据集
+ train_dataset = NERDataset(train_texts, train_labels, tokenizer, FlightNERConfig.LABELS)
+ val_dataset = NERDataset(val_texts, val_labels, tokenizer, FlightNERConfig.LABELS)
+
+ # 训练参数
+ training_args = TrainingArguments(
+ output_dir=FlightNERConfig.MODEL_PATH,
+ num_train_epochs=FlightNERConfig.EPOCHS,
+ per_device_train_batch_size=FlightNERConfig.BATCH_SIZE,
+ per_device_eval_batch_size=FlightNERConfig.BATCH_SIZE,
+ learning_rate=FlightNERConfig.LEARNING_RATE,
+ warmup_ratio=FlightNERConfig.WARMUP_RATIO,
+ weight_decay=FlightNERConfig.WEIGHT_DECAY,
+ gradient_accumulation_steps=FlightNERConfig.GRADIENT_ACCUMULATION_STEPS,
+ logging_steps=FlightNERConfig.LOGGING_STEPS,
+ save_total_limit=2,
+ no_cuda=True,
+ evaluation_strategy="steps",
+ eval_steps=FlightNERConfig.EVAL_STEPS,
+ save_strategy="steps",
+ save_steps=FlightNERConfig.SAVE_STEPS,
+ load_best_model_at_end=True,
+ metric_for_best_model="overall_f1",
+ greater_is_better=True,
+ logging_dir=FlightNERConfig.LOG_PATH,
+ logging_first_step=True,
+ report_to=["tensorboard"],
+ )
+
+ # 训练器
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=val_dataset,
+ compute_metrics=compute_metrics,
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=FlightNERConfig.EARLY_STOPPING_PATIENCE)]
+ )
+
+ # 训练模型
+ trainer.train()
+
+ # 评估结果
+ eval_results = trainer.evaluate()
+ print("\n评估结果:")
+ for key, value in eval_results.items():
+ print(f"{key}: {value:.4f}")
+
+ # 保存最终模型
+ model.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
+ tokenizer.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
--
Gitblit v1.9.3