用例子和代碼了解詞嵌入和位置編碼

1.嵌入(Input Embedding)

讓我用一個更具體的例子來解釋輸入嵌入(Input Embedding)。

背景

假設我們有一個非常小的詞匯表,其中包含以下 5 個詞:

  • "I"
  • "love"
  • "machine"
  • "learning"
  • "!"

假設我們想把這句話 "I love machine learning !" 作為輸入。

步驟 1:創建詞匯表(Vocabulary)

我們給每個詞分配一個唯一的索引號:

  • "I" -> 0
  • "love" -> 1
  • "machine" -> 2
  • "learning" -> 3
  • "!" -> 4
步驟 2:創建嵌入矩陣(Embedding Matrix)

假設我們選擇每個詞的向量維度為 3(實際應用中維度會更高)。我們初始化一個大小為 5x3 的嵌入矩陣,如下所示:

嵌入矩陣(Embedding Matrix):
[[0.1, 0.2, 0.3],  // "I" 的向量表示[0.4, 0.5, 0.6],  // "love" 的向量表示[0.7, 0.8, 0.9],  // "machine" 的向量表示[1.0, 1.1, 1.2],  // "learning" 的向量表示[1.3, 1.4, 1.5]   // "!" 的向量表示
]
步驟 3:查找表操作(Lookup Table Operation)

當我們輸入句子 "I love machine learning !" 時,我們首先將每個詞轉換為其對應的索引:

  • "I" -> 0
  • "love" -> 1
  • "machine" -> 2
  • "learning" -> 3
  • "!" -> 4

然后,我們使用這些索引在嵌入矩陣中查找相應的向量表示:

輸入句子嵌入表示:
[[0.1, 0.2, 0.3],  // "I" 的向量表示[0.4, 0.5, 0.6],  // "love" 的向量表示[0.7, 0.8, 0.9],  // "machine" 的向量表示[1.0, 1.1, 1.2],  // "learning" 的向量表示[1.3, 1.4, 1.5]   // "!" 的向量表示
]
步驟 4:輸入嵌入過程

????????通過查找表操作,我們把原本的句子 "I love machine learning !" 轉換成了一個二維數組,每一行是一個詞的嵌入向量。

碼示例代

讓我們用 Python 和 PyTorch 來實現這個過程:

import torch
import torch.nn as nn# 假設詞匯表大小為 5,嵌入維度為 3
vocab_size = 5
embedding_dim = 3# 創建一個嵌入層
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)# 初始化嵌入矩陣(為了便于理解,這里手動設置嵌入矩陣的值)
embedding_layer.weight = nn.Parameter(torch.tensor([[0.1, 0.2, 0.3],  # "I"[0.4, 0.5, 0.6],  # "love"[0.7, 0.8, 0.9],  # "machine"[1.0, 1.1, 1.2],  # "learning"[1.3, 1.4, 1.5]   # "!"
]))# 輸入句子對應的索引
input_indices = torch.tensor([0, 1, 2, 3, 4])# 獲取輸入詞的嵌入表示
embedded = embedding_layer(input_indices)print(embedded)

輸出:?

tensor([[0.1000, 0.2000, 0.3000],[0.4000, 0.5000, 0.6000],[0.7000, 0.8000, 0.9000],[1.0000, 1.1000, 1.2000],[1.3000, 1.4000, 1.5000]], grad_fn=<EmbeddingBackward>)

????????這樣我們就完成了輸入嵌入的過程,把離散的詞轉換為了連續的向量表示。

????????當你完成了詞嵌入,將離散的詞轉換為連續的向量表示后,位置編碼步驟如下:

2. 理解位置編碼

????????位置編碼(Positional Encoding)通過生成一組特殊的向量,表示詞在序列中的位置,并將這些向量添加到詞嵌入上,使模型能夠識別詞序。

2.1 位置編碼公式

????????位置編碼使用正弦和余弦函數生成。具體公式如下:

其中:

  • ?是詞在序列中的位置。
  • ?i是詞嵌入向量的維度索引。
  • ?d是詞嵌入向量的總維度。

2.2 生成位置編碼向量

????????以下是 Python 代碼示例,展示如何生成位置編碼向量,并將其添加到詞嵌入上:

????????生成位置編碼向量

import numpy as np
import torchdef get_positional_encoding(max_len, d_model):"""生成位置編碼向量:param max_len: 序列的最大長度:param d_model: 詞嵌入向量的維度:return: 形狀為 (max_len, d_model) 的位置編碼矩陣"""pos = np.arange(max_len)[:, np.newaxis]i = np.arange(d_model)[np.newaxis, :]angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))angle_rads = pos * angle_rates# 采用正弦函數應用于偶數索引 (2i)angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])# 采用余弦函數應用于奇數索引 (2i+1)angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])return torch.tensor(angle_rads, dtype=torch.float32)# 示例參數
max_len = 100  # 假設最大序列長度為 100
d_model = 512  # 假設詞嵌入維度為 512# 生成位置編碼矩陣
positional_encoding = get_positional_encoding(max_len, d_model)
print(positional_encoding.shape)  # 輸出: torch.Size([100, 512])

2.3 添加位置編碼到詞嵌入

????????假設你已經有一個詞嵌入張量?embedded,它的形狀為 (batch_size, seq_len, d_model),可以將位置編碼添加到詞嵌入中:

class TransformerEmbedding(nn.Module):def __init__(self, vocab_size, d_model, max_len):super(TransformerEmbedding, self).__init__()self.token_embedding = nn.Embedding(vocab_size, d_model)self.positional_encoding = get_positional_encoding(max_len, d_model)self.dropout = nn.Dropout(p=0.1)def forward(self, x):# 獲取詞嵌入token_embeddings = self.token_embedding(x)# 添加位置編碼seq_len = x.size(1)position_embeddings = self.positional_encoding[:seq_len, :]# 詞嵌入和位置編碼相加embeddings = token_embeddings + position_embeddings.unsqueeze(0)return self.dropout(embeddings)# 示例參數
vocab_size = 10000  # 假設詞匯表大小為 10000
d_model = 512       # 詞嵌入維度
max_len = 100       # 最大序列長度# 實例化嵌入層
embedding_layer = TransformerEmbedding(vocab_size, d_model, max_len)# 假設輸入序列為一批大小為 2,序列長度為 10 的張量
input_tensor = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.long)# 獲取嵌入表示
output_embeddings = embedding_layer(input_tensor)
print(output_embeddings.shape)  # 輸出: torch.Size([2, 10, 512])

2.4. 繼續進行 Transformer 模型的前向傳播

????????有了詞嵌入和位置編碼之后,接下來的步驟就是將這些嵌入輸入到 Transformer 模型的編碼器和解碼器中,進行進一步處理。Transformer 模型的編碼器和解碼器由多層注意力機制和前饋神經網絡組成。

????????位置編碼步驟通過生成一組正弦和余弦函數的向量,并將這些向量添加到詞嵌入上,使 Transformer 模型能夠捕捉序列中的位置信息。

import torch
import torch.nn as nnclass MultiHeadSelfAttention(nn.Module):def __init__(self, d_model, nhead):super(MultiHeadSelfAttention, self).__init__()assert d_model % nhead == 0, "d_model 必須能被 nhead 整除"self.d_model = d_modelself.d_k = d_model // nheadself.nhead = nheadself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.fc = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(0.1)self.scale = torch.sqrt(torch.FloatTensor([self.d_k]))def forward(self, x):batch_size = x.size(0)seq_len = x.size(1)# 線性變換得到 Q, K, VQ = self.W_q(x)K = self.W_k(x)V = self.W_v(x)# 分成多頭Q = Q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)K = K.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)V = V.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)# 計算注意力權重attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / self.scaleattn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)attn_weights = self.dropout(attn_weights)# 加權求和attn_output = torch.matmul(attn_weights, V)# 拼接多頭輸出attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)# 最后的線性變換output = self.fc(attn_output)return output# 示例參數
d_model = 8
nhead = 2# 輸入張量
x = torch.rand(2, 5, d_model)# 實例化多頭自注意力層
multi_head_attn = MultiHeadSelfAttention(d_model, nhead)# 前向傳播
output = multi_head_attn(x)
print("多頭自注意力輸出:\n", output)

解釋

  • 線性變換:使用 nn.Linear 實現線性變換,將輸入張量??通過三個不同的線性層得到查詢、鍵和值向量。
  • 分成多頭:使用 view 和 transpose 方法將查詢、鍵和值向量分成多頭,形狀變為?。
  • 計算注意力權重:通過點積計算查詢和鍵的相似度,并通過 softmax 歸一化得到注意力權重。
  • 加權求和:使用注意力權重對值向量進行加權求和,得到每個頭的輸出。
  • 拼接多頭輸出:將多頭的輸出拼接起來,并通過一個線性層進行變換,得到最終的輸出。

????????查詢、鍵和值向量的生成是多頭自注意力機制的關鍵步驟,通過線性變換將輸入向量轉換為查詢、鍵和值向量,然后使用這些向量計算注意力權重,捕捉輸入序列中不同位置的相關性。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/39107.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/39107.shtml
英文地址,請注明出處:http://en.pswp.cn/web/39107.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

10 Posix API與網絡協議棧

POSIX概念 POSIX是由IEEE指定的一系列標準,用于澄清和統一Unix-y操作系統提供的應用程序編程接口(以及輔助問題,如命令行shell實用程序),當您編寫程序以依賴POSIX標準時,您可以非常肯定能夠輕松地將它們移植到大量的Unix衍生產品系列中(包括Linux,但不限于此!)。 如…

DeepFaceLive----AI換臉簡單使用

非常強大的軟件,官方github https://github.com/iperov/DeepFaceLive 百度云鏈接: 鏈接&#xff1a;https://pan.baidu.com/s/1VHY-wxqJXSh5lCn1c4whZg 提取碼&#xff1a;nhev 1下載解壓軟件 下載完成后雙擊.exe文件進行解壓.完成后雙擊.bat文件打開軟件 2 視頻使用圖片換…

k8s部署單機版mysql8

一、創建命名空間 # cat mysql8-namespace.yaml apiVersion: v1 kind: Namespace metadata:name: mysql8labels:name: mysql8# kubectl apply -f mysql8-namespace.yaml namespace/mysql8 created# kubectl get ns|grep mysql8 mysql8 Active 8s二、創建mysql配…

Ubuntu環境下Graphics drawString 中文亂碼解決方法

問題描述 以下代碼在,在本地測試時 ,可以正常輸出中文字符的圖片,但部署到線上時中文亂碼 // 獲取Graphics2D對象以支持更多繪圖功能 Graphics2D g2d combined.createGraphics(); // 示例字體、樣式和大小 Font font new Font("微軟雅黑", Font.PLAI…

Swagger:swagger和knife4j

Swagger 一個規范完整的框架 用以生成,描述,調用和可視化 主要作用為 自動生成接口文檔 方便后端開發進行接口調試 Knife4j 為Java MVC框架集成 依賴引入: <!-- knife4j版接口文檔 訪問/doc.html--> <dependency><groupId>com.github.xiaoymin<…

SSM學習4:spring整合mybatis、spring整合Junit

spring整合mybatis 之前的內容是有service層&#xff08;業務實現層&#xff09;、dao層&#xff08;操作數據庫&#xff09;&#xff0c;現在新添加一個domain&#xff08;與業務相關的實體類&#xff09; 依賴配置 pom.xml <?xml version"1.0" encoding&quo…

解決ScaleBox來實現大屏自適應時,頁面的餅圖會變形的問題

封裝一個公用組件pieChartAdaptation.vue 代碼如下&#xff1a; <template><div :style"styleObject" class"pie-chart-adaptation"><slot></slot></div> </template><script setup lang"ts"> impo…

2.2.3 C#中顯示控件BDPictureBox 的實現----控件實現

2.2.3 C#中顯示控件BDPictureBox 的實現----控件實現 1 界面控件布局 2圖片內存Mat類說明 原始圖片&#xff1a;m_raw_mat ,Display_Mat()調用時更新或者InitDisplay_Mat時更新局部放大顯示圖片&#xff1a;m_extract_zoom_mat&#xff0c;更新scale和scroll信息后更新overla…

2024年精選100道軟件測試面試題(內含文檔)

測試技術面試題 1、我現在有個程序&#xff0c;發現在 Windows 上運行得很慢&#xff0c;怎么判別是程序存在問題還是軟硬件系統存在問題&#xff1f; 2、什么是兼容性測試&#xff1f;兼容性測試側重哪些方面&#xff1f; 3、測試的策略有哪些&#xff1f; 4、正交表測試用…

Eureka與Spring Cloud Bus的協同:打造智能服務發現新篇章

Eureka與Spring Cloud Bus的協同&#xff1a;打造智能服務發現新篇章 在微服務架構中&#xff0c;服務發現是實現服務間通信的關鍵機制。Eureka作為Netflix開源的服務發現框架&#xff0c;與Spring Cloud Bus的集成&#xff0c;提供了一種動態、響應式的服務治理解決方案。本文…

市場規模5萬億,護理員缺口550萬,商業護理企業如何解決服務供給難題?

干貨搶先看 1. 據統計&#xff0c;我國失能、半失能老人數量約4400萬&#xff0c;商業護理服務市場規模達5萬億。然而&#xff0c;當前養老護理員缺口巨大&#xff0c;人員的供需不匹配是很多養老服務企業需要克服的難題。 2. 當前居家護理服務的主要市場參與者分為兩類&…

利用GPT 將 matlab 內置 bwlookup 函數轉C

最近業務需要將 matlab中bwlookup 的轉C 這個函數沒有現成的m文件參考&#xff0c;內置已經打成庫了&#xff0c;所以沒有參考源代碼 但是它的解釋還是很清楚的&#xff0c;可以根據這個來寫 Nonlinear filtering using lookup tables - MATLAB bwlookup - MathWorks 中國 A…

python請求報錯::requests.exceptions.ProxyError: HTTPSConnectionPool

在發送網頁請求時&#xff0c;發現很久未響應&#xff0c;最后報錯&#xff1a; requests.exceptions.ProxyError: HTTPSConnectionPool(hostsvr-6-9009.share.51env.net, port443): Max retries exceeded with url: /prod-api/getInfo (Caused by ProxyError(Unable to conne…

秒懂設計模式--學習筆記(5)【創建篇-抽象工廠】

目錄 4、抽象工廠4.1 介紹4.2 品牌與系列&#xff08;針對工廠泛濫&#xff09;(**分類**)4.3 產品規劃&#xff08;**數據模型**&#xff09;4.4 生產線規劃&#xff08;**工廠類**&#xff09;4.5 分而治之4.6 抽象工廠模式的各角色定義如下4.7 基于此抽象工廠模式以品牌與系…

vue啟動時的錯誤

解決辦法一&#xff1a;在vue.config.js中直接添加一行代碼 lintOnSave:false 關閉該項目重新運行就可啟動 解決辦法二&#xff1a; 修改組件名稱

Python容器 之 通用功能

1.切片 1.格式&#xff1a; 數據[起始索引:結束索引:步長 2.適用類型&#xff1a; 字符串(str)、列表(list)、元組(tuple) 3.說明&#xff1a; 通過切片操作, 可以獲取數據中指定部分的內容 4.注意 : 結束索引對應的數據不會被截取到 支持正向索引和逆向索引 步長用于設置截取…

配音軟件有哪些?分享五款超級好用的配音軟件

隨著嫦娥六號的壯麗回歸&#xff0c;舉國上下都沉浸在這份自豪與激動之中。 在這樣一個歷史性的時刻&#xff0c;我們何不用聲音記錄下這份情感&#xff0c;讓這份記憶以聲音的形式流傳&#xff1f; 無論是制作視頻分享這份喜悅&#xff0c;還是創作音頻講述探月故事&#xff…

Oracle數據庫中RETURNING子句

RETURNING子句允許您檢索插入、刪除或更新所修改的列&#xff08;以及基于列的表達式&#xff09;的值。如果不使用RETURNING&#xff0c;則必須在DML語句完成后運行SELECT語句&#xff0c;才能獲得更改列的值。因此&#xff0c;RETURNING有助于避免再次往返數據庫&#xff0c;…

Plotly:原理、使用與數據可視化的未來

文章目錄 引言Plotly的原理Plotly的基本使用安裝Plotly創建基本圖表定制圖表樣式 Plotly的高級特性交互式圖表圖表動畫圖表集成 結論 引言 在當今的數據驅動世界中&#xff0c;數據可視化已經成為了一個至關重要的工具。它允許我們直觀地理解數據&#xff0c;發現數據中的模式…

CXL-GPU: 全球首款實現百ns以內的低延遲CXL解決方案

數據中心在追求更高性能和更低總擁有成本&#xff08;TCO&#xff09;的過程中面臨三大主要內存挑戰。首先&#xff0c;當前服務器內存層次結構存在局限性。直接連接的DRAM與固態硬盤&#xff08;SSD&#xff09;存儲之間存在三個數量級的延遲差異。當處理器直接連接的內存容量…