from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.trainer_callback import EarlyStoppingCallback
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from seqeval.metrics import f1_score, precision_score, recall_score
import random
import os
from ner_config import RepaymentNERConfig
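
# ner_config is not included in this file. A minimal sketch of the attributes this
# script reads from RepaymentNERConfig (the names are taken from the usages below;
# the values are illustrative assumptions, not the real configuration):
#
#     class RepaymentNERConfig:
#         SEED = 42
#         MODEL_NAME = "bert-base-chinese"   # any token-classification backbone
#         LABELS = ["O", "B-PICKUP_CODE", "I-PICKUP_CODE", "B-MIN_CODE", "I-MIN_CODE"]
#         MAX_LENGTH = 128
#         TEST_SIZE = 0.1
#         EPOCHS = 10
#         BATCH_SIZE = 16
#         WARMUP_RATIO = 0.1
#         WEIGHT_DECAY = 0.01
#         MAX_GRAD_NORM = 1.0
#         GRADIENT_ACCUMULATION_STEPS = 1
#         FP16 = True
#         LOGGING_STEPS = 50
#         EVAL_STEPS = 200
#         SAVE_TOTAL_LIMIT = 3
#         EARLY_STOPPING_PATIENCE = 3
#         DATALOADER_NUM_WORKERS = 2
#         DATALOADER_PIN_MEMORY = True
#         MODEL_PATH = "./models/repayment_ner"
#         LOG_PATH = "./logs"
#         DATA_PATH = "./data/train.txt"
#         AMOUNT_CONFIG = {"min_amount_keywords": ["最低", "最低还款"]}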


# Set all random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(RepaymentNERConfig.SEED)


class RepaymentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, label_list):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.label2id = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        self.encodings = self.tokenize_and_align_labels()
    def tokenize_and_align_labels(self):
        """Tokenize the word lists and align word-level labels with subword tokens."""
        # word_ids() below requires a fast (Rust-backed) tokenizer, which
        # AutoTokenizer returns by default when one is available.
        tokenized_inputs = self.tokenizer(
            self.texts,  # list of word lists, one per sample
            is_split_into_words=True,  # the inputs are already split into words
            truncation=True,
            padding=True,
            max_length=RepaymentNERConfig.MAX_LENGTH,
            return_tensors=None
        )

        labels = []
        for i, label_seq in enumerate(self.labels):
            word_ids = tokenized_inputs.word_ids(i)
            previous_word_idx = None
            label_ids = []

            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens such as [CLS], [SEP], [PAD]
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First token of a new word
                    try:
                        label_ids.append(self.label2id[label_seq[word_idx]])
                    except IndexError:
                        print(f"Error: label sequence of sample {i} does not match its text length")
                        print(f"Text length: {len(self.texts[i])}")
                        print(f"Label length: {len(label_seq)}")
                        print(f"word_idx: {word_idx}")
                        raise
                else:
                    # Subsequent tokens of the same word: if the word opens an
                    # entity (B-) continue it with the matching I- tag; if it is
                    # inside an entity (I-) keep that tag; otherwise it is O.
                    current_label = label_seq[word_idx]
                    if current_label.startswith("B-"):
                        label_ids.append(self.label2id[f"I-{current_label[2:]}"])
                    elif current_label.startswith("I-"):
                        label_ids.append(self.label2id[current_label])
                    else:
                        label_ids.append(self.label2id["O"])
                previous_word_idx = word_idx

            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs
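
    # Alignment example (illustrative): for the word list ["code", "A1B2C3"] with
    # labels ["O", "B-PICKUP_CODE"], a tokenizer that splits "A1B2C3" into three
    # subtokens yields word_ids like [None, 0, 1, 1, 1, None], which align to
    # [-100, O, B-PICKUP_CODE, I-PICKUP_CODE, I-PICKUP_CODE, -100].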
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.texts)


def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Drop positions labelled -100 (special/padding tokens) before scoring
    true_predictions = [
        [RepaymentNERConfig.LABELS[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [RepaymentNERConfig.LABELS[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = {
        "overall_f1": f1_score(true_labels, true_predictions),
        "overall_precision": precision_score(true_labels, true_predictions),
        "overall_recall": recall_score(true_labels, true_predictions)
    }

    return results
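
# Note that seqeval scores whole entity spans, not individual tags. For example
# (illustrative), with gold [["O", "B-PICKUP_CODE", "I-PICKUP_CODE"]] and
# prediction [["O", "B-PICKUP_CODE", "O"]], the truncated predicted span counts
# as both a false positive and a false negative, so precision, recall, and F1
# are all 0.0 even though two of the three tags match.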


def train_repayment_model(texts, labels):
    # Load the pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(RepaymentNERConfig.MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        RepaymentNERConfig.MODEL_NAME,
        num_labels=len(RepaymentNERConfig.LABELS),
        id2label={i: label for i, label in enumerate(RepaymentNERConfig.LABELS)},
        label2id={label: i for i, label in enumerate(RepaymentNERConfig.LABELS)}
    )

    # Split into training and validation sets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels,
        test_size=RepaymentNERConfig.TEST_SIZE,
        random_state=RepaymentNERConfig.SEED
    )

    # Build the datasets
    train_dataset = RepaymentDataset(train_texts, train_labels, tokenizer, RepaymentNERConfig.LABELS)
    val_dataset = RepaymentDataset(val_texts, val_labels, tokenizer, RepaymentNERConfig.LABELS)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=RepaymentNERConfig.MODEL_PATH,
        num_train_epochs=RepaymentNERConfig.EPOCHS,
        per_device_train_batch_size=RepaymentNERConfig.BATCH_SIZE,
        per_device_eval_batch_size=RepaymentNERConfig.BATCH_SIZE * 2,  # evaluation can use a larger batch
        warmup_ratio=RepaymentNERConfig.WARMUP_RATIO,
        weight_decay=RepaymentNERConfig.WEIGHT_DECAY,
        logging_dir=RepaymentNERConfig.LOG_PATH,
        logging_steps=RepaymentNERConfig.LOGGING_STEPS,
        evaluation_strategy="steps",
        eval_steps=RepaymentNERConfig.EVAL_STEPS,
        save_strategy="steps",
        save_steps=RepaymentNERConfig.EVAL_STEPS,
        save_total_limit=RepaymentNERConfig.SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        metric_for_best_model="overall_f1",
        greater_is_better=True,
        max_grad_norm=RepaymentNERConfig.MAX_GRAD_NORM,
        gradient_accumulation_steps=RepaymentNERConfig.GRADIENT_ACCUMULATION_STEPS,
        fp16=RepaymentNERConfig.FP16,
        dataloader_num_workers=RepaymentNERConfig.DATALOADER_NUM_WORKERS,
        dataloader_pin_memory=RepaymentNERConfig.DATALOADER_PIN_MEMORY,
        save_safetensors=True,
        optim="adamw_torch",
        disable_tqdm=False,
        report_to=["tensorboard"],
        group_by_length=True,  # batch samples of similar length together to reduce padding
        length_column_name="length",
        remove_unused_columns=True,
        label_smoothing_factor=0.1,  # label smoothing
    )
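
    # With these arguments the effective training batch size is
    # per_device_train_batch_size * gradient_accumulation_steps * number of
    # devices, and because save_steps equals eval_steps every saved checkpoint
    # has a matching evaluation, which load_best_model_at_end needs to pick the
    # best checkpoint by overall_f1.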

    # Build the trainer. Trainer registers a ProgressCallback by default, so it
    # is not added again here.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=RepaymentNERConfig.EARLY_STOPPING_PATIENCE,
                early_stopping_threshold=0.001
            )
        ]
    )

    try:
        # Train the model
        print("\nStarting training...")
        train_result = trainer.train()

        # Report training results
        print("\nTraining finished!")
        print(f"Training time: {train_result.metrics['train_runtime']:.2f}s")

        # Fetch and print numeric metrics safely
        metrics = train_result.metrics
        print("\nTraining metrics:")
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                print(f"- {key}: {value:.4f}")

        # Final evaluation
        final_eval = trainer.evaluate()
        print("\nFinal evaluation results:")
        print(f"F1: {final_eval['eval_overall_f1']:.4f}")
        print(f"Precision: {final_eval['eval_overall_precision']:.4f}")
        print(f"Recall: {final_eval['eval_overall_recall']:.4f}")

        # Save the best model
        print("\nSaving model...")
        save_path = f"{RepaymentNERConfig.MODEL_PATH}/best_model"
        trainer.save_model(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model saved to: {save_path}")

        return model, tokenizer

    except Exception as e:
        print(f"\nTraining interrupted: {str(e)}")
        # Try to save the current model
        try:
            save_path = f"{RepaymentNERConfig.MODEL_PATH}/interrupted_model"
            trainer.save_model(save_path)
            tokenizer.save_pretrained(save_path)
            print(f"Interrupted model saved to: {save_path}")
        except Exception as save_error:
            print(f"Failed to save interrupted model: {str(save_error)}")
        raise


def validate_labels(labels, valid_labels):
    """Check that every label in the sequences is legal."""
    label_set = set()
    for seq in labels:
        label_set.update(seq)

    invalid_labels = label_set - set(valid_labels)
    if invalid_labels:
        raise ValueError(f"Invalid labels found: {invalid_labels}")


def clean_text(text: str) -> str:
    """Normalize full-width characters to their half-width equivalents."""
    text = text.replace('¥', '¥')
    text = text.replace(',', ',')
    text = text.replace('。', '.')
    text = text.replace(':', ':')
    text = text.replace('(', '(')
    text = text.replace(')', ')')
    return text
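
# For example, clean_text("(¥1,000。)") returns "(¥1,000.)". Every replacement
# maps one character to one character, so the per-word string lengths, and hence
# the text/label alignment, are preserved.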


def preprocess_data(texts, labels):
    """Preprocess the data."""
    processed_texts = []
    processed_labels = []

    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        if len(text) != len(label_seq):
            print(f"Warning: sample {i} has mismatched text and label lengths, skipped")
            continue

        # Clean the text
        cleaned_text = [clean_text(word) for word in text]

        # Relabel amounts: a PICKUP_CODE entity preceded by a minimum-payment
        # keyword is reassigned to MIN_CODE
        is_min_amount = False
        new_labels = []
        for j, (word, label) in enumerate(zip(cleaned_text, label_seq)):
            if label.startswith("B-PICKUP_CODE"):
                # Check whether this is a minimum repayment amount
                context = ''.join(cleaned_text[max(0, j-5):j])
                if any(kw in context for kw in RepaymentNERConfig.AMOUNT_CONFIG['min_amount_keywords']):
                    is_min_amount = True
                    new_labels.append("B-MIN_CODE")
                else:
                    is_min_amount = False
                    new_labels.append(label)
            elif label.startswith("I-PICKUP_CODE"):
                if is_min_amount:
                    new_labels.append("I-MIN_CODE")
                else:
                    new_labels.append(label)
            else:
                new_labels.append(label)

        processed_texts.append(cleaned_text)
        processed_labels.append(new_labels)

    return processed_texts, processed_labels


def load_data(file_path):
    """Load and preprocess the data."""
    texts = []
    labels = []
    current_words = []
    current_labels = []
    skip_url = False
    url_indicators = {'u', 'ur', 'url', 'http', 'https', 'www', 'com', 'cn'}

    def is_url_part(word):
        return (word.lower() in url_indicators or
                '.' in word or
                '/' in word or
                word.startswith('?'))

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()

            if not line:  # a blank line separates samples
                if current_words:
                    texts.append(current_words)
                    labels.append(current_labels)
                    current_words = []
                    current_labels = []
                skip_url = False
                continue

            try:
                word, label = line.split(maxsplit=1)

                # URL handling: skip tokens inside a URL until a delimiter ends it
                if is_url_part(word):
                    skip_url = True
                elif word in ['】', ',', '。', ':']:
                    skip_url = False

                # Label validation
                if not skip_url:
                    if label not in RepaymentNERConfig.LABELS:
                        print(f"Warning: invalid label '{label}' at line {line_num}, skipped")
                        continue
                    current_words.append(word)
                    current_labels.append(label)

            except Exception as e:
                print(f"Error: failed to process line {line_num} '{line}': {str(e)}")
                continue

    # Handle the last sample
    if current_words:
        texts.append(current_words)
        labels.append(current_labels)

    return texts, labels
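
# load_data expects CoNLL-style input: one "word label" pair per line, with a
# blank line between samples. An illustrative snippet:
#
#     还款 O
#     码 O
#     A1B2C3 B-PICKUP_CODE
#
#     (blank line here; the next sample starts below it)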


def validate_dataset(texts, labels):
    """Validate the integrity and correctness of the dataset."""
    stats = {
        "total_samples": len(texts),
        "total_tokens": sum(len(t) for t in texts),
        "entity_counts": {},
        "avg_length": 0,
        "errors": []
    }

    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        # Length check
        if len(text) != len(label_seq):
            stats["errors"].append(f"Sample {i}: text and label lengths do not match")
            continue

        # Count entities and check BIO consistency
        current_entity = None
        for j, (word, label) in enumerate(zip(text, label_seq)):
            if label.startswith("B-"):
                entity_type = label[2:]
                stats["entity_counts"][entity_type] = stats["entity_counts"].get(entity_type, 0) + 1
                current_entity = entity_type
            elif label.startswith("I-"):
                if not current_entity:
                    stats["errors"].append(f"Sample {i}: I- label at position {j} has no preceding B- label")
                elif label[2:] != current_entity:
                    stats["errors"].append(f"Sample {i}: I- label type at position {j} does not match its B- label")
            else:
                current_entity = None

    stats["avg_length"] = stats["total_tokens"] / stats["total_samples"] if stats["total_samples"] > 0 else 0

    return stats
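
# Illustrative violations the check above reports: ["O", "I-PICKUP_CODE"] yields
# "I- label has no preceding B- label", and ["B-MIN_CODE", "I-PICKUP_CODE"]
# yields "I- label type does not match its B- label".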


def resume_training(checkpoint_path):
    """Resume training from a checkpoint."""
    print(f"Resuming training from checkpoint: {checkpoint_path}")

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)

    # Reload the data
    texts, labels = load_data(RepaymentNERConfig.DATA_PATH)
    texts, labels = preprocess_data(texts, labels)

    # Rebuild the datasets
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels,
        test_size=RepaymentNERConfig.TEST_SIZE,
        random_state=RepaymentNERConfig.SEED
    )

    train_dataset = RepaymentDataset(train_texts, train_labels, tokenizer, RepaymentNERConfig.LABELS)
    val_dataset = RepaymentDataset(val_texts, val_labels, tokenizer, RepaymentNERConfig.LABELS)

    # Build the trainer and continue training
    training_args = TrainingArguments(
        output_dir=RepaymentNERConfig.MODEL_PATH,
        num_train_epochs=RepaymentNERConfig.EPOCHS,
        # ... remaining arguments identical to train_repayment_model ...
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=RepaymentNERConfig.EARLY_STOPPING_PATIENCE)]
    )

    # Continue training
    trainer.train(resume_from_checkpoint=checkpoint_path)

    return model, tokenizer


def main():
    # Load the data
    print("Loading data...")
    texts, labels = load_data(RepaymentNERConfig.DATA_PATH)

    # Preprocess
    print("Preprocessing data...")
    texts, labels = preprocess_data(texts, labels)

    # Verify text/label alignment
    print("Checking dataset...")
    for i, (text, label_seq) in enumerate(zip(texts, labels)):
        if len(text) != len(label_seq):
            print(f"Error: sample {i} has mismatched text and label lengths")
            print(f"Text ({len(text)}): {text}")
            print(f"Labels ({len(label_seq)}): {label_seq}")
            return

    print(f"Data check passed, {len(texts)} valid samples")

    # Validate the dataset
    print("Validating dataset...")
    stats = validate_dataset(texts, labels)

    print("\n=== Dataset statistics ===")
    print(f"Total samples: {stats['total_samples']}")
    print(f"Average length: {stats['avg_length']:.2f}")
    print("\nEntity counts:")
    for entity_type, count in stats['entity_counts'].items():
        print(f"- {entity_type}: {count}")

    if stats['errors']:
        print("\nProblems found:")
        for error in stats['errors']:
            print(f"- {error}")
        if input("Continue training anyway? (y/n) ").lower() != 'y':
            return

    # Check for an interrupted training run
    interrupted_model_path = f"{RepaymentNERConfig.MODEL_PATH}/interrupted_model"
    if os.path.exists(interrupted_model_path):
        print("\nFound an interrupted training run")
        if input("Resume from the interrupted model? (y/n) ").lower() == 'y':
            model, tokenizer = resume_training(interrupted_model_path)
            return

    # Normal training flow
    print("\nStarting a new training run...")
    model, tokenizer = train_repayment_model(texts, labels)


if __name__ == "__main__":
    main()