# train_flight_ner.py
"""Fine-tune a token-classification model for flight / train ticket NER.

The script reads ``token label`` lines (samples separated by blank lines) from
``FlightNERConfig.DATA_PATH``, splits them into train/validation sets, augments
the training set, trains with the Hugging Face ``Trainer`` and saves the final
model and tokenizer under ``FlightNERConfig.MODEL_PATH``.
"""
import random
import re

import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_callback import EarlyStoppingCallback

from ner_config import FlightNERConfig

# Entity types for which compute_metrics reports a token-level F1 score.
ENTITY_TYPES = (
    "FLIGHT", "COMPANY", "START", "END", "DATE", "TIME",
    "DEPARTURE_TIME", "ARRIVAL_TIME", "TICKET_NUM", "SEAT",
)

# Two uppercase letters followed by three or four digits, e.g. "CA1234".
# NOTE(review): airline codes containing a digit (e.g. "3U8888") do not match
# and therefore lose their FLIGHT labels -- confirm this is intended.
_FLIGHT_NUMBER_RE = re.compile(r'^[A-Z]{2}\d{3,4}$')


# Set random seeds
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(FlightNERConfig.SEED)


def _entity_type(tag):
    """Return the entity type of a ``B-``/``I-`` tag, or ``None`` for other tags."""
    return tag[2:] if tag[:2] in ("B-", "I-") else None


class NERDataset(Dataset):
    """Torch dataset holding tokenized samples with token-aligned label ids.

    Args:
        texts: list of samples, each a list of tokens.
        labels: list of tag lists, parallel to ``texts``.
        tokenizer: a fast Hugging Face tokenizer (``word_ids`` is required).
        label_list: ordered tag names; the position of a tag is its id.
    """

    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        # Build the tag <-> id mappings.
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}

        # Print the label mapping.
        print("标签映射:")
        for label, idx in self.label2id.items():
            print(f"{label}: {idx}")

        # Encode the texts once, up front.
        self.encodings = self.tokenize_and_align_labels()

    def tokenize_and_align_labels(self):
        """Tokenize all samples and attach a ``labels`` list aligned to sub-tokens.

        Samples are passed as pre-split token lists so that ``word_ids`` index
        directly into the per-token tag lists. Special and padding positions
        get ``-100``; the first sub-token of a token gets the token's tag, and
        further sub-tokens get the matching ``I-`` tag (or ``O``).

        Raises:
            KeyError: if a tag (or its derived ``I-`` tag) is not in the label list.
        """
        tokenized_inputs = self.tokenizer(
            [list(words) for words in self.texts],
            # Keep the data file's token boundaries; letting the tokenizer
            # re-segment a joined string breaks the word_id -> tag alignment.
            is_split_into_words=True,
            truncation=True,
            padding=True,
            max_length=FlightNERConfig.MAX_LENGTH,
            return_tensors=None,
        )

        labels = []
        for i, tags in enumerate(self.labels):
            previous_word_idx = None
            label_ids = []
            for word_idx in tokenized_inputs.word_ids(i):
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First sub-token of a new token.
                    label_ids.append(self.label2id[tags[word_idx]])
                else:
                    # Later sub-token of the same token: continue its entity.
                    entity = _entity_type(tags[word_idx])
                    if entity:
                        label_ids.append(self.label2id[f"I-{entity}"])
                    else:
                        label_ids.append(self.label2id["O"])
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.texts)


def _clean_flight_labels(words, labels):
    """Reset malformed flight-number spans to ``O`` (mutates ``labels`` in place).

    A span is a ``B-FLIGHT`` tag followed by consecutive ``I-FLIGHT`` tags; it
    is kept only when its joined tokens match ``_FLIGHT_NUMBER_RE``.

    Returns:
        The same ``(words, labels)`` pair.
    """
    i = 0
    while i < len(words):
        if not labels[i].startswith("B-FLIGHT"):
            i += 1
            continue

        # Find the end of the flight-number span.
        j = i + 1
        while j < len(words) and labels[j].startswith("I-FLIGHT"):
            j += 1

        if not _FLIGHT_NUMBER_RE.match(''.join(words[i:j])):
            labels[i:j] = ["O"] * (j - i)
        i = j

    return words, labels


def load_data(file_path):
    """Load samples from a ``token label`` file.

    Each non-empty line holds a token and its tag separated by whitespace;
    a blank line ends the current sample. Lines that do not split into two
    fields are reported and skipped.

    Returns:
        ``(texts, labels)``: parallel lists of token lists and tag lists.
    """
    texts, labels = [], []
    current_words, current_labels = [], []

    def flush():
        # Clean flight-number annotations, then store the finished sample.
        words, tags = _clean_flight_labels(current_words, current_labels)
        texts.append(words)
        labels.append(tags)

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                try:
                    word, label = line.split(maxsplit=1)
                except ValueError:
                    # Line has no separate tag field.
                    print(f"错误:处理行时出错: '{line}'")
                    continue
                current_words.append(word)
                current_labels.append(label)
            elif current_words:  # blank line with a pending sample
                flush()
                current_words, current_labels = [], []

    if current_words:  # last sample when the file has no trailing blank line
        flush()

    return texts, labels


def compute_metrics(p):
    """Compute evaluation metrics for the ``Trainer``.

    Args:
        p: ``(predictions, labels)`` where ``predictions`` are per-token logits
            and ``labels`` are label ids with ``-100`` at ignored positions.

    Returns:
        Dict with seqeval ``overall_f1`` / ``overall_precision`` /
        ``overall_recall`` plus a token-level ``<TYPE>_f1`` for every type in
        ``ENTITY_TYPES`` (0.0 when the type never occurs).
    """
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Drop ignored positions and map ids back to tag names.
    true_predictions = [
        [FlightNERConfig.LABELS[pred] for pred, gold in zip(pred_row, gold_row) if gold != -100]
        for pred_row, gold_row in zip(predictions, labels)
    ]
    true_labels = [
        [FlightNERConfig.LABELS[gold] for gold in gold_row if gold != -100]
        for gold_row in labels
    ]

    # Overall (entity-level) metrics.
    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }

    # Token-level counts per entity type: [true pos, false pos, false neg].
    # Types are compared exactly, so "TIME" does not absorb "DEPARTURE_TIME".
    counts = {entity_type: [0, 0, 0] for entity_type in ENTITY_TYPES}
    for pred_seq, gold_seq in zip(true_predictions, true_labels):
        for pred_tag, gold_tag in zip(pred_seq, gold_seq):
            pred_type = _entity_type(pred_tag)
            gold_type = _entity_type(gold_tag)
            if pred_type == gold_type:
                if pred_type in counts:
                    counts[pred_type][0] += 1
            else:
                if pred_type in counts:
                    counts[pred_type][1] += 1
                if gold_type in counts:
                    counts[gold_type][2] += 1

    for entity_type, (tp, fp, fn) in counts.items():
        denominator = 2 * tp + fp + fn
        results[f"{entity_type}_f1"] = 2 * tp / denominator if denominator else 0.0

    return results


def augment_data(texts, labels):
    """Augment the data by randomly dropping non-entity tokens.

    Every sample is kept as is and followed by a copy in which each ``O``
    token is removed with probability 0.3. Copies that end up empty are
    skipped.

    Returns:
        ``(augmented_texts, augmented_labels)``.
    """
    augmented_texts = []
    augmented_labels = []
    for text, label in zip(texts, labels):
        # Original sample.
        augmented_texts.append(text)
        augmented_labels.append(label)

        # Variant with some non-entity tokens removed.
        kept = [
            (token, tag)
            for token, tag in zip(text, label)
            if not (tag == "O" and random.random() < 0.3)
        ]
        if not kept:
            continue
        augmented_texts.append([token for token, _ in kept])
        augmented_labels.append([tag for _, tag in kept])

    return augmented_texts, augmented_labels


def main():
    """Load the data, train the model, print evaluation results and save."""
    # Load data
    texts, labels = load_data(FlightNERConfig.DATA_PATH)
    print(f"加载的数据集大小:{len(texts)}个样本")

    # Split into train / validation sets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=FlightNERConfig.TEST_SIZE, random_state=FlightNERConfig.SEED
    )

    # Augment the training split only
    train_texts, train_labels = augment_data(train_texts, train_labels)

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(FlightNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        FlightNERConfig.MODEL_NAME,
        num_labels=len(FlightNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(FlightNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(FlightNERConfig.LABELS)}
    )

    # Build datasets
    train_dataset = NERDataset(train_texts, train_labels, tokenizer, FlightNERConfig.LABELS)
    val_dataset = NERDataset(val_texts, val_labels, tokenizer, FlightNERConfig.LABELS)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=FlightNERConfig.MODEL_PATH,
        num_train_epochs=FlightNERConfig.EPOCHS,
        per_device_train_batch_size=FlightNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=FlightNERConfig.BATCH_SIZE,
        learning_rate=FlightNERConfig.LEARNING_RATE,
        warmup_ratio=FlightNERConfig.WARMUP_RATIO,
        weight_decay=FlightNERConfig.WEIGHT_DECAY,
        gradient_accumulation_steps=FlightNERConfig.GRADIENT_ACCUMULATION_STEPS,
        logging_steps=FlightNERConfig.LOGGING_STEPS,
        save_total_limit=2,
        # NOTE(review): forces CPU training even when CUDA is available -- confirm.
        no_cuda=True,
        evaluation_strategy="steps",
        eval_steps=FlightNERConfig.EVAL_STEPS,
        save_strategy="steps",
        save_steps=FlightNERConfig.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="overall_f1",
        greater_is_better=True,
        logging_dir=FlightNERConfig.LOG_PATH,
        logging_first_step=True,
        report_to=["tensorboard"],
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=FlightNERConfig.EARLY_STOPPING_PATIENCE)]
    )

    # Train
    trainer.train()

    # Evaluate
    eval_results = trainer.evaluate()
    print("\n评估结果:")
    for key, value in eval_results.items():
        print(f"{key}: {value:.4f}")

    # Save the final model and tokenizer
    model.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")
    tokenizer.save_pretrained(f"{FlightNERConfig.MODEL_PATH}/best_model")


if __name__ == "__main__":
    main()