NLP項目實戰01--電影評論分類

介紹:

歡迎來到本篇文章!在這里,我們將探討一個常見而重要的自然語言處理任務——文本分類。具體而言,我們將關注情感分析任務,即通過分析電影評論的情感來判斷評論是正面的、負面的。

展示:
訓練展示如下:

在這里插入圖片描述
在這里插入圖片描述

實際使用如下:

請添加圖片描述

實現方式:

選擇PyTorch作為深度學習框架,使用電影評論IMDB數據集,并結合torchtext對數據進行預處理。

環境:

Windows+Anaconda
重要庫版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

實現思路:

1、數據集
本次使用的是IMDB數據集,IMDB是一個含有50000條關于電影評論的數據集
數據如下:
請添加圖片描述
請添加圖片描述

2、數據加載與預處理
使用torchtext加載IMDB數據集,并對數據集進行劃分
具體劃分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

創建一個 Field 對象,用于處理文本數據。同時使用spacy分詞器對文本進行分詞,由于IMDB是英文的,所以使用en_core_web_sm語言模型。
創建一個 LabelField 對象,用于處理標簽數據。設置dtype 參數為 torch.float,表示標簽的數據類型為浮點型。

使用 datasets.IMDB.splits 方法加載 IMDB 數據集,并將文本字段 TEXT 和標簽字段 LABEL 傳遞給該方法。返回的 train_data 和 test_data 包含了 IMDB 數據集的訓練和測試部分。
下面是train_data的輸出
請添加圖片描述

3、構建詞匯表與加載預訓練詞向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中數據構建詞匯表
max_size:限制詞匯表的大小為 25000
vectors=“glove.6B.100d”:表示使用預訓練的 GloVe 詞向量,其中 “glove.6B.100d” 指的是包含 100 維向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知單詞(UNK)的初始化方式,這里使用正態分布進行初始化。
LABEL.build_vocab(train_data):表示對標簽進行類似的操作,構建標簽的詞匯表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 來創建數據加載器,包括訓練、驗證和測試集的迭代器。這將確保你能夠方便地以批量的形式獲取數據進行訓練和評估。

4、定義神經網絡
這里的網絡定義比較簡單,主要采用在詞嵌入層(embedding)后接一個全連接層的方式完成對文本數據的分類。
具體如下:

class NetWork(nn.Module):def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):super(NetWork,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim,output_dim)self.dropout = nn.Dropout(0.5)self.relu = nn.ReLU()def forward(self,x):embedded = self.embedding(x)embedded = embedded.permute(1,0,2) pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)pooled = self.relu(pooled)pooled = self.dropout(pooled)output = self.fc(pooled)return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定義模型的超參數,包括詞匯表大小(vocab_size)、詞向量維度(embedding_dim)、輸出維度(output,在這個任務中是1,因為是二元分類,所以使用1),以及 PAD 標記的索引(pad_idx)

之后需要將預訓練的詞向量加載到嵌入層的權重中。TEXT.vocab.vectors 包含了詞匯表中每個單詞的預訓練詞向量,然后通過 copy_ 方法將這些詞向量復制到模型的嵌入層權重中對網絡進行初始化。這樣做確保了模型的初始化狀態良好。

6、訓練模型

 total_loss = 0train_acc = 0 
model.train()
for batch in train_iterator:optimizer.zero_grad()preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)total_loss += loss.item()batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()train_acc += batch_accloss.backward()optimizer.step()average_loss = total_loss / len(train_iterator)train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示將模型參數的梯度清零,以準備接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向傳播的過程,由于model輸出的是torch.tensor(batch_size,1)所以使用squeeze(1)給其中的1維度數據去除,以匹配標簽張量的形狀
criterion(preds,batch.label):定義的損失函數 criterion 計算預測值 preds 與真實標簽 batch.label 之間的損失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通過比較模型的預測值與真實標簽,計算當前批次的準確率,并將其累加到 train_acc 中
后面的就是進行反向傳播更新參數,還有就是計算loss和train_acc的值了
7、模型評估:

model.eval()valid_loss = 0valid_acc = 0best_valid_acc = 0with torch.no_grad():for batch in valid_iterator:preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)valid_loss += loss.item()batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())valid_acc += batch_acc

和訓練模型的類似,這里就不解釋了

8、保存模型
這里一共使用了兩種保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一種方式叫做模型的全量保存
第二種方式叫做模型的參數保存

全量保存是保存了整個模型,包括模型的結構、參數、優化器狀態等信息
參數量保存是保存了模型的參數(state_dict),不包括模型的結構
9、測試模型
測試模型的基本思路:
加載訓練保存的模型、對待推理的文本進行預處理、將文本數據加載給模型進行推理

加載模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

輸入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本進行處理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():output = saved_model(tensor_text).squeeze(1)prediction = torch.round(torch.sigmoid(output)).item()probability = torch.sigmoid(output).item()

由于筆者能力有限,所以在描述的過程中難免會有不準確的地方,還請多多包含!

更多NLP和CV文章以及完整代碼請到"陶陶name"獲取。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/209595.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/209595.shtml
英文地址,請注明出處:http://en.pswp.cn/news/209595.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

比較不同聚類方法的評估指標

歸一化互信息(NMI) 要求:需要每個序列的真實標簽(分類信息)

你在地鐵上修過bug嗎?

作為技術人員,有沒有遇到下班路上收到老板電話,系統故障,然后地鐵上掏出電腦,修bug的場景。自己負責的業務線上出現問題,負責人心里是很慌的,在這種心理狀態下做事很容易二次犯錯,造成更大的問題…

SAP UI5 walkthrough step10 Descriptor for Applications

在這一步,我們將會把所有的應用相關的描述性的文件獨立放到manifest.json 新建一個manifest.json文件 webapp/manifest.json (New) {"_version": "1.58.0","sap.app": {"id": "ui5.walkthrough","i18n&q…

【已解決】No module named ‘sklearn‘

問題描述 No module named ‘sklearn‘ 解決辦法 pip install scikit-learn 完結撒花 契約、包容、感恩、原則……這些成年人該有的基本精神,為什么我在他們身上找不到呢?

圖像疊加中文字體

目錄 1) 前言2) freetype下載3) Demo3.1) 下載3.2) 編譯3.3) 運行3.4) 結果3.5) 更詳細的使用見目錄中說明 4) 積少成多 1) 前言 最近在做圖片、視頻疊加文字,要求支持中文,基本原理是將圖片或視頻解碼后疊加文字,之后做圖片或視頻編碼即可。…

ASP.NET Core概述-微軟已經收購了mono,為什么還搞.NET Core呢

一、.NET Core概述 1、相關歷程 .NET在設計之初也是考慮像Java一樣跨平臺,.NET Framework是在Windows下運行的,大部分類是可以兼容移植到Linux下,但是沒有人做這個工作。 2001年米格爾為Gnome尋找桌面開發技術,在研究了微軟的.…

數據庫版本管理框架-Flyway(從入門到精通)

一、flyway簡介 Flyway是一個簡單開源數據庫版本控制器(約定大于配置),主要提供migrate、clean、info、validate、baseline、repair等命令。它支持SQL(PL/SQL、T-SQL)方式和Java方式,支持命令行客戶端等&am…

TCP對數據的拆分

應用程序的數據一般都比較大,因此TCP會按照網絡包的大小對數據進行拆分。 當發送緩沖區中的數據超過MSS的長度,數據會被以MSS長度為單位進行拆分,拆分出來的數據塊被放進單獨的網路包中。 根據發送緩沖區中的數據拆分情況,當判斷…

ffmpeg編譯問題

利用ffmpeg實現一個播放器,ffmpeg提供動態庫,但是編譯鏈接的時候遇到下面的問題: ../ffmpegWidgetPlayer/videoplayerwidget.cpp:23: error: undefined reference to sws_freeContext(SwsContext*) ../ffmpegWidgetPlayer/videoplayerwidget.…

JWT介紹及演示

JWT 介紹 cookie(放在瀏覽器) cookie 是一個非常具體的東西,指的就是瀏覽器里面能永久存儲的一種數據,僅僅是瀏覽器實現的一種數據存儲功能。 cookie由服務器生成,發送給瀏覽器,瀏覽器把cookie以kv形式保存到某個目錄下的文本…

JavaScript 金額元轉化為萬

function dealNum(price){if (price 0) {return 0元}const BASE 10000const decimal 0const SIZES ["", "萬", "億", "萬億"];let i undefined;let str "";if (price) {if ((price > 0 && price < BASE…

p標簽的水平居中和垂直居中

1行內塊元素水平居中垂直居中 行內元素和行內塊元素水平居中&#xff0c;給其父元素添加text-align:center&#xff1b;所以案例里面給one加了 text-align: center之后span就會水平居中了。在設置span行高和高都是一樣的 20px;這樣就實現上下居中了。 2塊級元素P元素水平居中…

通過命令行輸入參數控制激勵

1)在命令行的仿真參數&#xff08;SIM_OPT&#xff09;加上&#xff1a;“var_a100 var_b99” 2)在環境中調用&#xff1a; $test$plusargs("var_a")&#xff1b;如果命令行存在這個字符&#xff0c;返回1&#xff0c;否則返回0&#xff1b; $value$plusargs(&qu…

vue2 el-input里實現打字機 效果

vue2 el-input里實現打字機 效果 <el-col :span"24" v-if"ifshowOtherDesc""><el-form-item label"分析" prop"otherDesc"><el-input type"textarea" :disabled"disabled" autofocus"t…

藍牙物聯網對接技術難點有哪些?

#物聯網# 藍牙物聯網對接技術難點主要包括以下幾個方面&#xff1a; 1、設備兼容性&#xff1a;藍牙技術有多種版本和規格&#xff0c;如藍牙4.0、藍牙5.0等&#xff0c;不同版本之間的兼容性可能存在問題。同時&#xff0c;不同廠商生產的藍牙設備也可能存在兼容性問題。 2、…

0-1背包問題

二維版: import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class Main {static int N 1010;static int[][] dp new int[N][N]; //dp[i][j] 只選前i件物品,體積 < j的最優解static int[] w new int[N]; //存儲價…

Day03 嵌入式---中斷

目錄 一、簡單介紹 二、總體框架 三、NVIC 3.2 NVIC的寄存器 3.3 中斷向量表 3.4 中斷優先級 3.5 NVIC優先級分組 3.6 NVIC配置 3.6.1、設置中斷分組 3.6.2、初始化 四、EXTI 外部中斷 4.1.EXTI的基本概念 4.2.EXTI的?作原理 4.3 EXTI配置 五、SYSCFG 5.1 SYS…

字符串函數`strlen`、`strcpy`、`strcmp`、`strstr`、`strcat`的使用以及模擬實現

文章目錄 &#x1f680;前言&#x1f680;庫函數strlen??strlen的模擬實現 &#x1f680;庫函數strcpy??strcpy的模擬實現 &#x1f680;strcmp??strcmp的模擬實現 &#x1f680;strstr??strstr的模擬實現 &#x1f680;strcat??strcat的模擬實現 &#x1f680;前言 …

ReactJS和VueJS的簡介以及它們之間的區別

本文主要介紹ReactJS和VueJS的簡介以及它們之間的區別。 目錄 ReactJS簡介ReactJS的優缺點ReactJS的應用場景VueJS簡介VueJS的優缺點VueJS的應用場景ReactJS和VueJS的區別 ReactJS簡介 ReactJS是一個由Facebook開發的基于JavaScript的前端框架。它是一個用于構建用戶界面的庫&…

【C語言】——函數遞歸,用遞歸簡化并實現復雜問題

文章目錄 前言一、什么是遞歸二、遞歸的限制條件三、遞歸舉例1.求n的階乘2. 舉例2&#xff1a;順序打印一個整數的每一位 四、遞歸的優劣總結 前言 不多廢話了&#xff0c;直接開始。 一、什么是遞歸 遞歸是學習C語言函數繞不開的?個話題&#xff0c;那什么是遞歸呢&#xf…