Megatron-LM: Training Multi-Billion Parameter Language Models Using
Model Parallelism
1. 技術設計原則
Megatron-LM 提出輕量級層內模型并行,無需定制編譯器或修改框架,僅通過在 PyTorch 原生代碼中插入少量通信操作(如all-reduce)實現,且與流水線模型并行正交互補,可靈活組合。
2.背景:矩陣分塊計算
參考:https://www.bilibili.com/video/BV1HdXtY9EuF/?share_source=copy_web&vd_source=0f3d85b09673431159069a2a9a3da50c
矩陣XXX和YYY相乘,即計算XYXYXY,有兩種分塊運算方式:
- 把YYY按列拆分為[Y1,Y2][Y_1,Y_2][Y1?,Y2?],XXX不變
● XmnYnk=Xmn[Y1nk2,Y2nk2]=[XY1,XY2]X^{mn}Y^{nk}=X^{mn}[Y_1^{n\frac{k}{2}},Y_2^{n\frac{k}{2}}]=[XY_1,XY_2]XmnYnk=Xmn[Y1n2k??,Y2n2k??]=[XY1?,XY2?] - 把YYY按行拆分為 [Y1Y2]\begin{bmatrix} Y_1 \\ Y_2 \end{bmatrix}[Y1?Y2??],把XXX按列拆分為[X1,X2][X_1,X_2][X1?,X2?]
● XmnYnk=[X1mn2,X2mn2]×[Y1n2kY2n2k]=(X1Y1)mk+(X2Y2)mkX^{mn}Y^{nk}= \begin{bmatrix}X_1^{m \frac{n}{2}},X_2^{m \frac{n}{2}} \end{bmatrix} \times \begin{bmatrix} Y_1^{\frac{n}{2} k} \\ Y_2^{\frac{n}{2} k} \end{bmatrix} = (X_1Y_1)^{mk} + (X_2Y_2)^{m k}XmnYnk=[X1m2n??,X2m2n???]×[Y12n?k?Y22n?k??]=(X1?Y1?)mk+(X2?Y2?)mk
3. 關鍵模塊并行化實現
這一部分的圖解為作者自己根據理解畫的,如果有錯誤請指正
(1)前饋網絡層(MLP)
公式:FFN(X)=σ(XA)BFFN(X)=\sigma(XA)BFFN(X)=σ(XA)B
設序列長度為lll,隱藏層維度為ddd,前饋網絡的隱藏層維度為dFFNd_{FFN}dFFN?,A∈Rd×dFFN,B∈RdFFN×dA \in \mathbb{R}^{d \times d_{FFN}}, B \in \mathbb{R}^{d_{FFN} \times d}A∈Rd×dFFN?,B∈RdFFN?×d。
權重矩陣拆分策略:
第一層線性層的權重矩陣按列拆分(A=[A1,A2]A=[A_1,A_2]A=[A1?,A2?])
- 使 GeLU 非線性激活可在各 GPU 上獨立計算,避免中間同步;
- 即使得GeLU(XA)=[GeLU(XA1),?,GeLU(XAn)]\text{GeLU}(XA)=[ \text{GeLU}(XA_1),\cdots ,\text{GeLU}(XA_n)]GeLU(XA)=[GeLU(XA1?),?,GeLU(XAn?)]
第二層線性層的權重矩陣按行拆分,直接接收 GeLU 輸出
- 在前向傳播時僅需對第二層的輸出做一次 All-Reduce 聚合輸出。
- 在反向傳播時僅需在返回到輸入時做一次 All-Reduce 聚合梯度。
通信優化:整個 MLP 模塊僅需 2 次 all-reduce 操作(前向1次、反向1次),無額外同步點。
(2)多頭注意力(Multi-Head Attention)模塊
注意力頭拆分:
- 將 Q、K、V 對應的權重矩陣按列拆分,每個 GPU 負責部分注意力頭的計算,無需中間通信;
注意力輸出層權重按行拆分,直接接收并行計算結果,僅需在反向傳播聚合梯度。
優勢:充分利用注意力頭的天然并行性,每個 GPU 僅處理部分頭的計算,降低單設備內存壓力。
(3)輸入層與輸出層優化
輸入嵌入:
- 按詞匯表維度列拆分嵌入矩陣(E=[E1,E2]E=[E_1,E_2]E=[E1?,E2?])
- 通過 g 算子(前向 all-reduce)聚合結果,避免單 GPU 存儲完整詞匯表。
輸出層與損失計算:
- 融合最終線性層的輸出與交叉熵損失計算,直接在各 GPU 上計算局部損失后聚合
- 無需傳輸大規模 logits,減少通信量從b×s×vb×s×vb×s×v至b×sb×sb×s
- bbb 為批次大小、sss 為序列長度、vvv 為詞匯表大小
4.混合并行策略(模型+數據并行)
GPU分組:
- 將 GPU 劃分為模型并行組(如8個GPU一組,共同承載一個模型)和數據并行組(不同模型并行組中同位置 GPU 組成,負責梯度同步)
- 總 GPU 數 = 模型并行度 × 數據并行度