fix
cloudroam
6 天以前 e6fed94443177826cf7497a85e9cdcfc7c43ee21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import re
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
 
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
# 检查设备
print("Available devices:", "GPU" if torch.cuda.is_available() else "Only CPU")
 
# 加载数据
df = pd.read_csv("data/sms_classify.csv")
texts = df["text"].tolist()
labels = df["label"].tolist()
 
# 在数据加载后添加文本清洗函数 --0320
def clean_text(text):
    # 移除多余的空格
    text = ' '.join(text.split())
    # 移除特殊字符
    text = re.sub(r'[【】]', '', text)
    return text
 
texts = [clean_text(text) for text in texts]
 
 
# 标签映射
label2id = {label: idx for idx, label in enumerate(set(labels))}
id2label = {idx: label for label, idx in label2id.items()}
labels = [label2id[label] for label in labels]
 
# 划分数据集
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2)
 
# 加载分词器和模型(使用BERT中文模型)
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id
)
 
 
# 数据编码
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=64)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=64)
 
# 转换为PyTorch Dataset
class SMSDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
 
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item
 
    def __len__(self):
        return len(self.labels)
 
train_dataset = SMSDataset(train_encodings, train_labels)
val_dataset = SMSDataset(val_encodings, val_labels)
 
# 训练参数(CPU优化)定义训练参数(批次大小、学习率、保存策略等)
training_args = TrainingArguments(
    output_dir="./models/classifier",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    no_cuda=True,  # 强制使用CPU
    # 新增参数
    metric_for_best_model="accuracy",
    greater_is_better=True,
    learning_rate=2e-5,
    warmup_ratio=0.1,  # 使用warmup_ratio替代warmup_steps
    weight_decay=0.01,
)
 
 
# 添加评估指标计算  --0320
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision = precision_score(labels, preds, average='weighted')
    recall = recall_score(labels, preds, average='weighted')
    f1 = f1_score(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
 
# 训练器
# 更新Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]  # 添加早停回调
)
 
 
# 开始训练
trainer.train()
 
# 在训练后添加模型评估
eval_results = trainer.evaluate()
print("评估结果:", eval_results)
 
# 保存混淆矩阵
predictions = trainer.predict(val_dataset)
preds = predictions.predictions.argmax(-1)
cm = confusion_matrix(val_labels, preds)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=list(id2label.values()), yticklabels=list(id2label.values()))
plt.title('混淆矩阵')
plt.savefig('./models/classifier/confusion_matrix.png')
 
# 保存模型和分词器
model_path = "./models/classifier"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)