# train_train_ner.py
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score as binary_f1_score
from seqeval.metrics import f1_score, precision_score, recall_score
import random
import re
from ner_config import TrainNERConfig
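
# This script expects the following fields on ner_config.TrainNERConfig. The
# sketch below is inferred from the attribute accesses in this file; the field
# names are real, the example values are assumptions:
#
#   class TrainNERConfig:
#       SEED = 42
#       MODEL_NAME = "bert-base-chinese"   # any HF token-classification backbone
#       DATA_PATH = "data/train.txt"       # CoNLL-style annotation file
#       MODEL_PATH = "models/ner"          # output directory
#       LABELS = ["O", "B-TRIPS", "I-TRIPS", ...]  # full BIO label list
#       MAX_LENGTH = 128
#       TEST_SIZE = 0.1
#       EPOCHS = 10
#       BATCH_SIZE = 16
#       LEARNING_RATE = 3e-5
#       WARMUP_RATIO = 0.1
#       WEIGHT_DECAY = 0.01
#       GRADIENT_ACCUMULATION_STEPS = 1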

# Fix all random seeds so runs are reproducible
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(TrainNERConfig.SEED)

class NERDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        # Build the label <-> id mappings
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}

        # Print the label mapping for sanity checking
        print("Label mapping:")
        for label, idx in self.label2id.items():
            print(f"{label}: {idx}")

        # Encode all texts up front
        self.encodings = self.tokenize_and_align_labels()

    def tokenize_and_align_labels(self):
        # Tokenize the pre-split tokens so that word_ids() indexes line up with
        # the per-token label lists. (Joining the tokens into one string breaks
        # the alignment whenever a "word", e.g. a train number, spans several
        # characters.) Offset mappings are deliberately not requested: they are
        # unused, and __getitem__ would otherwise feed them to the model.
        tokenized_inputs = self.tokenizer(
            self.texts,
            is_split_into_words=True,
            truncation=True,
            padding=True,
            max_length=TrainNERConfig.MAX_LENGTH,
            return_tensors=None
        )

        labels = []
        for i, label in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            current_entity = None

            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens ([CLS], [SEP], padding) are masked from the loss
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First sub-token of a new word keeps the original label
                    label_ids.append(self.label2id[label[word_idx]])
                    if label[word_idx] == "O":
                        current_entity = None
                    else:
                        # Track the entity type for both B-XXX and I-XXX tags
                        current_entity = label[word_idx][2:]
                else:
                    # Remaining sub-tokens of the same word continue the entity
                    if current_entity:
                        label_ids.append(self.label2id[f"I-{current_entity}"])
                    else:
                        label_ids.append(self.label2id["O"])

                previous_word_idx = word_idx

            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs
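
    # Alignment example (illustrative; the exact sub-tokenization depends on
    # the tokenizer vocabulary). If the word "12" splits into two sub-tokens:
    #   words:    ["G", "12", "3", "次"]
    #   labels:   ["B-TRIPS", "I-TRIPS", "I-TRIPS", "O"]
    #   word_ids: [None, 0, 1, 1, 2, 3, None]            ([CLS] ... [SEP])
    #   aligned:  [-100, B-TRIPS, I-TRIPS, I-TRIPS, I-TRIPS, O, -100]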

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.texts)

def load_data(file_path):
    texts, labels = [], []
    current_words, current_labels = [], []

    def clean_trips_labels(words, labels):
        """Clean up train-number (TRIPS) annotations, dropping malformed spans."""
        i = 0
        while i < len(words):
            if labels[i].startswith("B-TRIPS"):
                # Find where this train-number span ends
                j = i + 1
                while j < len(words) and labels[j].startswith("I-TRIPS"):
                    j += 1

                # Re-assemble the span for validation
                train_words = words[i:j]
                train_str = ''.join(train_words)

                # Patterns a well-formed train number may take
                valid_patterns = [
                    re.compile(r'^[GDCZTKY]\d{1,2}$'),
                    re.compile(r'^[GDCZTKY]\d{1,2}/\d{1,2}$'),
                    re.compile(r'^[GDCZTKY]\d{1,2}-\d{1,2}$'),
                    re.compile(r'^\d{1,4}$'),
                    re.compile(r'^[A-Z]\d{1,4}$')
                ]
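
                # Illustrative examples (derived from the patterns above):
                #   kept:              "G12", "D1/2", "Z3-4", "1462", "K1234"
                #   relabeled as "O":  "G12345", "次G1", "GG12"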

                is_valid = any(pattern.match(train_str) for pattern in valid_patterns)
                if not is_valid:
                    # Malformed train numbers are relabeled as "O"
                    for k in range(i, j):
                        labels[k] = "O"

                i = j
            else:
                i += 1

        return words, labels

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                try:
                    word, label = line.split(maxsplit=1)
                    current_words.append(word)
                    current_labels.append(label)
                except ValueError:
                    print(f"Error: could not parse line: '{line}'")
                    continue
            elif current_words:  # blank line marks the end of a sentence
                # Clean the train-number annotations before storing the sample
                current_words, current_labels = clean_trips_labels(current_words, current_labels)
                texts.append(current_words)
                labels.append(current_labels)
                current_words, current_labels = [], []

    if current_words:  # handle the last sample if there is no trailing blank line
        current_words, current_labels = clean_trips_labels(current_words, current_labels)
        texts.append(current_words)
        labels.append(current_labels)

    return texts, labels
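
# Expected annotation file format, inferred from the parser above: one
# whitespace-separated "token label" pair per line, blank lines between
# sentences (the values shown are illustrative):
#
#   订 O
#   G  B-TRIPS
#   1  I-TRIPS
#   2  I-TRIPS
#   次 O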

def compute_metrics(p):
    """Compute evaluation metrics for the Trainer."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Drop predictions/labels at special-token positions (label == -100)
    true_predictions = [
        [TrainNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [TrainNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    # Overall entity-level metrics via seqeval
    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }

    # Per-entity-type metrics
    for entity_type in ["COMPANY", "TRIPS", "START", "END", "DATE", "TIME", "SEAT", "NAME"]:
        # Reduce the tags to a binary "is this entity type" mask per token
        binary_preds = []
        binary_labels = []

        for pred_seq, label_seq in zip(true_predictions, true_labels):
            pred_binary = []
            label_binary = []

            for pred, label in zip(pred_seq, label_seq):
                pred_is_entity = pred.endswith(entity_type)
                label_is_entity = label.endswith(entity_type)

                pred_binary.append(1 if pred_is_entity else 0)
                label_binary.append(1 if label_is_entity else 0)

            binary_preds.append(pred_binary)
            binary_labels.append(label_binary)

        # Token-level binary F1 via sklearn (seqeval expects BIO tag strings,
        # not 0/1 masks, so it cannot be used here)
        try:
            entity_f1 = binary_f1_score(
                sum(binary_labels, []),  # flatten the nested lists
                sum(binary_preds, []),
                average='binary'
            )
            results[f"{entity_type}_f1"] = entity_f1
        except Exception as e:
            print(f"Error computing F1 for {entity_type}: {str(e)}")
            results[f"{entity_type}_f1"] = 0.0

    return results
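
# Note on the two metric styles above: seqeval scores whole entity spans and
# takes lists of BIO tag sequences, e.g.
#   f1_score([["B-TRIPS", "I-TRIPS", "O"]], [["B-TRIPS", "I-TRIPS", "O"]])
# whereas the per-entity scores use sklearn's binary F1 on flattened 0/1 token
# masks, i.e. token-level rather than span-level scoring.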

def augment_data(texts, labels):
    """Simple data augmentation: randomly drop non-entity tokens."""
    augmented_texts = []
    augmented_labels = []
    for text, label in zip(texts, labels):
        # Keep the original sample
        augmented_texts.append(text)
        augmented_labels.append(label)

        # Add a copy with ~30% of the "O"-labeled tokens removed
        new_text = []
        new_label = []
        for t, l in zip(text, label):
            if l == "O" and random.random() < 0.3:
                continue
            new_text.append(t)
            new_label.append(l)
        augmented_texts.append(new_text)
        augmented_labels.append(new_label)

    return augmented_texts, augmented_labels
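
# Example effect of the augmentation above (each "O" token is dropped with
# probability 0.3; entity tokens are always kept, so BIO spans stay intact):
#   in:  ["我", "订", "G", "1", "2"] / ["O", "O", "B-TRIPS", "I-TRIPS", "I-TRIPS"]
#   out: ["订", "G", "1", "2"]       / ["O", "B-TRIPS", "I-TRIPS", "I-TRIPS"]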

def main():
    # Load the data
    texts, labels = load_data(TrainNERConfig.DATA_PATH)
    print(f"Loaded dataset: {len(texts)} samples")

    # Train/validation split
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=TrainNERConfig.TEST_SIZE, random_state=TrainNERConfig.SEED
    )

    # Augment the training split only
    train_texts, train_labels = augment_data(train_texts, train_labels)

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(TrainNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        TrainNERConfig.MODEL_NAME,
        num_labels=len(TrainNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(TrainNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(TrainNERConfig.LABELS)}
    )

    # Build the datasets
    train_dataset = NERDataset(train_texts, train_labels, tokenizer, TrainNERConfig.LABELS)
    val_dataset = NERDataset(val_texts, val_labels, tokenizer, TrainNERConfig.LABELS)

    # Training arguments; per-epoch evaluation is required for early stopping
    training_args = TrainingArguments(
        output_dir=TrainNERConfig.MODEL_PATH,
        num_train_epochs=TrainNERConfig.EPOCHS,
        per_device_train_batch_size=TrainNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=TrainNERConfig.BATCH_SIZE,
        learning_rate=TrainNERConfig.LEARNING_RATE,
        warmup_ratio=TrainNERConfig.WARMUP_RATIO,
        weight_decay=TrainNERConfig.WEIGHT_DECAY,
        gradient_accumulation_steps=TrainNERConfig.GRADIENT_ACCUMULATION_STEPS,
        evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="overall_f1"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        # Wire in the (previously unused) early-stopping import; a patience of
        # 3 epochs is an assumption, tune as needed
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )

    trainer.train()

    # Final evaluation
    eval_results = trainer.evaluate()
    print("\nEvaluation results:")
    for key, value in eval_results.items():
        print(f"{key}: {value:.4f}")

    # Save the final (best) model and tokenizer
    model.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")
    tokenizer.save_pretrained(f"{TrainNERConfig.MODEL_PATH}/best_model")

if __name__ == "__main__":
    main()