??本文介紹了一個基于TextCNN模型的文本分類項目,使用今日頭條新聞數據集進行訓練和評估。項目包括數據獲取、預處理、模型訓練、評估測試等環節。數據預處理涉及清洗文本、中文分詞、去除停用詞、構建詞匯表和向量化等步驟。TextCNN模型通過卷積層和池化層提取文本特征,并在訓練過程中記錄準確率和損失。最終,模型在測試集上達到了較高的準確率(84.06%),并生成了混淆矩陣可視化。項目還詳細介紹了TextCNN模型的結構和創新點,以及數據預處理和模型訓練的具體實現代碼。完整代碼已開源在個人GitHub:https://github.com/KLWU07/Chinese-text-classification-TextCNN
一、項目描述
1.數據獲取
(1)今日頭條新聞分類數據爬取(Python腳本)
- 使用今日頭條文本分類數據集,包含民生故事、文化、娛樂、體育、財經、房
產、汽車、教育、科技、軍事、旅游、國際、證券股票、農業、游戲15個類別,共30多萬條數據。
(2)每條數據形式
6554371968739574280_!_102_!_news_entertainment_!_今年你看《復聯3》哭得有多慘,明年你看《復聯4》就會叫得有多爽_!_卡魔拉,滅霸,小丑,復仇者聯盟3,理想主義者,復聯3,復聯4,雷神3,漫威,黑暗騎士
2.數據預處理
(1)清洗文本(保留中文字符)
(2)中文分詞(使用jieba)
(3)去除停用詞
(4)構建詞匯表并轉換為數字向量
(5)數據劃分為訓練集/驗證集/測試集
3.模型訓練
(1)使用TextCNN模型進行訓練
(2)記錄訓練過程中的準確率和損失
(3)保存最佳模型
4.評估測試
(1)輸出測試集的準確率、精確率、召回率和F1值
(2)生成混淆矩陣可視化
二、訓練過程和結果
1.訓練過程中的準確率和損失、混淆矩陣可視化
textCNN((embed): Embedding(165444, 64, padding_idx=1)(conv11): Conv2d(1, 16, kernel_size=(3, 64), stride=(1, 1))(conv12): Conv2d(1, 16, kernel_size=(4, 64), stride=(1, 1))(conv13): Conv2d(1, 16, kernel_size=(5, 64), stride=(1, 1))(dropout): Dropout(p=0.5, inplace=False)(fc1): Linear(in_features=48, out_features=15, bias=True)
)
Epoch: 1 [===========] cost: 775.52s; loss: 10520.5979; train acc: 0.1589; val acc:0.2256;
Epoch: 2 [===========] cost: 941.70s; loss: 9934.9760; train acc: 0.2239; val acc:0.2975;
Epoch: 3 [===========] cost: 1309.58s; loss: 9329.3478; train acc: 0.2844; val acc:0.3842;
Epoch: 4 [===========] cost: 1619.02s; loss: 8444.4464; train acc: 0.3663; val acc:0.4751;
Epoch: 5 [===========] cost: 1837.63s; loss: 7384.8970; train acc: 0.4545; val acc:0.5569;
Epoch: 6 [===========] cost: 1989.67s; loss: 6397.7872; train acc: 0.5371; val acc:0.6317;
Epoch: 7 [===========] cost: 2095.66s; loss: 5669.9501; train acc: 0.6029; val acc:0.6932;
Epoch: 8 [===========] cost: 2142.67s; loss: 5095.1675; train acc: 0.6578; val acc:0.7403;
Epoch: 9 [===========] cost: 2126.22s; loss: 4633.6697; train acc: 0.6989; val acc:0.7702;
Epoch: 10 [===========] cost: 2101.42s; loss: 4296.7775; train acc: 0.7272; val acc:0.7899;
Epoch: 11 [===========] cost: 2099.69s; loss: 4047.9119; train acc: 0.7460; val acc:0.8012;
Epoch: 12 [===========] cost: 2085.49s; loss: 3871.4451; train acc: 0.7594; val acc:0.8096;
Epoch: 13 [===========] cost: 2152.68s; loss: 3725.9579; train acc: 0.7689; val acc:0.8157;
Epoch: 14 [===========] cost: 2178.29s; loss: 3613.9486; train acc: 0.7761; val acc:0.8209;
Epoch: 15 [===========] cost: 2303.36s; loss: 3531.0598; train acc: 0.7828; val acc:0.8230;
Epoch: 16 [===========] cost: 2183.05s; loss: 3447.1700; train acc: 0.7878; val acc:0.8280;
Epoch: 17 [===========] cost: 2213.03s; loss: 3388.7178; train acc: 0.7918; val acc:0.8286;
Epoch: 18 [===========] cost: 2349.19s; loss: 3332.7325; train acc: 0.7953; val acc:0.8313;
Epoch: 19 [===========] cost: 2508.73s; loss: 3278.0177; train acc: 0.7982; val acc:0.8326;
Epoch: 20 [===========] cost: 2579.64s; loss: 3241.5735; train acc: 0.8021; val acc:0.8341;
Epoch: 21 [===========] cost: 2616.97s; loss: 3205.6926; train acc: 0.8039; val acc:0.8350;
Epoch: 22 [===========] cost: 2644.12s; loss: 3158.8484; train acc: 0.8072; val acc:0.8348;
Epoch: 23 [===========] cost: 2949.09s; loss: 3135.9280; train acc: 0.8093; val acc:0.8376;
Epoch: 24 [===========] cost: 2474.57s; loss: 3118.4267; train acc: 0.8100; val acc:0.8385;
Epoch: 25 [===========] cost: 2544.15s; loss: 3092.5366; train acc: 0.8117; val acc:0.8395;
Epoch: 26 [===========] cost: 2468.19s; loss: 3072.7564; train acc: 0.8132; val acc:0.8389;
Epoch: 27 [===========] cost: 2449.40s; loss: 3041.6649; train acc: 0.8155; val acc:0.8403;
Epoch: 28 [===========] cost: 2472.81s; loss: 3025.2191; train acc: 0.8161; val acc:0.8410;
Epoch: 29 [===========] cost: 2481.49s; loss: 3007.8602; train acc: 0.8182; val acc:0.8408;
Epoch: 30 [===========] cost: 2424.77s; loss: 2986.9278; train acc: 0.8195; val acc:0.8421;
test ...
test acc: 0.8406 precision: 0.8406 recall: 0.8406 f1: 0.8406
三、文本卷積神經網路TextCNN模型解釋
1.《Convolutional Neural Networks for Sentence Classification》論文創新點
(1)模型應用創新:將卷積神經網絡(CNN)應用于自然語言處理(NLP)中的句子分類任務,打破了傳統 NLP 任務主要依賴循環神經網絡(RNN)及其變體的局面。
(2)輸入表示創新:采用預訓練的詞向量(如 word2vec)作為輸入,替代傳統的 one - hot 編碼。預訓練詞向量能更好地捕捉詞的語義信息,提高了模型的泛化能力和性能。
(3)特征提取創新:使用多個不同尺寸的卷積核來提取句子中的關鍵信息,類似于多窗口大小的 ngram。不同大小的卷積核可以捕捉不同長度的局部相關性,從而更全面地提取句子的特征,提高模型的特征提取能力。
(4)模型結構創新:TextCNN 模型結構相對簡單,計算快速、實現方便,且在準確性方面表現較高。它包含嵌入層、卷積層、池化層和全連接層等,通過簡單的結構實現了高效的句子分類。
(5)訓練策略創新:提出了多種訓練策略,如 CNN - rand(基礎模型中所有單詞都是隨機初始化,然后在訓練期間進行修改)、CNN - static(帶有來自 word2vec 的預訓練向量,所有單詞保持靜態,只學習模型的其他參數)、CNN - non - static(預訓練的向量針對每個任務進行微調)和 CNN - multichannel(有兩組詞向量模型)。
2.數據預處理模塊
# 數據處理相關函數
def is_chinese(uchar):# 判斷字符是否為中文
def reserve_chinese(content):# 保留文本中的中文字符
def getStopWords():# 加載停用詞表
def dataParse(text, stop_words):# 解析文本數據,映射標簽,清洗文本并分詞
def getFormatData():# 處理原始數據,構建詞表,生成詞向量表示并保存
(1)中文分詞:jieba分詞
- jieba是一個專為中文文本設計的分詞庫,其主要功能是將連續的中文文本拆分成有意義的詞語序列。與英文不同,中文文本中詞與詞之間沒有空格作為自然分隔符。例如:中文句子:“我喜歡自然語言處理”,分詞結果:“我 / 喜歡 / 自然語言處理”。分詞的準確性直接影響后續文本分析的質量,如關鍵詞提取、情感分析、機器翻譯等。
- jieba分詞基于以下技術:基于前綴詞典的匹配算法、隱馬爾可夫模型 (HMM)、動態規劃優化
- jieba 分詞的主要模式:精確模式(默認),將文本精確地切分成詞語,適合文本分析。全模式,將文本中所有可能的詞語都掃描出來,速度快但可能產生冗余。搜索引擎模式,在精確模式基礎上,對長詞再次切分,適合搜索引擎分詞。
(2)去除停用詞
- 停用詞(Stop Words)是自然語言處理(NLP)中一類被認為對文本分析價值較低的詞匯。如中文中的 “的”“了”“在”“是”,英文中的 “the”“a”“is” 等。包括標點符號、連接詞、代詞等,本身不攜帶具體語義信息。去除停用詞是文本預處理的核心步驟,通過 stopwords.txt 文件可高效管理停用詞列表,配合 jieba 等分詞工具,能顯著提升文本分析的質量和效率。
- 去除停用詞的主要目的:減少數據量,提升處理效率;過濾無意義詞匯,突出關鍵信息(如名詞、動詞等實詞);避免停用詞對文本分析(如關鍵詞提取、情感分析)產生干擾。
import jiebadef filter_stopwords(text, stopwords):"""對文本分詞并去除停用詞"""# 分詞words = jieba.lcut(text)# 過濾停用詞filtered_words = [word for word in words if word not in stopwords and len(word) > 1]return filtered_words# 示例文本
text = "今天天氣很好,我打算去公園散步。"
filtered_result = filter_stopwords(text, stopwords)
print(filtered_result) # 輸出:['今天', '天氣', '很好', '打算', '公園', '散步']
(3)構建詞表:按詞頻排序,高頻詞獲得更小的索引
- 在自然語言處理中,構建詞表(Vocabulary)是將文本轉換為計算機可處理格式的關鍵步驟。按詞頻排序并讓高頻詞獲得更小索引的策略,是一種常見且有效的詞表構建方法。
- 文本數字化的需求:神經網絡等機器學習模型無法直接處理文本,需要將文本轉換為數字表示。詞表是文本與數字之間的映射橋梁,每個詞被映射為唯一的整數索引。
- 按詞頻排序:在大量文本中頻繁出現的詞,通常攜帶更關鍵的語義信息。更小的索引值占用更少的內存和計算資源,高頻詞優先使用小索引可提升效率。
(4)文本向量化:通過詞匯表將詞語映射為索引
- 在自然語言處理中,將文本轉換為詞向量序列是連接原始文本與深度學習模型的關鍵步驟。這一過程將離散的文本符號轉換為連續的數值表示,使模型能夠捕捉文本中的語義信息。
1.分詞:使用分詞工具(如代碼中的 jieba.lcut)將句子拆分為詞語列表。
例如:"深度學習很強大" → ["深度", "學習", "很", "強大"]。2.統計詞頻:遍歷所有文本,統計每個詞的出現頻率(使用 Counter)。
例如:{"深度": 10, "學習": 8, "強大": 5, ...}。3.排序和過濾:按詞頻從高到低排序(代碼中 reverse=True),保留高頻詞以降低維度。低頻詞可視為噪聲或替換為 <UNK>(未知詞)。4.建立映射:為每個詞分配唯一索引(從 1 開始,0 通常留給填充符 <PAD>)。
例如:"<PAD>": 0, "深度": 1, "學習": 2, "強大": 3, ..., "<UNK>": 100015.詞到索引的轉換:根據詞匯表將每個詞替換為對應的索引。處理未知詞:若詞不在詞匯表中,使用 <UNK> 的索引(如 10001)。
例如:["深度", "學習", "強大"] → [1, 2, 3]6.序列對齊(Padding):使用 pad_sequence 將所有序列填充到相同長度(按最長序列或預設值)。
例如(填充到長度5):
[1, 2, 3] → [1, 2, 3, 0, 0] # 0是<pad>的索引
(5) MAT 文件 (data.mat)
data = {'X': data, # 文本的數字向量表示(填充后的索引序列)'label': labels, # 文本對應的類別標簽(數字編碼)'num_words': len(my_vocab) # 詞匯表大小(總詞數)
}
io.savemat('./dataset/data/data.mat', data)
-
data.mat 是通過 scipy.io.savemat() 生成的一個 MATLAB 格式的數據文件,它保存了預處理后的文本數據及其標簽,用于后續的模型訓練和評估。
-
為什么使用 .mat 格式?
- 兼容性: MATLAB和Python(通過 scipy.io)均可讀寫,便于跨平臺使用。
- 結構化存儲: 適合存儲多維數組和元數據(如詞匯表大小)。
- 效率: 二進制格式加載速度快,占用空間比純文本小。
-
X(特征數據)
形狀: [num_samples, max_seq_len]
num_samples: 樣本數量(即多少條文本)。
max_seq_len: 填充后的序列最大長度。每條文本被轉換為固定長度的數字序列
例如:
原始文本: ["深度", "學習", "強大"] → 分詞后
詞匯表: {"深度":1, "學習":2, "強大":3, "<PAD>":0}
填充后: [1, 2, 3, 0, 0] # 填充到長度5
-
label(標簽數據)
形狀: [num_samples,]
每個文本的類別標簽(通過 LabelEncoder 從字符串標簽轉換為數字) -
num_words(詞匯表大小)
作用: 記錄詞匯表的總詞數,用于初始化模型的嵌入層(nn.Embedding)。
數據字段 | 維度 | 說明 |
---|---|---|
X | 382589×53 | 共 382,589 條文本樣本,每條文本被填充/截斷為 53 個詞的固定長度。 |
num_words | 165444 | 詞匯表總大小(包含所有唯一詞 + 特殊符號如 和 ) |
label | 1×382589 | 每條文本對應的類別標簽(共 382,589 個標簽,可能是 15 分類任務) |
- 53 是數據集中所有文本分詞后的 最大詞數(即至少有一條文本被分詞為 53 個詞)
text:"他是最帥反派專業戶,演《古惑仔》大火,今病魔纏身可憐無人識!"
經過 jieba.lcut() 分詞后可能變成:['他', '是', '最', '帥', '反派', '專業戶', ',', '演', '《', '古惑仔', '》', '大火', ',', '今', '病魔', '纏身', '可憐', '無人', '識', '!']
分詞數量:20 個(如果過濾掉標點符號和停用詞,可能更少)。
3.TextCNN 的核心結構
(1)Embedding Layer(詞嵌入層):將單詞映射為稠密向量。
輸入:詞的整數索引(如通過詞表映射得到的[1, 5, 10])。
輸出:固定維度的連續向量(如 100 維的[0.2, -0.5, 0.1, …])。
(2)Convolutional Layer(卷積層):使用 不同尺寸的卷積核(如 [3,4,5])在詞向量序列上滑動,提取局部特征。
-
為什么使用卷積?
- 捕獲局部特征:文本中的語義單元(如短語、情感表達)通常是連續詞的組合。
- 參數共享:減少模型參數量,提高泛化能力。
- 多尺度特征:不同大小的卷積核能捕獲不同長度的語義模式。
-
文本數據的二維表示
雖然文本本質是一維序列,但在卷積操作中通常表示為二維張量,形狀:[句子長度 × 詞向量維度]。如一個包含 10 個詞的句子,每個詞用 300 維向量表示 → [10 × 300] 的二維矩陣 -
kernel_size=2:每次滑動查看 2 個連續詞 的組合(如 [“深度”, “學習”])。
kernel_size=3:每次滑動查看 3 個連續詞 的組合(如 [“深度”, “學習”, “模型”])。
kernel_size=4:每次滑動查看 4 個連續詞 的組合(如 [“深度”, “學習”, “模型”, “強大”])。 -
經過 ReLU 激活函數增強非線性。
Max-Pooling Layer(最大池化層)
對每個卷積核的輸出進行 全局最大池化(Global Max Pooling),提取最重要的特征。
輸出:[batch_size, num_filters](每個卷積核保留一個最大值)。
(3)Concatenation(特征拼接):將所有卷積核的池化結果拼接,形成固定長度的特征向量。
- 池化層的核心功能:特征篩選與降維
- 最大池化:在每個滑動窗口中選取最大值作為輸出,本質是保留最突出的特征。
- 降維效果:減少特征維度,降低計算量,同時避免過擬合。
- 文本處理中的直觀意義
- 文本中的關鍵語義(如情感詞、主題詞)通常會產生較大的卷積響應值。
- 最大池化相當于 “篩選” 出每個局部窗口中最關鍵的特征,忽略次要信息。
import torch
import torch.nn.functional as F# 假設卷積后的特征圖(單通道)
feature_map = torch.tensor([[0.2, -0.5, 0.8, 0.1, -0.3], # 長度為5的特征序列
], dtype=torch.float32).unsqueeze(0) # [1, 1, 5]# 最大池化(窗口大小為5,覆蓋整個序列)
pooled = F.max_pool1d(feature_map, kernel_size=5)
print(pooled) # 輸出: tensor([[[0.8]]])
該操作從長度為 5 的序列中提取最大值 0.8,作為該通道的最終特征。
(4)Fully Connected Layer(全連接層):將拼接后的特征映射到類別空間(Softmax 輸出概率)。在文本分類模型中,全連接層 (Linear) 和 Dropout 層是構建分類器的關鍵組件。它們共同作用,將提取的文本特征映射到分類空間并防止過擬合。
四、完整代碼
# -*- coding: utf-8 -*-
from jieba import lcut
from torchtext.vocab import vocab
from collections import OrderedDict, Counter
from torchtext.transforms import VocabTransform
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from sklearn.preprocessing import LabelEncoder
import scipy.io as io
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from utils import metrics, safeCreateDir
import time
from sklearn.metrics import ConfusionMatrixDisplay
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from torch.nn import functional as F
import math
from sklearn.metrics import confusion_matrix
import os# 數據處理
def is_chinese(uchar):if (uchar >= '\u4e00' and uchar <= '\u9fa5'):return Trueelse:return Falsedef reserve_chinese(content):content_str = ''for i in content:if is_chinese(i):content_str += ireturn content_strdef getStopWords():file = open('./dataset/stopwords.txt', 'r', encoding='utf8')words = [i.strip() for i in file.readlines()]file.close()return wordsdef dataParse(text, stop_words):label_map = {'news_story': 0, 'news_culture': 1, 'news_entertainment': 2,'news_sports': 3, 'news_finance': 4, 'news_house': 5, 'news_car': 6,'news_edu': 7, 'news_tech': 8, 'news_military': 9, 'news_travel': 10,'news_world': 11, 'stock': 12, 'news_agriculture': 13, 'news_game': 14}_, _, label, content, _ = text.split('_!_')label = label_map[label]content = reserve_chinese(content)words = lcut(content)words = [i for i in words if not i in stop_words]return words, int(label)def getFormatData():file = open('./dataset/data/toutiao_cat_data.txt', 'r', encoding='utf8')texts = file.readlines()file.close()stop_words = getStopWords()all_words = []all_labels = []for text in texts:content, label = dataParse(text, stop_words)if len(content) <= 0:continueall_words.append(content)all_labels.append(label)ws = sum(all_words, [])set_ws = Counter(ws)keys = sorted(set_ws, key=lambda x: set_ws[x], reverse=True)dict_words = dict(zip(keys, list(range(1, len(set_ws) + 1))))ordered_dict = OrderedDict(dict_words)my_vocab = vocab(ordered_dict, specials=['<UNK>', '<SEP>'])vocab_transform = VocabTransform(my_vocab)vector = vocab_transform(all_words)vector = [torch.tensor(i) for i in vector]pad_seq = pad_sequence(vector, batch_first=True)labelencoder = LabelEncoder()labels = labelencoder.fit_transform(all_labels)data = pad_seq.numpy()data = {'X': data,'label': labels,'num_words': len(my_vocab)}io.savemat('./dataset/data/data.mat', data)# 數據集加載
class Data(Dataset):def __init__(self, mode='train'):data = io.loadmat('./dataset/data/data.mat')self.X = data['X']self.y = data['label']self.num_words = data['num_words'].item()train_X, val_X, train_y, val_y = train_test_split(self.X, self.y.squeeze(), test_size=0.3, random_state=1)val_X, test_X, val_y, test_y = train_test_split(val_X, val_y, test_size=0.5, random_state=1)if mode == 'train':self.X = train_Xself.y = train_yelif mode == 'val':self.X = val_Xself.y = val_yelif mode == 'test':self.X = test_Xself.y = test_ydef __getitem__(self, item):return self.X[item], self.y[item]def __len__(self):return self.X.shape[0]class getDataLoader():def __init__(self, batch_size):train_data = Data('train')val_data = Data('val')test_data = Data('test')self.traindl = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)self.valdl = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)self.testdl = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)self.num_words = train_data.num_words# 定義網絡結構
class textCNN(nn.Module):def __init__(self, param):super(textCNN, self).__init__()ci = 1kernel_num = param['kernel_num']kernel_size = param['kernel_size']vocab_size = param['vocab_size']embed_dim = param['embed_dim']dropout = param['dropout']class_num = param['class_num']self.param = paramself.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=1)self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))self.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)def init_embed(self, embed_matrix):self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))@staticmethoddef conv_and_pool(x, conv):x = conv(x)x = F.relu(x.squeeze(3))x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):x = self.embed(x)x = x.unsqueeze(1)x1 = self.conv_and_pool(x, self.conv11)x2 = self.conv_and_pool(x, self.conv12)x3 = self.conv_and_pool(x, self.conv13)x = torch.cat((x1, x2, x3), 1)x = self.dropout(x)logit = F.log_softmax(self.fc1(x), dim=1)return logitdef init_weight(self):for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.Linear):m.weight.data.normal_(0, 0.01)m.bias.data.zero_()def plot_acc(train_acc):sns.set(style='darkgrid')plt.figure(figsize=(10, 7))x = list(range(len(train_acc)))plt.plot(x, train_acc, alpha=0.9, linewidth=2, label='train acc')plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend(loc='best')plt.savefig('results/acc.png', dpi=400)def plot_loss(train_loss):sns.set(style='darkgrid')plt.figure(figsize=(10, 7))x = list(range(len(train_loss)))plt.plot(x, train_loss, alpha=0.9, linewidth=2, label='train loss')plt.xlabel('Epoch')plt.ylabel('loss')plt.legend(loc='best')plt.savefig('results/loss.png', dpi=400)# 定義訓練過程
class Trainer():def __init__(self):safeCreateDir('results/')self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')self._init_data()self._init_model()def _init_data(self):data = getDataLoader(batch_size=64)self.traindl = data.traindlself.valdl = data.valdlself.testdl = data.testdlself.num_words = data.num_wordsdef _init_model(self):self.textCNN_param = {'vocab_size': self.num_words,'embed_dim': 64,'class_num': 15,"kernel_num": 16,"kernel_size": [3, 4, 5],"dropout": 0.5,}self.net = textCNN(self.textCNN_param)self.opt = Adam(self.net.parameters(), lr=1e-4, weight_decay=5e-4)self.cri = nn.CrossEntropyLoss()def save_model(self):save_dir = 'saved_dict'if not os.path.exists(save_dir):os.makedirs(save_dir)torch.save(self.net.state_dict(), os.path.join(save_dir, 'cnn.pt'))def load_model(self):save_dir = 'saved_dict'model_path = os.path.join(save_dir, 'cnn.pt')if not os.path.exists(model_path):raise FileNotFoundError(f"Model file not found at {model_path}")self.net.load_state_dict(torch.load(model_path))def train(self, epochs):print('init net...')self.net.init_weight()print(self.net)patten = 'Epoch: %d [===========] cost: %.2fs; loss: %.4f; train acc: %.4f; val acc:%.4f;'train_accs = []c_loss = []for epoch in range(epochs):cur_preds = np.empty(0)cur_labels = np.empty(0)cur_loss = 0start = time.time()for batch, (inputs, targets) in enumerate(self.traindl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)loss = self.cri(pred, targets)self.opt.zero_grad()loss.backward()self.opt.step()cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])cur_loss += loss.item()acc, precision, f1, recall = metrics(cur_preds, cur_labels)val_acc, val_precision, val_f1, val_recall = self.val()train_accs.append(acc)c_loss.append(cur_loss)end = time.time()print(patten % (epoch + 1, end - start, cur_loss, acc, val_acc))self.save_model()plot_acc(train_accs)plot_loss(c_loss)@torch.no_grad()def val(self):self.net.eval()cur_preds = np.empty(0)cur_labels = np.empty(0)for batch, (inputs, targets) in enumerate(self.valdl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])acc, precision, f1, recall = metrics(cur_preds, cur_labels)self.net.train()return acc, precision, f1, recall@torch.no_grad()def test(self):print("test ...")self.load_model()patten = 'test acc: %.4f precision: %.4f recall: %.4f f1: %.4f 'self.net.eval()cur_preds = np.empty(0)cur_labels = np.empty(0)for batch, (inputs, targets) in enumerate(self.testdl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])acc, precision, f1, recall = metrics(cur_preds, cur_labels)cv_conf = confusion_matrix(cur_preds, cur_labels)labels11 = ['story', 'culture', 'entertainment', 'sports', 'finance','house', 'car', 'edu', 'tech', 'military','travel', 'world', 'stock', 'agriculture', 'game']fig, ax = plt.subplots(figsize=(15, 15))disp = ConfusionMatrixDisplay(confusion_matrix=cv_conf, display_labels=labels11)disp.plot(cmap="Blues", values_format='', ax=ax)plt.savefig("results/ConfusionMatrix.png", dpi=400)self.net.train()print(patten % (acc, precision, recall, f1))if __name__ == "__main__":getFormatData() # 數據預處理:數據清洗和詞向量trainer = Trainer()trainer.train(epochs=30) # 數據訓練trainer.test() # 測試