【論文筆記】Transformer

Transformer

2017 年,谷歌團隊提出 Transformer 結構,Transformer 首先應用在自然語言處理領域中的機器翻譯任務上,Transformer 結構完全構建于注意力機制,完全丟棄遞歸和卷積的結構,這使得 Transformer 結構效率更高。迄今為止,Transformer 廣泛應用于深度學習的各個領域。

模型架構

Transformer 結構如下圖所示,Transformer 遵循編碼器-解碼器(Encoder-Decoder)的結構,每個 Transformer Block 的結構基本上相同,其編碼器和解碼器可以視為兩個獨立的模型,例如:ViT 僅使用了 Transformer 編碼器,而 GPT 僅使用了 Transformer 解碼器。

Transformer Architecture

編碼器

編碼器包含 N = 6 N=6 N=6 個相同的層,每個層包含兩個子層,分別是多頭自注意力層(Multi Head Self-Attention)和前饋神經網絡層(Feed Forward Network),每個子層都包含殘差連接(Residual Connection)和層歸一化(Layer Normalization),使模型更容易學習。FFN 層是一個兩層的多層感知機(Multi Layer Perceptron)。

解碼器

解碼器也包含 N = 6 N=6 N=6 個相同的層,包含三個子層,分別是掩碼多頭自注意力層(Masked Multi-Head Attention)、編碼器-解碼器多頭注意力層(Cross Attention)和前饋神經網絡層。

其中,掩碼多頭自注意力層用于將輸出的 token 進行編碼,在應用注意力機制時存在一個注意力掩碼,以保持自回歸(Auto Regressive)特性,即先生成的 token 不能注意到后生成的 token,編碼后作為 Cross Attention 層的 Query,而 Cross Attention 層的 Key 和 Value 來自于編碼器的輸出,最后通過 FFN 層產生解碼器塊的輸出。

位置編碼

遞歸神經網絡(Recurrent Neural Networks)以串行的方式處理序列信息不同,注意力機制本身不包含位置關系,因此 Transformer 需要為序列中的每個 token 添加位置信息,因此需要位置編碼。Transformer 中使用了正弦位置編碼(Sinusoidal Position Embedding),位置編碼由以下數學表達式給出:

P E p o s , 2 i = sin ? ( p o s 1000 0 2 i / d model ) P E p o s , 2 i + 1 = cos ? ( p o s 1000 0 2 i / d model ) \begin{aligned} &PE_{pos,2i} = \sin \left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right)\\ &PE_{pos,2i+1} = \cos \left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right) \end{aligned} ?PEpos,2i?=sin(100002i/dmodel?pos?)PEpos,2i+1?=cos(100002i/dmodel?pos?)?

其中,pos 為 token 所在的序列位置,i 則是對應的特征維度。作者采用正弦位置編碼是基于正弦位置編碼可以使模型更容易學習到相對位置關系的假設。

下面是正弦位置編碼的 PyTorch 實現代碼,僅供參考。

class PositionEmbedding(nn.Module):"""Sinusoidal Positional Encoding."""def __init__(self, d_model: int, max_len: int) -> None:super(PositionEmbedding, self).__init__()self.pe = torch.zeros(max_len, d_model, requires_grad=False)factor = 10000 ** (torch.arange(0, d_model, step=2) / d_model)pos = torch.arange(0, max_len).float().unsqueeze(1)self.pe[:, 0::2] = torch.sin(pos / factor)self.pe[:, 1::2] = torch.cos(pos / factor)def forward(self, x: Tensor) -> Tensor:seq_len = x.size()[1]pos_emb = self.pos_encoding[:seq_len, :].unsqueeze(0).to(x.device)return pos_emb

注意力機制

注意力機制出現在 Transformer 之前,包括兩種類型:加性注意力和乘性注意力。Transformer 使用的是乘性注意力,這也是最常見的注意力機制,首先計算一個點積相似度,然后通過 Softmax 后得到注意力權重,根據注意力權重對 Values 進行加權求和,具體的過程可以表示為以下數學公式:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk? ?QKT?)V

Attention

其中,注意力計算中包含了一個溫度參數 d k \sqrt{d_k} dk? ?,一個直觀的解釋是避免點積的結果過大或過小,導致 softmax 后的結果梯度幾乎為 0 的區域,降低模型的收斂速度。對于自回歸生成任務而言,我們不希望前面生成的 token 關注后面生成 token,因此可能會采用一個下三角的 Attention Mask,掩蓋掉 attention 矩陣的上三角部分,注意力機制可以重寫為:

Attention ( Q , K , V ) = softmax ( Q K T d k + M ) V \text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}}+M)V Attention(Q,K,V)=softmax(dk? ?QKT?+M)V

具體實現中,需要 mask 掉的部分設置為負無窮即可,這會使得在 softmax 操作后得到的注意力權重為 0,避免注意到特定的 token。

有趣的是,注意力機制本身不包含可學習參數,因此,在 Transformer 中引入了多頭注意力機制,同時希望多頭注意力能夠捕捉多種模式,類似于卷積。多頭注意力機制可以表示為:

MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O where? head i = Attention ( Q W i Q , K W i K , V W i V ) \begin{aligned} \text{MultiHead}(Q,K,V)=\text{Concat}(\text{head}_1,\text{head}_2,\dots,\text{head}_h)W^O\\ \text{where }\text{head}_i=\text{Attention}(QW_i^Q,KW_i^K,VW_i^V) \end{aligned} MultiHead(Q,K,V)=Concat(head1?,head2?,,headh?)WOwhere?headi?=Attention(QWiQ?,KWiK?,VWiV?)?

以下為多頭注意力機制的 PyTorch 實現代碼,僅供參考。

from torch import nn, Tensor
from functools import partialclass MultiHeadAttention(nn.Module):"""Multi-Head Attention."""def __init__(self, d_model: int, n_heads: int) -> None:super(MultiHeadAttention, self).__init__()self.n_heads = n_headsself.proj_q = nn.Linear(d_model, d_model)self.proj_k = nn.Linear(d_model, d_model)self.proj_v = nn.Linear(d_model, d_model)self.proj_o = nn.Linear(d_model, d_model)self.attention = ScaledDotProductAttention()def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None) -> Tensor:# input tensor of shape (batch_size, seq_len, d_model)# 1. linear transformationq, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v)# 2. split tensor by the number of headsq, k, v = map(partial(_split, n_heads=self.n_heads), (q, k, v))# 3. scaled dot-product attentionout = self.attention(q, k, v, mask)# 4. concatenate headsout = _concat(out)# 5. linear transformationreturn self.proj_o(out)class ScaledDotProductAttention(nn.Module):"""Scaled Dot-Product Attention."""def __init__(self) -> None:super(ScaledDotProductAttention, self).__init__()self.softmax = nn.Softmax(dim=-1)def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None) -> Tensor:# input tensor of shape (batch_size, n_heads, seq_len, d_head)d_k = k.size()[3]k_t = k.transpose(2, 3)# 1. compute attention scorescore: Tensor = (q @ k_t) * d_k**-0.5# 2. apply mask(optional)if mask is not None:score = score.masked_fill(mask == 0, float("-inf"))# 3. compute attention weightsattn = self.softmax(score)# 4. compute attention outputout = attn @ vreturn outdef _split(tensor: Tensor, n_heads: int) -> Tensor:"""Split tensor by the number of heads."""batch_size, seq_len = tensor.size()[:2]d_model = tensor.size()[2]d_head = d_model // n_headsreturn tensor.view(batch_size, seq_len, n_heads, d_head).transpose(1, 2)def _concat(tensor: Tensor) -> Tensor:"""Concatenate tensor after splitting."""batch_size, n_heads, seq_len, d_head = tensor.size()d_model = n_heads * d_head

參考

[1] A. Vaswani et al., “Attention is All you Need,” in Advances in Neural Information Processing Systems, Curran Associates, Inc., 2017.

[2] A. Dosovitskiy et al., “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,” Jun. 03, 2021, arXiv: arXiv:2010.11929.

[3] K. He, X. Zhang, S. Ren, and J. Sun, “Deep Residual Learning for Image Recognition,” Dec. 10, 2015, arXiv: arXiv:1512.03385.

[4] A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever, “Improving Language Understanding by Generative Pre-Training”.

[5] hyunwoongko. "Transformer: PyTorch Implementation of ‘Attention Is All You Need’ " Github 2019. [Online] Available: https://github.com/hyunwoongko/transformer

[6] 李沐. “Transformer論文逐段精讀【論文精讀】” Bilibili 2021. [Online] Available: https://www.bilibili.com/video/BV1pu411o7BE/?spm_id_from=333.337.search-card.all.click&vd_source=c8a32a5a667964d5f1068d38d6182813

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/74217.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/74217.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/74217.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

CI/CD(三) 安裝nfs并指定k8s默認storageClass

一、NFS 服務端安裝(主節點 10.60.0.20) 1. 安裝 NFS 服務端 sudo apt update sudo apt install -y nfs-kernel-server 2. 創建共享目錄并配置權限 sudo mkdir -p /data/k8s sudo chown nobody:nogroup /data/k8s # 允許匿名訪問 sudo chmod 777 /dat…

【QA】單件模式在Qt中有哪些應用?

單例設計模式確保一個類僅有一個實例,并提供一個全局訪問點來獲取該實例。在 Qt 框架中,有不少類的設計采用了單例模式,以下為你詳細介紹并給出相應代碼示例。 1. QApplication QApplication 是 Qt GUI 應用程序的核心類,每個 Q…

存儲過程觸發器習題整理1

46、{blank}設有商品表(商品號,商品名,單價)和銷售表(銷售單據號,商品號,銷售時間,銷售數量,銷售單價)。其中,商品號代表一類商品,商品號、單價、銷售數量和銷售單價均為整型。請編寫…

基于ChatGPT、GIS與Python機器學習的地質災害風險評估、易發性分析、信息化建庫及災后重建高級實踐

第一章、ChatGPT、DeepSeek大語言模型提示詞與地質災害基礎及平臺介紹【基礎實踐篇】 1、什么是大模型? 大模型(Large Language Model, LLM)是一種基于深度學習技術的大規模自然語言處理模型。 代表性大模型:GPT-4、BERT、T5、Ch…

單表達式倒計時工具:datetime的極度優雅(智普清言)

一個簡單表達式,也可以優雅自成工具。 筆記模板由python腳本于2025-03-22 20:25:49創建,本篇筆記適合任意喜歡學習的coder翻閱。 【學習的細節是歡悅的歷程】 博客的核心價值:在于輸出思考與經驗,而不僅僅是知識的簡單復述。 Pyth…

最優編碼樹的雙子性

現在看一些書,不太愿意在書上面做一些標記,也沒啥特殊的原因。。哈哈。 樹的定義 無環連通圖,極小連通圖,極大無環圖。 度 某個節點,描述它的度,一般默認是出度,分叉的邊的條數。或者說孩子…

MiB和MB

本文來自騰訊元寶 MiB 和 ?MB 有區別,盡管它們都用于表示數據存儲的單位,但它們的計算方式不同,分別基于不同的進制系統。 1. ?MiB(Mebibyte)? ?MiB 是基于二進制的單位,使用1024作為基數。1 MiB 102…

Labview和C#調用KNX API 相關東西

敘述:完全沒有聽說過KNX這個協議...................我這次項目中也是簡單的用了一下沒有過多的去研究 C#調用示例工程鏈接(labview調用示例在 DEBUG文件夾里面) 通過網盤分享的文件:KNX調用示例.zip 鏈接: https://pan.baidu.com/s/1NQUEYM11HID0M4ksetrTyg?pwd…

損失函數理解(二)——交叉熵損失

損失函數的目的是為了定量描述不同模型(例如神經網絡模型和人腦模型)的差異。 交叉熵,顧名思義,與熵有關,先把模型換成熵這么一個數值,然后用這個數值比較不同模型之間的差異。 為什么要做這一步轉換&…

Kubernetes的Replica Set和ReplicaController有什么區別

ReplicaSet 和 ReplicationController 是 Kubernetes 中用于管理應用程序副本的兩種資源,它們有類似的功能,但 ReplicaSet 是 ReplicationController 的增強版本。 以下是它們的主要區別: 1. 功能的演進 ReplicationController 是 Kubernete…

信息系統運行管理員教程3--信息系統設施運維

第3章 信息系統設施運維 信息系統設施是支撐信息系統業務活動的信息系統軟硬件資產及環境。 第1節 信息系統設施運維的管理體系 信息系統設施運維的范圍包含信息系統涉及的所有設備及環境,主要包括基礎環境、硬件設備、網絡設備、基礎軟件等。 信息系統設施運維…

如何通過Python實現自動化任務:從入門到實踐

在當今快節奏的數字化時代,自動化技術正逐漸成為提高工作效率的利器。無論是處理重復性任務,還是管理復雜的工作流程,自動化都能為我們節省大量時間和精力。本文將以Python為例,帶你從零開始學習如何實現自動化任務,并通過一個實際案例展示其強大功能。 一、為什么選擇Pyt…

Spring Boot 與 MyBatis Plus 整合 KWDB 實現 JDBC 數據訪問

? 引言 本文主要介紹如何在 IDEA 中搭建一個使用 Maven 管理的 Spring Boot 應用項目工程,并結合在本地搭建的 KWDB 數據庫(版本為:2.0.3)來演示 Spring Boot 與 MyBatis Plus 的集成,以及對 KWDB 數據庫的數據操作…

Java鎖等待喚醒機制

在 Java 并發編程中,鎖的等待和喚醒機制至關重要,通常使用 wait()、notify() 和 notifyAll() 來實現線程間的協調。本文將詳細介紹這些方法的用法,并通過示例代碼加以說明。 1. wait()、notify() 與 notifyAll() 在 Java 中,Obj…

? UNIX網絡編程筆記:TCP客戶/服務器程序示例

服務器實例 有個著名的項目&#xff0c;tiny web&#xff0c;本項目將其改到windows下&#xff0c;并使用RAII重構&#xff0c;編寫過程中對于內存泄漏確實很頭疼&#xff0c;還沒寫完&#xff0c;后面會繼續更&#xff1a; #include <iostream> #include <vector&g…

AI Agent開發大全第四課-提示語工程:從簡單命令到AI對話的“魔法”公式

什么是提示語工程&#xff1f;一個讓AI“聽話”的秘密 如果你曾經嘗試過用ChatGPT或者其他大語言模型完成任務&#xff0c;那么你一定遇到過這樣的情況&#xff1a;明明你的問題是清晰的&#xff0c;但答案卻離題萬里&#xff1b;或者你認為自己提供的信息足夠詳盡&#xff0c…

系統架構設計知識體系總結

1.技術選型 1.什么是技術選型&#xff1f; 技術選型是指評估和選擇在項目或系統開發中使用的最合適的技術和工具的過程。這涉及考慮基于其能力、特性、與項目需求的兼容性、可擴展性、性能、維護和其他因素的各種可用選項。技術選型的目標是確定與項目目標相符合、能夠有效解…

基于3DMax與Vray引擎的輕量級室內場景渲染實踐

歡迎踏入3DMAX室內渲染的沉浸式學習之旅!在這個精心設計的實戰教程中,我們將攜手揭開3DMAX與Vray這對黃金搭檔在打造現實室內場景時的核心奧秘。無論您是渴望入門的3D新手,還是追求極致效果的專業設計師,這里都將為您呈現從場景藍圖構建到光影魔法施加的完整技術圖譜。我們…

邏輯卷,vdo,(阿里加速器)

一、邏輯卷 10 20 30 1.邏輯卷的2個特點 &#xff08;1&#xff09;邏輯卷可以將多個分區或者磁盤整合成一個更大的邏輯磁盤&#xff0c;然后可以從邏輯磁盤上劃分出分區&#xff08;邏輯磁盤的大小等于整合的物理磁盤大小之和。&#xff09; &#xff08;2&#xff09;能…

檢索增強生成(2)本地PDF 本地嵌入模型

from langchain_community.document_loaders import PyPDFLoader from pathlib import Pathdef load_local_pdf(file_path):if not Path(file_path).exists():raise FileNotFoundError(f"文件 {file_path} 不存在&#xff01;")loader PyPDFLoader(file_path)try:do…