自然語言處理之PyTorch實現詞袋CBOW模型

在自然語言處理(NLP)領域,詞向量(Word Embedding)是將文本轉換為數值向量的核心技術。它能讓計算機“理解”詞語的語義關聯,例如“國王”和“女王”的向量差可能與“男人”和“女人”的向量差相似。而Word2Vec作為經典的詞向量訓練模型,其核心思想是通過上下文預測目標詞(或反之)。本文將以 --CBOW(連續詞袋模型)為例,帶你從代碼到原理,一步步實現一個簡單的詞向量訓練過程。

一、CBOW模型簡介

CBOW(Continuous Bag-of-Words)是Word2Vec的兩種核心模型之一。其核心思想是:給定目標詞的上下文窗口內的所有詞,預測目標詞本身。例如,對于句子“We are about to study”,若上下文窗口大小為2(即目標詞左右各取2個詞),則當目標詞是“about”時,上下文是“We, are, to, study”,模型需要根據這4個詞預測出“about”。

CBOW的優勢在于通過平均上下文詞向量來預測目標詞,計算效率高;缺點是對低頻詞不友好。本文將實現的CBOW模型包含詞嵌入層、投影層和輸出層,最終輸出目標詞的概率分布。

二、環境準備與數據預處理

2.1 進度條庫安裝

pip install torch numpy tqdm

2.2 語料庫與基礎設置

我們使用一段英文文本作為語料庫,并定義上下文窗口大小(CONTEXT_SIZE=2,即目標詞左右各取2個詞):

CONTEXT_SIZE = 2  # 上下文窗口大小(左右各2個詞)
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()  # 按空格分割成單詞列表

2.3 構建詞匯表與映射

為了將文本轉換為模型可處理的數值,需要先構建詞匯表(所有唯一詞),并為每個詞分配唯一索引:

vocab = set(raw_text)  # 去重后的詞匯表(集合)
vocab_size = len(vocab)  # 詞匯表大小(本文示例中為49)# 詞到索引的映射(如:"We"→0,"are"→1)
word_to_idx = {word: i for i, word in enumerate(vocab)}
# 索引到詞的反向映射(如:0→"We",1→"are")
idx_to_word = {i: word for i, word in enumerate(vocab)}

三、生成訓練數據:上下文-目標詞對

CBOW的訓練數據是“上下文詞列表”與“目標詞”的配對。例如,若目標詞是raw_text[i],則上下文是[raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2]](假設窗口大小為2)。

3.1 數據生成邏輯

通過遍歷語料庫,跳過前CONTEXT_SIZE和后CONTEXT_SIZE個詞(避免越界),生成上下文-目標詞對:

data = []
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):# 左上下文:取i-2, i-1(j從0到1,2-j對應2,1)left_context = [raw_text[i - (2 - j)] for j in range(CONTEXT_SIZE)]# 右上下文:取i+1, i+2(j從0到1,i+j+1對應i+1, i+2)right_context = [raw_text[i + j + 1] for j in range(CONTEXT_SIZE)]context = left_context + right_context  # 合并上下文(共4個詞)target = raw_text[i]  # 目標詞(當前中心詞)data.append((context, target))

3.2 示例驗證

i=2為例:

  • 左上下文:i-2=0(“We”),i-1=1(“are”)→ ["We", "are"]
  • 右上下文:i+1=3(“to”),i+2=4(“study”)→ ["to", "study"]
  • 上下文合并:["We", "are", "to", "study"]
  • 目標詞:raw_text[2](“about”)

四、CBOW模型實現(PyTorch)

4.1 模型結構設計

CBOW模型的核心是通過上下文詞的詞向量預測目標詞。模型結構包含三層:

  1. 詞嵌入層(Embedding):將詞的索引映射為低維稠密向量(如10維)。
  2. 投影層(Linear):將拼接后的詞向量投影到更高維度(如128維),增加非線性表達能力。
  3. 輸出層(Linear):將投影后的向量映射回詞匯表大小,通過Softmax輸出目標詞的概率分布。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass CBOW(nn.Module):  # 神經網絡def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()  # 父類的初始化self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))  # nn.relu() 激活層out = self.output(out)nll_prob = F.log_softmax(out, dim=-1)  # softmax交叉熵return nll_prob

五、模型訓練與優化

5.1 初始化模型與超參數

設置詞向量維度(embedding_dim=10)、學習率(lr=0.001)、訓練輪數(epochs=200),并初始化模型、優化器和損失函數:

vocab_size = 49
model = CBOW(vocab_size, 10).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
losses = []# 存儲損失的集合  losses: []
loss_function = nn.NLLLoss()        #NLLLoss損失函數(當分類列表非常多的情況),將多個類

5.2 訓練循環邏輯

遍歷每個訓練輪次(epoch),對每個上下文-目標詞對進行前向傳播、損失計算、反向傳播和參數更新:

model.train()
for epoch in tqdm(range(200)):#開始訓練total_loss = 0for context, target in data:context_vector = make_context_vector(context, word_to_idx).to(device)target = torch.tensor([word_to_idx[target]]).to(device)# 開始前向傳播train_predict = model(context_vector)  # 可以不寫forward,torch的內置功能,loss = loss_function(train_predict, target)  # 計算損失# 反向傳播optimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向傳播計算得到每個參數的梯度optimizer.step()  # 根據梯度更新網絡參數total_loss += loss.item()

六、詞向量提取與應用

訓練完成后,模型的詞嵌入層(model.embeddings.weight)中存儲了每個詞的向量表示。我們可以將其提取并保存,用于后續任務(如文本分類、相似度計算)。

6.1 提取詞向量

# 將詞向量從GPU移至CPU,并轉換為NumPy數組
W = model.embeddings.weight.cpu().detach().numpy()
print("詞向量矩陣形狀:", W.shape)  # (vocab_size, embedding_dim) → (49, 10)

6.2 生成詞-向量映射字典

word_2_vec = {}
for word, idx in word_to_idx.items():word_2_vec[word] = W[idx]  # 每個詞對應詞向量矩陣中的一行
print("示例詞向量('process'):", word_2_vec["process"])

6.3 保存詞向量

使用np.savez保存詞向量矩陣,方便后續加載使用:

import numpy as np
np.savez('word2vec實現.npz', word_vectors=W)  # 保存為npz文件# 加載驗證
data = np.load('word2vec實現.npz')
loaded_vectors = data['word_vectors']
print("加載的詞向量形狀:", loaded_vectors.shape)  # 應與原始矩陣一致

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/95787.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/95787.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/95787.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

TCP, 三次握手, 四次揮手, 滑動窗口, 快速重傳, 擁塞控制, 半連接隊列, RST, SYN, ACK

目錄 TCP 是什么:面向連接 可靠 字節流三次握手:為什么不是兩次四次揮手與 TIME_WAIT:誰等誰序列號/確認號與去重、排序、確認重傳機制:超時重傳與快速重傳滑動窗口與流量控制擁塞控制:慢啟動/擁塞避免/快重傳/快恢…

CentOS 7.2 虛機 ssh 登錄報錯在重啟后無法進入系統

文章目錄前言1. 故障描述2. 故障診斷3. 故障原因4. 解決方案總結前言 上周幫用戶處理了一個 linux 虛擬機在重啟后無法正常進入操作系統的故障,覺得比較有意思,在這里分享給大家。 1. 故障描述 事情的起因是一臺系統版本為 CentOS 7.2 的 VMware 虛擬機…

《從使用到源碼:OkHttp3責任鏈模式剖析》

一 從使用開始0.依賴引入implementation ("com.squareup.okhttp3:okhttp:3.14.7")1.創建OkHttpClient實例方式一:直接使用默認配置的Builder//從源碼可以看出,當我們直接new創建OkHttpClient實例時,會默認給我們配置好一個Builder …

安裝3DS MAX 2026后,無法運行,提示缺少.net core的解決方案

今天安裝了3DS MAX 2026(俗稱3DMAX),安裝完畢后死活運行不了。提示如下: 大意是找不到所需的.NET Core 8庫文件。后來搜索了下,各種文章說.NET CORE和.NET FRAMEWORK不是一個東西。需要單獨下載安裝。然后根據提示&…

FastAPI + LangChain 和 Spring AI + LangChain4j

FastAPI+LangChain和Spring AI+LangChain4j這兩個技術組合進行詳細對比。 核心區別: 特性維度 FastAPI + LangChain (Python棧) Spring AI + LangChain4j (Java棧) 技術棧 Python生態 (FastAPI, LangChain) Java生態 (Spring Boot, Spring AI, LangChain4j) 核心設計哲學 靈活…

Apache 2.0 開源協議詳解:自由、責任與商業化的完美平衡-優雅草卓伊凡

Apache 2.0 開源協議詳解:自由、責任與商業化的完美平衡-優雅草卓伊凡引言由于我們優雅草要推出收銀系統,因此要采用開源代碼,卓伊凡目前看好了一個產品是apache 2.0協議,因此我們有必要深刻理解apache 2.0協議避免觸犯版權問題。…

自學嵌入式第37天:MQTT協議

一、MQTT(消息隊列遙測傳輸協議Message Queuing Telemetry Transport)1.MQTT是應用層的協議,是一種基于發布/訂閱模式的“輕量級”通訊協議,建構于TCP/IP協議上,可以以極少的代碼和有限的帶寬為連接遠程設備提供實時可…

RabbitMQ--延時隊列總結

一、延遲隊列概念 延遲隊列(Delay Queue)是一種特殊類型的隊列,隊列中的元素需要在指定的時間點被取出和處理。簡單來說,延時隊列就是存放需要在某個特定時間被處理的消息。它的核心特性在于“延遲”——消息在隊列中停留一段時間…

Java 提取 PDF 文件內容:告別手動復制粘貼,擁抱自動化解析!

在日常工作中,我們經常需要處理大量的 PDF 文檔,無論是提取報告中的關鍵數據,還是解析合同中的重要條款,手動復制粘貼不僅效率低下,還極易出錯。當面對海量的 PDF 文件時,這種傳統方式更是讓人望而卻步。那…

關鍵字 const

Flutter 是一個使用 Dart 語言構建的 UI 工具包,因此它完全遵循 Dart 的語法和規則。Dart 中的 const 是語言層面的特性,而 Flutter 因其聲明式 UI 和頻繁重建的特性,將 const 的效能發揮到了極致。Dart 中的 const(語言層面&…

Ubuntu22.04中使用cmake安裝abseil-cpp庫

Ubuntu22.04中使用cmake安裝abseil-cpp庫 關于Abseil庫 Abseil 由 Google 的基礎 C 和 Python 代碼庫組成,包括一些正支撐著如 gRPC、Protobuf 和 TensorFlow 等開源項目并一起 “成長” 的庫。目前已開源 C 部分,Python 部分將在后續開放。 Abseil …

FreeRTOS項目(序)目錄

這章是整個專欄的目錄,負責記錄這個小項目的開發日志和目錄。附帶總流程圖。 目錄 項目簡介 專欄目錄 開發日志 總流程圖 項目簡介 本項目基于STM32C8T6核心板和FreeRTOS,實現一些簡單的功能。以下為目前已實現的功能。 (1&#xff09…

Python 多任務編程:進程、線程與協程全面解析

目錄 一、多任務基礎:并發與并行 1. 什么是多任務 2. 兩種表現形式 二、進程:操作系統資源分配的最小單位 1. 進程的概念 2. 多進程實現多任務 2.1 基礎示例:邊聽音樂邊敲代碼 2.2 帶參數的進程任務 2.3 進程編號與應用注意點 2.3.…

ADSL技術

<摘要> ADSL&#xff08;非對稱數字用戶線路&#xff09;是一種利用傳統電話線實現寬帶上網的技術。其核心原理是頻率分割&#xff1a;將一根電話線的頻帶劃分為語音、上行數據&#xff08;慢&#xff09;和下行數據&#xff08;快&#xff09;三個獨立頻道&#xff0c;從…

信號衰減中的分貝到底是怎么回事

問題&#xff1a;在一個低通濾波中&#xff0c;經常會看到一個值-3dB&#xff08;-3分貝&#xff09;&#xff0c;到底是個什么含義&#xff1f; 今天我就來粗淺的講解這個問題。 在低通濾波器中&#xff0c;我們說的 “截止頻率”&#xff08;或叫 - 3dB 點&#xff09;&…

工具分享--IP與域名提取工具2.0

基于原版的基礎上新增了一個功能點:IP-A段過濾&#xff0c;可以快速把內網192、170、10或者其它你想要過濾掉的IP-A段輕松去掉&#xff0c;提高你的干活效率&#xff01;&#xff01;&#xff01; 界面樣式預覽&#xff1a;<!DOCTYPE html> <html lang"zh-CN&quo…

如何通過日志先行原則保障數據持久化:Redis AOF 和 MySQL redo log 的對比

在分布式系統或數據庫管理系統中&#xff0c;日志先行原則&#xff08;Write-Ahead Logging&#xff0c;WAL&#xff09; 是確保數據一致性、持久性和恢復能力的重要機制。通過 WAL&#xff0c;系統能夠在發生故障時恢復數據&#xff0c;保證數據的可靠性。在這篇博客中&#x…

臨床研究三千問——臨床研究體系的3個維度(8)

在上周的文章中&#xff0c;我們共同探討了1345-10戰策的“臨床研究的起點——如何提出一個犀利的臨床與科學問題”。問題固然是靈魂&#xff0c;但若沒有堅實的骨架與血肉&#xff0c;靈魂便無所依歸。今天&#xff0c;我們將深入“1345-10戰策”中的“3”&#xff0c;即支撐起…

AI+預測3D新模型百十個定位預測+膽碼預測+去和尾2025年9月7日第172彈

從今天開始&#xff0c;咱們還是暫時基于舊的模型進行預測&#xff0c;好了&#xff0c;廢話不多說&#xff0c;按照老辦法&#xff0c;重點8-9碼定位&#xff0c;配合三膽下1或下2&#xff0c;殺1-2個和尾&#xff0c;再殺4-5個和值&#xff0c;可以做到100-300注左右。(1)定位…

萬字詳解網絡編程之socket

一&#xff0c;socket簡介1.什么是socketsocket通常也稱作"套接字"&#xff0c;?于描述IP地址和端?&#xff0c;是?個通信鏈的句柄&#xff0c;應用程序通常通過"套接字"向?絡發出請求或者應答?絡請求。?絡通信就是兩個進程間的通信&#xff0c;這兩…