從代碼學習深度學習 - 情感分析:使用卷積神經網絡 PyTorch版

文章目錄

  • 前言
  • 加載數據集
  • 一維卷積
  • 最大時間匯聚層
  • textCNN模型
    • 定義模型
    • 加載預訓練詞向量
    • 訓練和評估模型
  • 總結


前言

在之前的章節中,我們探討了如何使用循環神經網絡(RNN)來處理序列數據。今天,我們將探索另一種強大的模型——卷積神經網絡(CNN)——并將其應用于自然語言處理中的經典任務:情感分析。

你可能會覺得奇怪,CNN不是主要用于圖像處理的嗎?確實,CNN在計算機視覺領域取得了巨大的成功,它通過二維卷積核捕捉圖像的局部特征(如邊緣、紋理)。但如果我們換個角度思考,文本序列可以被看作是一維的“圖像”,其中每個詞元(token)就是一個“像素”。這樣,我們就可以使用一維卷積來捕捉文本中的局部模式,比如由相鄰單詞組成的n-gram

本篇博客將詳細介紹如何使用 textCNN 模型,這是一種專為文本分類設計的CNN架構。我們將基于IMDb電影評論數據集,訓練一個能夠判斷評論是正面還是負面的模型。整個流程如下圖所示,我們將使用預訓練的GloVe詞向量作為輸入,將其送入textCNN模型,最終得到情感分類結果。

在這里插入圖片描述

讓我們開始吧!首先,我們需要加載所需的數據集。

完整代碼:下載鏈接

加載數據集

我們仍然使用IMDb電影評論數據集。通過我們預先準備好的 utils_for_data.load_data_imdb 輔助函數,我們可以方便地加載訓練和測試數據迭代器,以及一個根據訓練數據構建好的詞匯表(vocab)。

import torch
import utils_for_data
from torch import nnbatch_size = 64
train_iter, test_iter, vocab = utils_for_data.load_data_imdb(batch_size)

一維卷積

在深入textCNN模型之前,我們先來回顧一下一維卷積是如何工作的。它本質上是二維卷積在只有一個維度(時間或序列步長)上的特例。

如下圖所示,卷積窗口(或稱為卷積核)在一個一維輸入張量上從左到右滑動。在每個位置,輸入子張量與核張量進行逐元素相乘,然后求和,得到輸出張量中對應位置的一個標量值。例如,圖中第一個輸出值 2 是通過 0*1 + 1*2 = 2 計算得出的。

在這里插入圖片描述

我們可以通過代碼來實現這個一維互相關(corr1d)運算,代碼中詳盡的注釋解釋了每一步的維度變化和操作目的。

import torchdef corr1d(X, K):"""實現一維互相關(卷積)運算參數:X: 輸入張量,維度為 (n,) 其中 n 是輸入序列的長度K: 卷積核張量,維度為 (w,) 其中 w 是卷積核的長度返回:Y: 輸出張量,維度為 (n - w + 1,) 其中 n-w+1 是輸出序列的長度"""# 獲取卷積核的長度,維度: 標量w = K.shape[0]# 創建輸出張量,長度為輸入長度減去卷積核長度加1# Y的維度: (X.shape[0] - w + 1,)Y = torch.zeros((X.shape[0] - w + 1))# 遍歷輸出張量的每個位置for i in range(Y.shape[0]):# 在第i個位置進行卷積運算# X[i: i + w] 的維度: (w,) - 提取輸入序列的一個窗口# K 的維度: (w,) - 卷積核# 兩者逐元素相乘后求和得到標量結果Y[i] = (X[i: i + w] * K).sum()return Y# 測試代碼
# X: 輸入張量,維度 (7,) - 包含7個元素的一維張量
X = torch.tensor([0, 1, 2, 3, 4, 5, 6])# K: 卷積核張量,維度 (2,) - 包含2個元素的一維張量  
K = torch.tensor([1, 2])# 調用函數進行一維卷積運算
# 輸出結果的維度: (7 - 2 + 1,) = (6,)
result = corr1d(X, K)
print(result)

輸出結果與預期一致:

tensor([ 2.,  5.,  8., 11., 14., 17.])

在NLP中,詞嵌入通常是多維的,這意味著我們的輸入有多個通道。一維卷積同樣可以處理多通道輸入。此時,卷積核也需要有相同數量的輸入通道。運算時,對每個通道分別執行一維互相關,然后將所有通道的結果相加,得到一個單通道的輸出。

在這里插入圖片描述

下面是多輸入通道一維互相關的實現。

import torchdef corr1d_multi_in(X, K):"""實現多輸入通道的一維互相關(卷積)運算參數:X: 多通道輸入張量,維度為 (c, n) 其中 c 是輸入通道數,n 是每個通道的序列長度K: 多通道卷積核張量,維度為 (c, w) 其中 c 是輸入通道數,w 是卷積核的長度返回:result: 輸出張量,維度為 (n - w + 1,) 其中 n-w+1 是輸出序列的長度"""# 遍歷X和K的第0維(通道維),對每個通道分別進行一維卷積,然后求和# X的維度: (c, n) - c個通道,每個通道長度為n# K的維度: (c, w) - c個通道,每個通道的卷積核長度為w# zip(X, K) 將對應通道的輸入和卷積核配對# 每次corr1d(x, k)的結果維度: (n - w + 1,)# sum()將所有通道的結果相加,最終輸出維度: (n - w + 1,)return sum(corr1d(x, k) for x, k in zip(X, K))# 測試代碼
# X: 多通道輸入張量,維度 (3, 7) - 3個輸入通道,每個通道包含7個元素
X = torch.tensor([[0, 1, 2, 3, 4, 5, 6],[1, 2, 3, 4, 5, 6, 7],[2, 3, 4, 5, 6, 7, 8]])# K: 多通道卷積核張量,維度 (3, 2) - 3個通道,每個通道的卷積核長度為2
K = torch.tensor([[1, 2], [3, 4], [-1, -3]])# 調用函數進行多通道一維卷積運算
# 輸出結果的維度: (7 - 2 + 1,) = (6,)
result = corr1d_multi_in(X, K)
print(result)

輸出結果:

tensor([ 2.,  8., 14., 20., 26., 32.])

有趣的是,多輸入通道的一維互相關等價于單輸入通道的二維互相關,只要將二維卷積核的高度設置為與輸入張量的高度相同即可,如下圖所示。

在這里插入圖片描述

最大時間匯聚層

在卷積層之后,textCNN使用了一個稱為最大時間匯聚層(Max-over-time Pooling)的關鍵組件。卷積操作的輸出長度依賴于輸入序列和卷積核的寬度,導致不同卷積核產生的輸出序列長度不同。最大時間匯聚層的作用是在時間步(序列長度)維度上取最大值。這相當于從每個卷積核提取的特征圖中,只保留最強烈的信號。無論輸入序列多長,經過這個操作后,每個通道都只會輸出一個標量值,從而解決了不同卷積核輸出維度不一的問題,并生成了用于分類的固定長度的特征向量。

textCNN模型

理解了一維卷積和最大時間匯聚后,我們就可以構建textCNN模型了。整個模型的架構如下圖所示:

在這里插入圖片描述

輸入是一個句子,每個詞元由一個多維向量表示。我們定義了多種不同寬度的卷積核(

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/86640.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/86640.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/86640.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

深入解析分布式訓練基石:ps-lite源碼實現原理

分布式機器學習框架是現代推薦、廣告和搜索系統的核心支撐。面對海量訓練數據和高維稀疏特征,參數服務器(Parameter Server, PS) 架構應運而生。作為早期經典實現的ps-lite因其簡潔性和完整性,成為理解PS原理的絕佳切入點。本文將…

IDEA 插件開發:Internal Actions 與 UI Inspector 快速定位 PSI

在開發 IntelliJ 平臺插件的過程中,你常常需要搞清楚 某個 IDE 彈框背后是如何操作 PSI(Program Structure Interface) 的。下面這篇筆記將介紹如何通過 Internal Actions、UI Inspector 以及調試技巧快速定位 PSI 調用鏈。 1. 啟用 Internal…

26考研|數學分析:多元函數微分學

前言 本章我們將進行多元函數微分學的學習,多元函數微分學與一元函數微分學相對應,涉及到可微性、中值定理、泰勒公式等諸多問題的探討與研究,本章難度較大,在學習過程中需要進行深度思考與分析,才能真正掌握這一章的…

數星星--二分

https://www.matiji.net/exam/brushquestion/17/4498/F16DA07A4D99E21DFFEF46BD18FF68AD 二分思路不難&#xff0c;關鍵的區間內個數的確定 #include<bits/stdc.h> using namespace std; #define N 100011 #define inf 0x3f3f3f3f typedef long long ll; typedef pair&…

Oracle/PostgreSQL/MSSQL/MySQL函數實現對照表

函數列表清單 函數作用OraclePOSTGRESQLMSSQLMYSQL求字符串長度LENGTH(str)LENGTH(str)LEN(str)LENGTH(str)字符切割SUBSTR(str,index,length)SUBSTR(str,index,length)SUBSTRING(str,index,length)SUBSTRING(str,index,length)字符串連接str1||str2||str3...strNstr1||str2||…

pycharm客戶端安裝教程

二、 pycharm客戶端安裝 打開pycharm官網&#xff1a;https://www.jetbrains.com/pycharm/download/?sectionwindows 選擇其他版本 選擇2018社區版本&#xff0c;點擊下載 雙擊下載的安裝程序(第一個彈框允許)&#xff0c;選擇下一步 更改安裝路徑&#xff0c;在pycah…

博圖SCL語言中用戶自定義數據類型(UDT)使用詳解

博圖SCL語言中用戶自定義數據類型&#xff08;UDT&#xff09;使用詳解 一、UDT概述 用戶自定義數據類型&#xff08;UDT&#xff09;是TIA Portal中強大的結構化工具&#xff0c;允許將多個相關變量組合成單一數據結構。UDT本質是可重用的數據模板&#xff0c;具有以下核心優…

Vscode自定義代碼快捷方式

首選項>配置代碼片段 >新建全局代碼片段 (也可以選擇你的語言 為了避免有的時候不生效 選擇全局代碼) {"console.log": { //名字"prefix": "log",//prefix 快捷鍵 &#xff1a; log"body": ["console.log($1);", //b…

ESP32 008 MicroPython Web框架庫 Microdot 實現的網絡文件服務器

以下是整合了所有功能的完整 main.py(在ESP32 007 MicroPython 適用于 Python 和 MicroPython 的小型 Web 框架庫 Microdot基礎上)&#xff0c;實現了&#xff1a; Wi?Fi 自動連接&#xff08;支持靜態 IP&#xff09;&#xff1b;SD 卡掛載&#xff1b;從 /sd/www/ 讀取 HTML…

Mcp-git-ingest Quickstart

目錄 配置例子 文檔github鏈接&#xff1a;git_ingest.md 配置 {"mcpServers": {"mcp-git-ingest": {"command": "uvx","args": ["--from", "githttps://github.com/adhikasp/mcp-git-ingest", "…

(LeetCode 面試經典 150 題) 27.移除元素

目錄 題目&#xff1a; 題目描述&#xff1a; 題目鏈接&#xff1a; 思路&#xff1a; 核心思路&#xff1a; 思路詳解&#xff1a; 樣例模擬&#xff1a; 代碼&#xff1a; C代碼&#xff1a; Java代碼&#xff1a; 題目&#xff1a; 題目描述&#xff1a; 題目鏈接…

MySQL之事務原理深度解析

MySQL之事務原理深度解析 一、事務基礎&#xff1a;ACID特性的本質1.1 事務的定義與核心作用1.2 ACID特性的內在聯系 二、原子性與持久性的基石&#xff1a;日志系統2.1 Undo Log&#xff1a;原子性的實現核心2.2 Redo Log&#xff1a;持久性的保障2.3 雙寫緩沖&#xff08;Dou…

JUC:5.start()與run()

這兩個方法都可以使線程進行運行&#xff0c;但是start只能用于第一次運行線程&#xff0c;后續要繼續運行該線程需要使用run()方法。如果多次運行start()方法&#xff0c;會出現報錯。 初次調用線程使用run()方法&#xff0c;無法使線程運行。 如果你對一個 Thread 實例直接調…

微服務中解決高并發問題的不同方法!

如果由于流量大而在短時間內幾乎同時發出請求&#xff0c;或者由于服務器不穩定而需要很長時間來處理請求&#xff0c;并發問題可能會導致數據完整性問題。 示例問題情況 讓我們假設有一個邏輯可以檢索產品的庫存并將庫存減少一個&#xff0c;如上所述。此時&#xff0c;兩個請…

【2025CCF中國開源大會】OpenChain標準實踐:AI時代開源軟件供應鏈安全合規分論壇重磅來襲!

點擊藍字 關注我們 CCF Opensource Development Committee 在AI時代&#xff0c;軟件供應鏈愈發復雜&#xff0c;從操作系統到開發框架&#xff0c;從數據庫到人工智能工具&#xff0c;開源無處不在。AI 與開源生態深度融合&#xff0c;在為軟件行業帶來前所未有的創新效率的同…

[Java實戰]springboot3使用JDK21虛擬線程(四十)

[Java實戰]springboot3使用JDK21虛擬線程(四十) 告別線程池爆滿、內存溢出的噩夢!JDK21 虛擬線程讓高并發連接變得觸手可及。本文將帶你深入實戰,見證虛擬線程如何以極低資源消耗輕松應對高并發壓測。 一、虛擬線程 傳統 Java 線程(平臺線程)與 OS 線程 1:1 綁定,創建和…

SpringBoot 中使用 @Async 實現異步調用?

? ? SpringBoot 中使用 Async 實現異步調用 一、Async 注解的使用場合?二、Async 注解的創建與調試?三、Async 注解的注意事項?四、總結? 在高并發、高性能要求的應用場景下&#xff0c;異步處理能夠顯著提升系統的響應速度和吞吐量。Spring Boot 提供的 Async 注解為開…

CMOS SENSOR HDR場景下MIPI 虛擬端口的使用案例

CMOS SENSOR HDR場景下MIPI 虛擬端口的使用案例 文章目錄 CMOS SENSOR HDR場景下MIPI 虛擬端口的使用案例?? **一、HDR模式下的虛擬通道核心作用**?? **二、典型應用案例****1. 車載多目HDR系統****2. 工業檢測多模態HDR****3. 手機多攝HDR合成**?? **三、實現關鍵技術點…

RJ45 以太網與 5G 的原理解析及區別

一、RJ45 以太網的原理 1. RJ45 接口與以太網的關系 RJ45 是一種標準化的網絡接口&#xff0c;主要用于連接以太網設備&#xff08;如電腦、路由器&#xff09;&#xff0c;其物理形態為 8 針模塊化接口&#xff0c;適配雙絞線&#xff08;如 CAT5、CAT6 網線&#xff09;。以…

valkey之sdscatrepr 函數優化解析

一、函數功能概述 sds sdscatrepr(sds s, const char *p, size_t len)函數的核心功能是將字符串p追加到字符串s中。在追加過程中&#xff0c;它會對字符串p中的字符進行判斷&#xff0c;使用isprint()函數識別不可打印字符&#xff0c;并對這些字符進行轉義處理&#xff0c;確…