Starting point
The previous post walked through the model architecture of Chatglm2-6b and compared it against Chatglm-6b, but it left a few open questions (sigh). The goal of this post is to explain attention and RotaryEmbedding properly, answer those questions, and finish the overall goal: fully replacing modeling_chatglm.py while cutting the code down to half its size.
SelfAttention
```python
class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads
        # 128 * 32 = 4096, i.e. the hidden_size

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        # 128, the hidden size of each attention head
        self.num_attention_heads_per_partition = config.num_attention_heads
        # 32 attention heads
        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        # 2, the number of KV groups

        self.qkv_hidden_size = (
            self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        # 4096 + 2 * 128 * 2 = 4608, the combined hidden size of Q, K and V.
        # A quick note on why this is not 4096 * 3: GQA is used here, briefly introduced below.
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config))

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config))

    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None, use_cache=True):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )
        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        # the GQA step: repeat each KV group back to the full head count, i.e. (32, 128)
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        # the GQA step: repeat each KV group back to the full head count, i.e. (32, 128)

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer)
        # the core attention op, the same as attention_fn in Chatglm-6b

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
```
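To make the shapes concrete, here is a minimal standalone sketch of the split-and-expand steps above, with Chatglm2-6b's numbers (32 query heads, 2 KV groups, 128 dims per head) hard-coded; the variable names are mine, not the model's:

```python
import torch

heads, groups, hn = 32, 2, 128   # num_attention_heads, multi_query_group_num, kv_channels
sq, b = 5, 1                     # toy sequence length and batch size

# one Linear produces Q, K and V together: 32*128 + 2*2*128 = 4608 features
mixed = torch.randn(sq, b, heads * hn + 2 * groups * hn)
q, k, v = mixed.split([heads * hn, groups * hn, groups * hn], dim=-1)

q = q.view(sq, b, heads, hn)     # [5, 1, 32, 128]
k = k.view(sq, b, groups, hn)    # [5, 1, 2, 128] -- only 2 key heads are stored/cached

# GQA expansion: each KV group is repeated to serve 32 / 2 = 16 query heads
k = k.unsqueeze(-2).expand(-1, -1, -1, heads // groups, -1)
k = k.contiguous().view(sq, b, heads, hn)   # back to [5, 1, 32, 128]
print(q.shape, k.shape)
```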
The GQA paper: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
The idea is fairly plain. In MHA, queries, keys, and values are in one-to-one correspondence; that gives the best quality, but the caches get large. In MQA there is only one set of keys and values shared by all the queries, so the caches shrink, but quality suffers. GQA strikes a balance: there are g groups of keys and values, and each group is repeated several times to line up with the queries.
GQA gives a natural interpolation from MHA to MQA: g = h recovers MHA and g = 1 recovers MQA, while 1 < g < h compresses the KV cache to g/h of the original. The compression rate is lower than MQA's, but there is more freedom and quality is better preserved (a quick numeric check follows the list below).
Also worth linking here is Fast Transformer Decoding: One Write-Head is All You Need, the MQA paper.
That settles two of the open questions:

- multi_query_group_num is the number of groups in GQA
- kv_channels is the per-head hidden size of query, key, and value
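As a back-of-the-envelope check of the cache claim above, a sketch assuming fp16 and Chatglm2-6b's dimensions:

```python
heads, groups, hn, fp16_bytes = 32, 2, 128, 2

# KV-cache bytes per token, per layer: one K plus one V vector per stored head
mha = 2 * heads  * hn * fp16_bytes   # 16384 bytes
gqa = 2 * groups * hn * fp16_bytes   #  1024 bytes, i.e. g/h = 2/32 of MHA
mqa = 2 * 1      * hn * fp16_bytes   #   512 bytes

print(mha, gqa, mqa)
```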
CoreAttention
```python
class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # whether to scale the query/key layers; in practice it is enabled
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # the softmax is computed in fp32
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        self.hidden_size_per_partition = config.kv_channels * config.num_attention_heads  # 128 * 32
        self.hidden_size_per_attention_head = config.kv_channels  # 128
        self.num_attention_heads_per_partition = config.num_attention_heads  # 32

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)  # the sqrt(d) scaling

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3)
                                                   for k in [query_layer, key_layer, value_layer]]
            if query_layer.shape[2] == key_layer.shape[2]:
                # this branch is only taken when generating the first token
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer, is_causal=True)
                # from this you can tell Chatglm2-6b is purely a decoder-only model
            else:
                # here the query length is 1 while the key length is the total token count
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer, None)
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # preallocating input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3],
                dtype=query_layer.dtype, device=query_layer.device)

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),                  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),    # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )
            # Chatglm-6b put the alpha up front and divided the query separately; the result is the same.
            # One more note on torch.baddbmm: since beta=0, passing an empty tensor as input is fine, it is skipped anyway.

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            if attention_scores.shape[2] == attention_scores.shape[3]:
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                attention_mask.tril_()
                attention_mask = ~attention_mask
            else:
                attention_mask = None
            """
            Look closely at this small block: when sq == sk (query length equals key length) an
            attention_mask is built. It is simply a matrix that is True above the diagonal and
            False on and below it. Combined with the
            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            below, this sets the upper-triangular scores to -inf: textbook decoder-only masking.
            When sq != sk the attention_mask is None: from the second token on, the query length
            is 1 while the key carries the previous cache, so its length is > 1; the two differ,
            the mask is None, and nothing further is needed.
            """
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
```
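As a sanity check that the two branches agree, here is a small sketch of my own (not from the model) comparing the manual sqrt(d)-scaled, causally masked softmax against the fused PyTorch >= 2.0 call on random tensors:

```python
import math
import torch
import torch.nn.functional as F

b, heads, sq, hn = 1, 32, 5, 128
q, k, v = (torch.randn(b, heads, sq, hn) for _ in range(3))

# manual path: QK^T / sqrt(d), upper-triangular mask to -inf, softmax, weighted sum of V
scores = q @ k.transpose(-1, -2) / math.sqrt(hn)
mask = ~torch.ones(sq, sq, dtype=torch.bool).tril()
manual = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v

# fused path, as in the pytorch_major_version >= 2 branch
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```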
One more remark: the original code has some operations built around self.coeff, i.e. the layer_number.
There, self.norm_factor = self.coeff * math.sqrt(self.hidden_size_per_attention_head),
so when computing the attention_scores everything is divided by self.coeff * math.sqrt(self.hidden_size_per_attention_head),
and then, right before the softmax, the attention_scores are multiplied by self.coeff again.
Isn't that exactly the same as just dividing by math.sqrt(self.hidden_size_per_attention_head)?!
I don't know why this operation exists; it feels odd, mainly because the intent is unclear. If anyone understands it, please explain, thanks. (If I had to guess: it resembles Megatron-LM's apply_query_key_layer_scaling trick, where the extra per-layer division keeps the fp16 QK^T matmul from overflowing and the factor is multiplied back once the scores have been cast to fp32.)
The Chatglm-6b code had the same operation and I missed it back then (oops); the code here simply deletes it, with no effect on the results whatsoever.
And since the pytorch_major_version >= 2 branch contains no layer_number-related operations at all, it should be clear by now that the operation is mathematically a no-op.
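A minimal sketch of the cancellation (coeff here stands in for the layer_number, as in the original file):

```python
import math
import torch

d, coeff = 128, 7.0           # hidden_size_per_attention_head, layer_number
scores = torch.randn(2, 4, 4)

# original: divide by coeff * sqrt(d) inside baddbmm, multiply by coeff before softmax
a = torch.softmax(scores / (coeff * math.sqrt(d)) * coeff, dim=-1)
# simplified: just divide by sqrt(d)
b = torch.softmax(scores / math.sqrt(d), dim=-1)
print(torch.allclose(a, b))  # True, up to floating-point rounding
```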
RotaryEmbedding
```python
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device,
                     base: int = 10000):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype,
                                 device=self.inv_freq.device)


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2  # 32 * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # [:64], [64:]: split the input along the hidden dimension; only the first part x
    # receives rotary position information
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)          # [q_0, q_1], [q_2, q_3], ...
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)  # [cos0, sin0], [cos1, sin1], ...
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            # the real part of the complex product: q_0 * cos(m*theta) - q_1 * sin(m*theta)
            # [q0, q2, ...] * [cos0, cos1, ...] - [q1, q3, ...] * [sin0, sin1, ...]
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
            # the imaginary part: q_1 * cos(m*theta) + q_0 * sin(m*theta)
            # [q1, q3, ...] * [cos0, cos1, ...] + [q0, q2, ...] * [sin0, sin1, ...]
        ],
        -1,
    )
    # q0*cos0 - q1*sin0
    # q1*cos0 + q0*sin0
    # q2*cos1 - q3*sin1
    # q3*cos1 + q2*sin1
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
```
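To convince yourself that the stacked cos/sin arithmetic really is a complex rotation, here is a toy check of my own against torch's native complex multiplication, calling the apply_rotary_pos_emb defined above; it also confirms that the second half of the head dimension passes through untouched:

```python
import torch

sq, b, np_, hn = 4, 1, 2, 8
dim = hn // 2                     # rotary dim: only half of hn gets rotated
theta = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
idx_theta = torch.outer(torch.arange(sq).float(), theta)
rope_cache = torch.stack([idx_theta.cos(), idx_theta.sin()], dim=-1)  # [sq, dim/2, 2]

x = torch.randn(sq, b, np_, hn)
out = apply_rotary_pos_emb(x, rope_cache)

# the same rotation via native complex numbers: (q0 + i*q1) * (cos + i*sin)
x_c = torch.view_as_complex(x[..., :dim].reshape(sq, b, np_, dim // 2, 2).contiguous())
rot = torch.polar(torch.ones_like(idx_theta), idx_theta)
ref = torch.view_as_real(x_c * rot.view(sq, 1, 1, dim // 2)).flatten(3)

print(torch.allclose(out[..., :dim], ref, atol=1e-5))  # True: it is a rotation
print(torch.equal(out[..., dim:], x[..., dim:]))       # True: second half untouched
```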
This also explains why the dim passed into the rotary embedding is rotary_dim // 2: position encoding is only applied to half of the hidden size. That is another puzzling choice, and I haven't seen a convincing explanation for it; if anyone knows the reason, pointers are welcome, thanks.
The last bit of code
At this point the code is essentially complete; all that remains is to add two more definitions and a handful of imports.
""" PyTorch ChatGLM model. """import math
import copy
import reimport torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.modeling_utils import PreTrainedModel
from configuration_chatglm import ChatGLMConfigdef _config_to_kwargs(args):common_kwargs = {"dtype": args.torch_dtype,}return common_kwargsclass ChatGLMPreTrainedModel(PreTrainedModel):"""An abstract class to handle weights initialization anda simple interface for downloading and loading pretrained models."""is_parallelizable = Falseconfig_class = ChatGLMConfigbase_model_prefix = "transformer"_no_split_modules = ["GLMBlock"]
Save all of this as chatglm.py, drop it into the chatglm2-6b directory, and it works as usual; usage is the same as with chatglm-6b:
```python
from chatglm import *
from transformers import AutoTokenizer

model_path = "/usr/downloads/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = ChatGLMForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True).half().cuda()

prompt = '你好'
response = model.chat(tokenizer, prompt)
```
The code comes to about 650 lines versus the original 1280, so the small goal of halving the code is essentially achieved (done!).
Parameter count
A quick breakdown of the parameter count. It can be read straight off the model structure; this is just for the record (a runnable version of the same arithmetic follows the numbers below).
```
# word embedding
65024 * 4096 * 2 = 532676608
# the LN after the last layer
4096

# the following are present in every layer
# query_key_value
4608 * 4096 = 18874368
# query_key_value.bias
4608
# dense
4096 * 4096 = 16777216
# LN
2 * 4096
# dense_h_to_4h
4096 * 27392 = 112197632
# dense_4h_to_h
13696 * 4096 = 56098816

# 28 layers
(18874368 + 4608 + 16777216 + 2 * 4096 + 112197632 + 56098816) * 28 = 5710903296
5710903296 + 532676608 + 4096 = 6243584000
# clearly the bulk of the parameters sits in the word embedding and dense_h_to_4h
```
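The same arithmetic as a short runnable sketch, for anyone who wants to re-check it:

```python
vocab, h, ffn, layers = 65024, 4096, 13696, 28

embedding = vocab * h * 2                 # input embedding + output head
per_layer = (
    4608 * h + 4608                       # query_key_value weight + bias (GQA: 4608, not 3 * 4096)
    + h * h                               # dense
    + 2 * h                               # the two norm layers in each layer
    + h * 2 * ffn                         # dense_h_to_4h (doubled: SwiGLU needs two projections)
    + ffn * h                             # dense_4h_to_h
)
print(embedding + layers * per_layer + h)  # + final norm -> 6243584000, ~6.24B
```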
Closing remarks
This post walked through the chatglm2-6b code, cut it down to 650 lines, and analyzed the differences from chatglm-6b. As the structure makes clear, it is no longer the GLM architecture at all: it is a pure decoder-only model. The changes are, essentially: switching to RMSNorm, using GQA to shrink the caches, and using SwiGLU as the activation function.
One addendum: reading the code shows that chatglm3-6b is nearly identical to chatglm2-6b; the only differences are in how the tokenizer processes the input and how the response is returned, so chatglm3-6b will not get a separate write-up.