PyTorch中的nn.Embedding應用詳解

PyTorch


文章目錄

  • PyTorch
  • 前言
  • 一、nn.Embedding的基本原理
  • 二、nn.Embedding的實際應用
    • 簡單的例子
    • 自然語言處理任務


前言

在深度學習中,詞嵌入(Word Embedding)是一種常見的技術,用于將離散的詞匯或符號映射到連續的向量空間。這種映射使得相似的詞匯在向量空間中具有相似的向量表示,從而可以捕捉詞匯之間的語義關系。在PyTorch中,nn.Embedding模塊提供了一種簡單而高效的方式來實現詞嵌入。

一、nn.Embedding的基本原理

nn.Embedding是一個存儲固定大小的詞典的嵌入向量的查找表。給定一個編號,嵌入層能夠返回該編號對應的嵌入向量。這些嵌入向量反映了各個編號代表的符號之間的語義關系。在輸入一個編號列表時,nn.Embedding會輸出對應的符號嵌入向量列表。

在內部,nn.Embedding實際上是一個參數化的查找表,其中每一行都對應一個符號的嵌入向量。這些嵌入向量在訓練過程中通過反向傳播算法進行更新,以優化模型的性能。因此,nn.Embedding不僅可以用于降低數據的維度,減少計算和存儲開銷,還可以通過訓練學習輸入數據中的語義或結構信息。

二、nn.Embedding的實際應用

簡單的例子

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 3)def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)# print(emb_vec1)  ### 輸出對同一組詞匯的編碼output = torch.einsum('ik, kj -> ij', emb_vec1, vec)return output
def simple_train():model = Model()vec = torch.randn((3, 1))label = torch.Tensor(5, 1).fill_(3)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)print('初始化emebding參數權重:\n',model.emb.weight)for iter_num in range(100):output = model(vec)loss = loss_fun(output, label)opt.zero_grad()loss.backward(retain_graph=True)opt.step()# print('第{}次迭代emebding參數權重{}:\n'.format(iter_num, model.emb.weight))print('訓練后emebding參數權重:\n',model.emb.weight)torch.save(model.state_dict(),'./embeding.pth')return modeldef simple_test():model = Model()ckpt = torch.load('./embeding.pth')model.load_state_dict(ckpt)model=model.eval()vec = torch.randn((3, 1))print('加載emebding參數權重:\n', model.emb.weight)for iter_num in range(100):output = model(vec)print('n次預測后emebding參數權重:\n', model.emb.weight)if __name__ == '__main__':simple_train()  # 訓練與保存權重simple_test()

訓練代碼

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 10)def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)print(emb_vec1)  ### 輸出對同一組詞匯的編碼output = torch.einsum('ik, kj -> ij', emb_vec1, vec)print(output)return outputdef simple_train():model = Model()vec = torch.randn((10, 1))label = torch.Tensor(5, 1).fill_(3)print(label)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)for iter_num in range(1):output = model(vec)loss = loss_fun(output, label)print('iter:%d loss:%.2f' % (iter_num, loss))opt.zero_grad()loss.backward(retain_graph=True)opt.step()if __name__ == '__main__':simple_train()

自然語言處理任務

在自然語言處理任務中,詞嵌入是一種非常有用的技術。通過將每個單詞表示為一個實數向量,我們可以將高維的詞匯空間映射到一個低維的連續向量空間。這有助于提高模型的泛化能力和計算效率。例如,在文本分類任務中,我們可以使用nn.Embedding將文本中的每個單詞轉換為嵌入向量,然后將這些向量輸入到神經網絡中進行分類。

以下是一個簡單的示例代碼,演示了如何在PyTorch中使用nn.Embedding進行文本分類:

import torch
import torch.nn as nn
# 定義詞嵌入層,詞典大小為10000,嵌入向量維度為128
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=128)
# 假設我們有一個包含5個單詞的文本,每個單詞的編號分別為1, 2, 3, 4, 5
input_ids = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
# 通過詞嵌入層將單詞編號轉換為嵌入向量
embedded = embedding(input_ids)
# 輸出嵌入向量的形狀:(5, 128)
print(embedded.shape)
# 定義神經網絡模型
class TextClassifier(nn.Module):def __init__(self, embedding_dim, hidden_dim, num_classes):super(TextClassifier, self).__init__()self.embedding = embeddingself.fc1 = nn.Linear(embedding_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, num_classes)def forward(self, input_ids):embedded = self.embedding(input_ids)# 對嵌入向量進行平均池化,得到一個固定長度的向量表示整個文本pooled = embedded.mean(dim=0)# 通過全連接層進行分類logits = self.fc2(self.fc1(pooled))return logits
# 實例化模型并進行訓練...

上述代碼中,我們首先定義了一個詞嵌入層embedding,詞典大小為10000,嵌入向量維度為128。然后,我們創建了一個包含5個單詞的文本,每個單詞的編號分別為1到5。通過調用embedding(input_ids),我們將單詞編號轉換為嵌入向量。最后,我們定義了一個文本分類器模型TextClassifier,其中包含了詞嵌入層、全連接層等組件。在模型的前向傳播過程中,我們首先對嵌入向量進行平均池化,得到一個固定長度的向量表示整個文本,然后通過全連接層進行分類。

除了自然語言處理任務外,nn.Embedding還可以用于圖像處理任務。例如,在卷積神經網絡(CNN)中,嵌入層可以將圖像的像素值映射到一個高維的空間,從而更好地捕捉圖像中的復雜特征和結構。這有助于提高模型的性能和泛化能力。

需要注意的是,在圖像處理任務中,我們通常使用卷積層(nn.Conv2d)或像素嵌入層(nn.PixelEmbed)等模塊來處理圖像數據,而不是直接使用nn.Embedding。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/905375.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/905375.shtml
英文地址,請注明出處:http://en.pswp.cn/news/905375.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AI 檢測原創論文:技術迷思與教育本質的悖論思考

當高校將 AI 寫作檢測工具作為學術誠信的 "電子判官",一場由技術理性引發的教育異化正在悄然上演。GPT-4 檢測工具將人類創作的論文誤判為 AI 生成的概率高達 23%(斯坦福大學 2024 年研究數據),這種 "以 AI 制 AI&…

langchain4j集成QWen、Redis聊天記憶持久化

langchain4j實現聊天記憶默認是基于進程內存的方式,InMemoryChatMemoryStore是具體的實現了,是將聊天記錄到一個map中,如果用戶大的話,會造成內存溢出以及數據安全問題。位了解決這個問題 langchain4提供了ChatMemoryStore接口&am…

Tomcat 日志體系深度解析:從訪問日志配置到錯誤日志分析的全鏈路指南

一、Tomcat 核心日志文件架構與核心功能 1. 三大基礎日志文件對比(權威定義) 日志文件數據來源核心功能典型場景catalina.out標準輸出 / 錯誤重定向包含 Tomcat 引擎日志與應用控制臺輸出(System.out/System.err)排查 Tomcat 啟…

萬物互聯時代:ONVIF協議如何重構安防監控系統架構

前言 一、ONVIF協議是什么 ONVIF(Open Network Video Interface Forum,開放式網絡視頻接口論壇)是一種全球性的開放行業標準,由安訊士(AXIS)、博世(BOSCH)和索尼(SONY&…

leetcode - 雙指針問題

文章目錄 前言 題1 移動零: 思路: 參考代碼: 題2 復寫零: 思考: 參考代碼: 題3 快樂數: 思考: 參考代碼: 題4 盛最多水的容器: 思考:…

從概念表達到安全驗證:智能駕駛功能迎來系統性規范

隨著輔助駕駛事故頻發,監管機制正在迅速補位。面對能力表達、使用責任、功能部署等方面的新要求,行業開始重估技術邊界與驗證能力,數字樣機正成為企業合規落地的重要抓手。 2025年以來,圍繞智能駕駛功能的爭議不斷升級。多起因輔…

java數組題(5)

(1): 思路: 1.首先要對數組nums排序,這樣兩數之間的差距最小。 2.題目要求我們通過最多 k 次遞增操作,使數組中某個元素的頻數(出現次數)最大化。經過上面的排序,最大數…

Python(1) 做一個隨機數的游戲

有關變量的,其實就是 可以直接打印對應變量。 并且最后倒數第二行就是可以讓兩個數進行交換。 Py快捷鍵“ALTP 就是顯示上一句的代碼。 —————————————————————————————— 字符串 用 雙引號或者單引號 。 然后 保證成雙出現即可 要是…

【認知思維】驗證性偏差:認知陷阱的識別與克服

什么是驗證性偏差 驗證性偏差(Confirmation Bias)是人類認知中最普遍、最根深蒂固的心理現象之一,指的是人們傾向于尋找、解釋、偏愛和回憶那些能夠確認自己已有信念或假設的信息,同時忽視或貶低與之相矛盾的證據。這種認知偏差影…

Wpf學習片段

IRegionManager 和IContainerExtension IRegionManager 是 Prism 框架中用于管理 UI 區域(Regions)的核心接口,它實現了模塊化應用中視圖(Views)的動態加載、導航和生命周期管理。 IContainerExtension 是依賴注入&…

消息~組件(群聊類型)ConcurrentHashMap發送

為什么選擇ConcurrentHashMap? 在開發聊天應用時,我們需要存儲和管理大量的聊天消息數據,這些數據會被多個線程頻繁訪問和修改。比如,當多個用戶同時發送消息時,服務端需要同時處理這些消息的存儲和查詢。如果用普通的…

Stapi知識框架

一、Stapi 基礎認知 1. 框架定位 自動化API開發框架:專注于快速生成RESTful API 約定優于配置:通過標準化約定減少樣板代碼 企業級應用支持:適合構建中大型API服務 代碼生成導向:顯著提升開發效率 2. 核心特性 自動CRUD端點…

基于深度學習的水果識別系統設計

一、選擇YOLOv5s模型 YOLOv5:YOLOv5 是一個輕量級的目標檢測模型,它在 YOLOv4 的基礎上進行了進一步優化,使其在保持較高檢測精度的同時,具有更快的推理速度。YOLOv5 的網絡結構更加靈活,可以根據不同的需求選擇不同大…

Spring Security與SaToken的對比

Spring Security與SaToken的詳細對照與優缺點分析 1. 核心功能與設計理念 對比維度Spring SecuritySaToken核心定位企業級安全框架,深度集成Spring生態,提供全面的安全解決方案(認證、授權、攻擊防護等)輕量級權限認證框架&#…

【docker】--鏡像管理

文章目錄 拉取鏡像啟動鏡像為容器連接容器法一法二 保存鏡像加載鏡像鏡像打標簽移除鏡像 拉取鏡像 docker pull mysql:8.0.42啟動鏡像為容器 docker run -dp 8080:8080 --name container_mysql8.0.42 -e MYSQL_ROOT_PASSWORD123123123 mysql:8.0.42 連接容器 法一 docker e…

力扣HOT100之二叉樹:543. 二叉樹的直徑

這道題本來想到可以用遞歸做,但是還是沒想明白,最后還是去看靈神題解了,感覺這道題最大的收獲就是鞏固了我對lambda表達式的掌握。 按照靈神的思路,直徑可以理解為從一個葉子出發向上,在某個節點處拐彎,然后…

web 自動化之 yaml 數據/日志/截圖

文章目錄 一、yaml 數據獲取二、日志獲取三、截圖 一、yaml 數據獲取 需要安裝 PyYAML 庫 import yaml import os from TestPOM.common import dir_config as Dirdef read_yaml(key,file_name"test_datas.yaml"):file_path os.path.join(Dir.testcases_dir, file_…

rtty操作記錄說明

rtty操作記錄說明 前言 整理資料發現了幾年前做的操作記錄,分享出來,希望對大家有用。 rtty-master:rtty客戶端程序,其中buffer\log\ssl為源碼的子目錄,從git上下載https://github.com/zhaojh329, rtty…

mybatis中${}和#{}的區別

先測試&#xff0c;再說結論 userService.selectStudentByClssIds(10000, "wzh or 11");List<StudentEntity> selectStudentByClssIds(Param("stuId") int stuId, Param("field") String field);<select id"selectStudentByClssI…

【運維】MacOS藍牙故障排查與修復指南

在日常使用macOS系統過程中&#xff0c;藍牙連接問題時有發生。無論是無法連接設備、連接不穩定還是藍牙功能完全失效&#xff0c;這些問題都會嚴重影響我們的工作效率。本文將分享一些實用的排查方法和修復技巧&#xff0c;幫助你解決macOS系統上的藍牙故障。 問題癥狀 常見…