【大模型訓練】中短序列attention 和MOE層并行方式(二)

在這里插入圖片描述
在這里插入圖片描述在這里插入圖片描述

我們考慮一個典型的Transformer模型結構,在多層堆疊中,其中包含Attention層和MoE層(FeedForward層被替換為MoE層)。在模型最后是LM Head(語言模型頭),通常是一個全連接層,將隱層向量映射到詞表大小的輸出。

在您描述的配置中,MoE層采用allgatherEP模式,其輸出通過Reduce-Scatter操作得到分片輸出(每個TP組內,每張卡持有部分隱層維度)。而下一步的LM Head需要完整的隱層向量作為輸入。因此,在LM Head之前需要將分片的數據通過All-Gather操作聚合為完整向量。

具體流程如下:

MoE層輸出(Reduce-Scatter之后):每張卡持有分片數據,形狀為 [batch_size, seq_len, hidden_size/TP_size]
例如:TP=4,hidden_size=6144,則每張卡持有 [batch_size, 16000, 1536]
為了進行LM Head計算(需要完整的隱層向量),在TP組內執行All-Gather操作:
將其他3張卡上的分片數據收集過來,每張卡都得到完整的隱層向量 [batch_size, 16000, 6144]
LM Head計算:
此時每張卡都擁有完整的輸出向量,可以獨立計算LM Head(一個線性層,權重形狀為 [vocab_size, hidden_size])。
但是,由于LM Head的計算量可能很大(特別是詞表很大時),通常也會采用張量并行(TP)來切分這個線性層。
LM Head的TP切分:
將LM Head的權重矩陣按列切分(即hidden維度切分),每張卡持有一部分權重:[vocab_size, hidden_size/TP_size]。
這樣,每張卡只需要計算自己的部分:logits_local = input_full @ weight_local.T,輸出形狀為 [batch_size, seq_len, vocab_size/TP_size]
然后,通過All-Gather(在vocab維度)將logits聚合為完整的詞表大小(也可以使用Reduce-Scatter,但一般為了后續的損失計算,需要每張卡都有完整的logits,所以用All-Gather):
每張卡收集其他卡計算的logits分片,得到完整的logits:[batch_size, seq_len, vocab_size]
但是,這里有一個效率問題:在LM Head之前做了All-Gather得到完整隱層,而后面又進行了兩次通信(第一次是矩陣乘后的All-Gather,第二次是為了得到完整logits)。實際上,我們可以通過Gather操作合并來優化。

另一種常見的做法(如Megatron-LM的做法)是:避免在LM Head之前做All-Gather,而是直接利用分片的隱層和分片的權重計算,然后再進行All-Gather(在詞表維度聚合)。這樣就只需要一次All-Gather通信。

步驟:

MoE層輸出:分片隱層 [batch_size, seq_len, hidden_size/TP_size]
LM Head權重:每張卡持有 [vocab_size, hidden_size] 的列切分(每個分片大小為 [vocab_size, hidden_size/TP_size])
每張卡計算:logits_local = input_local @ weight_local.T 得到 [batch_size, seq_len, vocab_size] 的一部分(實際上是 [batch_size, seq_len, vocab_size] 在vocab維度上的一個分片?注意:矩陣乘法說明:input_local的隱層維度是hidden_size/TP,而weight_local.T的維度是 [hidden_size/TP_size, vocab_size],但這樣乘出來的結果是 [batch_size, seq_len, vocab_size],但是注意,這個結果是不完整的,因為輸入只是部分隱層。因此需要將各個卡的結果相加(Reduce-Sum)才能得到完整的logits。
因此,正確的做法是:
在計算完本地矩陣乘之后,需要進行一次All-Reduce(Sum)操作,將各個卡上的部分logits相加,得到完整的logits。但這里因為輸入本身就是分片的,所以需要All-Reduce。但是,這種做法在Megatron中只適用于無并行的線性層。在LM Head中,如果權重按列切分(即每張卡持有部分權重),那么每個分片權重與輸入分片相乘只能得到部分結果,而完整的結果需要將各個部分加起來(即沿著hidden維度拆分,然后結果相加)。

因此,LM Head的計算應該是:

輸入:X (每張卡上 [batch_size, seq_len, hidden_size/TP_size])
權重:W (每張卡上 [vocab_size, hidden_size/TP_size])

計算:local_logits = X @ W.T 得到 [batch_size, seq_len, vocab_size],但這個結果只是部分結果(因為是隱層分片和權重分片相乘)

然后,需要All-Reduce(Sum)操作:將各個卡上的local_logits相加,得到完整的logits。

然而,在MoE之后,如果我們不進行All-Gather(還原完整隱層),那么LM Head就要在分片隱層的基礎上計算,然后通過All-Reduce來聚合。這樣通信量是多少呢?

通信量 = batch_size * seq_len * vocab_size * sizeof(float)

這通常很大,因為vocab_size很大(幾萬到幾十萬)。例如,batch_size=1, seq_len=16000, vocab_size=50000, 則通信量=1 * 16000 * 50000 * 4字節=3.2GB,這比之前All-Gather的通信量(1 * 16000 * 6144 * 2字節≈196MB)大得多。

因此,更高效的做法是:

在MoE輸出后,使用All-Gather恢復完整隱層(每張卡得到 [batch, seq, hidden],通信量是 (TP_size-1)/TP_size * hidden_size * …,也就是上面提到的377MB左右,對于TP=4,每個分片1536,All-Gather需要發送3個1536的分片,所以輸入數據量是 3 * (batch_size * seq_len * 1536),以batch=2, seq=16000為例:3 * 2 * 16000 * 1536 * 2字節(fp16)≈ 2.25GB(注意這是總發送數據量,在NCCL中,All-Gather會將這些數據分發到其他卡,所以每卡接收3塊數據,共約2.25GB,但現代GPU卡間帶寬高,如NVLink 600GB/s,實際時間很短)
然后,使用TP切分權重計算LM Head(不需要立即做All-Gather)。但注意,由于我們已經All-Gather得到了完整隱層,那么我們可以將LM Head的權重按行切分(即切分vocab維度)。這樣每張卡計算一部分詞表的logits。
然后,再使用All-Gather將各個卡上分詞的logits收集起來,得到完整的logits。
第二步的通信量:All-Gather聚合logits,通信量為:batch_size * seq_len * vocab_size * (TP_size-1)/TP_size * sizeof(float)

例如:batch=2, seq=16000, vocab=50000, TP=4,則通信量 = 2 * 16000 * 50000 * 3/4 * 4字節≈4.8GB(按float32計算,如果用float16則減半)。這依然很大。

所以,為了減少通信量,另一種策略是將LM Head的權重切分到不同的數據并行組(DP)中,只在一個設備上計算整個LM Head(不并行),但這樣會引入計算瓶頸。

或者,我們可以不進行第二次All-Gather,而是在后續的損失計算中進行優化。但是,損失計算(交叉熵)通常需要完整的logits。

實際上,我們觀察到,在Megatron-LM中,對于LM Head,他們采用了以下兩種方式之一:

A. 不切分LM Head(權重復制),在計算之前通過All-Gather得到完整隱層(這樣每張卡有完整隱層),然后每張卡獨立計算整個LM Head(權重復制,所以輸出也是完整的logits)。這樣避免了第二步的All-Gather,但每張卡都存儲了整個詞表(可能很大)且計算了完整的矩陣乘法(計算量大,但并行度高)。

B. 將LM Head用張量并行切分(按hidden維度切分權重),然后通過All-Reduce得到完整的logits(如上所述)。但這種方法在詞表大時通信量很大。

由于詞表很大(如50k),All-Reduce的通信量甚至比All-Gather隱層還要大,因此實際中更常見的是使用All-Gather隱層的方式(即先將隱層收集完整),然后每張卡獨立計算整個LM Head(需要每張卡都存儲完整的LM Head權重)。

但是,這需要每張卡都有足夠的顯存放得下整個LM Head(詞表大小×隱層大小×2字節(fp16))。例如50k×6144×2≈600MB,這可以接受。

因此,推薦做法:

MoE層輸出后,在TP組內通過All-Gather得到完整隱層(每張卡都有 [batch, seq, hidden])。
然后,LM Head(一個線性層)的權重在每張卡上都是完整的(通過DP組廣播,每個DP組內的卡都有相同的權重副本)。
每張卡獨立計算:logits = hidden_full @ lm_head_weight.T,得到 [batch, seq, vocab]。
這樣避免了LM Head的并行通信,但增加了顯存占用(存儲完整LM Head權重)和計算量(每張卡都計算整個矩陣乘法)。然而,計算方面實際上通過數據并行分攤了(因為batch切分在數據并行組內,但實際上在TP組內也是完整batch?注意,TP組內的batch是完整的,因為前面All-Gather已經得到完整隱層,而每個TP組計算的是整個batch?)

澄清:我們的并行方式包括DP(數據并行)和TP(張量并行)。在TP組內,每個TP組處理一個數據并行分片(batch的一部分)。所以,在TP組內,All-Gather之后的隱層是該TP組負責的那部分batch(batch_size/DP_size)的完整隱層。然后,在TP組內每張卡計算完整的LM Head(對這部分batch)。

因此,總計算量:每個TP組獨立計算batch_size/DP_size個樣本的整個LM Head。

這樣,LM Head的計算在多個TP組之間是數據并行的(每個組算一部分batch)。

所以,這個方案是可行的,并且沒有額外的通信。

總結流程:

在注意力層:使用TP=4(和DP=4)計算。
MoE層:使用TP=4和EP=4(allgatherEP模式),輸出通過Reduce-Scatter得到分片隱層(每個卡持有部分隱層)。
為了進入后續層(比如下一層可能是Attention層)需要完整隱層?但這里是進入LM Head。
但是,在模型結構上,通常MoE層后面可能還有層(如另一個Attention層),但最后才是LM Head。在您的配置中,可能MoE層是最后一層專家層,然后輸出到LM Head。

因此,如果后面是LM Head,那么:

Step1: MoE層輸出(Reduce-Scatter后) -> 每張卡有 [batch, seq, hidden/TP]
Step2: 執行All-Gather(在TP組內)得到完整的隱層 [batch, seq, hidden](注意,這個batch是當前DP分片,即batch/DP_size)
Step3: 將完整隱層輸入到LM Head(每張卡都有完整的LM Head權重,獨立計算)
Step4: 得到logits [batch, seq, vocab],然后計算損失(在數據并行組內需要聚合損失?或者直接在每個設備上計算自己部分batch的損失,然后梯度聚合時自然合并)

注意:損失計算不需要跨設備通信,因為每個設備只計算自己部分batch的損失(標簽也是分片的),最后梯度聚合在數據并行組內進行(只對模型參數梯度,損失值本身不需要聚合)。

所以,在MoE層后到LM Head之間,需要一次All-Gather(在TP組內)來恢復完整隱層,然后直接計算LM Head。

這個All-Gather操作是必要的。

如果您不想做這次All-Gather,另一種方法是:將LM Head的TP計算與前面的MoE層的TP輸出分片對齊(按隱層維度切分),然后通過All-Reduce(Sum)來聚合logits。但是,如上所述,All-Reduce的通信量很大(因為詞表大),所以通常不采用。

結論:在您的配置下,MoE層后進入LM Head之前,需要一次All-Gather(在TP組內)來恢復完整隱層。

這也就是為什么您的問題描述中說“然后進入下一步采用all gather”的原因。
在這里插入圖片描述
在這里插入圖片描述

在這里插入圖片描述
在這里插入圖片描述

在這里插入圖片描述
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/83423.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/83423.shtml
英文地址,請注明出處:http://en.pswp.cn/web/83423.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025-06-09(批量智能裁剪視頻尺寸并延長視頻時長)

import os import subprocess import random import json # 配置參數 TARGET_WIDTH 500 TARGET_HEIGHT 600 TARGET_DURATION 180 # 目標時長(秒) OUTPUT_DIR "processed_videos" MIRROR_MODES ["none", "horizontal&quo…

CKA考試知識點分享(9)---gateway api

CKA 版本:1.32 第九套題是涉及gateway api相關。 注意:本文不是題目,只是為了學習相關知識點做的實驗。僅供參考 實驗目的 創建一個gateway api,來實現后端鏡像的外部訪問。 gateway api 通過nginx實現 實驗開始 安裝nginx ga…

Kafka 消息模式實戰:從簡單隊列到流處理(一)

一、Kafka 簡介 ** Kafka 是一種分布式的、基于發布 / 訂閱的消息系統,由 LinkedIn 公司開發,并于 2011 年開源,后來成為 Apache 基金會的頂級項目。它最初的設計目標是處理 LinkedIn 公司的海量數據,如用戶活動跟蹤、消息傳遞和…

Linux中使用yum安裝MYSQL

1、關系型數據庫 MySQL 使用 yum 安裝mysql 1、檢查是否已經安裝 Mysql rpm -qa | grep mysql如果安裝了 就進行卸載 rpm -e mysql-community-libs-5.7.44-1.el7.x86_64 rpm -e mysql57-community-release-el7-11.noarch rpm -e mysql-community-common-5.7.44-1.el7.x86_64…

Linux 文件系統與 I/O 編程核心原理及實踐筆記

文章目錄 一、理解文件1.1 狹義理解1.2 廣義理解1.3 文件操作的歸類認識1.4 系統角度:進程與文件的交互1.5 實踐示例 二、回顧 C 文件接口2.1 hello.c 打開文件2.2 hello.c 寫文件2.3 hello.c 讀文件2.4 輸出信息到顯示器的幾種方法2.5 stdin & stdout & st…

1.9 Express

Express 是一個基于 Node.js 平臺的輕量級、靈活的 Web 應用框架,它為構建 Web 應用和 API 提供了一系列強大的功能。 核心特性 中間件支持:Express 使用中間件(middleware)函數來處理 HTTP 請求和響應。中間件可以訪問請求對象&…

面壁智能MiniCPM4.0技術架構與應用場景

📋 目錄 1. 引言:端側智能新時代2. MiniCPM4.0概述3. 核心技術架構 3.1 高效雙頻換擋機制3.2 稀疏注意力機制3.3 系統級優化創新 4. 技術突破與性能表現5. 應用場景深度解析 5.1 智能手機應用5.2 智能家居場景5.3 汽車智能化5.4 其他端側應用 6. 行業影…

RabbitMQ路由核心解密:從Exchange到RoutingKey的深度實踐與避坑指南

🔍 RabbitMQ路由核心解密:從Exchange到RoutingKey的深度實踐與避坑指南 “消息去哪了?”——這是每位RabbitMQ使用者在調試時最常發出的靈魂拷問。 理解Exchange與RoutingKey的協作機制,正是解開路由謎題的關鍵鑰匙。 一、Exchang…

Spring MVC完全指南 - 從入門到精通

目錄 1. Spring MVC簡介 2. MVC架構模式 3. Spring MVC核心組件 4. 請求處理流程 5. 控制器詳解 6. 請求映射 7. 參數綁定 8. 數據驗證 9. 視圖解析器 10. 模型數據處理 11. 異常處理 12. 攔截器 13. 文件上傳下載 14. RESTful API 15. 配置詳解 總結 1. Sprin…

實戰使用docker compose 搭建 Redis 主從復制集群

文章目錄 前言技術積累1、Redis 主從復制機制2、Docker Compose 編排3、 Redis 配置文件定制4、 驗證主從狀態5、 自動化部署與維護 環境準備實戰演示創建redis目錄及配置1、創建redis目錄2、創建redis配置文件 啟動redis集群服務1、創建docker-compose編排文件2、編排docker-c…

【學習筆記】RTSP-Ovnif-GB28181

【學習筆記】RTSP-Ovnif-GB28181 一、RTSP_RTP_RTCP RTSP(Real Time Streaming Protocol),RFC2326,實時流傳輸協議,是TCP/IP協議體系中的一個應用層協議。 RTP協議詳細說明了在互聯網上傳遞音頻和視頻的標準數據包格…

stm32-c8t6實現語音識別(LD3320)

目錄 LD3320介紹: 功能引腳 主要特色功能 通信協議 端口信息 開發流程 stm32c8t6代碼 LD3320驅動代碼: LD3320介紹: 內置單聲道mono 16-bit A/D 模數轉換內置雙聲道stereo 16-bit D/A 數模轉換內置 20mW 雙聲道耳機放大器輸出內置 5…

RAG技術全解析:從概念到實踐,構建高效語義檢索系統——嵌入模型與向量數據庫搭建指南

一、RAG技術概述:為什么需要RAG? 1.1 什么是RAG? RAG(Retrieval-Augmented Generation)是一種結合檢索與生成能力的AI架構。其核心思想是通過外部知識庫動態增強大語言模型(LLM)的生成能力&…

【資源分享】手機玩轉經典游戲!小雞模擬器1.9.0:PSP/NDS/GBA完美運行!

阿燦今天給大家推薦一款小雞模擬器,這是一個老款PC和掌上游戲機模擬器。完美模擬街機(fbamamemameplus).PS、PSP、FC(NES)SFC(SNES)、GBA、GBC、MD、NDS、DC、NGP、WS (WSC) PCE、ONS 等18款經典掌機游戲機。小雞模擬器同時也提供海量熱門的漢化版游戲免…

matlab脈沖信號并繪制波形2025.6.11

以下是一個使用MATLAB生成5V、10MHz脈沖信號并繪制波形的示例代碼: % 5V 10MHz脈沖信號仿真 clc; clear; close all; % 參數設置 voltage = 5; % 信號幅度(V) frequency = 10e6; % 脈沖頻率(10MHz) duty_cycle =

ElasticJob初探

依賴版本 JDK版本是&#xff1a;jdk17 springboot版本 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.2.4</version></parent>zookeeper elasticjo…

【Vue3】(三)vue3中的pinia狀態管理、組件通信方式及總結、插槽

目錄 一、vue3的pinia 1、什么是pinia&#xff1f; 2、為什么Vue3選擇pinia&#xff1f; 3、使用pinia的好處 4、安裝pinia 2、項目配置 3、存儲/讀取pinia中的數據 4、修改pinia中的數據 5、storeToRefs&#xff08;保持store中數據的響應式&#xff09; 6、getters 7、…

WEB3全棧開發——面試專業技能點P1Node.js / Web3.js / Ethers.js

一、Node.js 事件循環 Node.js 的事件循環&#xff08;Event Loop&#xff09;是其異步編程的核心機制&#xff0c;它使得 Node.js 可以在單線程中實現非阻塞 I/O 操作。 &#x1f501; 簡要原理 Node.js 是基于 libuv 實現的&#xff0c;它使用事件循環來處理非阻塞操作。事件…

大數據學習棧記——Neo4j的安裝與使用

本文介紹圖數據庫Neofj的安裝與使用&#xff0c;操作系統&#xff1a;Ubuntu24.04&#xff0c;Neofj版本&#xff1a;2025.04.0。 Apt安裝 Neofj可以進行官網安裝&#xff1a;Neo4j Deployment Center - Graph Database & Analytics 我這里安裝是添加軟件源的方法 最新版…

web架構4------(nginx常用變量,nginx中英文自動匹配,lnmp網站架構,正向代理,反向代理,負載均衡)

一.前言 本期來介紹nginx最后幾個知識點&#xff0c;看著要說的內容很多&#xff0c;其實一點也不多&#xff0c;都是所見即所得的東西。 二.nginx常用變量 2.1 常用變量 $args 請求中的參數&#xff0c;也叫查詢參數&#xff0c;如www.123.com/1.php?a1&b2的$args就是…