ROCm上情感分析:使用循環神經網絡

15.2. 情感分析:使用循環神經網絡 — 動手學深度學習 2.0.0 documentation (d2l.ai)

代碼

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)class BiRNN(nn.Module):def __init__(self, vocab_size, embed_size, num_hiddens,num_layers, **kwargs):super(BiRNN, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 將bidirectional設置為True以獲取雙向循環神經網絡self.encoder = nn.LSTM(embed_size, num_hiddens, num_layers=num_layers,bidirectional=True)self.decoder = nn.Linear(4 * num_hiddens, 2)def forward(self, inputs):# inputs的形狀是(批量大小,時間步數)# 因為長短期記憶網絡要求其輸入的第一個維度是時間維,# 所以在獲得詞元表示之前,輸入會被轉置。# 輸出形狀為(時間步數,批量大小,詞向量維度)embeddings = self.embedding(inputs.T)self.encoder.flatten_parameters()# 返回上一個隱藏層在不同時間步的隱狀態,# outputs的形狀是(時間步數,批量大小,2*隱藏單元數)outputs, _ = self.encoder(embeddings)# 連結初始和最終時間步的隱狀態,作為全連接層的輸入,# 其形狀為(批量大小,4*隱藏單元數)encoding = torch.cat((outputs[0], outputs[-1]), dim=1)outs = self.decoder(encoding)return outsembed_size, num_hiddens, num_layers = 100, 100, 2
devices = d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)def init_weights(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)if type(m) == nn.LSTM:for param in m._flat_weights_names:if "weight" in param:nn.init.xavier_uniform_(m._parameters[param])
net.apply(init_weights);glove_embedding = d2l.TokenEmbedding('glove.6b.100d')embeds = glove_embedding[vocab.idx_to_token]
embeds.shapenet.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = Falselr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)#@save
def predict_sentiment(net, vocab, sequence):"""預測文本序列的情感"""sequence = torch.tensor(vocab[sequence.split()], device=d2l.try_gpu())label = torch.argmax(net(sequence.reshape(1, -1)), dim=1)return 'positive' if label == 1 else 'negative'predict_sentiment(net, vocab, 'this movie is so great')predict_sentiment(net, vocab, 'this movie is so bad')

代碼解析

這段代碼實現了一個用于情感分析的雙向循環神經網絡(BiRNN)。下面我將逐部分用中文解析它:
1. 導入所需的庫和模塊:

import torch
from torch import nn
from d2l import torch as d2l

這里導入了PyTorch庫、神經網絡模塊`nn`和基于PyTorch的深度學習庫`d2l`(深度學習的一本書)。
2. 加載數據集:

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

加載IMDB電影評論數據集,并用迭代器`train_iter`和`test_iter`進行訓練和測試。`vocab`是數據集的詞匯表。
3. 定義雙向循環神經網絡(BiRNN)模型:

class BiRNN(nn.Module):...

創建了一個名為`BiRNN`的類,用于定義雙向LSTM模型。模型有一個嵌入層(`embedding`),將詞匯映射到向量空間。LSTM層(`encoder`)設定為雙向,輸出經過全連接層(`decoder`)得到最終的分類結果。
4. 初始化模型參數:

def init_weights(m):...
net.apply(init_weights);

init_weights函數用于模型參數的初始化。`net.apply(init_weights);`使用這個函數來應用參數初始化。
5. 加載預訓練的詞向量:

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = False

使用GloVe預訓練的100維詞向量,并將它們復制到嵌入層`net.embedding`。同時設置`requires_grad = False`使得這些詞向量在訓練中不被更新。
6. 訓練模型:

lr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

設置學習率和迭代次數,使用Adam優化器和交叉熵損失函數。用`d2l.train_ch13`函數來訓練和評估模型。
7. 定義預測函數:

def predict_sentiment(net, vocab, sequence):...

這個函數用于預測給定文本序列的情感標簽(積極或消極)。
8. 使用模型進行預測:

predict_sentiment(net, vocab, 'this movie is so great')
predict_sentiment(net, vocab, 'this movie is so bad')

調用`predict_sentiment`函數分別對兩個句子進行情感預測。
整體來看,這段代碼主要是利用循環神經網絡對電影評論的情感進行分類,它通過加載預訓練好的詞向量,構建一個雙向LSTM網絡,并在IMDB評論數據集上進行訓練和測試。最后定義了一個實用函數,用于預測輸入句子的情感傾向。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/14860.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/14860.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/14860.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

java抽象類,接口,枚舉練習題

第一題: 答案: class Animal{//成員變量protected String name;protected int weight;//構造方法public Animal(){this.name"refer";this.weight50;}public Animal(String name,int weight){this.namename;this.weightweight;}//成員方法publ…

Bugku Crypto 部分題目簡單題解(四)

目錄 python_jail 簡單的rsa 托馬斯.杰斐遜 這不是md5 進制轉換 affine Crack it rsa python_jail 啟動場景 使用虛擬機nc進行連接 輸入print(flag) 發現報錯,經過測試只能傳入10個字符多了就會報錯 利用python中help()函數,借報錯信息帶出flag變…

【力扣刷題筆記第三期】Python 數據結構與算法

先從簡單的題型開始刷起,一起加油啊!! 點個關注和收藏唄,一起刷題鴨!! 第一批題目 1.設備編號 給定一個設備編號區間[start, end],包含4或18的編號都不能使用,如:418、…

對于map的新應用

題源codeforces1974 problemC 題目大意 定義當兩個三元組A和B中,滿足三元組中有且僅有兩個元素相等,比如 a 1 b 1 , a 2 b 2 , a 3 ! b 3 a_1b_1,a_2b_2,a_3!b_3 a1?b1?,a2?b2?,a3?!b3? 這只是一種情況,三種情況之一 解題思路 …

java抽象類和接口知識總結

一.抽象類 1.啥是抽象類 用專業語言描述就是:如果一個類中沒有包含足夠的信息來描繪一個具體的對象,這樣的類就是抽象類 當然這話說的也很抽象,所以我們來用人話來解釋一下抽象類 拋開編程語言這些,就以現實舉例,我…

每日練習之排序——鏈表的合并;完全背包—— 兌換零錢

鏈表的合并 題目描述 運行代碼 #include<iostream> #include<algorithm> using namespace std; int main() { int a[31];for(int i 1;i < 30;i)cin>>a[i];sort(a 1,a 1 30);for(int i 1;i < 30;i)cout<<a[i]<<" ";cout&…

Mysql之Innodb存儲引擎

1.Innodb數據存儲 innodb如今能夠做到mysql的默認數據存儲引擎&#xff0c;肯定有著其好處的&#xff0c;那么innodb有什么好處呢? 1. 當意外斷電或者重啟&#xff0c; InnoDB 能夠做到奔潰恢復&#xff0c;撤銷沒有提交的數據 2.InnoDB 存儲引擎維護自己的緩沖池&#xff0c…

UDS(ISO 14229)學習筆記

文章目錄 名詞縮寫Vector視頻筆記$10$27Fault Memory物理尋址和功能尋址UDS服務分類0x19服務0x14DTC汽車控制器(ECU)中DTC的狀態位物理尋址和功能尋址單幀 多幀 首幀 連續幀名詞縮寫 DTC Diagnostic Trouble Code FTB Fault Type Byte SID Service Identifier SF Subfunctio…

DML(Data Manipulation Language)數據操作語言

一、增加 insert into -- 寫全所有列名 insert into 表名(列名1,列名2,...列名n) values(值1,值2,...值n);-- 不寫列名&#xff08;所有列全部添加&#xff09; insert into 表名 values(值1,值2,...值n);-- 插入部分數據 insert into 表名(列名1,列名2) values(值1,值2); 舉…

醫院掛號就診系統的設計與實現

前端使用Vue.js 后端使用SpiringBoot MyBatis 數據使用MySQL 需要項目和論文加企鵝&#xff1a;2583550535 醫院掛號就診系統的設計與實現_嗶哩嗶哩_bilibili 隨著社會的發展&#xff0c;醫療資源分布不均&#xff0c;患者就診難、排隊時間長等問題日益突出&#xff0c;傳統的…

軟考備考三

操作系統 操作系統概述 功能&#xff1a;組織和管理軟件&#xff0c;硬件資源以及計算機系統中的工作流程&#xff0c;控制程序的執行&#xff0c;向用戶提供接口。 分類&#xff1a; 1.批處理操作系統 單道批 多道批&#xff08;宏觀上并行&#xff0c;微觀上串行&#xff09…

Hadoop3:HDFS的Fsimage和Edits文件介紹

一、概念 Fsimage文件&#xff1a;HDFS文件系統元數據的一個永久性的檢查點&#xff0c;其中包含HDFS文件系統的所有目 錄和文件inode的序列化信息。 Edits文件&#xff1a;存放HDFS文件系統的所有更新操作的路徑&#xff0c;文件系統客戶端執行的所有寫操作首先 會被記錄到Ed…

K8s 身份認證和權限

文章目錄 K8s 身份認證和權限認證Service AccountsService Account Admission ControllerToken ControllerService Account Controller 授權(RBAC)RoleClusterRoleRoleBindingClusterRoleBinding K8s 身份認證和權限 Kubernetes 中提供了良好的多租戶認證管理機制&#xff0c;…

二叉樹的鏈式結構

1.二叉樹的遍歷 2.二叉樹鏈式結構的實現 3.解決單值二叉樹題 1.二叉樹的遍歷 1.1前序&#xff0c;中序以及后序遍歷 二叉樹的遍歷是按照某種特定的規則&#xff0c;依次對二叉樹的結點進行相應的操作&#xff0c;并且每個結點只操作一次。 二叉樹的遍歷有這些規則&#xff…

主流電商平臺商品實時數據采集API接口||抖音電商數據分析實例|可視化

— 1 — 抖音電商數據【抖音電商API數據采集】分析場景 1. 這里&#xff0c;我們選擇“伊利”這個品牌作為案例進行分析&#xff0c;在短短的4個月里&#xff0c;從最初每月營收17.07萬&#xff0c;到6月份達到了2485.54 萬&#xff0c;伊利的牛奶&#xff0c;有點牛&#xff…

Spring 對 Junit4,Junit5 的支持上的運用

1. Spring 對 Junit4,Junit5 的支持上的運用 文章目錄 1. Spring 對 Junit4,Junit5 的支持上的運用每博一文案2. Spring對Junit4 的支持3. Spring對Junit5的支持4. 總結&#xff1a;5. 最后&#xff1a; 每博一文案 關于理想主義&#xff0c;在知乎上看到一句話&#xff1a;“…

在Windows下訪問WSL(Windows Subsystem for Linux)文件夾

在Windows下訪問WSL&#xff08;Windows Subsystem for Linux&#xff09;文件夾&#xff0c;可以按照以下步驟操作&#xff1a; 通過Windows文件資源管理器訪問&#xff1a; 打開文件資源管理器。在地址欄中輸入\\wsl$&#xff0c;然后按回車鍵。這將打開一個顯示WSL可用發行版…

kafka配置消費者重要參數

文章目錄 fetch.min.bytesfetch.max.wait.msfetch.max.bytesmax.poll.recordsmax.partition.fetch.bytessession.timeout.ms和heartbeat.interval.msmax.poll.interval.msrequest.timeout.msauto.offset.resetenable.auto.commitpartition.assignment.strategy區間(range)輪詢(…

Xline社區會議Call Up|在 CURP 算法中實現聯合共識的安全性

為了更全面地向大家介紹Xline的進展&#xff0c;同時促進Xline社區的發展&#xff0c;我們將于2024年5月31日北京時間11:00 p.m.召開Xline社區會議。 歡迎您屆時登陸zoom觀看直播&#xff0c;或點擊“閱讀原文”鏈接加入會議&#xff1a; 會議號: 832 1086 6737 密碼: 41125…

通過cmd命令行使用用3dmax自帶的vray渲染

有時調試需要使用vray渲染vrscene文件看效果&#xff0c;只裝有3dmax下可以使用自帶vray渲染&#xff0c;在3dmax的渲染日志里面看自帶引擎路徑 使用命令行進入到此目錄 執行命令指定vr文件即可看到效果&#xff0c;如&#xff1a;vray.exe -sceneFile“C:\test15\202405241…