【Torch】nn.Embedding算法詳解

1. 定義

nn.Embedding 是 PyTorch 中的 查表式嵌入層(lookup‐table),用于將離散的整數索引(如詞 ID、實體 ID、離散特征類別等)映射到一個連續的、可訓練的低維向量空間。它通過維護一個形狀為 (num_embeddings, embedding_dim) 的權重矩陣,實現高效的“索引 → 向量”轉換。

2. 輸入與輸出

  • 輸入

    • 類型:整型張量(torch.longtorch.int64),必須是 LongTensor,其他類型會報錯。
    • 形狀:任意形狀 (*, L),其中最內層長度 L 常為序列長度,前面的 * 可以是 batch 及其他維度。
    • 取值范圍0 ≤ index < num_embeddings;超出范圍會拋出 IndexError
  • 輸出

    • 類型:浮點型張量(與權重相同的 dtype,默認為 torch.float32)。
    • 形狀(*, L, embedding_dim);就是在輸入張量后追加一個維度 embedding_dim
    • 語義:若輸入某位置的值為 j,則該位置對應輸出就是權重矩陣的第 j 行。

3. 底層原理

  1. 查表操作 vs. One-hot 乘法

    • 直觀上,Embedding 相當于:
      output = one_hot ( i n p u t ) × W \text{output} = \text{one\_hot}(input) \;\times\; W output=one_hot(input)×W
      其中 W(num_embeddings×embedding_dim) 的權重矩陣。
    • 為避免顯式構造稀疏的 one-hot 張量,PyTorch 直接根據索引做“取行”操作,效率更高、內存更省。
  2. 梯度更新

    • 稠密模式(默認):整個 W 都有梯度緩沖,優化器根據梯度更新所有行。
    • 稀疏模式sparse=True):僅對被索引過的行計算和存儲梯度,可配合 optim.SparseAdam 高效更新,適合超大字典(百萬級以上)但每次只訪問少量索引的場景。
  3. 范數裁剪

    • 若指定 max_norm,每次前向都會對輸出向量(即對應的行)做范數裁剪,保證其 L-norm_type 范數不超過 max_norm,有助于防止某些頻繁訪問的詞向量過大。
  4. 權重初始化

    • 默認初始化使用均勻分布:
      W i , j ~ U ( ? 1 num_embeddings , 1 num_embeddings ) W_{i,j} \sim \mathcal{U}\Bigl(-\sqrt{\tfrac{1}{\text{num\_embeddings}}},\;\sqrt{\tfrac{1}{\text{num\_embeddings}}}\Bigr) Wi,j?U(?num_embeddings1? ?,num_embeddings1? ?)
    • 可以通過 _weight 參數傳入外部預訓練權重(如 Word2Vec、GloVe 等)。

4. 構造函數參數詳解

參數類型及默認說明
num_embeddingsint必填。嵌入表行數,等于類別總數(最大索引 + 1)。
embedding_dimint必填。每個向量的維度。
padding_idxintNone默認 None。指定該索引對應行始終輸出全零,并且該行的梯度永遠為 0,適合做序列填充。
max_normfloatNone默認 None。若設為數值,每次前向時對取出的向量做范數裁剪(L-norm_typemax_norm)。
norm_typefloat,默認 2max_norm 配合使用時定義范數類型,如 1-范數、2-范數等。
scale_grad_by_freqbool,默認 False若為 True,在反向傳播階段按照索引在 batch 中出現的頻次對梯度做縮放(出現越多,梯度越小),有助于高頻詞的梯度平滑。
sparsebool,默認 False若為 True,開啟稀疏更新,僅對被訪問行生成梯度;必須配合 optim.SparseAdam 使用,不支持常規稠密優化器。
_weightTensorNone若提供,則用此張量(形狀應為 (num_embeddings, embedding_dim))作為權重初始化,否則隨機初始化。

5. 使用示例

import torch
import torch.nn as nn# 1. 參數設定
vocab_size = 10000   # 詞表大小
embed_dim  = 300     # 嵌入維度# 2. 創建 Embedding 層
embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,padding_idx=0,         # 將 0 作為填充索引,輸出全 0max_norm=5.0,          # 向量范數不超過 5norm_type=2.0,scale_grad_by_freq=True,sparse=False
)# 3. 構造輸入
# batch_size=2, seq_len=6
input_ids = torch.tensor([[  1, 234,  56, 789,   0,  23],[123,   4, 567,   8,   9,   0],
], dtype=torch.long)# 4. 前向計算
# 輸出 shape = [2, 6, 300]
output = embedding(input_ids)
print(output.shape)  # -> torch.Size([2, 6, 300])

加載并凍結預訓練權重

import numpy as np# 假設有預訓練權重 pre_trained.npy,shape=(10000,300)
weights = torch.from_numpy(np.load("pre_trained.npy"))
embed_pre = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,_weight=weights
)
# 凍結所有權重
embed_pre.weight.requires_grad = False

6. 注意事項

  1. 類型與范圍
    • 輸入必須為 LongTensor,且所有索引滿足 0 ≤ index < num_embeddings
  2. Padding 與 Mask
    • 僅指定 padding_idx 會返回零向量,但上游網絡(如 RNN、Transformer)還需顯式 mask,避免無效位置影響注意力或累積狀態。
  3. 性能考量
    • max_norm 每次前向都做范數計算和裁剪,若不需要可關閉以提升速度。
  4. 稀疏更新限制
    • sparse=True 可節省內存,但只支持 SparseAdam,且在 GPU 上效率有時不如稠密模式。
  5. EmbeddingBag
    • 對于可變長度序列的 sum/mean/power-mean 匯聚,可使用 nn.EmbeddingBag,避免中間張量開銷。
  6. 分布式與大詞表
    • 在分布式訓練時,可將嵌入表切分到多個進程上(torch.nn.parallel.DistributedDataParallel + torch.nn.Embedding 支持參數分布式)。
    • 超大詞表(千萬級)時,可考慮動態加載、分布式哈希表或專用庫(如 DeepSpeed 的嵌入稀疏優化)。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/912175.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/912175.shtml
英文地址,請注明出處:http://en.pswp.cn/news/912175.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

cdq 三維偏序應用 / P4169 [Violet] 天使玩偶/SJY擺棋子

最近學了 cdq 分治想來做做這道題&#xff0c;結果被有些毒瘤的代碼惡心到了。 /ll 題目大意&#xff1a;一開始給定一些平面中的點。然后給定一些修改和詢問&#xff1a; 修改&#xff1a;增加一個點。詢問&#xff1a;給定一個點&#xff0c;求離這個點最近&#xff08;定義…

System.Threading.Tasks 庫簡介

System.Threading.Tasks 是 .NET 中任務并行庫(Task Parallel Library, TPL)的核心組件&#xff0c;它提供了基于任務的異步編程模型&#xff0c;是現代 .NET 并發編程的基礎。 設計原理 1. 核心目標 抽象并發工作&#xff1a;將并發操作抽象為"任務"概念 資源高效…

Python爬蟲實戰:研究jieba相關技術

1. 引言 1.1 研究背景與意義 隨著互聯網技術的飛速發展,網絡新聞已成為人們獲取信息的主要渠道之一。每天產生的新聞文本數據量呈爆炸式增長,如何從海量文本中高效提取有價值的信息,成為信息科學領域的重要研究課題。文本分析技術通過對文本內容的結構化處理和語義挖掘,能…

github 淘金技巧

1. 效率&#xff0c;搜索&#xff0c;先不管。后面再說。 2. 分享的話&#xff0c; 其實使用默認的分享功能也行。也是后面再說。此 app &#xff0c; 今天先做到這里。 下面我們再聊點其他東西。其實我還想問&#xff0c;這個事情&#xff0c;其他人是否也做了&#xff0c; ht…

RAG技術發展綜述

摘要 檢索增強生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;技術已成為大語言模型應用的核心技術棧。RAG有效解決了LLM的幻覺問題、知識截止和實時更新挑戰&#xff0c;目前正處于全面產業化階段。本文系統性地分析RAG的全棧技術架構&#xff0c;包括檢索…

集群聊天服務器---muduo庫(3)

使用muduo網絡庫進行編譯和鏈接的示例 項目的目錄結構 bin: 存放可執行文件。 lib: 存放庫文件。 include: 存放頭文件。 src: 存放源代碼文件。 build: 存放編譯生成的中間文件。 example: 存放示例代碼。 thirdparty: 存放第三方庫。 CMakeLists.txt: CMake構建系統…

雙核SOC/5340 應用和網絡核間通訊

1&#xff1a; 可以在 nRF Connect SDK 文件夾結構的 samples/ipc/ipc_service 下找到示例&#xff0c;應用和網絡核心在由 CONFIG_APP_IPC_SERVICE_SEND_INTERVAL 選項指定的時隙內相互發送數據。可以更改該值并觀察每個核心的吞吐量如何變化 nRF5340 DK 可以使用 RPMsg 或 IC…

Spring Cloud Ribbon核心負載均衡算法詳解

Ribbon 作為 Spring Cloud 生態中的客戶端負載均衡工具&#xff0c;提供多種動態負載均衡算法&#xff0c;根據后端服務狀態智能分配請求。其核心算法及適用場景如下&#xff1a; &#x1f9e0; 一、Ribbon 負載均衡算法 算法名稱工作原理引用來源輪詢 (RoundRobinRule)按服務…

網站圖片過于太大影響整體加載響應速度怎么辦? Typecho高級圖像處理插件

文章目錄 LeleImges - Typecho高級圖像處理插件 ???插件介紹 ??插件架構 ???主要功能 ?性能優勢 ??系統要求 ??安裝方法 ??詳細配置說明 ??圖片質量設置 ???最大寬度/高度限制 ??壓縮格式選擇 ???壓縮方法選擇 ??GIF處理方式 ???備份源文件 ??…

VUE3入門很簡單(1)--- 響應式對象

前言 重要提示&#xff1a;文章只適合初學者&#xff0c;不適合專家&#xff01;&#xff01;&#xff01; 什么是響應式對象&#xff1f; 在Vue3中&#xff0c;響應式對象就是這種智能溫控器。當你修改JavaScript對象的數據時&#xff0c;Vue會自動更新網頁上顯示的內容&am…

廣州華銳互動攜手中石油:AR 巡檢系統實現重大突破?

廣州華銳互動在 AR 技術領域的卓越成就&#xff0c;通過一系列與知名企業、機構的成功合作案例得以充分彰顯。其中&#xff0c;與中石油的合作項目堪稱經典&#xff0c;展現了廣州華銳互動運用 AR 技術解決實際難題、達成目標的強大實力。? 中石油作為能源行業的巨擘&#xff…

權威認證!華宇TAS應用中間件榮獲CCRC“中間件產品安全認證”

近日&#xff0c;華宇TAS應用中間件順利通過了中國網絡安全審查認證和市場監管大數據中心(CCRC)的信息安全認證&#xff0c;獲得了IT產品信息安全認證證書。此次獲證&#xff0c;標志著華宇TAS應用中間件在安全性、可靠性及合規性等方面達到行業領先水平&#xff0c;可以為政企…

BI財務分析 – 反映盈利水平利潤占比的指標如何分析(下)

之前的文章重點把構成銷售凈利率、主營業務利潤率、成本費用利潤率、營業利潤率、銷售毛利率的分母像銷售收入、營業收入、主營業務收入凈額、成本費用總額做了比較細致的說明&#xff0c;把這幾個基本的概念搞明白后&#xff0c;再來看這幾個指標就比較容易理解了。 銷售凈利…

竹云受邀出席華為開發者大會,與華為聯合發布海外政務數字化解決方案

6月20日-22日&#xff0c;華為開發者大會&#xff08;HDC 2025&#xff09;在東莞松山湖盛大召開。作為華為一年一度面向全球開發者的頂級科技盛會&#xff0c;今年的HDC不僅帶來了HarmonyOS 6.0 Beta版本、盤古大模型5.5等多項重磅技術和產品更新&#xff0c;更聚集了全球極客…

AI助力游戲設計——從靈感到行動-靠岸篇

OK&#xff0c;朋友&#xff0c;如果你到了這里&#xff0c;那就證明這趟旅程&#xff0c;快要到岸了。 首先&#xff0c;恭喜你&#xff0c;到了需要這一步的時候。其實&#xff0c;如果你有一天真的用到了&#xff0c;希望你可以回來打個卡。行了&#xff0c;不廢話&#xf…

vue將頁面導出pdf,vue導出pdf ,使用html2canvas和jspdf組件

vue導出pdf 需求&#xff1a;需要前端下載把當前html下載成pdf文件–有十八頁超長&#xff0c;之前使用vue-html2pdf組件&#xff0c;但是這個組件有長度限制和比較新瀏覽器版本限制&#xff0c;所以改成使用html2canvas和jspdf組件 方法&#xff1a; 1、第一步&#xff1a;我…

024 企業客戶管理系統技術解析:基于 Spring Boot 的全流程管理平臺

企業客戶管理系統技術解析&#xff1a;基于Spring Boot的全流程管理平臺 在企業數字化轉型的浪潮中&#xff0c;高效的客戶管理系統成為提升企業競爭力的關鍵工具。本文將深入解析基于Java和Spring Boot框架構建的企業客戶管理系統&#xff0c;該系統涵蓋員工管理、客戶信息管…

JavaScript性能優化代碼示例

JavaScript性能優化實戰大綱 性能優化的核心目標 減少加載時間、提升渲染效率、降低內存占用、優化交互響應 代碼層面的優化實踐 避免全局變量污染&#xff0c;使用局部變量和模塊化開發 減少DOM操作頻率&#xff0c;批量處理DOM更新 使用事件委托替代大量事件監聽器 優化循…

樹的重心(雙dfs,換根)

思路&#xff1a; 基于樹形 DP 的兩次遍歷&#xff08;第一次dfs計算以某個初始根&#xff08;這里選了 1&#xff09;為根時各子樹的深度和與節點數&#xff0c;第二次zy進行換根操作&#xff0c;更新每個節點作為根時的深度和&#xff09; 換根原理&#xff1a; 更換主根&…

官方App Store,直鏈下載macOS ,無需Apple ID,macOS10.10以上.

前言 想必很多人都有過維修老舊Mac的體驗,也有過想要重裝macos的體驗. 尤其是前者,想要重裝或者升級系統,由于官方已經無法更新,必須下載iSo鏡像 這時就會遇到死循環:想要更新macOS ,必須先使用更高版本的App Store,但要使用更高版本的App Store,必須先更新macOS !!! 如果想…