基于CBOW模型的詞向量訓練實戰:從原理到PyTorch實現

基于CBOW模型的詞向量訓練實戰:從原理到PyTorch實現

在自然語言處理(NLP)領域,詞向量是將單詞映射為計算機可處理的數值向量的重要方式。通過詞向量,單詞之間的語義關系能夠以數學形式表達,為后續的文本分析、機器翻譯、情感分析等任務奠定基礎。本文將結合連續詞袋模型(CBOW),詳細介紹如何使用PyTorch訓練詞向量,并通過具體代碼實現和分析訓練過程。

一、CBOW模型原理簡介

CBOW(Continuous Bag-of-Words)模型是一種用于生成詞向量的神經網絡模型,它基于上下文預測目標詞。其核心思想是:給定一個目標詞的上下文單詞,通過模型預測該目標詞。在訓練過程中,模型會不斷調整參數,使得預測結果盡可能接近真實的目標詞,最終訓練得到的詞向量能夠捕捉單詞之間的語義關系。

例如,在句子 “People create programs to direct processes” 中,如果目標詞是 “programs”,CBOW模型會利用其上下文單詞 “People”、“create”、“to”、“direct” 來預測 “programs”。通過大量類似樣本的訓練,模型能夠學習到單詞之間的語義關聯,從而生成有效的詞向量。

二、代碼實現與詳細解析

下面我會逐行解釋你提供的代碼,此代碼借助 PyTorch 實現了一個連續詞袋模型(CBOW)來學習詞向量。

1. 導入必要的庫

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange  # 顯示進度條
import numpy as np
  • torch:PyTorch 深度學習框架的核心庫。
  • torch.nn:用于構建神經網絡的模塊。
  • torch.nn.functional:提供了許多常用的函數,像激活函數等。
  • torch.optim:包含各種優化算法。
  • tqdmtrange:用于在訓練過程中顯示進度條。
  • numpy:用于處理數值計算和數組操作。

2. 定義上下文窗口大小和原始文本

CONTEXT_SIZE = 2
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules called a program. 
People create programs to direct processes. 
In effect,we conjure the spirits of the computer with our spells.""".split()
  • CONTEXT_SIZE:上下文窗口的大小,意味著在預測目標詞時,會考慮其前后各 CONTEXT_SIZE 個單詞。
  • raw_text:原始文本,將其按空格分割成單詞列表。

3. 構建詞匯表和索引映射

vocab = set(raw_text)  # 集合,詞庫,里面的內容獨一無二(將文本中所有單詞去重后得到的詞匯表)
vocab_size = len(vocab)  # 詞匯表的大小word_to_idx = {word: i for i, word in enumerate(vocab)}  # 單詞到索引的映射字典
idx_to_word = {i: word for i, word in enumerate(vocab)}  # 索引到單詞的映射字典
  • vocab:把原始文本中的所有單詞去重后得到的詞匯表。
  • vocab_size:詞匯表的大小。
  • word_to_idx:將單詞映射為對應的索引。
  • idx_to_word:將索引映射為對應的單詞。

4. 構建訓練數據集

data = []  # 獲取上下文詞,將上下文詞作為輸入,目標詞作為輸出,構建訓練數據集(用于存儲訓練數據,每個元素是一個元組,包含上下文詞列表和目標詞)
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):context = ([raw_text[i - (2 - j)] for j in range(CONTEXT_SIZE)]+ [raw_text[i + j + 1] for j in range(CONTEXT_SIZE)])  # 獲取上下文詞target = raw_text[i]  # 獲取目標詞data.append((context, target))  # 將上下文詞和目標詞保存到 data 中
  • data:用于存儲訓練數據,每個元素是一個元組,包含上下文詞列表和目標詞。
  • 通過循環遍歷原始文本,提取每個目標詞及其上下文詞,然后將它們添加到 data 中。

5. 定義將上下文詞轉換為張量的函數

def make_context_vector(context, word_to_ix):  # 將上下詞轉換為 one - hotidxs = [word_to_ix[w] for w in context]return torch.tensor(idxs, dtype=torch.long)
  • make_context_vector:把上下文詞列表轉換為對應的索引張量。

6. 打印第一個上下文詞的索引張量

print(make_context_vector(data[0][0], word_to_idx))
  • 打印第一個訓練樣本的上下文詞對應的索引張量。

7. 定義 CBOW 模型

class CBOW(nn.Module):  # 神經網絡def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))  # nn.relu() 激活層out = self.output(out)nll_prob = F.log_softmax(out, dim=1)return nll_prob
  • CBOW:繼承自 nn.Module,定義了 CBOW 模型的結構。
    • __init__:初始化模型的層,包含一個嵌入層、一個線性層和另一個線性層。
    • forward:定義了前向傳播過程,將輸入的上下文詞索引轉換為嵌入向量,求和后經過線性層和激活函數,最后輸出對數概率。

8. 選擇設備并創建模型實例

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 字符串的格式化
model = CBOW(vocab_size, 10).to(device)
  • device:檢查當前設備是否支持 GPU(CUDA 或 MPS),若支持則使用 GPU,否則使用 CPU。
  • model:創建 CBOW 模型的實例,并將其移動到指定設備上。

9. 定義優化器、損失函數和損失列表

optimizer = optim.Adam(model.parameters(), lr=0.001)  # 創建一個優化器,
losses = []  # 存儲損失的集合
loss_function = nn.NLLLoss()
  • optimizer:使用 Adam 優化器來更新模型的參數。
  • losses:用于存儲每個 epoch 的損失值。
  • loss_function:使用負對數似然損失函數。

10. 訓練模型

model.train()for epoch in tqdm(range(200)):total_loss = 0for context, target in data:context_vector = make_context_vector(context, word_to_idx).to(device)target = torch.tensor([word_to_idx[target]]).to(device)# 開始向前傳播train_predict = model(context_vector)loss = loss_function(train_predict, target)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()losses.append(total_loss)print(losses)
  • model.train():將模型設置為訓練模式。
  • 通過循環進行 200 個 epoch 的訓練,每個 epoch 遍歷所有訓練數據。
    • 將上下文詞和目標詞轉換為張量并移動到指定設備上。
    • 進行前向傳播得到預測結果。
    • 計算損失。
    • 進行反向傳播并更新模型參數。
    • 累加每個 epoch 的損失值。

11. 進行預測

context = ['People', 'create', 'to', 'direct']
context_vector = make_context_vector(context, word_to_idx).to(device)model.eval()  # 進入到測試模式
predict = model(context_vector)
max_idx = predict.argmax(1)  # dim = 1 表示每一行中的最大值對應的索引號, dim = 0 表示每一列中的最大值對應的索引號print("CBOW embedding weight =", model.embeddings.weight)  # GPU
W = model.embeddings.weight.cpu().detach().numpy()
print(W)
  • 選擇一個上下文詞列表進行預測。
  • model.eval():將模型設置為評估模式。
  • 進行預測并獲取預測結果中概率最大的索引。
  • 打印嵌入層的權重,并將其轉換為 NumPy 數組。

12. 構建詞向量字典

word_2_vec = {}
for word in word_to_idx.keys():word_2_vec[word] = W[word_to_idx[word], :]
print('jiesu')
  • word_2_vec:將每個單詞映射到其對應的詞向量。

13. 保存和加載詞向量

np.savez('word2vec實現.npz', file_1 = W)
data = np.load('word2vec實現.npz')
print(data.files)
  • np.savez:將詞向量保存為 .npz 文件。
  • np.load:加載保存的 .npz 文件,并打印文件中的數組名稱。

綜上所述,這段代碼實現了一個簡單的 CBOW 模型來學習詞向量,并將學習到的詞向量保存到文件中。 。運行結果
在這里插入圖片描述

三、總結

通過上述代碼的實現和分析,我們成功地使用CBOW模型在PyTorch框架下完成了詞向量的訓練。從數據準備、模型定義,到訓練和測試,再到詞向量的保存,每一個步驟都緊密相連,共同構建了一個完整的詞向量訓練流程。

CBOW模型通過上下文預測目標詞的方式,能夠有效地學習到單詞之間的語義關系,生成的詞向量可以應用于各種自然語言處理任務。在實際應用中,我們還可以通過調整模型的超參數(如詞向量維度、上下文窗口大小、訓練輪數等),以及使用更大規模的數據集,進一步優化詞向量的質量和模型的性能。希望本文的內容能夠幫助讀者更好地理解CBOW模型和詞向量訓練的原理與實踐。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/78723.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/78723.shtml
英文地址,請注明出處:http://en.pswp.cn/web/78723.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux——進程終止/等待/替換

前言 本章主要對進程終止,進程等待,進程替換的詳細認識,根據實驗去理解其中的原理,干貨滿滿! 1.進程終止 概念:進程終止就是釋放進程申請的內核數據結構和對應的代碼和數據 進程退出的三種狀態 代碼運行…

iOS開發架構——MVC、MVP和MVVM對比

文章目錄 前言MVC(Model - View - Controller)MVP(Model - View - Presenter)MVVM(Model - View - ViewModel) 前言 在 iOS 開發中,MVC、MVVM、和 MVP 是常見的三種架構模式,它們主…

0506--01-DA

36. 單選題 在娛樂方式多元化的今天,“ ”是不少人(特別是中青年群體)對待戲曲的態度。這里面固然存在 的偏見、難以靜下心來欣賞戲曲之美等因素,卻也有另一個無法回避的原因:一些戲曲雖然與觀眾…

關于Java多態簡單講解

面向對象程序設計有三大特征,分別是封裝,繼承和多態。 這三大特性相輔相成,可以使程序員更容易用編程語言描述現實對象。 其中多態 多態是方法的多態,是通過子類通過對父類的重寫,實現不同子類對同一方法有不同的實現…

【Trea】Trea國際版|海外版下載

Trea目前有兩個版本,海外版和國內版。? Trae 版本差異 ?大模型選擇?: ?國內版?:提供了字節自己的Doubao-1.5-pro以及DeepSeek的V3版本和R1版本。海外版:提供了ChartGPT以及Claude-3.5-Sonnet和3.7-Sonnt. ?功能和界面?&a…

Missashe考研日記-day33

Missashe考研日記-day33 1 專業課408 學習時間:2h30min學習內容: 今天開始學習OS最后一章I/O管理的內容,聽了第一小節的內容,然后把課后習題也做了。知識點回顧: 1.I/O設備分類:按信息交換單位、按設備傳…

鏈表的面試題3找出中間節點

來來來,接著繼續我們的第三道題 。 解法 暴力求解 快慢指針 https://leetcode.cn/problems/middle-of-the-linked-list/submissions/ 這道題的話,思路是非常明確的,就是讓你找出我們這個所謂的中間節點并且輸出。 那這道題我們就需要注意…

linux磁盤介紹與LVM管理

一、磁盤基本概述 GPT是全局唯一標識分區表的縮寫,是全局唯一標示磁盤分區表格式。而MBR則是另一種磁盤分區形式,它是主引導記錄的縮寫。相比之下,MBR比GPT出現得要更早一些。 MBR 與 GPT MBR 支持的磁盤最大容量為 2 TB,GPT 最大支持的磁盤容量為 18 EB,當前數據盤支持…

突破測試環境文件上傳帶寬瓶頸!React Native 阿里云 OSS 直傳文件格式問題攻克二

上一篇我們對服務端和阿里云oss的配置及前端調用做了簡單的介紹,但是一直報錯。最終判斷是文件格式問題,通常我們在reactnative中用formData上傳, formData.append(file, {uri: file, name: nameType(type), type: multipart/form-data});這…

Spring Boot 中 @Bean 注解詳解:從入門到實踐

在 Spring Boot 開發中,Bean注解是一個非常重要且常用的注解,它能夠幫助開發者輕松地將 Java 對象納入 Spring 容器的管理之下,實現對象的依賴注入和生命周期管理。對于新手來說,理解并掌握Bean注解,是深入學習 Spring…

TCP 協議設計入門:自定義消息格式與粘包解決方案

目錄 一、為什么需要自定義 TCP 協議? TCP粘包問題的本質 1.1 粘包與拆包的定義 1.2 粘包的根本原因 1.3 粘包的典型場景 二、自定義消息格式設計 2.1 協議結構設計 方案1:固定長度協議 方案2:分隔符標記法 方案3:長度前…

了解一下OceanBase中的表分區

OceanBase 是一個高性能的分布式關系型數據庫,它支持 SQL 標準的大部分功能,包括分區表。分區表可以幫助管理大量數據,提高查詢效率,通過將數據分散到不同的物理段中,可以減少查詢時的數據掃描量。 在 OceanBase 中操…

多線程網絡編程:粘包問題、多線程/多進程服務器實戰與常見問題解析

多線程網絡編程:粘包問題、多線程/多進程服務器實戰與常見問題解析 一、TCP粘包問題:成因、影響與解決方案 1. 粘包問題本質 TCP是面向流的協議,數據傳輸時沒有明確的消息邊界,導致多個消息可能被合并(粘包&#xf…

大模型主干

1.什么是語言模型骨架LLM-Backbone,在多模態模型中的作用? 語言模型骨架(LLM Backbone)是多模態模型中的核心組件之一。它利用預訓練的語言模型(如Flan-T5、ChatGLM、UL2等)來處理各種模態的特征,進行語義…

[創業之路-350]:光刻機、激光器、自動駕駛、具身智能:跨學科技術體系全景解析(光-機-電-材-熱-信-控-軟-網-算-智)

光刻機、激光器、自動駕駛、具身智能四大領域的技術突破均依賴光、機、電、材、熱、信、控、軟、網、算、智十一大學科體系的深度耦合。以下從技術原理、跨學科融合、關鍵挑戰三個維度展開系統性分析: 一、光刻機:精密制造的極限挑戰 1. 核心技術與學科…

SVTAV1 編碼函數 svt_aom_is_pic_skipped

一 函數解釋 1.1 svt_aom_is_pic_skipped函數的作用是判斷當前圖片是否可以跳過編碼處理。 具體分析如下 函數邏輯 參數說明:函數接收一個指向圖片父控制集的指針PictureParentControlSet *pcs, 通過這個指針可以獲取與圖片相關的各種信息,用于判斷是否跳…

【Redis新手入門指南】從小白入門到日常使用(全)

文章目錄 前言redis是什么?定義原理與特點與MySQL對比 Redis安裝方式一、Homebrew 快速安裝 Redis(推薦)方式二、源碼編譯安裝redisHomebrew vs 源碼安裝對比 redis配置說明修改redis配置的方法常見redis配置項說明 redis常用命令redis服務啟…

Linux grep 命令詳解及示例大全

文章目錄 一、基本語法二、常用選項及示例1. 基本匹配:查找包含某字符串的行2. 忽略大小寫匹配 -i3. 顯示行號 -n4. 遞歸查找目錄下的文件 -r 或 -R5. 僅顯示匹配的字符串 -o6. 使用正則表達式 -E(擴展)或 egrep7. 顯示匹配前后行 -A, -B, -C…

【排序算法】快速排序(全坤式超詳解)———有這一篇就夠啦

【排序算法】——快速排序 目錄 一:快速排序——思想 二:快速排序——分析 三:快速排序——動態演示圖 四:快速排序——單趟排序 4.1:霍爾法 4.2:挖坑法 4.3:前后指針法 五:…

【platform push 提示 Invalid source ref: HEAD】

platform push 提示 Invalid source ref: HEAD 場景:環境:排查過程:解決: 場景: 使用platform push 命令行輸入git -v 可以輸出git 版本號,但就是提示Invalid source ref: HEAD,platform creat…