NLP項目實戰01之電影評論分類

介紹:

歡迎來到本篇文章!在這里,我們將探討一個常見而重要的自然語言處理任務——文本分類。具體而言,我們將關注情感分析任務,即通過分析電影評論的情感來判斷評論是正面的、負面的。

展示:
訓練展示如下:

在這里插入圖片描述
在這里插入圖片描述

實際使用如下:

請添加圖片描述

實現方式:

選擇PyTorch作為深度學習框架,使用電影評論IMDB數據集,并結合torchtext對數據進行預處理。

環境:

Windows+Anaconda
重要庫版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

實現思路:

1、數據集
本次使用的是IMDB數據集,IMDB是一個含有50000條關于電影評論的數據集
數據如下:
請添加圖片描述
請添加圖片描述

2、數據加載與預處理
使用torchtext加載IMDB數據集,并對數據集進行劃分
具體劃分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

創建一個 Field 對象,用于處理文本數據。同時使用spacy分詞器對文本進行分詞,由于IMDB是英文的,所以使用en_core_web_sm語言模型。
創建一個 LabelField 對象,用于處理標簽數據。設置dtype 參數為 torch.float,表示標簽的數據類型為浮點型。

使用 datasets.IMDB.splits 方法加載 IMDB 數據集,并將文本字段 TEXT 和標簽字段 LABEL 傳遞給該方法。返回的 train_data 和 test_data 包含了 IMDB 數據集的訓練和測試部分。
下面是train_data的輸出
請添加圖片描述

3、構建詞匯表與加載預訓練詞向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中數據構建詞匯表
max_size:限制詞匯表的大小為 25000
vectors=“glove.6B.100d”:表示使用預訓練的 GloVe 詞向量,其中 “glove.6B.100d” 指的是包含 100 維向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知單詞(UNK)的初始化方式,這里使用正態分布進行初始化。
LABEL.build_vocab(train_data):表示對標簽進行類似的操作,構建標簽的詞匯表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 來創建數據加載器,包括訓練、驗證和測試集的迭代器。這將確保你能夠方便地以批量的形式獲取數據進行訓練和評估。

4、定義神經網絡
這里的網絡定義比較簡單,主要采用在詞嵌入層(embedding)后接一個全連接層的方式完成對文本數據的分類。
具體如下:

class NetWork(nn.Module):def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):super(NetWork,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim,output_dim)self.dropout = nn.Dropout(0.5)self.relu = nn.ReLU()def forward(self,x):embedded = self.embedding(x)embedded = embedded.permute(1,0,2) pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)pooled = self.relu(pooled)pooled = self.dropout(pooled)output = self.fc(pooled)return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定義模型的超參數,包括詞匯表大小(vocab_size)、詞向量維度(embedding_dim)、輸出維度(output,在這個任務中是1,因為是二元分類,所以使用1),以及 PAD 標記的索引(pad_idx)

之后需要將預訓練的詞向量加載到嵌入層的權重中。TEXT.vocab.vectors 包含了詞匯表中每個單詞的預訓練詞向量,然后通過 copy_ 方法將這些詞向量復制到模型的嵌入層權重中對網絡進行初始化。這樣做確保了模型的初始化狀態良好。

6、訓練模型

 total_loss = 0train_acc = 0 
model.train()
for batch in train_iterator:optimizer.zero_grad()preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)total_loss += loss.item()batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()train_acc += batch_accloss.backward()optimizer.step()average_loss = total_loss / len(train_iterator)train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示將模型參數的梯度清零,以準備接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向傳播的過程,由于model輸出的是torch.tensor(batch_size,1)所以使用squeeze(1)給其中的1維度數據去除,以匹配標簽張量的形狀
criterion(preds,batch.label):定義的損失函數 criterion 計算預測值 preds 與真實標簽 batch.label 之間的損失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通過比較模型的預測值與真實標簽,計算當前批次的準確率,并將其累加到 train_acc 中
后面的就是進行反向傳播更新參數,還有就是計算loss和train_acc的值了
7、模型評估:

model.eval()valid_loss = 0valid_acc = 0best_valid_acc = 0with torch.no_grad():for batch in valid_iterator:preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)valid_loss += loss.item()batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())valid_acc += batch_acc

和訓練模型的類似,這里就不解釋了

8、保存模型
這里一共使用了兩種保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一種方式叫做模型的全量保存
第二種方式叫做模型的參數保存

全量保存是保存了整個模型,包括模型的結構、參數、優化器狀態等信息
參數量保存是保存了模型的參數(state_dict),不包括模型的結構
9、測試模型
測試模型的基本思路:
加載訓練保存的模型、對待推理的文本進行預處理、將文本數據加載給模型進行推理

加載模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

輸入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本進行處理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():output = saved_model(tensor_text).squeeze(1)prediction = torch.round(torch.sigmoid(output)).item()probability = torch.sigmoid(output).item()

由于筆者能力有限,所以在描述的過程中難免會有不準確的地方,還請多多包含!

更多NLP和CV文章以及完整代碼請到"陶陶name"獲取。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/209046.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/209046.shtml
英文地址,請注明出處:http://en.pswp.cn/news/209046.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【基于LicheePi-4A的 人臉識別系統軟件設計】

參考:https://www.xrvm.cn/community/post/detail?spm=a2cl5.27438731.0.0.31d40dck0dckmg&id=4253195599836418048 1.前言 原先計劃做基于深度學習的炸藥抓取和智能填裝方法研究,但是后來發現板卡不支持pyrealsense2等多個依賴包。因此改變策略,做一款基于LicheePi…

Android Studio的筆記--三元表達式、布爾運算符、與() 或(||) 非(!)

[TOC](三元表達式、布爾運算符、與(&&) 或(||) 非(!)) 表達式 int x 1; int y 2;x < y 結果 true x > y 結果 false x < y 結果 false x > y 結果 true x y 結果 false x ! y 結果 true 布爾運算符 boolean boolean a true; boolean b false; 與…

【Python】列表乘積的計算時間

概述 使用以下三種模式測量了計算列表乘積所需的時間。 使用 for 語句傳遞list使用math模塊使用numpy 下面是實際運行的代碼。 import timestart time.time() A [1] * 100000000 ans 1 for a in A:ans * a print("list loop:", time.time() - start)import m…

前端面試提問(4)

1、手撕防抖與節流、樹與對象的轉換、遞歸調用&#xff0c;鏈表頭插法 1.1、防抖 防抖函數用于延遲執行某個函數&#xff0c;直到過了一定的間隔時間&#xff08;例如等待用戶停止輸入&#xff09;后再執行。 即后一次點擊事件發生時間距離一次點擊事件至少間隔一定時間。 …

笙默考試管理系統-MyExamTest----codemirror(49)

笙默考試管理系統-MyExamTest----codemirror&#xff08;49&#xff09; 目錄 笙默考試管理系統-MyExamTest----codemirror&#xff08;49&#xff09; 一、 笙默考試管理系統-MyExamTest----codemirror 二、 笙默考試管理系統-MyExamTest----codemirror 三、 笙默考試…

有哪些已經上線的vue商城項目?

前言 下面是一些商城的項目&#xff0c;需要練手的同學可以挑選一些來練&#xff0c;廢話少說&#xff0c;讓我們直接開始正題~~ 1、newbee-mall-vue3-app 是一個基于 Vue 3 和 TypeScript 的電商前端項目&#xff0c;它是 newbee-mall 項目的升級版。該項目包含了商品列表、…

內網環境下 - 安裝linux命令、搭建docker以及安裝鏡像

一 內網環境安裝docker 先在外網環境下載好docker二進制文件docker二進制文件下載&#xff0c;要下載對應硬件平臺的文件&#xff0c;否則不兼容 如下載linux平臺下的文件&#xff0c;直接訪問這里即可linux版本docker二進制文件 這里下載docker-24.0.5.tgz 將下載好的文件…

計算機存儲單位 + 程序編譯過程

C語言的編譯過程 計算機存儲單位 頭文件包含的兩種方式 使用 C/C 程序常用的IDE 常用的C語言編譯器&#xff1a; 在選擇編譯器時&#xff0c;需考慮平臺兼容性、性能優化、調試工具和開發人員的個人偏好等因素。 詳細教程可轉 愛編程的大丙

Java編程中通用的正則表達式(一)

正則表達式&#xff08;Regular Expression&#xff0c;簡稱RegEx&#xff09;&#xff0c;又稱常規表示法、正則表示、正規表示式、規則表達式、常式、表達式等&#xff0c;是計算機科學中的一個概念。正則表達式是用于描述某種特定模式的字符序列&#xff0c;特別是用來匹配、…

持續集成和持續交付

引言 CI/CD 是一種通過在應用開發階段引入自動化來頻繁向客戶交付應用的方法。CI/CD 的核心概念是持續集成、持續交付和持續部署。作為一種面向開發和運維團隊的解決方案&#xff0c;CI/CD 主要針對在集成新代碼時所引發的問題&#xff08;亦稱&#xff1a;“集成地獄”&#…

力扣刷題筆記——反轉鏈表

力扣&#xff08;LeetCode&#xff09;官網 - 全球極客摯愛的技術成長平臺 經典問題反轉鏈表 這里給出四種解法 1.雙指針 這種方法是用一個next指針記錄當前節點的下一個節點&#xff0c;一個pre指針記錄當前節點的前一個節點。 只需要遍歷一遍鏈表就可以完成鏈表的反轉 c…

idea__SpringBoot微服務05——JSR303校驗(新注解)(新的依賴),配置文件優先級,多環境切換

JSR303校驗&#xff0c;配置文件優先級&#xff0c;多環境切換 一、JSR303數據校驗二、配置文件優先級三、多環境切換一、properties多環境切換二、yaml多環境切換————————創作不易&#xff0c;如覺不錯&#xff0c;隨手點贊&#xff0c;關注&#xff0c;收藏(*&#x…

電腦待機怎么設置?讓你的電腦更加節能

在日常使用電腦的過程中&#xff0c;合理設置待機模式是一項省電且環保的好習慣。然而&#xff0c;許多用戶對于如何設置電腦待機感到困擾。那么電腦待機怎么設置呢&#xff1f;本文將深入探討三種常用的電腦待機設置方法&#xff0c;通過詳細的步驟&#xff0c;幫助用戶更好地…

【C語言期末】題目+筆記

文章目錄 題目1.下面哪個不是C語言的基本數據類型&#xff1f;&#xff08; B &#xff09;2.C語言的標識符應以字母或&#xff08; A &#xff09;開頭。3.如果需要在C程序里調用標準函數庫中的printf函數&#xff0c;則應該在程序的開頭包含哪個頭文件&#xff1f;&#xff0…

【數據結構】順序表的定義和運算

目錄 1.初始化 2.插入 3.刪除 4.查找 5.修改 6.長度 7.遍歷 8.完整代碼 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高興與大家相識&#xff0c;希望我的博客能對你有所幫助。 &#x1f4a1;本文由Filotimo__??原創&#xff0c;首發于CSDN&#x1f4da;。 &…

web前端開發html/css練習

目標圖&#xff1a; 素材&#xff1a; 代碼&#xff1a; <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns"http://www.w3.org/1999/xhtml"…

使用RSA工具進行對信息加解密

我們在開發中需要對用戶敏感數據進行加解密&#xff0c;比如密碼 這邊科普一下RSA算法 RSA是非對稱加密算法&#xff0c;與對稱加密算法不同;在對稱加密中&#xff0c;相同的密鑰用于加密和解密數據,因此密鑰的安全性至關重要;而在RSA非對稱加密中&#xff0c;有兩個密鑰&…

【USRP】5G / 6G OAI 系統 5g / 6G OAI system

面向5G/6G科研應用 USRP專門用于5G/6G產品的原型開發與驗證。該系統可以在實驗室搭建一個真實的5G 網絡&#xff0c;基于開源的代碼&#xff0c;專為科研用戶設計。 軟件無線電架構&#xff0c;構建真實5G移動通信系統 X410 采用了目前流行的異構式系統&#xff0c;融合了FP…

SQLite基本使用

目錄 1. 概述2. 引入SQLite3. 連接數據庫創建游標4. 創建數據庫文件5. 新增單條數據6. 批量新增數據7. 查詢單條數據8.查詢全部數據9. 查詢指定條數的數據10. 修改數據11. 刪除數據12. 事務回滾13. 關閉數據庫關閉游標1. 概述 SQLite是一個進程內的庫,實現了自給自足的、無服務…

【嵌入式開發 Linux 常用命令系列 4.2 -- .repo 各個目錄介紹】

文章目錄 概述.repo 目錄結構manifests/default.xmlManifest 文件的作用default.xml 文件內容示例linkfile 介紹 .repo/projects 子目錄配置和管理configHEADhooksinfo/excludeobjectsrr-cache 工作區中的對應目錄 概述 repo 是一個由 Google 開發的版本控制工具&#xff0c;它…