在 PyTorch 中借助 GloVe 詞嵌入完成情感分析

一. Glove 詞嵌入原理

GloVe是一種學習詞嵌入的方法,它希望擬合給定上下文單詞i時單詞j出現的次數x_{ij}。使用的誤差函數為:

\sum_{i=1}^{N}\sum_{j=1}^{N}f(x_{ij})(\theta _{j}^{T}e_{i}+b_{i}+b_{j}^{'}-logx_{ij})

其中N是詞匯表大小,\theta ,b是線性層參數,e_{i}?是詞嵌入。f(x)是權重項,用于平衡不同頻率的單詞對誤差的影響,并消除log0時式子不成立情況。

GloVe作者提供了官方的預訓練詞嵌入(https://nlp.stanford.edu/projects/glove/?)。預訓練的GloVe有好幾個版本,按數據來源,可以分成:

  • 維基百科+gigaword(6B)
  • 爬蟲(42B)
  • 爬蟲(840B)
  • 推特(27B)

按照詞嵌入向量的大小分,又可以分成50維,100維,200維等不同維度。

預訓練GloVe的文件格式非常簡明,一行代表一個單詞向量,每行先是一個單詞,再是若干個浮點數,表示該單詞向量的每一個元素。

在Pytorch里,我們不必自己去下載解析GloVe,而是可以直接調用Pytorch庫自動下載解析GloVe。首先我們要安裝Pytorch的NLP庫-- torchtext。

如上所述,GloVe版本可以由其數據來源和向量維數確定,在構建GloVe類時,要提供這兩個參數,我們選擇的是6B token,維度100的GloVe

調用glove.get_vecs_by_tokens,我們能夠把token轉換成GloVe里的向量。

import torch

from torchtext.vocab import GloVe

glove = GloVe(name='6B', dim=100)

# Get vectors

tensor = glove.get_vecs_by_tokens(['', '1998', '199999998', ',', 'cat'], True)

print(tensor)

PyTorch提供的這個函數非常方便。如果token不在GloVe里的話,該函數會返回一個全0向量。如果你運行上面的代碼,可以觀察到一些有趣的事:空字符串和199999998這樣的不常見數字不在詞匯表里,而1998這種常見的數字以及標點符號都在詞匯表里。

GloVe類內部維護了一個矩陣,即每個單詞向量的數組。因此,GloVe需要一個映射表來把單詞映射成向量數組的下標。glove.itosglove.stoi完成了下標與單詞字符串的相互映射。比如用下面的代碼,我們可以知道詞匯表的大小,并訪問詞匯表的前幾個單詞:

myvocab = glove.itos
print(len(myvocab))
print(myvocab[0], myvocab[1], myvocab[2], myvocab[3])

最后,我們來通過一個實際的例子認識一下詞嵌入的意義。詞嵌入就是向量,向量的關系常常與語義關系對應。利用詞嵌入的相對關系,我們能夠回答“x1之于y1,相當于x2之于誰?”這種問題。比如,男人之于女人,相當于國王之于王后。設我們要找的向量為y2,我們想讓x1-y1=x2-y2,即找出一個和x2-(x1-y1)最相近的向量y2出來。這一過程可以用如下的代碼描述:

def get_counterpart(x1, y1, x2):x1_id = glove.stoi[x1]y1_id = glove.stoi[y1]x2_id = glove.stoi[x2]#print("x1:",x1,"y1:",y1,"x2:",x2)x1, y1, x2 = glove.get_vecs_by_tokens([x1, y1, x2],True)#print("x1:",x1,"y1:",y1,"x2:",x2)target = x2 - x1 + y1max_sim =0 max_id = -1for i in range(len(myvocab)):vector = glove.get_vecs_by_tokens([myvocab[i]],True)[0]cossim = torch.dot(target, vector)if cossim > max_sim and i not in {x1_id, y1_id, x2_id}:max_sim = cossimmax_id = ireturn myvocab[max_id]
print(get_counterpart('man', 'woman', 'king'))
print(get_counterpart('more', 'less', 'long'))
print(get_counterpart('apple', 'red', 'banana'))

運行結果:?

queen

short

yellow

二.基于GloVe的情感分析

情感分析任務與數據集

和貓狗分類類似,情感分析任務是一種比較簡單的二分類NLP任務:給定一段話,輸出這段話的情感是積極的還是消極的。

比如下面這段話:

I went and saw this movie last night after being coaxed to by a few friends of mine. I'll admit that I was reluctant to see it because from what I knew of Ashton Kutcher he was only able to do comedy. I was wrong. Kutcher played the character of Jake Fischer very well, and Kevin Costner played Ben Randall with such professionalism. ......

這是一段影評,大意說,這個觀眾本來不太想去看電影,因為他認為演員Kutcher只能演好喜劇。但是,看完后,他發現他錯了,所有演員都演得非常好。這是一段積極的評論。

1. 讀取數據集:

import os 
from torchtext.data import get_tokenizerdef read_imdb(dir='aclImdb', split = 'pos', is_train=True):subdir = 'train' if is_train else 'test'dir = os.path.join(dir, subdir, split)lines = []for file in os.listdir(dir):with open(os.path.join(dir, file), 'rb') as f:line = f.read().decode('utf-8')lines.append(line)return lineslines = read_imdb()
print('Length of the file:', len(lines))
print('lines[0]:', lines[0])
tokenizer = get_tokenizer('basic_english')
tokens = tokenizer(lines[0])
print('lines[0] tokens:', tokens)

output:?

2.獲取經GloVe預處理的數據

在這個作業里,模型其實很簡單,輸入序列經過詞嵌入,送入單層RNN,之后輸出結果。作業最難的是如何把token轉換成GloVe詞嵌入。

torchtext其實還提供了一些更方便的NLP工具類(Field,Vectors),用于管理向量。但是,這些工具需要一定的學習成本,后續學習pytorch時再學習。

Pytorch通常用nn.Embedding來表示詞嵌入層。nn.Embedding其實就是一個矩陣,每一行都是一個詞嵌入,每一個token都是整型索引,表示該token再詞匯表里的序號。有了索引,有了矩陣就可以得到token的詞嵌入了。但是有些token在詞匯表中并不存在,我們得對輸入做處理,把詞匯表里沒有的token轉換成<unk>這個表示未知字符的特殊token。同時為了對齊序列的長度,我們還得添加<pad>這個特殊字符。而用glove直接生成的nn.Embedding里沒有<unk>和<pad>字符。如果使用nn.Embedding的話,我們要編寫非常復雜的預處理邏輯。

為此,我們可以用GloVe類的get_vecs_by_tokens直接獲取token的詞嵌入,以代替nn.Embedding。回憶一下前文提到的get_vecs_by_tokens的使用結果,所有沒有出現的token都會被轉換成零向量。這樣,我們就不必操心數據預處理的事了。get_vecs_by_tokens應該發生在數據讀取之后,可以直接被寫在Dataset的讀取邏輯里

from torch.utils.data import DataLoader, Dataset
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVeclass IMDBDataset(Dataset):def __init__(self, is_train=True, dir = 'aclImdb'):super().__init__()self.tokenizer = get_tokenizer('basic_english')pos_lines = read_imdb(dir, 'pos', is_train)neg_lines = read_imdb(dir, 'neg', is_train)self.pos_length = len(pos_lines)self.neg_length = len(neg_lines)self.lines = pos_lines+neg_linesdef __len__(self):return self.pos_length + self.neg_lengthdef __getitem__(self, index):sentence = self.tokenizer(self.lines[index])x = glove.get_vecs_by_tokens(sentence)label = 1 if index < self.pos_length else 0return x, label

數據預處理的邏輯都在__getitem__里。每一段字符串會先被token化,之后由GLOVE.get_vecs_by_tokens得到詞嵌入數組。?

3.對齊輸入

使用一個batch的序列數據時常常會碰到序列不等長的問題。實際上利用Pytorch Dataloader的collate_fn機制有更簡潔的實現方法。

from torch.nn.utils.rnn import pad_sequencedef get_dataloader(dir='aclImdb'):def collate_fn(batch):x, y = zip(*batch)x_pad = pad_sequence(x, batch_first=True)y = torch.Tensor(y)return x_pad, ytrain_dataloader = DataLoader(IMDBDataset(True, dir),batch_size=32,shuffle=True,collate_fn=collate_fn)test_dataloader = DataLoader(IMDBDataset(False, dir),batch_size=32,shuffle=True,collate_fn=collate_fn)return train_dataloader, test_dataloader

PyTorch DataLoader在獲取Dataset的一個batch的數據時,實際上會先吊用Dataset.__getitem__獲取若干個樣本,再把所有樣本拼接成一個batch,比如用__getitem__獲取四個[4,3,10,10]這一個batch,可是序列數據通常長度不等,__getitem__可能會獲得[10, 100],?[15, 100]這樣不等長的詞嵌入數組。

為了解決這個問題,我們要手動編寫把所有張量拼成一個batch的函數。這個函數就是DataLoadercollate_fn函數。我們的collate_fn應該這樣編寫:

def collate_fn(batch):x, y = zip(*batch)x_pad = pad_sequence(x, batch_first=True)y = torch.Tensor(y)return x_pad, y

collate_fn的輸入batch是每次__getitem__的結果的數組。比如在我們這個項目中,第一次獲取了一個長度為10的積極的句子,__getitem__返回(Tensor[10, 100], 1);第二次獲取了一個長度為15的消極的句子,__getitem__返回(Tensor[15, 100], 0)。那么,輸入batch的內容就是:

[(Tensor[10, 100], 1), (Tensor[15, 100], 0)]

我們可以用x, y = zip(*batch)把它巧妙地轉換成兩個元組:

x = (Tensor[10, 100], Tensor[15, 100])
y = (1, 0)

之后,PyTorch的pad_sequence可以把不等長序列的數組按最大長度填充成一整個batch張量。也就是說,經過這個函數后,x_pad變成了:

x_pad = Tensor[2, 15, 100]

pad_sequencebatch_first決定了batch是否在第一維。如果它為False,則結果張量的形狀是[15, 2, 100]

pad_sequence還可以決定填充內容,默認填充0。在我們這個項目中,被填充的序列已經是詞嵌入了,直接用全零向量表示<pad>沒問題。

有了collate_fn,構建DataLoader就很輕松了:

DataLoader(IMDBDataset(True, dir),batch_size=32,shuffle=True,collate_fn=collate_fn)

注意,使用shuffle=True可以令DataLoader隨機取數據構成batch。由于我們的Dataset十分工整,前一半的標簽是1,后一半是0,必須得用隨機的方式去取數據以提高訓練效率。?

4.模型

import torch.nn as nn
GLOVE_DIM = 100
GLOVE = GloVe(name = '6B', dim=GLOVE_DIM)
class RNN(torch.nn.Module):def __init__(self, hidden_units=64, dropout_rate = 0.5):super().__init__()self.drop = nn.Dropout(dropout_rate)self.rnn = nn.GRU(GLOVE_DIM, hidden_units, 1, batch_first=True)self.linear = nn.Linear(hidden_units,1)self.sigmoid = nn.Sigmoid()def forward(self, x:torch.Tensor):# x: [batch, max_word_length, embedding_length]emb = self.drop(x)output,_ = self.rnn(emb)output = output[:, -1]output = self.linear(output)output = self.sigmoid(output)return output

這里要注意一下,PyTorch的RNN會返回整個序列的輸出。而在預測分類概率時,我們只需要用到最后一輪RNN計算的輸出。因此,要用output[:, -1]取最后一次的輸出。

5. 訓練、測試、推理?

train_dataloader, test_dataloader = get_dataloader()
model = RNN()optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
citerion = torch.nn.BCELoss()for epoch in range(100):loss_sum = 0dataset_len = len(train_dataloader.dataset)for x, y in train_dataloader:batchsize = y.shape[0]hat_y = model(x)hat_y = hat_y.squeeze(-1)loss = citerion(hat_y, y)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()loss_sum += loss * batchsizeprint(f'Epoch{epoch}. loss :{loss_sum/dataset_len}')torch.save(model.state_dict(),'rnn.pth')

output:?

model.load_state_dict(torch.load('rnn.pth'))
accuracy = 0
dataset_len = len(test_dataloader.dataset)
model.eval()
for x, y in test_dataloader:with torch.no_grad():hat_y = model(x)hat_y.squeeze_(1)predictions = torch.where(hat_y>0.5,1,0)score = torch.sum(torch.where(predictions==y,1,0))accuracy += score.item()
accuracy /= dataset_lenprint(f'Accuracy:{accuracy}')   

Accuracy:0.90516

tokenizer = get_tokenizer('basic_english')
article = "U.S. stock indexes fell Tuesday, driven by expectations for tighter Federal Reserve policy and an energy crisis in Europe. Stocks around the globe have come under pressure in recent weeks as worries about tighter monetary policy in the U.S. and a darkening economic outlook in Europe have led investors to sell riskier assets."x = GLOVE.get_vecs_by_tokens(tokenizer(article)).unsqueeze(0)
with torch.no_grad():hat_y = model(x)
hat_y = hat_y.squeeze_().item()
result = 'positive' if hat_y > 0.5 else 'negative'
print(result)

negative

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/79706.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/79706.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/79706.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

kotlin中 熱流 vs 冷流 的本質區別

&#x1f525; 冷流&#xff08;Cold Flow&#xff09; vs 熱流&#xff08;Hot Flow&#xff09;區別 特性冷流&#xff08;Cold Flow&#xff09;熱流&#xff08;Hot Flow&#xff09;數據生產時機每次 collect 才開始執行啟動時就開始生產、始終運行生命周期與 collect 者…

精益數據分析(44/126):深度解析媒體網站商業模式的關鍵要點

精益數據分析&#xff08;44/126&#xff09;&#xff1a;深度解析媒體網站商業模式的關鍵要點 在創業與數據分析的探索道路上&#xff0c;我們不斷挖掘不同商業模式的核心要素&#xff0c;今天將深入剖析媒體網站商業模式。希望通過對《精益數據分析》相關內容的解讀&#xf…

Android學習總結之Java和kotlin區別

一、空安全機制 真題 1&#xff1a;Kotlin 如何解決 Java 的 NullPointerException&#xff1f;對比兩者在空安全上的設計差異 解析&#xff1a; 核心考點&#xff1a;Kotlin 可空類型系統&#xff08;?&#xff09;、安全操作符&#xff08;?./?:&#xff09;、非空斷言&…

[Survey]Remote Sensing Temporal Vision-Language Models: A Comprehensive Survey

BaseInfo TitleRemote Sensing Temporal Vision-Language Models: A Comprehensive SurveyAdresshttps://arxiv.org/abs/2412.02573Journal/Time2024 arxivAuthor北航 上海AI LabCodehttps://github.com/Chen-Yang-Liu/Awesome-RS-Temporal-VLM 1. Introduction 傳統遙感局限…

jmeter讀取CSV文件中文亂碼的解決方案

原因分析? CSV文件出現中文亂碼通常是因為文件編碼與JMeter讀取編碼不一致。常見場景&#xff1a; 文件保存為GBK/GB2312編碼&#xff0c;但JMeter以UTF-8讀取。文件包含BOM頭&#xff08;如Windows記事本保存的UTF-8&#xff09;&#xff0c;但JMeter未正確處理。腳本讀取文…

Webview通信系統學習指南

Webview通信系統學習指南 一、定義與核心概念 1. 什么是Webview&#xff1f; 定義&#xff1a;Webview是移動端&#xff08;Android/iOS&#xff09;內置的輕量級瀏覽器組件&#xff0c;用于在原生應用中嵌入網頁內容。作用&#xff1a;實現H5頁面與原生應用的深度交互&…

【C++】C++中的命名/名字/名稱空間 namespace

C中的命名/名字/名稱空間 namespace 1、問題引入2、概念3、作用4、格式5、使用命名空間中的成員5.1 using編譯指令&#xff08; 引進整個命名空間&#xff09; ---將這個盒子全部打開5.2 using聲明使特定的標識符可用(引進命名空間的某個成員) ---將這個盒子中某個成員的位置打…

Arduino IDE中離線更新esp32 3.2.0版本的辦法

在Arduino IDE中更新esp32-3.2.0版本是個不可能的任務&#xff0c;下載文件速度極慢。網上提供了離線的辦法&#xff0c;提供了安裝文件&#xff0c;但是沒有3.2.0的版本。 下面提供了一種離線安裝方法 一、騰訊元寶查詢解決辦法 通過打開開發板管理地址&#xff1a;通過在騰…

【工具使用-數據可視化工具】Apache Superset

1. 工具介紹 1.1. 簡介 一個輕量級、高性能的數據可視化工具 官網&#xff1a;https://superset.apache.org/GitHub鏈接&#xff1a;https://github.com/apache/superset官方文檔&#xff1a;https://superset.apache.ac.cn/docs/intro/ 1.2. 核心功能 豐富的可視化庫&…

算法每日一題 | 入門-順序結構-三角形面積

三角形面積 題目描述 一個三角形的三邊長分別是 a、b、c&#xff0c;那么它的面積為 p ( p ? a ) ( p ? b ) ( p ? c ) \sqrt{p(p-a)(p-b)(p-c)} p(p?a)(p?b)(p?c) ?&#xff0c;其中 p 1 2 ( a b c ) p\frac{1}{2}(abc) p21?(abc) 。輸入這三個數字&#xff0c;…

MongoDB入門詳解

文章目錄 MongoDB下載和安裝1.MongoDBCompass字段簡介1.1 Aggregations&#xff08;聚合&#xff09;1.2 Schema&#xff08;模式分析&#xff09;1.3 Indexes&#xff08;索引&#xff09;1.4 Validation&#xff08;數據驗證&#xff09; 2.增刪改查操作2.1創建、刪除數據庫&…

從Oculus到Meta:Facebook實現元宇宙的硬件策略

Oculus的起步 Facebook在2014年收購了Oculus&#xff0c;這標志著其在虛擬現實&#xff08;VR&#xff09;領域的首次重大投資。Oculus Rift作為公司的旗艦產品&#xff0c;是一款高端的VR頭戴設備&#xff0c;它為用戶帶來了沉浸式的體驗。Facebook通過Oculus Rift&#xff0…

安裝與配置Go語言開發環境 -《Go語言實戰指南》

為了開始使用Go語言進行開發&#xff0c;我們首先需要正確安裝并配置Go語言環境。Go的安裝相對簡單&#xff0c;支持多平臺&#xff0c;包括Windows、macOS和Linux。本節將逐一介紹各平臺的安裝流程及環境變量配置方式。 一、Windows系統 1. 下載Go安裝包 前往Go語言官網&…

網絡的搭建

1、rpm rpm -ivh 2、yum倉庫&#xff08;rpm包&#xff09;&#xff1a;網絡源 ----》網站 本地源 ----》/dev/sr0 光盤映像文件 3、源碼安裝 源碼安裝&#xff08;編譯&#xff09; 1、獲取源碼 2、檢測環境生成Ma…

多元隨機變量協方差矩陣

主要記錄多元隨機變量數字特征相關內容。 關鍵詞&#xff1a;多元統計分析 一元隨機變量 總體 隨機變量Y 總體均值 μ E ( Y ) ∫ y f ( y ) d y \mu E(Y) \int y f(y) \, dy μE(Y)∫yf(y)dy 總體方差 σ 2 V a r ( Y ) E ( Y ? μ ) 2 \sigma^2 Var(Y) E(Y - \…

Ros工作空間

工作空間其實放到嵌入式里就是相關的編程包 ------------------------------------- d第一個Init 就是類型的初始化 然后正常一個catkin_make 后 就會產生如devil之類的文件&#xff0c; 你需要再自己 終端 一個catkin_make install 一下 。這樣對應install也會產生&#xf…

qt國際化翻譯功能用法

文章目錄 [toc]1 概述2 設置待翻譯文本3 生成ts翻譯源文件4 編輯ts翻譯源文件5 生成qm翻譯二進制文件6 加載qm翻譯文件進行翻譯 更多精彩內容&#x1f449;內容導航 &#x1f448;&#x1f449;Qt開發經驗 &#x1f448; 1 概述 在 Qt 中&#xff0c;ts 文件和 qm 文件是用于國…

PyTorch 與 TensorFlow 中基于自定義層的 DNN 實現對比

深度學習雙雄對決&#xff1a;PyTorch vs TensorFlow 自定義層大比拼 目錄 深度學習雙雄對決&#xff1a;PyTorch vs TensorFlow 自定義層大比拼一、TensorFlow 實現 DNN1. 核心邏輯 二、PyTorch 實現自定義層1. 核心邏輯 三、關鍵差異對比四、總結 一、TensorFlow 實現 DNN 1…

1ms城市算網穩步啟航,引領數字領域的“1小時經濟圈”效應

文 | 智能相對論 作者 | 陳選濱 為什么近年來國產動畫、國產3A大作迎來了井噴式爆發&#xff1f;拋開制作水平以及市場需求的升級不談&#xff0c;還有一個重要原因往往被大多數人所忽視&#xff0c;那就是新型信息的完善與成熟。 譬如&#xff0c;現階段驚艷用戶的云游戲以及…

【計算機視覺】語義分割:Segment Anything (SAM):通用圖像分割的范式革命

Segment Anything&#xff1a;通用圖像分割的范式革命 技術突破與架構創新核心設計理念關鍵技術組件 環境配置與快速開始硬件要求安裝步驟基礎使用示例 深度功能解析1. 多模態提示融合2. 全圖分割生成3. 高分辨率處理 模型微調與定制1. 自定義數據集準備2. 微調訓練配置 常見問…