# -*- coding: utf-8 -*-
"""Fine-tune ``bert-base-chinese`` to classify SMS messages.

Pipeline: load ``data/sms_classify.csv`` (columns ``text`` and ``label``),
clean the text, map labels to integer ids, split into train/validation,
fine-tune with the Hugging Face ``Trainer`` (forced onto the CPU), print the
evaluation metrics, save a confusion-matrix plot, and save the fine-tuned
model and tokenizer to ``./models/classifier``.
"""
import os
import re
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

warnings.filterwarnings('ignore')

# Report device availability. Training below is forced onto the CPU
# (``no_cuda=True``) regardless of what is printed here.
print("Available devices:", "GPU" if torch.cuda.is_available() else "Only CPU")

# Load data. A missing text (NaN is a float, which has no ``.split()``) or a
# missing label would break the steps below, so drop such rows up front and
# coerce any non-string text to ``str``.
df = pd.read_csv("data/sms_classify.csv")
df = df.dropna(subset=["text", "label"])
texts = df["text"].astype(str).tolist()
labels = df["label"].tolist()

# Text cleaning (added 03-20). Compiled once instead of on every call.
_BRACKETS_RE = re.compile(r'[【】]')


def clean_text(text):
    """Return *text* with 【】 brackets removed and whitespace runs collapsed.

    Brackets are removed first so that collapsing whitespace afterwards cannot
    leave a double space where a bracket used to sit between two spaces.
    """
    text = _BRACKETS_RE.sub('', text)
    return ' '.join(text.split())


texts = [clean_text(text) for text in texts]

# Label mapping. Sorting makes the id assignment deterministic across runs;
# iterating a raw ``set`` of strings depends on hash randomization.
label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
id2label = {idx: label for label, idx in label2id.items()}
labels = [label2id[label] for label in labels]

# Train/validation split with a fixed seed for reproducibility. Stratify so both
# splits keep the class balance; sklearn raises ValueError when that is
# impossible (e.g. a class with a single sample), in which case fall back to a
# plain random split.
try:
    split = train_test_split(texts, labels, test_size=0.2, random_state=42, stratify=labels)
except ValueError:
    split = train_test_split(texts, labels, test_size=0.2, random_state=42)
train_texts, val_texts, train_labels, val_labels = split

# Load the tokenizer and model (Chinese BERT).
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

# Encode: truncate to 64 tokens and pad each split to its longest sequence.
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=64)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=64)


class SMSDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset pairing tokenizer encodings with label ids."""

    def __init__(self, encodings, labels):
        # encodings: dict-like mapping feature name -> per-example sequences
        # (the tokenizer output above); labels: list of integer label ids.
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = SMSDataset(train_encodings, train_labels)
val_dataset = SMSDataset(val_encodings, val_labels)

# Training arguments (CPU-oriented): batch sizes, learning rate, checkpointing.
training_args = TrainingArguments(
    output_dir="./models/classifier",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    # NOTE(review): newer transformers releases rename ``evaluation_strategy`` to
    # ``eval_strategy`` and deprecate ``no_cuda`` in favour of ``use_cpu``; kept
    # as-is for the installed version — confirm before upgrading transformers.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    no_cuda=True,  # force CPU
    # Select the best checkpoint by validation accuracy from compute_metrics.
    metric_for_best_model="accuracy",
    greater_is_better=True,
    learning_rate=2e-5,
    warmup_ratio=0.1,  # warmup as a fraction of total steps instead of warmup_steps
    weight_decay=0.01,
)


def compute_metrics(pred):
    """Return accuracy plus weighted precision/recall/F1 for a Trainer prediction.

    ``pred.predictions`` holds per-class logits (argmax over the last axis gives
    the predicted id); ``pred.label_ids`` holds the true ids. ``zero_division=0``
    makes explicit the value used for a class with no predicted/true samples,
    matching sklearn's default result without relying on its warning path.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1_score(labels, preds, average='weighted', zero_division=0),
        'precision': precision_score(labels, preds, average='weighted', zero_division=0),
        'recall': recall_score(labels, preds, average='weighted', zero_division=0),
    }


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    # Stop once the tracked metric fails to improve for 2 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

# Evaluate after training (load_best_model_at_end=True is set above).
eval_results = trainer.evaluate()
print("评估结果:", eval_results)

# Confusion matrix on the validation set. Passing every class id as ``labels``
# pins the matrix shape so rows/columns always line up with the tick labels,
# even when a class is absent from both the validation labels and predictions.
class_ids = list(range(len(id2label)))
class_names = [id2label[i] for i in class_ids]
predictions = trainer.predict(val_dataset)
preds = predictions.predictions.argmax(-1)
cm = confusion_matrix(val_labels, preds, labels=class_ids)
os.makedirs('./models/classifier', exist_ok=True)
fig = plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
# NOTE(review): matplotlib's default font may lack CJK glyphs, so this title and
# any Chinese class names can render as empty boxes — configure a CJK font if so.
plt.title('混淆矩阵')
plt.savefig('./models/classifier/confusion_matrix.png')
plt.close(fig)  # release the figure; nothing else is drawn afterwards

# Save the model and tokenizer.
model_path = "./models/classifier"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)