The only difference from standard multi-head attention is that K and V are shared across all heads, while Q remains distinct for each head.
As a result, the code here is almost identical to a standard multi-head attention implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Query, key, and value projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)      # multi-head queries
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # single shared key head
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # single shared value head
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        # Projections
        q = self.q_proj(x)  # (batch, seq_len, embed_dim)
        k = self.k_proj(x)  # (batch, seq_len, head_dim)
        v = self.v_proj(x)  # (batch, seq_len, head_dim)
        # Reshape queries into multiple heads
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (batch, num_heads, seq_len, head_dim)
        # Keys and values stay single-head; add a head dimension so they broadcast across all heads
        k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)
        v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)
        # Attention computation
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # (batch, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)
        # Merge heads
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # (batch, seq_len, embed_dim)
        return out


# Example usage
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim) # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape) # torch.Size([2, 10, 64])
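
To make the saving from sharing K and V concrete, here is a minimal sketch (not part of the original example; the helper kv_proj_params is hypothetical) that compares the number of K/V projection parameters in this multi-query layout against a standard multi-head layout where every head has its own K and V:

# Hypothetical comparison: K/V projection parameter counts, MQA vs. standard MHA
def kv_proj_params(embed_dim, num_heads, multi_query):
    head_dim = embed_dim // num_heads
    kv_out = head_dim if multi_query else embed_dim  # MQA keeps a single K/V head
    # Two Linear layers (K and V), each with weight (embed_dim * kv_out) plus bias (kv_out)
    return 2 * (embed_dim * kv_out + kv_out)

print(kv_proj_params(64, 8, multi_query=True))   # 1040 -> one shared K/V head
print(kv_proj_params(64, 8, multi_query=False))  # 8320 -> one K/V head per query head

The same factor-of-num_heads reduction applies to the KV cache at inference time, which is the main practical motivation for multi-query attention.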