Transformer
2017 年,谷歌團隊提出 Transformer 結構,Transformer 首先應用在自然語言處理領域中的機器翻譯任務上,Transformer 結構完全構建于注意力機制,完全丟棄遞歸和卷積的結構,這使得 Transformer 結構效率更高。迄今為止,Transformer 廣泛應用于深度學習的各個領域。
模型架構
Transformer 結構如下圖所示,Transformer 遵循編碼器-解碼器(Encoder-Decoder)的結構,每個 Transformer Block 的結構基本上相同,其編碼器和解碼器可以視為兩個獨立的模型,例如:ViT 僅使用了 Transformer 編碼器,而 GPT 僅使用了 Transformer 解碼器。
編碼器
編碼器包含 N = 6 N=6 N=6 個相同的層,每個層包含兩個子層,分別是多頭自注意力層(Multi Head Self-Attention)和前饋神經網絡層(Feed Forward Network),每個子層都包含殘差連接(Residual Connection)和層歸一化(Layer Normalization),使模型更容易學習。FFN 層是一個兩層的多層感知機(Multi Layer Perceptron)。
解碼器
解碼器也包含 N = 6 N=6 N=6 個相同的層,包含三個子層,分別是掩碼多頭自注意力層(Masked Multi-Head Attention)、編碼器-解碼器多頭注意力層(Cross Attention)和前饋神經網絡層。
其中,掩碼多頭自注意力層用于將輸出的 token 進行編碼,在應用注意力機制時存在一個注意力掩碼,以保持自回歸(Auto Regressive)特性,即先生成的 token 不能注意到后生成的 token,編碼后作為 Cross Attention 層的 Query,而 Cross Attention 層的 Key 和 Value 來自于編碼器的輸出,最后通過 FFN 層產生解碼器塊的輸出。
位置編碼
與遞歸神經網絡(Recurrent Neural Networks)以串行的方式處理序列信息不同,注意力機制本身不包含位置關系,因此 Transformer 需要為序列中的每個 token 添加位置信息,因此需要位置編碼。Transformer 中使用了正弦位置編碼(Sinusoidal Position Embedding),位置編碼由以下數學表達式給出:
P E p o s , 2 i = sin ? ( p o s 1000 0 2 i / d model ) P E p o s , 2 i + 1 = cos ? ( p o s 1000 0 2 i / d model ) \begin{aligned} &PE_{pos,2i} = \sin \left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right)\\ &PE_{pos,2i+1} = \cos \left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right) \end{aligned} ?PEpos,2i?=sin(100002i/dmodel?pos?)PEpos,2i+1?=cos(100002i/dmodel?pos?)?
其中,pos 為 token 所在的序列位置,i 則是對應的特征維度。作者采用正弦位置編碼是基于正弦位置編碼可以使模型更容易學習到相對位置關系的假設。
下面是正弦位置編碼的 PyTorch 實現代碼,僅供參考。
class PositionEmbedding(nn.Module):"""Sinusoidal Positional Encoding."""def __init__(self, d_model: int, max_len: int) -> None:super(PositionEmbedding, self).__init__()self.pe = torch.zeros(max_len, d_model, requires_grad=False)factor = 10000 ** (torch.arange(0, d_model, step=2) / d_model)pos = torch.arange(0, max_len).float().unsqueeze(1)self.pe[:, 0::2] = torch.sin(pos / factor)self.pe[:, 1::2] = torch.cos(pos / factor)def forward(self, x: Tensor) -> Tensor:seq_len = x.size()[1]pos_emb = self.pos_encoding[:seq_len, :].unsqueeze(0).to(x.device)return pos_emb
注意力機制
注意力機制出現在 Transformer 之前,包括兩種類型:加性注意力和乘性注意力。Transformer 使用的是乘性注意力,這也是最常見的注意力機制,首先計算一個點積相似度,然后通過 Softmax 后得到注意力權重,根據注意力權重對 Values 進行加權求和,具體的過程可以表示為以下數學公式:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk??QKT?)V
其中,注意力計算中包含了一個溫度參數 d k \sqrt{d_k} dk??,一個直觀的解釋是避免點積的結果過大或過小,導致 softmax 后的結果梯度幾乎為 0 的區域,降低模型的收斂速度。對于自回歸生成任務而言,我們不希望前面生成的 token 關注后面生成 token,因此可能會采用一個下三角的 Attention Mask,掩蓋掉 attention 矩陣的上三角部分,注意力機制可以重寫為:
Attention ( Q , K , V ) = softmax ( Q K T d k + M ) V \text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}}+M)V Attention(Q,K,V)=softmax(dk??QKT?+M)V
具體實現中,需要 mask 掉的部分設置為負無窮即可,這會使得在 softmax 操作后得到的注意力權重為 0,避免注意到特定的 token。
有趣的是,注意力機制本身不包含可學習參數,因此,在 Transformer 中引入了多頭注意力機制,同時希望多頭注意力能夠捕捉多種模式,類似于卷積。多頭注意力機制可以表示為:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O where? head i = Attention ( Q W i Q , K W i K , V W i V ) \begin{aligned} \text{MultiHead}(Q,K,V)=\text{Concat}(\text{head}_1,\text{head}_2,\dots,\text{head}_h)W^O\\ \text{where }\text{head}_i=\text{Attention}(QW_i^Q,KW_i^K,VW_i^V) \end{aligned} MultiHead(Q,K,V)=Concat(head1?,head2?,…,headh?)WOwhere?headi?=Attention(QWiQ?,KWiK?,VWiV?)?
以下為多頭注意力機制的 PyTorch 實現代碼,僅供參考。
from torch import nn, Tensor
from functools import partialclass MultiHeadAttention(nn.Module):"""Multi-Head Attention."""def __init__(self, d_model: int, n_heads: int) -> None:super(MultiHeadAttention, self).__init__()self.n_heads = n_headsself.proj_q = nn.Linear(d_model, d_model)self.proj_k = nn.Linear(d_model, d_model)self.proj_v = nn.Linear(d_model, d_model)self.proj_o = nn.Linear(d_model, d_model)self.attention = ScaledDotProductAttention()def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None) -> Tensor:# input tensor of shape (batch_size, seq_len, d_model)# 1. linear transformationq, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v)# 2. split tensor by the number of headsq, k, v = map(partial(_split, n_heads=self.n_heads), (q, k, v))# 3. scaled dot-product attentionout = self.attention(q, k, v, mask)# 4. concatenate headsout = _concat(out)# 5. linear transformationreturn self.proj_o(out)class ScaledDotProductAttention(nn.Module):"""Scaled Dot-Product Attention."""def __init__(self) -> None:super(ScaledDotProductAttention, self).__init__()self.softmax = nn.Softmax(dim=-1)def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None) -> Tensor:# input tensor of shape (batch_size, n_heads, seq_len, d_head)d_k = k.size()[3]k_t = k.transpose(2, 3)# 1. compute attention scorescore: Tensor = (q @ k_t) * d_k**-0.5# 2. apply mask(optional)if mask is not None:score = score.masked_fill(mask == 0, float("-inf"))# 3. compute attention weightsattn = self.softmax(score)# 4. compute attention outputout = attn @ vreturn outdef _split(tensor: Tensor, n_heads: int) -> Tensor:"""Split tensor by the number of heads."""batch_size, seq_len = tensor.size()[:2]d_model = tensor.size()[2]d_head = d_model // n_headsreturn tensor.view(batch_size, seq_len, n_heads, d_head).transpose(1, 2)def _concat(tensor: Tensor) -> Tensor:"""Concatenate tensor after splitting."""batch_size, n_heads, seq_len, d_head = tensor.size()d_model = n_heads * d_head
參考
[1] A. Vaswani et al., “Attention is All you Need,” in Advances in Neural Information Processing Systems, Curran Associates, Inc., 2017.
[2] A. Dosovitskiy et al., “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,” Jun. 03, 2021, arXiv: arXiv:2010.11929.
[3] K. He, X. Zhang, S. Ren, and J. Sun, “Deep Residual Learning for Image Recognition,” Dec. 10, 2015, arXiv: arXiv:1512.03385.
[4] A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever, “Improving Language Understanding by Generative Pre-Training”.
[5] hyunwoongko. "Transformer: PyTorch Implementation of ‘Attention Is All You Need’ " Github 2019. [Online] Available: https://github.com/hyunwoongko/transformer
[6] 李沐. “Transformer論文逐段精讀【論文精讀】” Bilibili 2021. [Online] Available: https://www.bilibili.com/video/BV1pu411o7BE/?spm_id_from=333.337.search-card.all.click&vd_source=c8a32a5a667964d5f1068d38d6182813