從代碼學習深度學習 - 預訓練word2vec PyTorch版

文章目錄

  • 前言
  • 輔助工具
    • 1. 繪圖工具 (`utils_for_huitu.py`)
    • 2. 數據處理工具 (`utils_for_data.py`)
    • 3. 訓練輔助工具 (`utils_for_train.py`)
  • 預訓練 Word2Vec - 主流程
    • 1. 環境設置與數據加載
    • 2. 跳元模型 (Skip-gram Model)
      • 2.1. 嵌入層 (Embedding Layer)
      • 2.2. 定義前向傳播
    • 3. 訓練
      • 3.1. 二元交叉熵損失
      • 3.2. 初始化模型參數
      • 3.3. 定義訓練階段代碼
      • 3.4. 開始訓練
    • 4. 應用詞嵌入
  • 總結


前言

詞嵌入(Word Embeddings)是自然語言處理(NLP)領域中的基石技術之一。它們將詞語從稀疏的、高維的獨熱編碼(one-hot encoding)表示轉換為稠密的、低維的向量表示。這些向量能夠捕捉詞語之間的語義和句法關系,使得相似的詞在向量空間中距離更近。Word2Vec是其中一種非常流行且有效的詞嵌入算法,由Google的Tomas Mikolov等人在2013年提出。它主要包含兩種模型架構:CBOW(Continuous Bag-of-Words,連續詞袋模型)和Skip-gram(跳字模型)。

本篇博客將聚焦于Skip-gram模型,并結合**負采樣(Negative Sampling)**這一重要的優化技巧,通過PyTorch框架從零開始實現一個Word2Vec模型。我們將詳細探討數據預處理的每一個步驟,如何構建模型,如何進行訓練,以及訓練完成后如何應用得到的詞向量來尋找相似詞。通過深入代碼細節,我們希望能幫助讀者更好地理解Word2Vec的內部工作原理及其在PyTorch中的實現。

我們將依賴一系列輔助腳本來處理數據、可視化訓練過程以及進行模型訓練。讓我們一步步揭開Word2Vec的神秘面紗。

完整代碼:下載鏈接

輔助工具

在構建和訓練Word2Vec模型之前,我們首先介紹一下項目中用到的一些輔助Python腳本。這些腳本提供了數據加載、預處理、可視化以及訓練監控等常用功能。

1. 繪圖工具 (utils_for_huitu.py)

這個腳本主要封裝了使用matplotlib進行繪圖的常用函數,特別是在Jupyter Notebook環境中,它包含了一個Animator類,可以動態地展示訓練過程中的損失變化。

# 導入必要的包
import matplotlib.pyplot as plt  # 用于創建和操作 Matplotlib 圖表
from matplotlib_inline import backend_inline  # 用于在Jupyter中設置Matplotlib輸出格式
from IPython import display  # 用于后續動態顯示(如 Animator)
import torch  # 導入PyTorch庫,用于處理張量類型的圖像
import numpy as np  # 導入NumPy,可能用于數據處理
import matplotlib as mpl  # 導入Matplotlib主模塊,用于設置圖像屬性def set_figsize(figsize=(3.5, 2.5)):"""設置matplotlib圖形的大小參數:figsize: tuple[float, float] - 圖形大小,形狀為 (寬度, 高度),單位為英寸輸出:無返回值"""plt.rcParams['figure.figsize'] = figsize  # 設置圖形默認大小def use_svg_display():"""使用 SVG 格式在 Jupyter 中顯示繪圖輸入:無輸出:無返回值"""backend_inline.set_matplotlib_formats('svg')  # 設置 Matplotlib 使用 SVG 格式def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""設置 Matplotlib 的軸  輸入:axes: Matplotlib 的軸對象  # 輸入參數:軸對象xlabel: x 軸標簽  # 輸入參數:x 軸標簽ylabel: y 軸標簽  # 輸入參數:y 軸標簽xlim: x 軸范圍  # 輸入參數:x 軸范圍ylim: y 軸范圍  # 輸入參數:y 軸范圍xscale: x 軸刻度類型  # 輸入參數:x 軸刻度類型yscale: y 軸刻度類型  # 輸入參數:y 軸刻度類型legend: 圖例標簽列表  # 輸入參數:圖例標簽輸出:無返回值  # 函數無顯式返回值"""axes.set_xlabel(xlabel)  # 設置 x 軸標簽axes.set_ylabel(ylabel)  # 設置 y 軸標簽axes.set_xscale(xscale)  # 設置 x 軸刻度類型axes.set_yscale(yscale)  # 設置 y 軸刻度類型axes.set_xlim(xlim)  # 設置 x 軸范圍axes.set_ylim(ylim)  # 設置 y 軸范圍if legend:  # 檢查是否提供了圖例標簽axes.legend(legend)  # 如果有圖例,則設置圖例axes.grid()  # 為軸添加網格線class Animator:"""在動畫中繪制數據,僅針對一張圖的情況"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):"""初始化 Animator 類 輸入:xlabel: x 軸標簽,默認為 None  # 輸入參數:x 軸標簽ylabel: y 軸標簽,默認為 None  # 輸入參數:y 軸標簽legend: 圖例標簽列表,默認為 None  # 輸入參數:圖例標簽xlim: x 軸范圍,默認為 None  # 輸入參數:x 軸范圍ylim: y 軸范圍,默認為 None  # 輸入參數:y 軸范圍xscale: x 軸刻度類型,默認為 'linear'  # 輸入參數:x 軸刻度類型yscale: y 軸刻度類型,默認為 'linear'  # 輸入參數:y 軸刻度類型fmts: 繪圖格式元組,默認為 ('-', 'm--', 'g-.', 'r:')  # 輸入參數:線條格式nrows: 子圖行數,默認為 1  # 輸入參數:子圖行數ncols: 子圖列數,默認為 1  # 輸入參數:子圖列數figsize: 圖像大小元組,默認為 (3.5, 2.5)  # 輸入參數:圖像大小輸出:無返回值  # 方法無顯式返回值定義位置::numref:`sec_softmax_scratch`  # 指明定義的參考位置"""if legend is None:  # 檢查 legend 是否為 Nonelegend = []  # 如果為 None,則初始化為空列表use_svg_display()  # 設置繪圖顯示為 SVG 格式self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)  # 創建繪圖對象和子圖if nrows * ncols == 1:  # 判斷是否只有一個子圖self.axes = [self.axes, ]  # 如果是單個子圖,將 axes 轉為列表self.config_axes = lambda: set_axes(  # 定義 lambda 函數配置坐標軸self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)  # 調用 set_axes 設置參數self.X, self.Y, self.fmts = None, None, fmts  # 初始化數據和格式屬性def add(self, x, y):"""向圖表中添加多個數據點  輸入:x: x 軸數據點  # 輸入參數:x 軸數據y: y 軸數據點  # 輸入參數:y 軸數據輸出:無返回值  # 方法無顯式返回值"""if not hasattr(y, "__len__"):  # 檢查 y 是否具有長度屬性(是否可迭代)y = [y]  # 如果不可迭代,將 y 轉為單元素列表n = len(y)  # 獲取 y 的長度if not hasattr(x, "__len__"):  # 檢查 x 是否具有長度屬性x = [x] * n  # 如果不可迭代,將 x 擴展為與 y 同長度的列表if not self.X:  # 檢查 self.X 是否已初始化self.X = [[] for _ in range(n)]  # 如果未初始化,為每條線創建空列表if not self.Y:  # 檢查 self.Y 是否已初始化self.Y = [[] for _ in range(n)]  # 如果未初始化,為每條線創建空列表for i, (a, b) in enumerate(zip(x, y)):  # 遍歷 x 和 y 的數據對if a is not None and b is not None:  # 檢查數據點是否有效self.X[i].append(a)  # 將 x 數據點添加到對應列表self.Y[i].append(b)  # 將 y 數據點添加到對應列表self.axes[0].cla()  # 清除當前軸的內容for x, y, fmt in zip(self.X, self.Y, self.fmts):  # 遍歷所有數據和格式self.axes[0].plot(x, y, fmt)  # 繪制每條線self.config_axes()  # 調用 lambda 函數配置坐標軸display.display(self.fig)  # 顯示當前圖形display.clear_output(wait=True)  # 標記當前輸出為待清除,但由于 wait=True,它不會立即清除,而是等待下一次 display.display()。def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):"""繪制列表長度對的直方圖,用于比較兩組列表中元素長度的分布參數:legend: list[str] - 圖例標簽,形狀為 (2,),分別對應xlist和ylist的標簽xlabel: str - x軸標簽ylabel: str - y軸標簽xlist: list[list] - 第一組列表,形狀為 (樣本數量, 每個樣本的元素數)ylist: list[list] - 第二組列表,形狀為 (樣本數量, 每個樣本的元素數)輸出:無返回值,但會顯示生成的直方圖"""set_figsize()  # 設置圖形大小# plt.hist返回的三個值:# n: list[array] - 每個bin中的樣本數量,形狀為 (2, bin數量)# bins: array - bin的邊界值,形狀為 (bin數量+1,)# patches: list[list[Rectangle]] - 直方圖的矩形對象,形狀為 (2, bin數量)_, _, patches = plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]])  # 繪制兩組數據長度的直方圖plt.xlabel(xlabel)  # 設置x軸標簽plt.ylabel(ylabel)  # 設置y軸標簽# 為第二組數據(ylist)的直方圖添加斜線圖案,以區分兩組數據for patch in patches[1].patches:  # patches[1]是ylist對應的矩形對象列表patch.set_hatch('/')  # 設置填充圖案為斜線plt.legend(legend)  # 添加圖例

解讀

  • set_figsizeuse_svg_display 用于基礎的Matplotlib繪圖設置。
  • set_axes 是一個通用的函數,用于配置圖表的坐標軸標簽、范圍、刻度類型和圖例。
  • Animator 類是實現動態繪圖的關鍵。在訓練循環中,我們可以周期性地調用其add方法,傳入當前的訓練輪次(或迭代次數)和對應的損失值(或其他指標)。Animator會清除舊的圖像并重新繪制,從而在Jupyter Notebook中形成動畫效果,直觀地展示訓練趨勢。
  • show_list_len_pair_hist 函數用于繪制兩個列表集合中,各子列表長度分布的直方圖,方便進行數據分析和比較。

2. 數據處理工具 (utils_for_data.py)

這個腳本是Word2Vec數據預處理的核心,包含了從讀取原始文本、構建詞匯表、下采樣、生成中心詞-上下文詞對、負采樣到最終打包成PyTorch DataLoader的完整流程。

from collections import Counter  # 導入 Counter 類
from collections import Counter  # 用于詞頻統計
import torch  # PyTorch 核心庫
from torch.utils import data  # PyTorch 數據加載工具
import numpy as np  # NumPy 用于數組操作
import random  # 導入隨機模塊,用于下采樣和負采樣
import math  # 導入數學函數模塊,用于概率計算
import osdef count_corpus(tokens):"""統計詞元的頻率參數:tokens: 詞元列表,可以是:- 一維列表,例如 ['a', 'b']- 二維列表,例如 [['a', 'b'], ['c']]返回值:Counter: Counter 對象,統計每個詞元的出現次數"""# 如果輸入為空列表,直接返回空計數器if not tokens:  # 等價于 len(tokens) == 0return Counter()# 檢查輸入是否為二維列表if isinstance(tokens[0], list):# 將二維列表展平為一維列表flattened_tokens = [token for sublist in tokens for token in sublist]else:# 如果是一維列表,直接使用原列表flattened_tokens = tokens# 使用 Counter 統計詞頻并返回return Counter(flattened_tokens)class Vocab:"""文本詞表類,用于管理詞元及其索引的映射關系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化詞表Args:tokens: 輸入的詞元列表,可以是1D或2D列表,默認為空列表min_freq: 詞元最小出現頻率,小于此頻率的詞元將被忽略,默認為0reserved_tokens: 預留的特殊詞元列表(如'<pad>'),默認為空列表"""# 處理默認參數self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []# 統計詞元頻率并按頻率降序排序counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化詞表,'<unk>'為未知詞元,索引為0self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 添加滿足最小頻率要求的詞元到詞表for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = 

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/906810.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/906810.shtml
英文地址,請注明出處:http://en.pswp.cn/news/906810.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python實現對大批量Word文檔進行自動添加頁碼(16)

前言 本文是該專欄的第16篇,后面會持續分享Python辦公自動化干貨知識,記得關注。 在處理word文檔的時候,相信或多或少都遇到過這樣的需求——需要對“目標word文檔,自動添加頁碼”。 換言之,如果有大批量的word文檔文件需要你添加頁碼,這個時候最聰明的辦法就是使用“程…

云原生安全:Linux命令行操作全解析

&#x1f525;「炎碼工坊」技術彈藥已裝填&#xff01; 點擊關注 → 解鎖工業級干貨【工具實測|項目避坑|源碼燃燒指南】 ——從基礎概念到安全實踐的完整指南 一、基礎概念 1. Shell與終端交互 Shell是Linux命令行的解釋器&#xff08;如Bash、Zsh&#xff09;&#xff0c;負…

Day 34

GPU訓練 要讓模型在 GPU 上訓練&#xff0c;主要是將模型和數據遷移到 GPU 設備上。 在 PyTorch 里&#xff0c;.to(device) 方法的作用是把張量或者模型轉移到指定的計算設備&#xff08;像 CPU 或者 GPU&#xff09;上。 對于張量&#xff08;Tensor&#xff09;&#xff1…

C++筆試題(金山科技新未來訓練營):

題目分布&#xff1a; 17道單選&#xff08;每題3分&#xff09;3道多選題&#xff08;全對3分&#xff0c;部分對1分&#xff09;2道編程題&#xff08;每一道20分&#xff09;。 不過題目太多&#xff0c;就記得一部分了&#xff1a; 單選題&#xff1a; static變量的初始…

Spark(29)基礎自定義分區器

&#xff08;一&#xff09;什么是分區 【復習提問&#xff1a;RDD的定義是什么&#xff1f;】 在 Spark 里&#xff0c;彈性分布式數據集&#xff08;RDD&#xff09;是核心的數據抽象&#xff0c;它是不可變的、可分區的、里面的元素并行計算的集合。 在 Spark 中&#xf…

python打卡訓練營打卡記錄day35

知識點回顧&#xff1a; 三種不同的模型可視化方法&#xff1a;推薦torchinfo打印summary權重分布可視化進度條功能&#xff1a;手動和自動寫法&#xff0c;讓打印結果更加美觀推理的寫法&#xff1a;評估模式 作業&#xff1a;調整模型定義時的超參數&#xff0c;對比下效果 1…

【MySQL】07.表內容的操作

1. insert 我們先創建一個表結構&#xff0c;這部分操作我們使用這張表完成我們的操作&#xff1a; mysql> create table student(-> id int primary key auto_increment,-> name varchar(20) not null,-> qq varchar(20) unique-> ); Query OK, 0 rows affec…

使用SQLite Expert個人版VACUUM功能修復數據庫

使用SQLite Expert個人版VACUUM功能修復數據庫 一、SQLite Expert工具簡介 SQLite Expert 是一款功能強大的SQLite數據庫管理工具&#xff0c;分為免費的個人版&#xff08;Personal Edition&#xff09;和收費的專業版&#xff08;Professional Edition&#xff09;。其核心功…

LM-BFF——語言模型微調新范式

gpt3&#xff08;GPT3——少樣本示例推動下的通用語言模型雛形)結合提示詞和少樣本示例后&#xff0c;展示出了強大性能。但大語言模型的訓練門檻太高&#xff0c;普通研究人員無力&#xff0c;LM-BFF(Making Pre-trained Language Models Better Few-shot Learners)的作者受gp…

遙感解譯項目Land-Cover-Semantic-Segmentation-PyTorch之二訓練模型

遙感解譯項目Land-Cover-Semantic-Segmentation-PyTorch之一推理模型 背景 上一篇文章了解了這個項目的環境安裝和模型推理,這篇文章介紹下如何訓練這個模型,添加類別 下載數據集 在之前的一篇文章中,也有用到這個數據集 QGIS之三十六Deepness插件實現AI遙感訓練模型 數…

【NLP 71、常見大模型的模型結構對比】

三到五年的深耕&#xff0c;足夠讓你成為一個你想成為的人 —— 25.5.8 模型名稱位置編碼Transformer結構多頭機制Feed Forward層設計歸一化層設計線性層偏置項激活函數訓練數據規模及來源參數量應用場景側重GPT-5 (OpenAI)RoPE動態相對編碼混合專家架構&#xff08;MoE&#…

[250521] DBeaver 25.0.5 發布:SQL 編輯器、導航器全面升級,新增 Kingbase 支持!

目錄 DBeaver 25.0.5 發布&#xff1a;SQL 編輯器、導航器全面升級&#xff0c;新增 Kingbase 支持&#xff01; DBeaver 25.0.5 發布&#xff1a;SQL 編輯器、導航器全面升級&#xff0c;新增 Kingbase 支持&#xff01; 近日&#xff0c;DBeaver 發布了 25.0.5 版本&#xf…

服務器硬盤虛擬卷的處理

目前的情況是需要刪除邏輯卷&#xff0c;然后再重新來弄一遍。 數據已經備份好了&#xff0c;所以不用擔心數據會丟失。 查看服務器的具體情況 使用 vgdisplay 操作查看服務器的卷組情況&#xff1a; --- Volume group ---VG Name vg01System IDFormat …

Flutter 中 build 方法為何寫在 StatefulWidget 的 State 類中

Flutter 中 build 方法為何寫在 StatefulWidget 的 State 類中 在 Flutter 中&#xff0c;build 方法被設計在 StatefulWidget 的 State 類中而非 StatefulWidget 類本身&#xff0c;這種設計基于幾個重要的架構原則和實際考量&#xff1a; 1. 核心設計原因 1.1 生命周期管理…

傳統醫療系統文檔集中標準化存儲和AI智能化更新路徑分析

引言 隨著醫療數智化建設的深入推進&#xff0c;傳統醫療系統如醫院信息系統(HIS)、臨床信息系統(CIS)、護理信息系統(NIS)、影像歸檔與通信系統(PACS)和實驗室信息系統(LIS)已經成為了現代醫療機構不可或缺的技術基礎設施。這些系統各自承擔著不同的功能&#xff0c;共同支撐…

探索常識性概念圖譜:構建智能生活的知識橋梁

目錄 一、知識圖譜背景介紹 &#xff08;一&#xff09;基本背景 &#xff08;二&#xff09;與NLP的關系 &#xff08;三&#xff09;常識性概念圖譜的引入對比 二、常識性概念圖譜介紹 &#xff08;一&#xff09;常識性概念圖譜關系圖示例 &#xff08;二&#xff09…

Linux/aarch64架構下安裝Python的Orekit開發環境

1.背景 國產化趨勢越來越強&#xff0c;從軟件到硬件&#xff0c;從操作系統到CPU&#xff0c;甚至顯卡&#xff0c;就產生了在國產ARM CPU和Kylin系統下部署Orekit的需求&#xff0c;且之前的開發是基于Python的&#xff0c;需要做適配。 2.X86架構下安裝Python/Orekit開發環…

Ctrl+鼠標滾動阻止頁面放大/縮小

項目場景&#xff1a; 提示&#xff1a;這里簡述項目相關背景&#xff1a; 一般在我們做大屏的時候&#xff0c;不希望Ctrl鼠標上下滾動的時候頁面會放大/縮小&#xff0c;那么在有時候&#xff0c;又不希望影響到別的頁面&#xff0c;比如說這個大屏是在另一個管理后臺中&am…

MySQL——復合查詢表的內外連

目錄 復合查詢 回顧基本查詢 多表查詢 自連接 子查詢 where 字句中使用子查詢 單行子查詢 多行子查詢 多列子查詢 from 字句中使用子查詢 合并查詢 實戰OJ 查找所有員工入職時候的薪水情況 獲取所有非manager的員工emp_no 獲取所有員工當前的manager 表的內外…

聊一下CSS中的標準流,浮動流,文本流,文檔流

在網絡上關于CSS的文章中&#xff0c;有時候能聽到“標準流”&#xff0c;“浮動流”&#xff0c;“定位流”等等詞語&#xff0c;還有像“文檔流”&#xff0c;“文本流”等詞&#xff0c;這些流是什么意思&#xff1f;它們是CSS中的一些布局方案和特性。今天我們就來聊一下CS…