# -*- coding: utf-8 -*-
|
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
|
import pandas as pd
|
from sklearn.model_selection import train_test_split
|
import torch
|
import re
|
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
import warnings
|
# Silence library warnings (deliberate: transformers/sklearn are chatty).
warnings.filterwarnings('ignore')

# Report whether CUDA is available (training below is forced onto CPU anyway).
device_msg = "GPU" if torch.cuda.is_available() else "Only CPU"
print("Available devices:", device_msg)

# Load the SMS dataset; expects 'text' and 'label' columns.
df = pd.read_csv("data/sms_classify.csv")
texts = df["text"].tolist()
labels = df["label"].tolist()
|
# Text cleaning applied right after data loading --0320
def clean_text(text):
    """Normalize an SMS message.

    Collapses any run of whitespace into a single space and strips the
    fullwidth bracket characters 【 and 】.

    Args:
        text: raw message string.

    Returns:
        The cleaned string.
    """
    # split()/join collapses tabs, newlines and repeated spaces in one pass
    collapsed = ' '.join(text.split())
    # drop only the CJK brackets; all other characters are kept
    return re.sub(r'[【】]', '', collapsed)
|
# Apply cleaning to every message.
texts = [clean_text(text) for text in texts]

# Label mapping. Sort the label set so the id assignment is deterministic:
# plain set iteration order varies between processes (hash randomization),
# which would make the id2label saved with the model unreliable across runs.
label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
id2label = {idx: label for label, idx in label2id.items()}
labels = [label2id[label] for label in labels]

# Train/validation split. Fix the seed so experiments are reproducible and
# stratify so both splits keep the original class distribution.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)
|
# Load tokenizer and model (Chinese BERT base). The classification head is
# sized to the number of labels, and the id/label mappings are stored in the
# model config so they persist with save_pretrained below.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id
)

# Tokenize both splits; pad/truncate to 64 tokens (SMS messages are short,
# so a small max_length keeps CPU training affordable).
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=64)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=64)
|
# Wrap the tokenizer output as a map-style PyTorch dataset.
class SMSDataset(torch.utils.data.Dataset):
    """Pairs tokenizer encodings with integer labels for the Trainer."""

    def __init__(self, encodings, labels):
        # encodings: dict of token field name -> per-example sequences
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Convert each encoding field for this example into a tensor.
        sample = {
            name: torch.tensor(values[idx])
            for name, values in self.encodings.items()
        }
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        return len(self.labels)
|
# Wrap the encoded splits as PyTorch datasets.
train_dataset = SMSDataset(train_encodings, train_labels)
val_dataset = SMSDataset(val_encodings, val_labels)

# Training configuration (CPU-friendly): batch sizes, LR schedule,
# per-epoch evaluation/checkpointing, and best-model selection.
training_args = TrainingArguments(
    output_dir="./models/classifier",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    # NOTE(review): newer transformers versions renamed this to
    # `eval_strategy` (and `no_cuda` to `use_cpu`) — confirm against
    # the pinned transformers version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,  # restore the best checkpoint after training
    no_cuda=True,  # force CPU
    # Parameters added --0320
    metric_for_best_model="accuracy",  # "best" = highest eval accuracy
    greater_is_better=True,
    learning_rate=2e-5,
    warmup_ratio=0.1,  # proportional warmup instead of a fixed step count
    weight_decay=0.01,
)
|
|
# Evaluation metrics for the Trainer --0320
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for an eval pass.

    Args:
        pred: Trainer EvalPrediction with `.label_ids` and `.predictions`
            (raw logits).

    Returns:
        Dict mapping metric name to score.
    """
    y_true = pred.label_ids
    # argmax over the class axis turns logits into predicted label ids
    y_pred = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='weighted'),
        'precision': precision_score(y_true, y_pred, average='weighted'),
        'recall': recall_score(y_true, y_pred, average='weighted'),
    }
|
# Trainer: wires together model, config, datasets, and metrics.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    # Stop early if eval accuracy fails to improve for 2 consecutive epochs
    # (relies on metric_for_best_model / load_best_model_at_end set above).
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)
|
|
# Fine-tune the model.
trainer.train()

# Final evaluation on the validation split (best checkpoint, per
# load_best_model_at_end).
eval_results = trainer.evaluate()
print("评估结果:", eval_results)

# Confusion matrix over validation predictions, saved as a PNG.
predictions = trainer.predict(val_dataset)
preds = predictions.predictions.argmax(-1)
cm = confusion_matrix(val_labels, preds)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=list(id2label.values()), yticklabels=list(id2label.values()))
# NOTE(review): the Chinese title needs a CJK-capable matplotlib font
# (e.g. rcParams['font.sans-serif']) or it will render as boxes — confirm.
plt.title('混淆矩阵')
plt.savefig('./models/classifier/confusion_matrix.png')

# Persist the fine-tuned model and tokenizer for later inference.
model_path = "./models/classifier"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)