前情回顧
在之前的章節我們已經構建好了視覺編碼器,預處理模塊,以及gemma模型的頂層。gemma模型的頂層,主要是構建圖中圈出的輸入,它把視覺編碼器里每個圖像patch的編碼維度對齊到自然語言token的嵌入維度,并組裝成了一個大的輸入向量。同時在模型的頂層,我們準備好了位置id 以及attention mask,用來在后面的模型層計算旋轉位置編碼和注意力得分矩陣。接下來,我們要開始構建gemma模型的架構了。
頂層模型 GemmaForCausalLM
還記得嗎,在之前的paligemma模型的頂層,我們有一個GemmaForCausalLM,然后我們通過下面的代碼把輸入傳入了語言模型:
self.language_model = GemmaForCausalLM(config.text_config)
outputs = self.language_model(
inputs_embeds = input_embeds,
position_ids = position_ids,
attention_mask = attention_mask,
kv_cache = kv_cache,
**kwargs
)
現在我們首先要實現這個GemmaForCausalLM。
一般模型的上層是對整個模型邏輯的簡單封裝,故這里GemmaForCausalLM的作用很簡單,它僅僅把上下文編碼后的注意力嵌入通過一個MLP轉換為不同token的輸出概率,也就是logits,然后返回給上層,從而讓上層根據概率分布來采樣下一個要輸出的token是什么。
先給出代碼:
class GemmaForCausalLM(nn.Module): ## 匹配
def __init__(self,config:GemmaConfig): ##CasualLM實際上是Transformer模型加一個投影層,即將嵌入轉換為對數概率
super().__init__()
self.config = config
self.model = GemmaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size,config.vocab_size,bias=False) def get_input_embeddings(self): ##這里返回的是模型對象本身
return self.model.embed_tokens def tie_weights(self):
self.lm_head.weight = self.model.embed_tokens.weight def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
position_ids: Optional[torch.Tensor] = None
):
'''
input: [Batch_size, Seq_len, Hidden_size]
output: [Batch_size, Seq_len, Vocab_size]
'''
## [Batch_size, Seq_len, Hidden_size]
outputs = self.model(
attention_mask = attention_mask,
inputs_embeds = inputs_embeds,
kv_cache = kv_cache,
position_ids = position_ids
) hidden_states = outputs
logits = self.lm_head(hidden_states) #lm_head負責將hidden_states映射到vocab_size維度的向量,即logits
logits = logits.float() return_data = {
"logits": logits
}
if kv_cache is not None:
return_data["kv_cache"] = kv_cache ##這里kv cache是要傳遞下去的,因為自回歸的邏輯下,后面生成的token的注意力計算要能夠通過kv cache來看到之前的token的kv return return_data
以上便是頂層模型的前向傳遞過程:
- 就是通過 GemmaModel 生成的注意力嵌入來計算logits
- 注意:由于我們在推理過程中,后續的token計算要用到之前的kv,所以kv cache必須在推理的過程中依次傳遞下去,同時也要返回給上層,從而在下一次推理運算的時候有kv cache可以傳入。
- 我們之前用到了參數捆綁的策略,即token嵌入的模型參數等于嵌入反解碼成logits的模型參數,所以我們提供這兩個函數供上層調用:
def get_input_embeddings(self): ##這里返回的是模型對象本身return self.model.embed_tokensdef tie_weights(self):self.lm_head.weight = self.model.embed_tokens.weight
GemmaModel
GemmaModel里面實際上就是一個注意力塊的序列,就像一個注意力塊數組一樣,而該層需要做的僅僅是將輸入在不同的注意力塊里依次傳遞,并把最后一個注意力塊的輸出返回給上層即可。
class GemmaModel(nn.Module): ## 匹配def __init__(self,config:GemmaConfig):super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.embed_tokens = nn.Embedding(config.vocab_size,config.hidden_size,padding_idx=config.pad_token_id)
self.layers = nn.ModuleList([GemmaLayer(config, _) for _ in range(config.num_hidden_layers)])
self.norm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps) ##Root Mean Square Normalization均方根標準化,該論文表明并不一定要標準化到標準正態分布,而是只要方差為1就可以def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
position_ids: Optional[torch.Tensor] = None):#[Batch_size, Seq_len, Hidden_size]
hidden_states = inputs_embeds
normalizer = torch.tensor(self.hidden_size ** 0.5,dtype= inputs_embeds.dtype)
hidden_states = hidden_states * normalizerfor layer in self.layers:
hidden_states = layer(
hidden_states = hidden_states,
attention_mask = attention_mask,
kv_cache = kv_cache,
position_ids = position_ids)## 均方根歸一化,不改變shape
hidden_states = self.norm(hidden_states)return hidden_states
這里我們用一個nn.ModuleList來存儲所有的GemmaLayer,一個GemmaLayer實際上就是一個attention 塊。值得注意的是,在每個attention塊內部我們將會做兩次歸一化,但是每個attention layer的輸出不會做歸一化,為了使得上層的計算能拿到歸一化后的結果,我們在整個list前向傳遞完了之后再補一個normalization的過程:
hidden_states = self.norm(hidden_states)
- 注意:我們此處用的是RMSNorm,即均方根歸一化,關于這個歸一化與之前的其他歸一化的不同我們會在文末補充一些資料。
有人可能想問,為什么嵌入模型會放到這里:self.embed_tokens
這是因為paligemma的作者是這么實現的,而我們將從huggingface來導入整個模型的參數,所以我們的架構也必須和作者一樣才能正確導入參數,所以我們不得不放在這里。
GemmaLayer
在一個attention塊里面我們有一個多頭注意力層和一個前向傳播網絡,以及兩個歸一化,但我們實際的實現中會把歸一化提前,即add&norm -> attention -> add&norm -> ff。
這也就是為什么上面提到在layer的輸出沒有做歸一化。
代碼如下:
class GemmaLayer(nn.Module): ##匹配def __init__(self,config:GemmaConfig,layer_idx:int): ##layer_idx是當前layer的索引,輔助attention存儲kv_cachesuper().__init__()self.config = configself.layer_idx = layer_idxself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.input_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.mlp = GemmaMLP(config)self.self_attn = GemmaAttention(config,layer_idx)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,kv_cache: Optional[KVCache] = None,position_ids: Optional[torch.Tensor] = None)-> Tuple[torch.Tensor,Optional[Tuple[torch.FloatTensor,torch.FloatTensor]]]:
'''input: [Batch_size, Seq_len, Hidden_size]output: [Batch_size, Seq_len, Hidden_size] '''residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)hidden_states,_ = self.self_attn(hidden_states = hidden_states,attention_mask = attention_mask,kv_cache = kv_cache,position_ids = position_ids)hidden_states = residual + hidden_statesresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesreturn hidden_states
- 在這里的兩個歸一化我們也用RMSNorm來進行歸一化,注意除了歸一化,我們還要處理好殘差。
- 殘差的作用是防止梯度為0導致訓練緩慢。
RMSNorm
在前面的第四章節:手搓多模態-04 歸一化介紹 里面我們介紹了BatchNormalization和LayerNormalization,我們了解到以下信息:
- 歸一化是為了防止不同模型層的輸入輸出不穩定,分布不均勻導致的訓練速度過慢
- BN 依賴于batch 的規模,而batch的規模過大會導致訓練速度變相過慢
- LN 通過對單個樣本的所有特征進行標準化規避了BN的問題,主要做法是對單個樣本的所有特征計算均值和方差,從而將其分布轉換為0-1分布。
RMSNormalization,又稱均方差歸一化,是由論文《Root Mean Square Layer Normalization》提出的,該文章發現,其實分布不穩定的問題和均值沒有關系,主要是方差的問題,所以只需要特征的方差穩定即可,不需要計算均值,這樣可以減少計算的時間,從而加速訓練。
論文提出用均方根來對每個值進行縮放,從而使得方差更小,如圖所示。
其中,a_i 表示縮放前的特征值,RMS(a)表示所有特征值計算出來的均方根,g是一個可學習的參數向量,b是偏置。
在paligemma的實現中,RMSNorm的代碼如下:
class GemmaRMSNorm(nn.Module): ##匹配
def __init__(self,dim,eps=1e-6): ##dim是hidden_size
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) def _norm(self,x): return x * torch.rsqrt(x.pow(2).mean(dim = -1,keepdim=True) + self.eps) ##rsqrt表示平方的倒數,self.eps是防止分母為0 def forward(self,x):
x = self._norm(x)
output = x * (1.0 + self.weight.float()) ##論文中的可學習參數g
return output.type_as(x)
其中特征的維度為嵌入的維度大小。