從入門到放棄
花了幾天時間,看懂了DeepSeek V3 源碼的邏輯。源碼的邏輯是不難的,但為什么模型結構需要這樣設計,為什么參數需要這樣設置呢?知其然,但不知其所以然。除了模型結構以外,模型的訓練數據、訓練腳本和訓練經驗,也是DeepSeek V3能夠訓練出來的關鍵,但這些是DeepSeek母公司的核心機密,我們無從得知。
因此,看懂了源碼,算是入門了DeepSeek V3,因為沒有條件知道更多重要細節,因此不得不放棄重現整個模型的訓練。
Paper 和源碼
Paper URL: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
Code URL: https://github.com/deepseek-ai/DeepSeek-V3
模型邏輯
下面這張圖,代表了DeepSeek的核心邏輯。左邊是Transformer的邏輯結構,可以認為有N個左邊這樣的Block結構不斷重復,組成Transformer模型。每個Block中,分成兩個部分,Attention 和 Feed-Forward Network。對這兩個部分使用不同的網絡結構,我們就得到了不同的模型。
DeepSeek V3 的 Attention 用的是 Multi-Head Latent Attention(MLA) ,Feed-Forward Network 用的是DeepSeekMoE。
MLA
Multi-Head Latent Attention(MLA)即多頭潛在注意力,是DeepSeek模型中引入的一種創新注意力機制,旨在優化傳統多頭注意力(Multi-Head Attention,MHA)的計算效率和內存占用。具體介紹如下:
核心創新點
- 低秩鍵值壓縮
- KV的低秩壓縮:不直接存儲原始的Key和Value,而是先將隱藏狀態投影到一個更小的壓縮潛在向量。在推理時,只需緩存該壓縮潛在向量,而不是完整的Key和Value,從而大大降低了KV緩存的存儲需求。
- Query的低秩壓縮:對Query也進行低秩壓縮,雖然不會減少KV緩存的大小,但可以減少訓練時的激活存儲需求,進而降低計算成本。
- 解耦旋轉位置嵌入(RoPE)
- 額外引入“解耦查詢”:將查詢拆分為兩個部分,一部分不經過RoPE變換,代表非位置敏感的特征信息;另一部分專門用于嵌入RoPE位置編碼信息。
- 共享RoPE變換的Key:所有注意力頭共用一個旋轉變換后的Key,減少了計算開銷,也減小了KV緩存大小,降低了GPU內存占用,提高了推理速度,特別適用于長序列任務和大規模Transformer。
推理過程中的優化
將上投影矩陣吸收到里面,簡化查詢計算,并優化注意力分數的計算,減少了計算步驟,提升了計算效率。避免了先計算Value向量,減少了矩陣運算的開銷,使推理更快。
整體優勢
- 降低內存占用:通過對鍵值進行低秩聯合壓縮以及解耦RoPE等策略,顯著減少了KV緩存的存儲需求,降低了GPU內存占用。
- 提高計算效率:減少了訓練和推理過程中的計算量,加快了模型的推理速度,在保持甚至提高模型性能的同時,提升了模型的運行效率。
- 增強模型適應性:特別適用于長序列任務和大規模Transformer模型,能夠更好地處理長序列輸入,提高模型在各種自然語言處理任務中的表現。
MLA 有物理意義嗎?
Multi-Head Latent Attention(MLA)能夠起作用主要源于其獨特的技術設計,在數學和信息處理層面有清晰的邏輯,不過它是一種抽象的算法概念,并不直接對應具體的物理意義,以下是對其作用原理的分析:
起作用的原因
- 低秩壓縮的有效性
- 信息濃縮與降噪:通過低秩鍵值壓縮,MLA將高維的Key和Value信息投影到低維的潛在向量空間,這一過程類似于對原始信息進行濃縮,提取出最關鍵、最具代表性的特征,去除了一些可能的噪聲和冗余信息,使得模型能夠更聚焦于重要信息,從而提高信息處理的效率和準確性。
- 減少計算量和存儲需求:低秩壓縮大大降低了數據的維度,減少了模型訓練和推理過程中的計算量和存儲需求,使得模型能夠更高效地運行,尤其是在處理大規模數據和長序列數據時,這種優勢更為明顯。
- 解耦旋轉位置嵌入的優勢
- 位置信息與內容信息的分離:傳統的位置編碼方式將位置信息和內容信息混合在一起進行處理,而MLA的解耦旋轉位置嵌入將查詢拆分為位置敏感和非位置敏感兩部分,使模型能夠更清晰地分離和處理位置信息與內容信息,更好地捕捉文本中的長距離依賴關系。
- 共享RoPE變換的Key:所有注意力頭共用一個旋轉變換后的Key,不僅減少了計算開銷,還使得模型能夠從更宏觀的角度利用位置信息,增強了模型對序列數據整體結構的理解和把握能力。
- 多頭機制的協同作用
- 捕捉多維度信息:MLA中的多頭機制允許模型同時從多個不同的角度和維度去捕捉輸入數據中的信息,每個頭可以關注到輸入序列的不同方面,通過多個頭的并行計算和協同工作,模型能夠更全面、更深入地理解輸入數據,提高模型的表示能力和泛化能力。
難以直接賦予物理意義的原因及近似理解
- 抽象的算法概念:MLA是一種基于數學和計算機科學的算法概念,主要用于處理和分析數據中的模式和關系,它不像物理概念那樣具有直接可觀測的物理實體或現象與之對應,更多地是在數據空間和計算邏輯中發揮作用。
- 類比物理現象理解:可以進行一些類比來幫助理解。比如低秩壓縮類似于物理中的能量聚集,將分散的能量(信息)聚集到關鍵的“點”上;解耦旋轉位置嵌入有點像物理中對不同性質力的分解,將位置信息和內容信息這兩種“力”分開處理;多頭機制如同多個物理傳感器從不同方向和角度對環境進行感知,然后綜合這些感知信息來對整個系統進行理解和判斷。
DeepSeekMoE
DeepSeekMoE是由深度求索(DeepSeek)研發的基于混合專家系統(Mixture of Experts,MoE)的技術架構,以下是具體介紹:
架構原理
- 混合專家系統核心:采用MoE架構,核心在于通過動態路由機制,把輸入數據分配給最相關的專家處理。比如在自然語言處理中,有的專家專門處理情感分析,有的處理主題建模。
- 結合多頭潛在注意力機制:與MLA相結合,MLA通過引入潛在向量,減少鍵值緩存(KV cache)需求,提升推理效率。
- Transformer架構基礎:以Transformer架構為基礎,每個Transformer塊由一個注意力模塊和一個前饋網絡(FFN)組成,在注意力機制和FFN方面采用創新架構。
技術優勢
- 降低算力需求:MoE的動態分配機制和MLA減少KV緩存需求等特點,使模型在訓練和推理時對算力的要求降低。
- 保持高性能:在參數量減少的情況下仍能保持高性能,例如DeepSeek-V2以236B總參數、21B激活,大致可以達到70B-110B Dense的模型能力。
- 減少計算量:自研Sparse結構DeepSeekMoE進一步降低了計算量。
- 長上下文理解能力強:支持超100萬token的上下文窗口,顯著優于行業平均水平,適用于長文檔分析、代碼開發等復雜場景的連貫交互。
DeepSeekMoE的物理意義是什么?
DeepSeekMoE作為一種人工智能技術架構,沒有嚴格意義上的物理意義,但可以從一些角度進行類比和理解:
從系統資源分配角度
- 資源按需分配類比:可以將DeepSeekMoE的專家網絡和動態路由機制類比為一個智能電力分配系統。在這個系統中,不同的電器設備(任務)需要不同的電量(計算資源)來運行。專家網絡就像不同功率的發電機,而動態路由機制則像是智能電表和分配器,它會根據每個電器設備的實際需求,將電力(計算資源)精準地分配給需要的設備,避免了資源的浪費,提高了整個系統的能源利用效率。
- 負載均衡類比:類似于在一個大型物流中心,不同的倉庫區域(專家)負責存儲和處理不同類型的貨物(數據)。當有貨物運輸任務時,調度系統(動態路由)會根據貨物的特點和倉庫的負載情況,合理地安排貨物存儲到哪個倉庫,確保每個倉庫都能在其承載能力范圍內高效運作,不會出現某個倉庫過度擁擠而其他倉庫閑置的情況,實現了負載均衡,提高了物流中心的整體運營效率。
從信息處理角度
- 多維度信息處理類比:可以把DeepSeekMoE處理信息的過程想象成一個由多個不同專業的偵探(專家)組成的偵探團隊在調查一個復雜案件。每個偵探都有自己獨特的專業技能和視角,比如有的擅長調查線索,有的擅長分析人物關系,有的擅長破解密碼等。當面對案件(輸入數據)時,隊長(路由器)會根據案件的具體情況,分配合適的偵探去處理相應的部分,最后將各個偵探的調查結果綜合起來,形成對整個案件的全面了解和判斷,從而更高效地解決復雜問題。
- 特征提取與融合類比:如同在一個化學實驗中,不同的化學試劑(專家)可以與不同的物質發生反應,提取出特定的化學特征。DeepSeekMoE中的專家網絡就像這些化學試劑,它們各自對輸入數據進行處理,提取出不同的特征。然后通過融合機制,將這些特征像混合化學物質一樣進行整合,得到更全面、更有價值的信息,用于后續的分析和決策。
從模型架構角度
- 積木搭建類比:把DeepSeekMoE的架構比作搭建積木。每個專家網絡就像不同形狀和功能的積木塊,有的積木塊負責搭建基礎結構,有的負責構建上層建筑,有的負責添加裝飾等。路由器則像是搭建者的手,根據要搭建的目標模型的需求,選擇合適的積木塊進行組合,最終搭建出一個復雜而功能強大的模型結構,實現對各種自然語言處理任務的高效處理。
- 人體神經系統類比:可以將DeepSeekMoE類比為人體的神經系統。專家網絡類似于人體的不同神經細胞或神經中樞,它們各自負責處理特定類型的信息,如視覺神經細胞負責處理視覺信息,聽覺神經細胞負責處理聽覺信息等。路由器就像神經系統中的神經遞質或信號傳導機制,它負責將外界的刺激信號(輸入數據)準確地傳遞給相應的神經細胞,并將各個神經細胞處理后的信號進行整合和傳遞,使人體能夠做出協調的反應和決策,實現對外部世界的感知和交互。
代碼邏輯
整體 - Transformer
下面這段代碼是典型的 Transformer 實現,核心可以看 forward 函數邏輯:
- 進行 Embeding;
- 經過各個 Block;
- 歸一化并輸出。
對應的代碼:
# 通過嵌入層將輸入標記轉換為向量表示
h = self.embed(tokens)
# 依次通過每個Transformer塊進行處理
for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)
# 對輸出進行層歸一化,并取最后一個時間步的輸出
h = self.norm(h)[:, -1]
# 通過輸出投影層得到對數概率
logits = self.head(h)
完整代碼:
# 定義Transformer類,繼承自PyTorch的nn.Module類
class Transformer(Module):"""Transformer模型,包含位置嵌入、多個層以及輸出投影。屬性:max_seq_len (int): Transformer允許的最大序列長度。embed (nn.Module): 用于輸入標記的嵌入層,將輸入的標記轉換為向量表示。layers (torch.nn.ModuleList): 存儲多個Transformer塊的列表,每個塊包含多頭注意力和前饋網絡。norm (nn.Module): 層歸一化層,在所有Transformer塊之后應用,用于穩定訓練。head (nn.Module): 輸出投影層,將模型的輸出映射到詞匯表大小,用于預測下一個標記。freqs_cis (torch.Tensor): 預計算的復指數值,用于旋轉位置嵌入,幫助模型捕捉序列中的位置信息。"""def __init__(self, args):"""初始化Transformer模型。參數:args: 模型參數對象,包含Transformer的各種參數,如詞匯表大小、維度、層數等。"""# 獲取全局變量world_size和rank,分別表示分布式訓練中的進程總數和當前進程的編號global world_size, rank# 如果分布式訓練已初始化,則獲取進程總數,否則默認為1world_size = dist.get_world_size() if dist.is_initialized() else 1# 如果分布式訓練已初始化,則獲取當前進程編號,否則默認為0rank = dist.get_rank() if dist.is_initialized() else 0# 根據參數設置線性層的數據類型Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16# 調用父類的初始化方法super().__init__()# 保存最大序列長度self.max_seq_len = args.max_seq_len# 初始化嵌入層,將輸入標記轉換為向量表示self.embed = ParallelEmbedding(args.vocab_size, args.dim)# 初始化一個空的ModuleList,用于存儲Transformer塊self.layers = torch.nn.ModuleList()# 循環創建指定數量的Transformer塊,并添加到layers列表中for layer_id in range(args.n_layers):self.layers.append(Block(layer_id, args))# 初始化層歸一化層self.norm = RMSNorm(args.dim)# 初始化輸出投影層,將模型的輸出映射到詞匯表大小self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())# 預計算旋轉位置嵌入所需的復指數值,并將其注冊為緩沖區,不參與模型參數的更新self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)@torch.inference_mode()def forward(self, tokens, start_pos=0):"""Transformer模型的前向傳播過程。參數:tokens (torch.Tensor): 輸入的標記ID張量,形狀為 (batch_size, seq_len)。start_pos (int, 可選): 旋轉位置嵌入的起始位置,默認為0。返回:torch.Tensor: 對數概率張量,形狀為 (batch_size, vocab_size),表示每個標記的預測概率。"""# 獲取輸入序列的長度seqlen = tokens.size(1)# 通過嵌入層將輸入標記轉換為向量表示h = self.embed(tokens)# 從預計算的復指數值中截取當前序列所需的部分freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]# 初始化掩碼為Nonemask = None# 如果序列長度大于1,則創建一個上三角掩碼,用于屏蔽未來的標記if seqlen > 1:mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)# 依次通過每個Transformer塊進行處理for layer in self.layers:h = layer(h, start_pos, freqs_cis, mask)# 對輸出進行層歸一化,并取最后一個時間步的輸出h = self.norm(h)[:, -1]# 通過輸出投影層得到對數概率logits = self.head(h)# 如果使用分布式訓練,則收集所有進程的對數概率if world_size > 1:# 創建一個列表,用于存儲所有進程的對數概率all_logits = [torch.empty_like(logits) for _ in range(world_size)]# 收集所有進程的對數概率dist.all_gather(all_logits, logits)# 將所有進程的對數概率拼接在一起logits = torch.cat(all_logits, dim=-1)return logits
單個 - Block
核心代碼非常簡單MLA(attention) + MOE(Feed-Forward Network):
# 首先對輸入進行層歸一化,然后通過注意力層進行計算,最后將結果與輸入進行殘差連接
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
# 接著對上述結果進行層歸一化,再通過前饋網絡層進行計算,最后將結果與之前的結果進行殘差連接
x = x + self.ffn(self.ffn_norm(x))
全部代碼:
# 定義一個Transformer塊類,繼承自PyTorch的nn.Module類
class Block(Module):"""Transformer塊,結合了注意力層和前饋網絡層。屬性:attn (nn.Module): 注意力層(采用多頭潛在注意力機制,即MLA),用于捕捉輸入序列中不同位置之間的依賴關系。ffn (nn.Module): 前饋網絡層(可以是多層感知機MLP或者混合專家模型MoE),對注意力層的輸出進行非線性變換。attn_norm (nn.Module): 用于注意力層的層歸一化層,對輸入到注意力層的數據進行歸一化處理,穩定訓練過程。ffn_norm (nn.Module): 用于前饋網絡層的層歸一化層,對輸入到前饋網絡層的數據進行歸一化處理。"""def __init__(self, layer_id, args):"""初始化Transformer塊。參數:layer_id (int): 當前塊在Transformer模型中的層索引,用于確定使用哪種前饋網絡結構。args: 模型參數對象,包含了塊的各種參數,如維度、層數等。"""# 調用父類的初始化方法super().__init__()# 初始化注意力層,使用多頭潛在注意力機制(MLA)self.attn = MLA(args)# 根據當前層的索引來決定使用MLP還是MoE作為前饋網絡# 如果當前層索引小于密集層的數量,則使用MLP# 否則使用混合專家模型(MoE)self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)# 初始化用于注意力層的層歸一化層self.attn_norm = RMSNorm(args.dim)# 初始化用于前饋網絡層的層歸一化層self.ffn_norm = RMSNorm(args.dim)def forward(self, x, start_pos, freqs_cis, mask=None):"""Transformer塊的前向傳播過程。參數:x (torch.Tensor): 輸入張量,包含了序列的特征信息。start_pos (int): 序列中的起始位置,用于旋轉位置嵌入。freqs_cis (torch.Tensor): 預計算的復指數值,用于旋轉位置嵌入,幫助模型捕捉序列中的位置信息。mask (Optional[torch.Tensor]): 掩碼張量,用于在注意力計算中排除某些位置,避免模型關注到不應該關注的信息。返回:torch.Tensor: 經過當前Transformer塊計算后的輸出張量。"""# 首先對輸入進行層歸一化,然后通過注意力層進行計算,最后將結果與輸入進行殘差連接x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)# 接著對上述結果進行層歸一化,再通過前饋網絡層進行計算,最后將結果與之前的結果進行殘差連接x = x + self.ffn(self.ffn_norm(x))return x
Attention 模塊
經典的QKV計算公式。解釋可以自行搜索,或者參考:Transformer結構和注意力機制
和傳統的QKV相比,可以認為是做了壓縮,主要是為了減小 KV Cache。
代碼,就是做了一堆下采樣上采樣和矩陣的組合變換。最終目的是減少計算量和顯存使用量。
# 定義多頭注意力層類,繼承自PyTorch的nn.Module類
class MLA(Module):"""多頭注意力層(MLA)。屬性:dim (int): 輸入特征的維度。n_heads (int): 注意力頭的數量。n_local_heads (int): 分布式系統中本地注意力頭的數量。q_lora_rank (int): 查詢(query)的低秩投影的秩。kv_lora_rank (int): 鍵(key)和值(value)的低秩投影的秩。qk_nope_head_dim (int): 非位置相關的查詢/鍵投影的維度。qk_rope_head_dim (int): 旋轉位置編碼的查詢/鍵投影的維度。qk_head_dim (int): 查詢/鍵投影的總維度。v_head_dim (int): 值投影的維度。softmax_scale (float): 注意力計算中softmax函數的縮放因子。"""def __init__(self, args):# 調用父類的初始化方法super().__init__()# 保存輸入特征的維度self.dim = args.dim# 保存注意力頭的數量self.n_heads = args.n_heads# 計算分布式系統中本地注意力頭的數量self.n_local_heads = args.n_heads // world_size# 保存查詢的低秩投影的秩self.q_lora_rank = args.q_lora_rank# 保存鍵和值的低秩投影的秩self.kv_lora_rank = args.kv_lora_rank# 保存非位置相關的查詢/鍵投影的維度self.qk_nope_head_dim = args.qk_nope_head_dim# 保存旋轉位置編碼的查詢/鍵投影的維度self.qk_rope_head_dim = args.qk_rope_head_dim# 計算查詢/鍵投影的總維度self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim# 保存值投影的維度self.v_head_dim = args.v_head_dim# 如果查詢的低秩投影的秩為0,直接使用列并行線性層進行查詢投影if self.q_lora_rank == 0:self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)# 否則,使用低秩分解的方式進行查詢投影else:self.wq_a = Linear(self.dim, self.q_lora_rank)self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)# 對輸入進行線性變換得到鍵和值的低秩表示self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)# 對鍵和值的低秩表示進行歸一化self.kv_norm = RMSNorm(self.kv_lora_rank)# 對歸一化后的鍵和值進行線性變換self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))# 對多頭注意力的輸出進行行并行線性變換self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)# 計算softmax函數的縮放因子self.softmax_scale = self.qk_head_dim ** -0.5# 如果最大序列長度大于原始序列長度,對縮放因子進行調整if args.max_seq_len > args.original_seq_len:mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0self.softmax_scale = self.softmax_scale * mscale * mscale# 如果注意力實現方式為樸素方式if attn_impl == "naive":# 注冊鍵緩存self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)# 注冊值緩存self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)# 否則else:# 注冊鍵值緩存self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)# 注冊位置編碼緩存self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)def forward(self, x, start_pos, freqs_cis, mask=None):"""多頭注意力層(MLA)的前向傳播過程。參數:x (torch.Tensor): 輸入張量,形狀為 (batch_size, seq_len, dim)。start_pos (int): 序列中用于緩存的起始位置。freqs_cis (torch.Tensor): 預計算的復指數值,用于旋轉位置編碼。mask (Optional[torch.Tensor]): 掩碼張量,用于在注意力計算中排除某些位置。返回:torch.Tensor: 輸出張量,形狀與輸入相同。"""# 獲取輸入張量的批次大小、序列長度bsz, seqlen, _ = x.size()# 計算序列的結束位置end_pos = start_pos + seqlen# 如果查詢的低秩投影的秩為0,直接通過線性層得到查詢if self.q_lora_rank == 0:q = self.wq(x)# 否則,通過低秩分解的方式得到查詢else:q = self.wq_b(self.q_norm(self.wq_a(x)))# 調整查詢的形狀,將其劃分為多個頭q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)# 將查詢劃分為非位置相關部分和位置相關部分q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)# 對位置相關部分應用旋轉位置編碼q_pe = apply_rotary_emb(q_pe, freqs_cis)# 通過線性層得到鍵和值的低秩表示kv = self.wkv_a(x)# 將鍵和值的低秩表示劃分為低秩部分和位置編碼部分kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)# 對位置編碼部分應用旋轉位置編碼k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)# 如果注意力實現方式為樸素方式if attn_impl == "naive":# 將非位置相關部分和位置相關部分拼接得到完整的查詢q = torch.cat([q_nope, q_pe], dim=-1)# 對鍵和值的低秩表示進行歸一化和線性變換kv = self.wkv_b(self.kv_norm(kv))# 調整鍵和值的形狀,將其劃分為多個頭kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)# 將鍵和值劃分為非位置相關部分和值部分k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)# 將非位置相關部分和位置編碼部分拼接得到完整的鍵k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)# 將鍵存入緩存self.k_cache[:bsz, start_pos:end_pos] = k# 將值存入緩存self.v_cache[:bsz, start_pos:end_pos] = v# 計算注意力分數scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale# 否則else:# 獲取鍵和值的線性變換層的權重wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) # 調整權重的形狀wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)# 計算非位置相關部分的注意力分數q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])# 將鍵和值的低秩表示歸一化后存入緩存self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)# 將位置編碼部分存入緩存self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)# 計算注意力分數scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale# 如果存在掩碼,將掩碼加到注意力分數上if mask is not None:scores += mask.unsqueeze(1)# 對注意力分數應用softmax函數scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)# 如果注意力實現方式為樸素方式if attn_impl == "naive":# 通過注意力分數和值緩存計算輸出x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])# 否則else:# 通過注意力分數和鍵值緩存計算中間結果x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])# 通過中間結果和權重計算輸出x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])# 對輸出進行線性變換x = self.wo(x.flatten(2))return x
代碼解釋總結
這段代碼定義了一個多頭注意力層(MLA)類。在初始化時,根據傳入的參數設置各種維度、低秩投影的秩等,并初始化相應的線性層和歸一化層,同時根據注意力實現方式注冊不同的緩存。在前向傳播過程中,對輸入進行處理得到查詢、鍵和值,應用旋轉位置編碼,根據不同的注意力實現方式計算注意力分數,最后通過注意力分數和緩存得到輸出并進行線性變換。
Feed-Forward Network
這個MoE分成兩個部分,左邊是一些可以分享的專家,就是每次都需要去計算的,右邊的是根據分數來選擇的。如何選擇,是通過一個門控機制來選擇。這個門控是如何設計的,代碼里有實現,但論文和代碼都沒有對它物理意義的解釋。門控機制,簡單來說,就是設計了一個網絡,來選出K個候選。
核心邏輯:
- 通過門控確定本輪要用到的本地專家:
weights, indices = self.gate(x)
- 用選擇的每個本地專家進行計算:
y[idx] += expert(x[idx]) * weights[idx, top, None]
- 用共享專家進行計算:
z = self.shared_experts(x)
- 將本地專家的輸出和共享專家的輸出相加,并恢復到原始形狀:
return (y + z).view(shape)
全部代碼:
# 定義混合專家(Mixture-of-Experts, MoE)模塊類,繼承自PyTorch的nn.Module類
class MoE(nn.Module):"""混合專家(Mixture-of-Experts, MoE)模塊。屬性:dim (int): 輸入特征的維度。n_routed_experts (int): 模型中專家的總數。n_local_experts (int): 在分布式系統中本地處理的專家數量。n_activated_experts (int): 每個輸入激活的專家數量。gate (nn.Module): 門控機制,用于將輸入路由到不同的專家。experts (nn.ModuleList): 專家模塊列表,包含多個專家網絡。shared_experts (nn.Module): 共享專家模塊,應用于所有輸入。"""def __init__(self, args):"""初始化MoE模塊。參數:args: 模型參數對象,包含MoE模塊的相關參數。"""# 調用父類的初始化方法super().__init__()# 保存輸入特征的維度self.dim = args.dim# 確保專家總數能被分布式系統中的進程數整除assert args.n_routed_experts % world_size == 0, f"專家數量必須能被進程數整除 (進程數={world_size})"# 保存模型中專家的總數self.n_routed_experts = args.n_routed_experts# 計算本地處理的專家數量self.n_local_experts = args.n_routed_experts // world_size# 保存每個輸入激活的專家數量self.n_activated_experts = args.n_activated_experts# 計算本地專家在所有專家中的起始索引self.experts_start_idx = rank * self.n_local_experts# 計算本地專家在所有專家中的結束索引self.experts_end_idx = self.experts_start_idx + self.n_local_experts# 初始化門控機制self.gate = Gate(args)# 初始化專家模塊列表,本地負責的專家使用Expert模塊,其他位置置為Noneself.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else Nonefor i in range(self.n_routed_experts)])# 初始化共享專家模塊self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)def forward(self, x):"""MoE模塊的前向傳播過程。參數:x (torch.Tensor): 輸入張量。返回:torch.Tensor: 經過專家路由和計算后的輸出張量。"""# 保存輸入張量的原始形狀shape = x.size()# 將輸入張量展平為二維張量,方便后續處理x = x.view(-1, self.dim)# 通過門控機制得到每個輸入分配到各個專家的權重和對應的專家索引weights, indices = self.gate(x)# 初始化輸出張量,形狀與輸入相同,初始值全為0y = torch.zeros_like(x)# 統計每個專家被分配到的輸入數量counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()# 遍歷本地負責的專家for i in range(self.experts_start_idx, self.experts_end_idx):# 如果該專家沒有被分配到輸入,則跳過if counts[i] == 0:continue# 獲取當前專家模塊expert = self.experts[i]# 找出分配到當前專家的輸入的索引idx, top = torch.where(indices == i)# 將這些輸入通過當前專家模塊進行計算,并乘以對應的權重,累加到輸出張量中y[idx] += expert(x[idx]) * weights[idx, top, None]# 將輸入通過共享專家模塊進行計算z = self.shared_experts(x)# 如果使用分布式訓練,對本地專家的輸出進行全局歸約操作if world_size > 1:dist.all_reduce(y)# 將本地專家的輸出和共享專家的輸出相加,并恢復到原始形狀return (y + z).view(shape)
代碼解釋總結
這段代碼定義了一個混合專家(MoE)模塊。在初始化時,根據傳入的參數設置專家的數量、門控機制、專家模塊列表和共享專家模塊。在前向傳播過程中,首先通過門控機制將輸入路由到不同的專家,然后對本地負責的專家進行計算并累加結果,同時將輸入通過共享專家模塊進行計算,最后將兩部分結果相加并恢復原始形狀。如果使用分布式訓練,還會對本地專家的輸出進行全局歸約操作。
Gate
不考慮分組路由來看看它的核心邏輯,實際上就是線下變換,然后激活,選擇K個極值(如果用了分組,就是選擇K的方式發生了一些變化):
# 通過線性變換計算每個輸入對應各個專家的分數
scores = linear(x, self.weight)
# 根據評分函數類型對分數進行處理
if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)
else:scores = scores.sigmoid()
# 選擇分數最高的若干專家
indices = torch.topk(scores, self.topk, dim=-1)[1]
# 根據選擇的專家索引,從原始分數中獲取對應的權重
weights = scores.gather(1, indices)
# 如果評分函數是sigmoid,對權重進行歸一化
if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)
# 對權重進行縮放
weights *= self.route_scale
return weights.type_as(x), indices
# 定義門控機制類,用于在混合專家(MoE)模型中對輸入進行路由
class Gate(nn.Module):"""混合專家(MoE)模型中用于輸入路由的門控機制。屬性:dim (int): 輸入特征的維度。topk (int): 每個輸入激活的頂級專家數量。n_groups (int): 用于路由的分組數量。topk_groups (int): 輸入將被路由到的分組數量。score_func (str): 評分函數,取值為 'softmax' 或 'sigmoid'。route_scale (float): 路由權重的縮放因子。weight (torch.nn.Parameter): 門控機制的可學習權重。bias (Optional[torch.nn.Parameter]): 門控機制的可選偏置項。"""def __init__(self, args):"""初始化門控機制模塊。參數:args: 模型參數對象,包含門控機制的相關參數。"""# 調用父類的初始化方法super().__init__()# 保存輸入特征的維度self.dim = args.dim# 保存每個輸入激活的頂級專家數量self.topk = args.n_activated_experts# 保存用于路由的分組數量self.n_groups = args.n_expert_groups# 保存輸入將被路由到的分組數量self.topk_groups = args.n_limited_groups# 保存評分函數類型self.score_func = args.score_func# 保存路由權重的縮放因子self.route_scale = args.route_scale# 初始化可學習權重self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))# 根據輸入特征維度決定是否初始化偏置項self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else Nonedef forward(self, x):"""門控機制的前向傳播過程。參數:x (torch.Tensor): 輸入張量。返回:Tuple[torch.Tensor, torch.Tensor]: 路由權重和選擇的專家索引。"""# 通過線性變換計算每個輸入對應各個專家的分數scores = linear(x, self.weight)# 根據評分函數類型對分數進行處理if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)else:scores = scores.sigmoid()# 保存原始分數,后續計算權重時使用original_scores = scores# 如果存在偏置項,將其加到分數上if self.bias is not None:scores = scores + self.bias# 如果分組數量大于1,進行分組路由操作if self.n_groups > 1:# 調整分數的形狀,以便按組處理scores = scores.view(x.size(0), self.n_groups, -1)# 根據是否有偏置項,計算每個組的分數表示if self.bias is None:group_scores = scores.amax(dim=-1)else:group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)# 選擇分數最高的若干組indices = group_scores.topk(self.topk_groups, dim=-1)[1]# 創建掩碼,用于屏蔽未選擇的組mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)# 將屏蔽組的分數設為負無窮scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)# 選擇分數最高的若干專家indices = torch.topk(scores, self.topk, dim=-1)[1]# 根據選擇的專家索引,從原始分數中獲取對應的權重weights = original_scores.gather(1, indices)# 如果評分函數是sigmoid,對權重進行歸一化if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)# 對權重進行縮放weights *= self.route_scalereturn weights.type_as(x), indices
代碼解釋總結
這段代碼定義了一個門控機制(Gate)類,用于在混合專家(MoE)模型中對輸入進行路由。在初始化時,根據傳入的參數設置各種屬性,如輸入維度、激活專家數量、分組數量等,并初始化可學習的權重和偏置項。在前向傳播過程中,首先計算每個輸入對應各個專家的分數,然后根據評分函數類型進行處理,接著根據分組情況進行分組路由操作,選擇激活的專家并計算對應的權重,最后返回路由權重和選擇的專家索引。
B站大牛詳解視頻鏈接
B站大牛有詳細視頻講解:
https://www.bilibili.com/video/BV1RtNLeqEeu/