文章目錄
- 前言
- 輔助工具
- 1. 繪圖工具 (`utils_for_huitu.py`)
- 2. 數據處理工具 (`utils_for_data.py`)
- 3. 訓練輔助工具 (`utils_for_train.py`)
- 預訓練 Word2Vec - 主流程
- 1. 環境設置與數據加載
- 2. 跳元模型 (Skip-gram Model)
- 2.1. 嵌入層 (Embedding Layer)
- 2.2. 定義前向傳播
- 3. 訓練
- 3.1. 二元交叉熵損失
- 3.2. 初始化模型參數
- 3.3. 定義訓練階段代碼
- 3.4. 開始訓練
- 4. 應用詞嵌入
- 總結
前言
詞嵌入(Word Embeddings)是自然語言處理(NLP)領域中的基石技術之一。它們將詞語從稀疏的、高維的獨熱編碼(one-hot encoding)表示轉換為稠密的、低維的向量表示。這些向量能夠捕捉詞語之間的語義和句法關系,使得相似的詞在向量空間中距離更近。Word2Vec是其中一種非常流行且有效的詞嵌入算法,由Google的Tomas Mikolov等人在2013年提出。它主要包含兩種模型架構:CBOW(Continuous Bag-of-Words,連續詞袋模型)和Skip-gram(跳字模型)。
本篇博客將聚焦于Skip-gram模型,并結合**負采樣(Negative Sampling)**這一重要的優化技巧,通過PyTorch框架從零開始實現一個Word2Vec模型。我們將詳細探討數據預處理的每一個步驟,如何構建模型,如何進行訓練,以及訓練完成后如何應用得到的詞向量來尋找相似詞。通過深入代碼細節,我們希望能幫助讀者更好地理解Word2Vec的內部工作原理及其在PyTorch中的實現。
我們將依賴一系列輔助腳本來處理數據、可視化訓練過程以及進行模型訓練。讓我們一步步揭開Word2Vec的神秘面紗。
完整代碼:下載鏈接
輔助工具
在構建和訓練Word2Vec模型之前,我們首先介紹一下項目中用到的一些輔助Python腳本。這些腳本提供了數據加載、預處理、可視化以及訓練監控等常用功能。
1. 繪圖工具 (utils_for_huitu.py
)
這個腳本主要封裝了使用matplotlib
進行繪圖的常用函數,特別是在Jupyter Notebook環境中,它包含了一個Animator
類,可以動態地展示訓練過程中的損失變化。
# 導入必要的包
import matplotlib.pyplot as plt # 用于創建和操作 Matplotlib 圖表
from matplotlib_inline import backend_inline # 用于在Jupyter中設置Matplotlib輸出格式
from IPython import display # 用于后續動態顯示(如 Animator)
import torch # 導入PyTorch庫,用于處理張量類型的圖像
import numpy as np # 導入NumPy,可能用于數據處理
import matplotlib as mpl # 導入Matplotlib主模塊,用于設置圖像屬性def set_figsize(figsize=(3.5, 2.5)):"""設置matplotlib圖形的大小參數:figsize: tuple[float, float] - 圖形大小,形狀為 (寬度, 高度),單位為英寸輸出:無返回值"""plt.rcParams['figure.figsize'] = figsize # 設置圖形默認大小def use_svg_display():"""使用 SVG 格式在 Jupyter 中顯示繪圖輸入:無輸出:無返回值"""backend_inline.set_matplotlib_formats('svg') # 設置 Matplotlib 使用 SVG 格式def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""設置 Matplotlib 的軸 輸入:axes: Matplotlib 的軸對象 # 輸入參數:軸對象xlabel: x 軸標簽 # 輸入參數:x 軸標簽ylabel: y 軸標簽 # 輸入參數:y 軸標簽xlim: x 軸范圍 # 輸入參數:x 軸范圍ylim: y 軸范圍 # 輸入參數:y 軸范圍xscale: x 軸刻度類型 # 輸入參數:x 軸刻度類型yscale: y 軸刻度類型 # 輸入參數:y 軸刻度類型legend: 圖例標簽列表 # 輸入參數:圖例標簽輸出:無返回值 # 函數無顯式返回值"""axes.set_xlabel(xlabel) # 設置 x 軸標簽axes.set_ylabel(ylabel) # 設置 y 軸標簽axes.set_xscale(xscale) # 設置 x 軸刻度類型axes.set_yscale(yscale) # 設置 y 軸刻度類型axes.set_xlim(xlim) # 設置 x 軸范圍axes.set_ylim(ylim) # 設置 y 軸范圍if legend: # 檢查是否提供了圖例標簽axes.legend(legend) # 如果有圖例,則設置圖例axes.grid() # 為軸添加網格線class Animator:"""在動畫中繪制數據,僅針對一張圖的情況"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):"""初始化 Animator 類 輸入:xlabel: x 軸標簽,默認為 None # 輸入參數:x 軸標簽ylabel: y 軸標簽,默認為 None # 輸入參數:y 軸標簽legend: 圖例標簽列表,默認為 None # 輸入參數:圖例標簽xlim: x 軸范圍,默認為 None # 輸入參數:x 軸范圍ylim: y 軸范圍,默認為 None # 輸入參數:y 軸范圍xscale: x 軸刻度類型,默認為 'linear' # 輸入參數:x 軸刻度類型yscale: y 軸刻度類型,默認為 'linear' # 輸入參數:y 軸刻度類型fmts: 繪圖格式元組,默認為 ('-', 'm--', 'g-.', 'r:') # 輸入參數:線條格式nrows: 子圖行數,默認為 1 # 輸入參數:子圖行數ncols: 子圖列數,默認為 1 # 輸入參數:子圖列數figsize: 圖像大小元組,默認為 (3.5, 2.5) # 輸入參數:圖像大小輸出:無返回值 # 方法無顯式返回值定義位置::numref:`sec_softmax_scratch` # 指明定義的參考位置"""if legend is None: # 檢查 legend 是否為 Nonelegend = [] # 如果為 None,則初始化為空列表use_svg_display() # 設置繪圖顯示為 SVG 格式self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize) # 創建繪圖對象和子圖if nrows * ncols == 1: # 判斷是否只有一個子圖self.axes = [self.axes, ] # 如果是單個子圖,將 axes 轉為列表self.config_axes = lambda: set_axes( # 定義 lambda 函數配置坐標軸self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend) # 調用 set_axes 設置參數self.X, self.Y, self.fmts = None, None, fmts # 初始化數據和格式屬性def add(self, x, y):"""向圖表中添加多個數據點 輸入:x: x 軸數據點 # 輸入參數:x 軸數據y: y 軸數據點 # 輸入參數:y 軸數據輸出:無返回值 # 方法無顯式返回值"""if not hasattr(y, "__len__"): # 檢查 y 是否具有長度屬性(是否可迭代)y = [y] # 如果不可迭代,將 y 轉為單元素列表n = len(y) # 獲取 y 的長度if not hasattr(x, "__len__"): # 檢查 x 是否具有長度屬性x = [x] * n # 如果不可迭代,將 x 擴展為與 y 同長度的列表if not self.X: # 檢查 self.X 是否已初始化self.X = [[] for _ in range(n)] # 如果未初始化,為每條線創建空列表if not self.Y: # 檢查 self.Y 是否已初始化self.Y = [[] for _ in range(n)] # 如果未初始化,為每條線創建空列表for i, (a, b) in enumerate(zip(x, y)): # 遍歷 x 和 y 的數據對if a is not None and b is not None: # 檢查數據點是否有效self.X[i].append(a) # 將 x 數據點添加到對應列表self.Y[i].append(b) # 將 y 數據點添加到對應列表self.axes[0].cla() # 清除當前軸的內容for x, y, fmt in zip(self.X, self.Y, self.fmts): # 遍歷所有數據和格式self.axes[0].plot(x, y, fmt) # 繪制每條線self.config_axes() # 調用 lambda 函數配置坐標軸display.display(self.fig) # 顯示當前圖形display.clear_output(wait=True) # 標記當前輸出為待清除,但由于 wait=True,它不會立即清除,而是等待下一次 display.display()。def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):"""繪制列表長度對的直方圖,用于比較兩組列表中元素長度的分布參數:legend: list[str] - 圖例標簽,形狀為 (2,),分別對應xlist和ylist的標簽xlabel: str - x軸標簽ylabel: str - y軸標簽xlist: list[list] - 第一組列表,形狀為 (樣本數量, 每個樣本的元素數)ylist: list[list] - 第二組列表,形狀為 (樣本數量, 每個樣本的元素數)輸出:無返回值,但會顯示生成的直方圖"""set_figsize() # 設置圖形大小# plt.hist返回的三個值:# n: list[array] - 每個bin中的樣本數量,形狀為 (2, bin數量)# bins: array - bin的邊界值,形狀為 (bin數量+1,)# patches: list[list[Rectangle]] - 直方圖的矩形對象,形狀為 (2, bin數量)_, _, patches = plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]]) # 繪制兩組數據長度的直方圖plt.xlabel(xlabel) # 設置x軸標簽plt.ylabel(ylabel) # 設置y軸標簽# 為第二組數據(ylist)的直方圖添加斜線圖案,以區分兩組數據for patch in patches[1].patches: # patches[1]是ylist對應的矩形對象列表patch.set_hatch('/') # 設置填充圖案為斜線plt.legend(legend) # 添加圖例
解讀:
set_figsize
和use_svg_display
用于基礎的Matplotlib繪圖設置。set_axes
是一個通用的函數,用于配置圖表的坐標軸標簽、范圍、刻度類型和圖例。Animator
類是實現動態繪圖的關鍵。在訓練循環中,我們可以周期性地調用其add
方法,傳入當前的訓練輪次(或迭代次數)和對應的損失值(或其他指標)。Animator
會清除舊的圖像并重新繪制,從而在Jupyter Notebook中形成動畫效果,直觀地展示訓練趨勢。show_list_len_pair_hist
函數用于繪制兩個列表集合中,各子列表長度分布的直方圖,方便進行數據分析和比較。
2. 數據處理工具 (utils_for_data.py
)
這個腳本是Word2Vec數據預處理的核心,包含了從讀取原始文本、構建詞匯表、下采樣、生成中心詞-上下文詞對、負采樣到最終打包成PyTorch DataLoader
的完整流程。
from collections import Counter # 導入 Counter 類
from collections import Counter # 用于詞頻統計
import torch # PyTorch 核心庫
from torch.utils import data # PyTorch 數據加載工具
import numpy as np # NumPy 用于數組操作
import random # 導入隨機模塊,用于下采樣和負采樣
import math # 導入數學函數模塊,用于概率計算
import osdef count_corpus(tokens):"""統計詞元的頻率參數:tokens: 詞元列表,可以是:- 一維列表,例如 ['a', 'b']- 二維列表,例如 [['a', 'b'], ['c']]返回值:Counter: Counter 對象,統計每個詞元的出現次數"""# 如果輸入為空列表,直接返回空計數器if not tokens: # 等價于 len(tokens) == 0return Counter()# 檢查輸入是否為二維列表if isinstance(tokens[0], list):# 將二維列表展平為一維列表flattened_tokens = [token for sublist in tokens for token in sublist]else:# 如果是一維列表,直接使用原列表flattened_tokens = tokens# 使用 Counter 統計詞頻并返回return Counter(flattened_tokens)class Vocab:"""文本詞表類,用于管理詞元及其索引的映射關系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化詞表Args:tokens: 輸入的詞元列表,可以是1D或2D列表,默認為空列表min_freq: 詞元最小出現頻率,小于此頻率的詞元將被忽略,默認為0reserved_tokens: 預留的特殊詞元列表(如'<pad>'),默認為空列表"""# 處理默認參數self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []# 統計詞元頻率并按頻率降序排序counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化詞表,'<unk>'為未知詞元,索引為0self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 添加滿足最小頻率要求的詞元到詞表for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] =