機器學習監督學習實戰七:文本卷積神經網絡TextCNN對中文短文本分類(15類)

??本文介紹了一個基于TextCNN模型的文本分類項目,使用今日頭條新聞數據集進行訓練和評估。項目包括數據獲取、預處理、模型訓練、評估測試等環節。數據預處理涉及清洗文本、中文分詞、去除停用詞、構建詞匯表和向量化等步驟。TextCNN模型通過卷積層和池化層提取文本特征,并在訓練過程中記錄準確率和損失。最終,模型在測試集上達到了較高的準確率(84.06%),并生成了混淆矩陣可視化。項目還詳細介紹了TextCNN模型的結構和創新點,以及數據預處理和模型訓練的具體實現代碼。完整代碼已開源在個人GitHub:https://github.com/KLWU07/Chinese-text-classification-TextCNN

一、項目描述

1.數據獲取

(1)今日頭條新聞分類數據爬取(Python腳本)

  • 使用今日頭條文本分類數據集,包含民生故事、文化、娛樂、體育、財經、房
    產、汽車、教育、科技、軍事、旅游、國際、證券股票、農業、游戲15個類別,共30多萬條數據。

(2)每條數據形式

6554371968739574280_!_102_!_news_entertainment_!_今年你看《復聯3》哭得有多慘,明年你看《復聯4》就會叫得有多爽_!_卡魔拉,滅霸,小丑,復仇者聯盟3,理想主義者,復聯3,復聯4,雷神3,漫威,黑暗騎士

2.數據預處理

(1)清洗文本(保留中文字符)
(2)中文分詞(使用jieba)
(3)去除停用詞
(4)構建詞匯表并轉換為數字向量
(5)數據劃分為訓練集/驗證集/測試集

3.模型訓練

(1)使用TextCNN模型進行訓練
(2)記錄訓練過程中的準確率和損失
(3)保存最佳模型

4.評估測試

(1)輸出測試集的準確率、精確率、召回率和F1值
(2)生成混淆矩陣可視化

二、訓練過程和結果

1.訓練過程中的準確率和損失、混淆矩陣可視化

在這里插入圖片描述
在這里插入圖片描述

在這里插入圖片描述

textCNN((embed): Embedding(165444, 64, padding_idx=1)(conv11): Conv2d(1, 16, kernel_size=(3, 64), stride=(1, 1))(conv12): Conv2d(1, 16, kernel_size=(4, 64), stride=(1, 1))(conv13): Conv2d(1, 16, kernel_size=(5, 64), stride=(1, 1))(dropout): Dropout(p=0.5, inplace=False)(fc1): Linear(in_features=48, out_features=15, bias=True)
)
Epoch: 1   [===========]  cost: 775.52s;  loss: 10520.5979;  train acc: 0.1589;  val acc:0.2256;
Epoch: 2   [===========]  cost: 941.70s;  loss: 9934.9760;  train acc: 0.2239;  val acc:0.2975;
Epoch: 3   [===========]  cost: 1309.58s;  loss: 9329.3478;  train acc: 0.2844;  val acc:0.3842;
Epoch: 4   [===========]  cost: 1619.02s;  loss: 8444.4464;  train acc: 0.3663;  val acc:0.4751;
Epoch: 5   [===========]  cost: 1837.63s;  loss: 7384.8970;  train acc: 0.4545;  val acc:0.5569;
Epoch: 6   [===========]  cost: 1989.67s;  loss: 6397.7872;  train acc: 0.5371;  val acc:0.6317;
Epoch: 7   [===========]  cost: 2095.66s;  loss: 5669.9501;  train acc: 0.6029;  val acc:0.6932;
Epoch: 8   [===========]  cost: 2142.67s;  loss: 5095.1675;  train acc: 0.6578;  val acc:0.7403;
Epoch: 9   [===========]  cost: 2126.22s;  loss: 4633.6697;  train acc: 0.6989;  val acc:0.7702;
Epoch: 10   [===========]  cost: 2101.42s;  loss: 4296.7775;  train acc: 0.7272;  val acc:0.7899;
Epoch: 11   [===========]  cost: 2099.69s;  loss: 4047.9119;  train acc: 0.7460;  val acc:0.8012;
Epoch: 12   [===========]  cost: 2085.49s;  loss: 3871.4451;  train acc: 0.7594;  val acc:0.8096;
Epoch: 13   [===========]  cost: 2152.68s;  loss: 3725.9579;  train acc: 0.7689;  val acc:0.8157;
Epoch: 14   [===========]  cost: 2178.29s;  loss: 3613.9486;  train acc: 0.7761;  val acc:0.8209;
Epoch: 15   [===========]  cost: 2303.36s;  loss: 3531.0598;  train acc: 0.7828;  val acc:0.8230;
Epoch: 16   [===========]  cost: 2183.05s;  loss: 3447.1700;  train acc: 0.7878;  val acc:0.8280;
Epoch: 17   [===========]  cost: 2213.03s;  loss: 3388.7178;  train acc: 0.7918;  val acc:0.8286;
Epoch: 18   [===========]  cost: 2349.19s;  loss: 3332.7325;  train acc: 0.7953;  val acc:0.8313;
Epoch: 19   [===========]  cost: 2508.73s;  loss: 3278.0177;  train acc: 0.7982;  val acc:0.8326;
Epoch: 20   [===========]  cost: 2579.64s;  loss: 3241.5735;  train acc: 0.8021;  val acc:0.8341;
Epoch: 21   [===========]  cost: 2616.97s;  loss: 3205.6926;  train acc: 0.8039;  val acc:0.8350;
Epoch: 22   [===========]  cost: 2644.12s;  loss: 3158.8484;  train acc: 0.8072;  val acc:0.8348;
Epoch: 23   [===========]  cost: 2949.09s;  loss: 3135.9280;  train acc: 0.8093;  val acc:0.8376;
Epoch: 24   [===========]  cost: 2474.57s;  loss: 3118.4267;  train acc: 0.8100;  val acc:0.8385;
Epoch: 25   [===========]  cost: 2544.15s;  loss: 3092.5366;  train acc: 0.8117;  val acc:0.8395;
Epoch: 26   [===========]  cost: 2468.19s;  loss: 3072.7564;  train acc: 0.8132;  val acc:0.8389;
Epoch: 27   [===========]  cost: 2449.40s;  loss: 3041.6649;  train acc: 0.8155;  val acc:0.8403;
Epoch: 28   [===========]  cost: 2472.81s;  loss: 3025.2191;  train acc: 0.8161;  val acc:0.8410;
Epoch: 29   [===========]  cost: 2481.49s;  loss: 3007.8602;  train acc: 0.8182;  val acc:0.8408;
Epoch: 30   [===========]  cost: 2424.77s;  loss: 2986.9278;  train acc: 0.8195;  val acc:0.8421;
test ...
test acc: 0.8406   precision: 0.8406   recall: 0.8406    f1: 0.8406 

三、文本卷積神經網路TextCNN模型解釋

1.《Convolutional Neural Networks for Sentence Classification》論文創新點

(1)模型應用創新:將卷積神經網絡(CNN)應用于自然語言處理(NLP)中的句子分類任務,打破了傳統 NLP 任務主要依賴循環神經網絡(RNN)及其變體的局面。
(2)輸入表示創新:采用預訓練的詞向量(如 word2vec)作為輸入,替代傳統的 one - hot 編碼。預訓練詞向量能更好地捕捉詞的語義信息,提高了模型的泛化能力和性能。
(3)特征提取創新:使用多個不同尺寸的卷積核來提取句子中的關鍵信息,類似于多窗口大小的 ngram。不同大小的卷積核可以捕捉不同長度的局部相關性,從而更全面地提取句子的特征,提高模型的特征提取能力。
(4)模型結構創新:TextCNN 模型結構相對簡單,計算快速、實現方便,且在準確性方面表現較高。它包含嵌入層、卷積層、池化層和全連接層等,通過簡單的結構實現了高效的句子分類。
(5)訓練策略創新:提出了多種訓練策略,如 CNN - rand(基礎模型中所有單詞都是隨機初始化,然后在訓練期間進行修改)、CNN - static(帶有來自 word2vec 的預訓練向量,所有單詞保持靜態,只學習模型的其他參數)、CNN - non - static(預訓練的向量針對每個任務進行微調)和 CNN - multichannel(有兩組詞向量模型)。

2.數據預處理模塊

# 數據處理相關函數
def is_chinese(uchar):# 判斷字符是否為中文
def reserve_chinese(content):# 保留文本中的中文字符
def getStopWords():# 加載停用詞表
def dataParse(text, stop_words):# 解析文本數據,映射標簽,清洗文本并分詞
def getFormatData():# 處理原始數據,構建詞表,生成詞向量表示并保存

(1)中文分詞:jieba分詞

  • jieba是一個專為中文文本設計的分詞庫,其主要功能是將連續的中文文本拆分成有意義的詞語序列。與英文不同,中文文本中詞與詞之間沒有空格作為自然分隔符。例如:中文句子:“我喜歡自然語言處理”,分詞結果:“我 / 喜歡 / 自然語言處理”。分詞的準確性直接影響后續文本分析的質量,如關鍵詞提取、情感分析、機器翻譯等。
  • jieba分詞基于以下技術:基于前綴詞典的匹配算法、隱馬爾可夫模型 (HMM)、動態規劃優化
  • jieba 分詞的主要模式:精確模式(默認),將文本精確地切分成詞語,適合文本分析。全模式,將文本中所有可能的詞語都掃描出來,速度快但可能產生冗余。搜索引擎模式,在精確模式基礎上,對長詞再次切分,適合搜索引擎分詞。

(2)去除停用詞

  • 停用詞(Stop Words)是自然語言處理(NLP)中一類被認為對文本分析價值較低的詞匯。如中文中的 “的”“了”“在”“是”,英文中的 “the”“a”“is” 等。包括標點符號、連接詞、代詞等,本身不攜帶具體語義信息。去除停用詞是文本預處理的核心步驟,通過 stopwords.txt 文件可高效管理停用詞列表,配合 jieba 等分詞工具,能顯著提升文本分析的質量和效率。
  • 去除停用詞的主要目的:減少數據量,提升處理效率;過濾無意義詞匯,突出關鍵信息(如名詞、動詞等實詞);避免停用詞對文本分析(如關鍵詞提取、情感分析)產生干擾。
import jiebadef filter_stopwords(text, stopwords):"""對文本分詞并去除停用詞"""# 分詞words = jieba.lcut(text)# 過濾停用詞filtered_words = [word for word in words if word not in stopwords and len(word) > 1]return filtered_words# 示例文本
text = "今天天氣很好,我打算去公園散步。"
filtered_result = filter_stopwords(text, stopwords)
print(filtered_result)  # 輸出:['今天', '天氣', '很好', '打算', '公園', '散步']

(3)構建詞表:按詞頻排序,高頻詞獲得更小的索引

  • 在自然語言處理中,構建詞表(Vocabulary)是將文本轉換為計算機可處理格式的關鍵步驟。按詞頻排序并讓高頻詞獲得更小索引的策略,是一種常見且有效的詞表構建方法。
  • 文本數字化的需求:神經網絡等機器學習模型無法直接處理文本,需要將文本轉換為數字表示。詞表是文本與數字之間的映射橋梁,每個詞被映射為唯一的整數索引。
  • 按詞頻排序:在大量文本中頻繁出現的詞,通常攜帶更關鍵的語義信息。更小的索引值占用更少的內存和計算資源,高頻詞優先使用小索引可提升效率。

(4)文本向量化:通過詞匯表將詞語映射為索引

  • 在自然語言處理中,將文本轉換為詞向量序列是連接原始文本與深度學習模型的關鍵步驟。這一過程將離散的文本符號轉換為連續的數值表示,使模型能夠捕捉文本中的語義信息。
1.分詞:使用分詞工具(如代碼中的 jieba.lcut)將句子拆分為詞語列表。
例如:"深度學習很強大"["深度", "學習", "很", "強大"]2.統計詞頻:遍歷所有文本,統計每個詞的出現頻率(使用 Counter)。
例如:{"深度": 10, "學習": 8, "強大": 5, ...}3.排序和過濾:按詞頻從高到低排序(代碼中 reverse=True),保留高頻詞以降低維度。低頻詞可視為噪聲或替換為 <UNK>(未知詞)。4.建立映射:為每個詞分配唯一索引(從 1 開始,0 通常留給填充符 <PAD>)。
例如:"<PAD>": 0, "深度": 1, "學習": 2, "強大": 3, ..., "<UNK>": 100015.詞到索引的轉換:根據詞匯表將每個詞替換為對應的索引。處理未知詞:若詞不在詞匯表中,使用 <UNK> 的索引(如 10001)。
例如:["深度", "學習", "強大"][1, 2, 3]6.序列對齊(Padding):使用 pad_sequence 將所有序列填充到相同長度(按最長序列或預設值)。
例如(填充到長度5):
[1, 2, 3][1, 2, 3, 0, 0]  # 0是<pad>的索引

(5) MAT 文件 (data.mat)

data = {'X': data,          # 文本的數字向量表示(填充后的索引序列)'label': labels,     # 文本對應的類別標簽(數字編碼)'num_words': len(my_vocab)  # 詞匯表大小(總詞數)
}
io.savemat('./dataset/data/data.mat', data)
  • data.mat 是通過 scipy.io.savemat() 生成的一個 MATLAB 格式的數據文件,它保存了預處理后的文本數據及其標簽,用于后續的模型訓練和評估。

  • 為什么使用 .mat 格式?

    • 兼容性: MATLAB和Python(通過 scipy.io)均可讀寫,便于跨平臺使用。
    • 結構化存儲: 適合存儲多維數組和元數據(如詞匯表大小)。
    • 效率: 二進制格式加載速度快,占用空間比純文本小。
  • X(特征數據)
    形狀: [num_samples, max_seq_len]
    num_samples: 樣本數量(即多少條文本)。
    max_seq_len: 填充后的序列最大長度。每條文本被轉換為固定長度的數字序列
    例如:

原始文本: ["深度", "學習", "強大"] → 分詞后
詞匯表: {"深度":1, "學習":2, "強大":3, "<PAD>":0}
填充后: [1, 2, 3, 0, 0]  # 填充到長度5
  • label(標簽數據)
    形狀: [num_samples,]
    每個文本的類別標簽(通過 LabelEncoder 從字符串標簽轉換為數字)

  • num_words(詞匯表大小)
    作用: 記錄詞匯表的總詞數,用于初始化模型的嵌入層(nn.Embedding)。

數據字段維度說明
X382589×53共 382,589 條文本樣本,每條文本被填充/截斷為 53 個詞的固定長度。
num_words165444詞匯表總大小(包含所有唯一詞 + 特殊符號如 和 )
label1×382589每條文本對應的類別標簽(共 382,589 個標簽,可能是 15 分類任務)
  • 53 是數據集中所有文本分詞后的 最大詞數(即至少有一條文本被分詞為 53 個詞)
text:"他是最帥反派專業戶,演《古惑仔》大火,今病魔纏身可憐無人識!"
經過 jieba.lcut() 分詞后可能變成:['他', '是', '最', '帥', '反派', '專業戶', ',', '演', '《', '古惑仔', '》', '大火', ',', '今', '病魔', '纏身', '可憐', '無人', '識', '!']
分詞數量:20 個(如果過濾掉標點符號和停用詞,可能更少)。

3.TextCNN 的核心結構

(1)Embedding Layer(詞嵌入層):將單詞映射為稠密向量。
輸入:詞的整數索引(如通過詞表映射得到的[1, 5, 10])。
輸出:固定維度的連續向量(如 100 維的[0.2, -0.5, 0.1, …])。

(2)Convolutional Layer(卷積層):使用 不同尺寸的卷積核(如 [3,4,5])在詞向量序列上滑動,提取局部特征。

  • 為什么使用卷積?

    • 捕獲局部特征:文本中的語義單元(如短語、情感表達)通常是連續詞的組合。
    • 參數共享:減少模型參數量,提高泛化能力。
    • 多尺度特征:不同大小的卷積核能捕獲不同長度的語義模式。
  • 文本數據的二維表示
    雖然文本本質是一維序列,但在卷積操作中通常表示為二維張量,形狀:[句子長度 × 詞向量維度]。如一個包含 10 個詞的句子,每個詞用 300 維向量表示 → [10 × 300] 的二維矩陣

  • kernel_size=2:每次滑動查看 2 個連續詞 的組合(如 [“深度”, “學習”])。
    kernel_size=3:每次滑動查看 3 個連續詞 的組合(如 [“深度”, “學習”, “模型”])。
    kernel_size=4:每次滑動查看 4 個連續詞 的組合(如 [“深度”, “學習”, “模型”, “強大”])。

  • 經過 ReLU 激活函數增強非線性。

Max-Pooling Layer(最大池化層)
對每個卷積核的輸出進行 全局最大池化(Global Max Pooling),提取最重要的特征。
輸出:[batch_size, num_filters](每個卷積核保留一個最大值)。

(3)Concatenation(特征拼接):將所有卷積核的池化結果拼接,形成固定長度的特征向量。

  • 池化層的核心功能:特征篩選與降維
    • 最大池化:在每個滑動窗口中選取最大值作為輸出,本質是保留最突出的特征。
    • 降維效果:減少特征維度,降低計算量,同時避免過擬合。
  • 文本處理中的直觀意義
    • 文本中的關鍵語義(如情感詞、主題詞)通常會產生較大的卷積響應值。
    • 最大池化相當于 “篩選” 出每個局部窗口中最關鍵的特征,忽略次要信息。
import torch
import torch.nn.functional as F# 假設卷積后的特征圖(單通道)
feature_map = torch.tensor([[0.2, -0.5, 0.8, 0.1, -0.3],  # 長度為5的特征序列
], dtype=torch.float32).unsqueeze(0)  # [1, 1, 5]# 最大池化(窗口大小為5,覆蓋整個序列)
pooled = F.max_pool1d(feature_map, kernel_size=5)
print(pooled)  # 輸出: tensor([[[0.8]]])

該操作從長度為 5 的序列中提取最大值 0.8,作為該通道的最終特征。
(4)Fully Connected Layer(全連接層):將拼接后的特征映射到類別空間(Softmax 輸出概率)。在文本分類模型中,全連接層 (Linear) 和 Dropout 層是構建分類器的關鍵組件。它們共同作用,將提取的文本特征映射到分類空間并防止過擬合。

四、完整代碼

# -*- coding: utf-8 -*-
from jieba import lcut
from torchtext.vocab import vocab
from collections import OrderedDict, Counter
from torchtext.transforms import VocabTransform
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from sklearn.preprocessing import LabelEncoder
import scipy.io as io
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from utils import metrics, safeCreateDir
import time
from sklearn.metrics import ConfusionMatrixDisplay
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from torch.nn import functional as F
import math
from sklearn.metrics import confusion_matrix
import os# 數據處理
def is_chinese(uchar):if (uchar >= '\u4e00' and uchar <= '\u9fa5'):return Trueelse:return Falsedef reserve_chinese(content):content_str = ''for i in content:if is_chinese(i):content_str += ireturn content_strdef getStopWords():file = open('./dataset/stopwords.txt', 'r', encoding='utf8')words = [i.strip() for i in file.readlines()]file.close()return wordsdef dataParse(text, stop_words):label_map = {'news_story': 0, 'news_culture': 1, 'news_entertainment': 2,'news_sports': 3, 'news_finance': 4, 'news_house': 5, 'news_car': 6,'news_edu': 7, 'news_tech': 8, 'news_military': 9, 'news_travel': 10,'news_world': 11, 'stock': 12, 'news_agriculture': 13, 'news_game': 14}_, _, label, content, _ = text.split('_!_')label = label_map[label]content = reserve_chinese(content)words = lcut(content)words = [i for i in words if not i in stop_words]return words, int(label)def getFormatData():file = open('./dataset/data/toutiao_cat_data.txt', 'r', encoding='utf8')texts = file.readlines()file.close()stop_words = getStopWords()all_words = []all_labels = []for text in texts:content, label = dataParse(text, stop_words)if len(content) <= 0:continueall_words.append(content)all_labels.append(label)ws = sum(all_words, [])set_ws = Counter(ws)keys = sorted(set_ws, key=lambda x: set_ws[x], reverse=True)dict_words = dict(zip(keys, list(range(1, len(set_ws) + 1))))ordered_dict = OrderedDict(dict_words)my_vocab = vocab(ordered_dict, specials=['<UNK>', '<SEP>'])vocab_transform = VocabTransform(my_vocab)vector = vocab_transform(all_words)vector = [torch.tensor(i) for i in vector]pad_seq = pad_sequence(vector, batch_first=True)labelencoder = LabelEncoder()labels = labelencoder.fit_transform(all_labels)data = pad_seq.numpy()data = {'X': data,'label': labels,'num_words': len(my_vocab)}io.savemat('./dataset/data/data.mat', data)# 數據集加載
class Data(Dataset):def __init__(self, mode='train'):data = io.loadmat('./dataset/data/data.mat')self.X = data['X']self.y = data['label']self.num_words = data['num_words'].item()train_X, val_X, train_y, val_y = train_test_split(self.X, self.y.squeeze(), test_size=0.3, random_state=1)val_X, test_X, val_y, test_y = train_test_split(val_X, val_y, test_size=0.5, random_state=1)if mode == 'train':self.X = train_Xself.y = train_yelif mode == 'val':self.X = val_Xself.y = val_yelif mode == 'test':self.X = test_Xself.y = test_ydef __getitem__(self, item):return self.X[item], self.y[item]def __len__(self):return self.X.shape[0]class getDataLoader():def __init__(self, batch_size):train_data = Data('train')val_data = Data('val')test_data = Data('test')self.traindl = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)self.valdl = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)self.testdl = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)self.num_words = train_data.num_words# 定義網絡結構
class textCNN(nn.Module):def __init__(self, param):super(textCNN, self).__init__()ci = 1kernel_num = param['kernel_num']kernel_size = param['kernel_size']vocab_size = param['vocab_size']embed_dim = param['embed_dim']dropout = param['dropout']class_num = param['class_num']self.param = paramself.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=1)self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))self.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)def init_embed(self, embed_matrix):self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))@staticmethoddef conv_and_pool(x, conv):x = conv(x)x = F.relu(x.squeeze(3))x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):x = self.embed(x)x = x.unsqueeze(1)x1 = self.conv_and_pool(x, self.conv11)x2 = self.conv_and_pool(x, self.conv12)x3 = self.conv_and_pool(x, self.conv13)x = torch.cat((x1, x2, x3), 1)x = self.dropout(x)logit = F.log_softmax(self.fc1(x), dim=1)return logitdef init_weight(self):for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.Linear):m.weight.data.normal_(0, 0.01)m.bias.data.zero_()def plot_acc(train_acc):sns.set(style='darkgrid')plt.figure(figsize=(10, 7))x = list(range(len(train_acc)))plt.plot(x, train_acc, alpha=0.9, linewidth=2, label='train acc')plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend(loc='best')plt.savefig('results/acc.png', dpi=400)def plot_loss(train_loss):sns.set(style='darkgrid')plt.figure(figsize=(10, 7))x = list(range(len(train_loss)))plt.plot(x, train_loss, alpha=0.9, linewidth=2, label='train loss')plt.xlabel('Epoch')plt.ylabel('loss')plt.legend(loc='best')plt.savefig('results/loss.png', dpi=400)# 定義訓練過程
class Trainer():def __init__(self):safeCreateDir('results/')self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')self._init_data()self._init_model()def _init_data(self):data = getDataLoader(batch_size=64)self.traindl = data.traindlself.valdl = data.valdlself.testdl = data.testdlself.num_words = data.num_wordsdef _init_model(self):self.textCNN_param = {'vocab_size': self.num_words,'embed_dim': 64,'class_num': 15,"kernel_num": 16,"kernel_size": [3, 4, 5],"dropout": 0.5,}self.net = textCNN(self.textCNN_param)self.opt = Adam(self.net.parameters(), lr=1e-4, weight_decay=5e-4)self.cri = nn.CrossEntropyLoss()def save_model(self):save_dir = 'saved_dict'if not os.path.exists(save_dir):os.makedirs(save_dir)torch.save(self.net.state_dict(), os.path.join(save_dir, 'cnn.pt'))def load_model(self):save_dir = 'saved_dict'model_path = os.path.join(save_dir, 'cnn.pt')if not os.path.exists(model_path):raise FileNotFoundError(f"Model file not found at {model_path}")self.net.load_state_dict(torch.load(model_path))def train(self, epochs):print('init net...')self.net.init_weight()print(self.net)patten = 'Epoch: %d   [===========]  cost: %.2fs;  loss: %.4f;  train acc: %.4f;  val acc:%.4f;'train_accs = []c_loss = []for epoch in range(epochs):cur_preds = np.empty(0)cur_labels = np.empty(0)cur_loss = 0start = time.time()for batch, (inputs, targets) in enumerate(self.traindl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)loss = self.cri(pred, targets)self.opt.zero_grad()loss.backward()self.opt.step()cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])cur_loss += loss.item()acc, precision, f1, recall = metrics(cur_preds, cur_labels)val_acc, val_precision, val_f1, val_recall = self.val()train_accs.append(acc)c_loss.append(cur_loss)end = time.time()print(patten % (epoch + 1, end - start, cur_loss, acc, val_acc))self.save_model()plot_acc(train_accs)plot_loss(c_loss)@torch.no_grad()def val(self):self.net.eval()cur_preds = np.empty(0)cur_labels = np.empty(0)for batch, (inputs, targets) in enumerate(self.valdl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])acc, precision, f1, recall = metrics(cur_preds, cur_labels)self.net.train()return acc, precision, f1, recall@torch.no_grad()def test(self):print("test ...")self.load_model()patten = 'test acc: %.4f   precision: %.4f   recall: %.4f    f1: %.4f    'self.net.eval()cur_preds = np.empty(0)cur_labels = np.empty(0)for batch, (inputs, targets) in enumerate(self.testdl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])acc, precision, f1, recall = metrics(cur_preds, cur_labels)cv_conf = confusion_matrix(cur_preds, cur_labels)labels11 = ['story', 'culture', 'entertainment', 'sports', 'finance','house', 'car', 'edu', 'tech', 'military','travel', 'world', 'stock', 'agriculture', 'game']fig, ax = plt.subplots(figsize=(15, 15))disp = ConfusionMatrixDisplay(confusion_matrix=cv_conf, display_labels=labels11)disp.plot(cmap="Blues", values_format='', ax=ax)plt.savefig("results/ConfusionMatrix.png", dpi=400)self.net.train()print(patten % (acc, precision, recall, f1))if __name__ == "__main__":getFormatData()  # 數據預處理:數據清洗和詞向量trainer = Trainer()trainer.train(epochs=30)  # 數據訓練trainer.test()  # 測試

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/87600.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/87600.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/87600.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

iot-dc3 項目Bug修復保姆喂奶級教程

一.Uncaught (in promise) ReferenceError: TinyArea is not defined 1.觸發場景 前端設備模塊,點擊關聯模板、關聯位號、設備數據,無反應,一直切不過去,沒有報錯通知,F12查看控制臺報錯如下: 2.引起原因 前端導入的庫為"@antv/g2": "^5.3.0",在 P…

Spring Boot + MyBatis Plus + SpringAI + Vue 畢設項目開發全解析(源碼)

前言 前些天發現了一個巨牛的人工智能免費學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。點擊跳轉到網站 Spring Boot MyBatis Plus SpringAI Vue 畢設項目開發全解析 目錄 一、項目概述與技術選型 項目背景與需求分析技術棧選擇…

Vitess數據庫部署與運維深度指南:構建可伸縮、高可用與安全的云原生數據庫

摘要 Vitess是一個為MySQL和MariaDB設計的云原生、水平可伸縮的分布式數據庫系統&#xff0c;它通過分片&#xff08;sharding&#xff09;實現無限擴展&#xff0c;同時保持對應用程序的透明性&#xff0c;使其無需感知底層數據分布。該項目于2019年從云原生計算基金會&#…

SpringAI+DeepSeek大模型應用開發——6基于MongDB持久化對話

持久化對話 默認情況下&#xff0c;聊天記憶存儲在內存中ChatMemory chatMemory new InMemoryChatMemory()。 如果需要持久化存儲&#xff0c;可以實現一個自定義的聊天記憶存儲類&#xff0c;以便將聊天消息存儲在你選擇的任何持久化存儲介質中。 MongoDB 文檔型數據庫&…

Mac電腦-音視頻剪輯編輯-Final Cut Pro X(fcpx)

Final Cut Pro Mac是一款專業的視頻剪輯工具&#xff0c;專為蘋果用戶設計。 它具備強大的視頻剪輯、音軌、圖形特效和調色功能&#xff0c;支持整片輸出&#xff0c;提升創作效率。 經過Apple芯片優化&#xff0c;利用Metal引擎動力&#xff0c;可處理更復雜的項目&#xff…

不同程度多徑效應影響下的無線通信網絡電磁信號仿真數據生成程序

生成.mat數據&#xff1a; %創建時間&#xff1a;2025年6月19日 %zhouzhichao %遍歷生成不同程度多徑效應影響的無線通信網絡拓撲推理數據用于測試close all clearsnr 40; n 30;dataset_n 100;for bias 0.1:0.1:0.9nodes_P ones(n,1);Sampling_M 3000;%獲取一幀信號及對…

Eureka 和 Feign(二)

Eureka 和 Feign 是 Spring Cloud 微服務架構中協同工作的兩個核心組件&#xff0c;它們的關系可以通過以下比喻和詳解來說明&#xff1a; 關系核心&#xff1a;服務發現 → 動態調用 組件角色核心功能Eureka服務注冊中心服務實例的"電話簿"Feign聲明式HTTP客戶端根…

Springboot仿抖音app開發之RabbitMQ 異步解耦(進階)

Springboot仿抖音app開發之評論業務模塊后端復盤及相關業務知識總結 Springboot仿抖音app開發之粉絲業務模塊后端復盤及相關業務知識總結 Springboot仿抖音app開發之用短視頻務模塊后端復盤及相關業務知識總結 Springboot仿抖音app開發之用戶業務模塊后端復盤及相關業務知識…

1.部署KVM虛擬化平臺

一.KVM原理簡介 廣義的KVM實際上包含兩部分&#xff0c;一部分是基于Linux內核支持的KVM內核模塊&#xff0c;另一部分就是經過簡化和修改的Qemuo KVM內核模塊是模擬處理器和內存以支持虛擬機的運行&#xff0c;Qemu主要處理丨℃以及為用戶提供一個用戶空間工具來進行虛擬機的…

優化與管理數據庫連接池

優化與管理數據庫連接池 在現代高并發系統中,數據庫連接池是保障數據庫訪問性能的核心組件之一。合理配置、優化和管理連接池,可以有效緩解連接創建成本高、連接頻繁斷開重連等問題,從而提升系統整體的響應速度與穩定性。 數據庫連接池的作用與價值 數據庫連接池的核心思…

實現回顯服務器(基于UDP)

目錄 一.回顯服務器的基本概念 二.回顯服務器的簡單示意圖 三.實現回顯服務器&#xff08;基于UDP&#xff09;必須要知道的API 1.DatagramSocket 2.DatagramPacket 3.InetSocketAddress 4.二者區別 1. 功能職責 2. 核心作用 3. 使用場景流程 四.實現服務器端的主…

LabVIEW電液伺服閥自動測試

針對航空航天及工業液壓領域電液伺服閥測試需求&#xff0c;采用 LabVIEW 圖形化編程平臺&#xff0c;集成 NI、GE Druck 等品牌硬件&#xff0c;構建集靜態特性&#xff08;流量/ 壓力 / 泄漏&#xff09;與動態特性&#xff08;頻率響應&#xff09;測試于一體的自動化系統&a…

性能優化 - 高級進階: Spring Boot服務性能優化

文章目錄 Pre引言&#xff1a;為何提前暴露指標與分析的重要性指標暴露與監控接入Prometheus 集成 性能剖析工具&#xff1a;火焰圖與 async-profilerasync-profiler 下載與使用結合 Flame 圖優化示例 HTTP 及 Web 層優化CDN 與靜態資源加速Cache-Control/Expires 在 Nginx 中配…

力扣網C語言編程題:除自身以外數組的乘積

一. 簡介 本文記錄力扣網上涉及數組方面的編程題&#xff0c;主要以 C語言實現。 二. 力扣上C語言編程題&#xff1a;涉及數組 題目&#xff1a;除自身以外數組的乘積 給你一個整數數組 nums&#xff0c;返回 數組 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i…

SpringBoot擴展——發送郵件!

發送郵件 在日常工作和生活中經常會用到電子郵件。例如&#xff0c;當注冊一個新賬戶時&#xff0c;系統會自動給注冊郵箱發送一封激活郵件&#xff0c;通過郵件找回密碼&#xff0c;自動批量發送活動信息等。郵箱的使用基本包括這幾步&#xff1a;先打開瀏覽器并登錄郵箱&…

【html】iOS26 液態玻璃實現效果

<!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>液體玻璃效果演示</title><style>bo…

探索算法秘境:量子隨機游走算法及其在圖論問題中的創新應用

目錄 ?編輯 一、量子隨機游走算法的起源與原理 二、量子隨機游走算法在圖論問題中的創新應用 三、量子隨機游走算法的優勢與挑戰 四、結語 在算法研究的浩瀚星空中&#xff0c;總有一些領域如同遙遠星系&#xff0c;閃爍著神秘而誘人的光芒。今天&#xff0c;我們將一同深…

C# 一維數組和矩形數組全解析

在編程的世界里&#xff0c;數組是一種非常重要的數據結構。今天&#xff0c;我們就來詳細了解一下一維數組和矩形數組。 數組基礎認知 數組實例是從 System.Array 繼承類型的對象。由于它從 BCL 基類派生而來&#xff0c;所以繼承了許多有用的成員&#xff1a; Rank 屬性&a…

WebStorm編輯器側邊欄

目錄 編輯器側邊欄行號配置行號隱藏行號 代碼折疊側邊欄圖標書簽添加匿名書簽添加助記符書簽 運行和調試管理斷點配置斷點圖標 版本控制配置Git Blame注釋 編輯器側邊欄 編輯器左側的垂直區域。當編寫代碼時&#xff0c;提供重要信息和操作圖標。外觀和行為可以根據你的喜好進…

騰訊云TCCA認證考試報名 - TDSQL數據庫交付運維工程師(PostgreSQL版)

數據庫交付運維工程師-騰訊云TDSQL(PostgreSQL版)認證 適合人群&#xff1a; 適合從事TDSQL(PostgreSQL版)交付、運維、售前咨詢以及TDSQL(PostgreSQL版)相關項目的管理人員。 認證考試 單選*40道多選*20道 成績查詢 70分及以上通過認證&#xff0c;官網個人中心->認證考…