[分布式并行策略] 數據并行 DP/DDP/FSDP/ZeRO

上篇文章【[論文品鑒] DeepSeek V3 最新論文 之 DeepEP】 介紹了分布式并行策略中的EP,簡單的提到了其他幾種并行策略,但礙于精力和篇幅限制決定將內容分幾期,本期首先介紹DP,但并不是因為DP簡單,相反DP的水也很深,例如:“DP到底同步的是什么數據?怎么同步的?“,“AllReduce/Ring-AllReduce是什么?”,“ZeRO1、2、3又都是什么?” 等各種問題,會結合PyTorch代碼,盡量做到詳細由淺入深。

單機單卡

在深入分布式并行策略前,先回顧一下單機單卡的訓練模式:

在這里插入圖片描述

  1. CPU 加載數據,并將數據分成 batch 批次
  2. CPUbatch 批次數據傳給 GPU
  3. GPU 進行 前向傳播 計算得到 loss
  4. GPU 再通過 反向傳播 通過 loss 得到 梯度
  5. GPU 再通過 梯度 更新 參數

偽代碼:

model = Model(xx) # 1. 模型初始化
optmizer = Optimizer(xx) # 2. 優化器初始化
output = model(input) # 3. 模型計算
loss = loss_function(output, target) # 4. loss計算
loss.backward() # 5. 反向傳播計算梯度
optimizer.step() # 6. 優化器更新參數 

DP

就像 多線程編程 一樣,可以通過引入 多個GPU 來提高訓練效率,這就引出了最基礎的 單機多卡 DP,即 Data Parallel 數據并行。

在這里插入圖片描述

  1. CPU 加載數據,并將數據拆分,分給不同的 GPU
  2. GPU0模型 復制到 其他所有 GPU
  3. 每塊 GPU 獨立的進行 前向傳播反向傳播 得到 梯度
  4. 其余所有 GPU梯度 傳給 GPU0
  5. GPU0 匯總全部 梯度 進行 全局平均 計算
  6. GPU0 通過 全局平均梯度 更新自己的 模型
  7. GPU0 再把最新的 模型 同步到其他 GPU

DPPyTorch偽代碼,相比于單機單卡,大部分都沒有變化,只是把模型換了DataParallel模型,在PyTorch中通過nn.DataParallel(module, device_ids) 實現:

model = Model(xx) # 1. 模型初始化(沒變化)
model_new = torch.nn.DataParallel(model, device_ids=[0,1,2]) # 1.1 啟用DP (新增)
optmizer = Optimizer(xx) # 2. 優化器初始化(不變)
output = model_new(input) # 3. 模型計算,替換使用DP(變更)
loss = loss_function(output, target) # 4. loss計算(不變)
loss.backward() # 5. 反向傳播計算梯度(不變)
optimizer.step() # 6. 優化器更新參數 (不變)

可見DP使用上非常簡單,通過nn.DataParallel套上之前的模型即可。

但是DP存在 2個 比較嚴重的問題:

  1. 數據傳輸量較大:不考慮CPU將input數據拆分傳輸給每塊GPU,單獨看GPU間的數據傳遞;對于GPU0它需要把整個模型的參數廣播到其他所有GPU,假設有 N N N塊GPU,那么就需要傳輸 ( N ? 1 ) ? w (N-1)*w (N?1)?w參數,同時GPU0也需要從其他所有GPU上Reduce所有梯度,那么就要傳輸 ( N ? 1 ) ? g (N-1)*g (N?1)?g,所以對于GPU0來說要傳輸 ( N ? 1 ) ? ( w + g ) (N-1)*(w +g) (N?1)?(w+g)的數據,同理對于其他GPU來說,要傳輸與來自GPU0的參數,與傳出自己那份梯度。所以整體上個說,GPU數量多 N N N越大,傳輸的數據量就越多。
  2. GPU0的壓力太大:它要收集梯度、更新參數、同步參數,計算和通信壓力都很大

接下來看一下更高級用法 DDP

DDP

DDPDistributed Data Parallel,多機多卡的分布式數據并行。

在這里插入圖片描述
DP 最主要的區別就是,解決了 DP主節點瓶頸,實現了真正的 分布式通信。

而精髓就是 Ring-AllReduce,下面介紹它是如何實現 梯度累計 的:

  1. 假設梯度目前都是單獨存在于不同GPU上,而目標是將三個GPU的梯度進行累計,也就是得到下圖中三個梯度的和,a0+a1+a2b0+b1+b2c0+c1+c2
    在這里插入圖片描述
  2. 首先第一階段:GPU0a0發送給GPU1去求和a0+a1GPU1b1發送給GPU2去求和b1+b2GPU2c2發送給GPU0去求和c0+c2
    在這里插入圖片描述
  3. 然后,繼續累加,將GPU0上的c0+c2發送給GPU1去求c0+c1+c2GPU1a0+a1發送給GPU2去求a0+a1+a2,將GPU2b1+b2發送給GPU0去求b0+b1+b2
    在這里插入圖片描述
  4. 此時第一階段完成,通過Scatter-Reduce將參數分發后集合,分別得到了各個參數梯度累計結果
    在這里插入圖片描述
  5. 之后的第二節階段,通過All-Gather將各個參數的梯度進行傳播,使得每個GPU上都得到了完整的梯度結果
  6. 首先,GPU0將完整的b0+b1+b2傳遞給GPU1,同理GPU1GPU2也傳遞完整的梯度
    在這里插入圖片描述
  7. 最后,再將剩余的梯度進行傳遞
    在這里插入圖片描述
  8. 最終每個設備得到了所有參數的完整梯度累計
    在這里插入圖片描述

DDPRing-AllReduce 中還有一個細節:如果每個參數都這么Ring著進行信息梯度累計,那么通信壓力太大了;

所以設計了,通過將參數分桶聚合,也就是一個桶中維護了多個參數,當整個桶中的所有梯度都計算完畢后,再以桶維度進行Ring梯度累計,這樣降低了通信壓力,提高了訓練效率。

DDP的落地,相較于DP會復雜很多,首先簡單理解幾個概念:

  • world:代表著DDP集群中的那些卡的
  • rankworld中,每張卡的唯一標識
  • ncclgloo:都是通信庫,也就是那些分布式原語的實現,現在普遍都用老黃家的NCCL,搭配RMDA食用效率更高

接下來看一下DDPPyTorch偽代碼:

# 首先需要在每張卡,也就是進程單位設置一下,可以理解為在“組網” (新增)
import torch
import torch.distributed as dist
dist.init_process_group(backend = "nccl", # 使用NCCL通信rank = xx, # 這張卡的標識world_size = xx # 所有卡的數量
)
torch.cuda.set_device(rank) # 綁定這個進程的GPU# 然后是模型定義(變化)
model = Model(xx).cuda(rank)
model_ddp = nn.parallel.DistributedDataParallel(mode, device_ids=[rank]) # 相較于DP,這里用DDP來包裝模型# 優化器(沒變)
optimizer = Optimizer(xx)# 分布式數據加載(新增)
train_sampler = torch.utils.data.distributed.DsitributedSampler(dataset,num_replicas = world_size,rank = rank
)
dataloader = DataLoader(dataset,batch_size = per_gpu_batch_size,sampler = train_sampler
)# 訓練(不變)
output = model_ddp(input)
loss = loss_function(output, target)
loss.backward()
optimizer.step()# 訓練后結束"組網"
dist.destroy_process_group()# 使用torchrun啟動DDP
torchrun train.py # torchrun是pytorch官方DDP的最佳實踐,就別用其他的了

FSDP

不論是DP還是DDP數據并行,都有一個核心問題:模型在每個GPU上都存儲一份,如果模型特別大,單卡顯存不足的話就無法訓練。

這就引入了 FSDP(fully sharded data parallel)核心思想是:把模型的參數、梯度、優化器狀態 分片存儲,顯著降低顯存占用。

分片機制:

  • 參數分片:把模型的參數切分到所有GPU上,每個GPU僅存儲部分參數
  • 前向傳播:通過 AllGather 收集完整參數 -> 計算 -> 丟棄 非本地分片(不在顯存中存儲,僅僅是計算用)
  • 反向傳播:通過 AllGather 收集參數 -> 計算梯度 -> 再通過 reduce-scatter同步梯度分片
  • 優化器狀態:每個GPU僅維護與其參數分片對應的優化器狀態

但這時候就會有疑問了:把模型分片存儲,這還算DP嗎,這不成了MP么?

確實,FSDP融合DPMP兩種思想,但核心仍然是DP,因為它仍然是在 數據維度 進行并行(不同GPU處理不同數據),并且每個GPU都獨立的完整前向+反向傳播;這是用DP的思想,去解決DP單卡顯存瓶頸的問題。

“FSDP is DP with model sharding, not MP. It extends DP beyond single-device memory limits.”
—— PyTorch Distributed Team, Meta AI

下面展示FSDPFULLY_SHARD策略,也就是對標ZeRO-3的訓練流程:

  1. 通過FULLY_SHARD策略,將參數、梯度、優化器狀態進行了分片
    在這里插入圖片描述
  2. 前向傳播中,由于每個GPU都只有部分參數,所以當走到缺失那部分參數的時候,依賴其他GPU將參數傳進來,執行完畢后就丟棄;通過這種方式,使得即使每個GPU只保存部分參數,但依然可以完成整個前向傳播
    在這里插入圖片描述
  3. 當得到output開始計算梯度時,每個GPU完整自己那部分的梯度計算,在此過程中如果本地沒有相對應的參數,也依然需要從其他GPU傳過來;當完成梯度計算后,再把梯度發送給負責更新這部分參數的優化器分片的GPU,由它進行本地參數更新;這樣就完成了一次前向+反向傳播
    在這里插入圖片描述

再來看看FSDPPyTorch偽代碼:

# “組網”,也就是設置分布式環境方式和DDP沒有區別(不變)
import torch.distributed as dist
from torch.distrbuted.fsdp import FullyShardDataParallel as FSDP
def setup(rank, world_size): dis.init_process_group("nccl", rank=rank, world_size=world_size)torch.cuda.set_device(rank)# 使用FSDP包裝模型,同時設置分片策略(新增)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
model = Model(xx)
model_fsdp = FSDP(mode, auto_wrap_policy=size_based_auto_wrap_policy, # 按層大小自動分片mixed_precision=True, # 啟用混合精度device_id=rank,sharding_strategy=torch.distributed.ShardingStrategy.FULLY_SHARD # 相當于ZeRO-3
)# 數據加載和分布式采樣,和DDP沒有區別(不變)
from torch.utils.data.distributed import DistributedSampler
dataset = datasets(xx)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader == torch.utils.data.DataLoader(datase4t, batch_size=64, sampler=sampler)# 訓練和DP、DDP沒有區別(不變)
for epoch in range(epochs):sampler.set_epoch(epoch)for batch in dataloader:data, target = batck[0].to(rank), batch[1].to(rank) # H2Doptimizer.zero_grad()output = model_fsdp(mode) # 使用fsdp包裝的model進行前向傳播loss = loss(output, target)loss.backward()optimizer.step()

ZeRO1/2/3

ZeRO 是微軟家 DeepSpeed 中的核心技術,思想和 FSDP 是相同,二者都是 通過分片消除模型冗余存儲,擴大分布式并行訓練能力,只不過 FSDPPyTorch 的官方實現版。

ZeRO(Zero Redundancy Optimizer)有三種策略:

  • ZeRO-1:只分片 優化器狀態
  • ZeRO-2:分片 梯度優化器狀態 ,對應了 FSDPSHARD_GRAD_OP 策略
  • ZeRO-3:分片 參數梯度優化器狀態,對應了 FSDPFULLY_SHARD 策略

雖然 ZeRO 因為深度集成在 DeepSpeed 中,還可以利用上 DeepSpeed 的其他特性,但從生態偏好上講,個人更推薦使用 PyTorch官方的 FSDP

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/84187.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/84187.shtml
英文地址,請注明出處:http://en.pswp.cn/web/84187.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

LeeCode144二叉樹的前序遍歷

項目場景: 給你二叉樹的根節點 root ,返回它節點值的 前序 遍歷。 示例 1: 輸入:root [1,null,2,3] 輸出:[1,2,3] 解釋: 示例 2: 輸入:root [1,2,3,4,5,null,8,null,null,6,7…

日本生活:日語語言學校-日語作文-溝通無國界(3)-題目:わたしの友達

日本生活:日語語言學校-日語作文-溝通無國界(3)-題目:わたしの友達 1-前言2-作文原稿3-作文日語和譯本(1)日文原文(2)對應中文(3)對應英文 4-老師…

使用 rsync 拉取文件(從遠程服務器同步到本地)

最近在做服務器遷移,文件好幾個T。。。。只能單向訪問,服務器。怎么辦!!! 之前一直是使用rsync 服務器和服務器之間的雙向同步、備份(這是推的)。現在服務器要遷移,只能單向訪問&am…

Linux 并發編程:從線程池到單例模式的深度實踐

文章目錄 一、普通線程池:高效線程管理的核心方案1. 線程池概念:為什么需要 "線程工廠"?2. 線程池的實現:從 0 到 1 構建基礎框架 二、模式封裝:跨語言線程庫實現1. C 模板化實現:類型安全的泛型…

2013年SEVC SCI2區,自適應變領域搜索算法Adaptive VNS+多目標設施布局,深度解析+性能實測

目錄 1.摘要2.自適應局部搜索原理3.自適應變領域搜索算法Adaptive VNS4.結果展示5.參考文獻6.代碼獲取7.算法輔導應用定制讀者交流 1.摘要 VNS是一種探索性的局部搜索方法,其基本思想是在局部搜索過程中系統性地更換鄰域。傳統局部搜索應用于進化算法每一代的解上&…

詳細介紹醫學影像顯示中窗位和窗寬

在醫學影像(如DICOM格式的CT圖像)中,**窗寬(Window Width, WW)和窗位(Window Level, WL)**是兩個核心參數,用于調整圖像的顯示對比度和亮度,從而優化不同組織的可視化效果…

Unity_VR_如何用鍵鼠模擬VR輸入

文章目錄 [TOC] 一、創建項目1.直接創建VR核心模板(簡單)2.創建3D核心模板導入XR包 二、添加XR設備模擬器1.打開包管理器2.添加XR設備模擬器3.將XR設備模擬器拖到場景中4.運行即可用鍵盤模擬VR輸入 一、創建項目 1.直接創建VR核心模板(簡單&…

SpringBoot定時監控數據庫狀態

1.application.properties配置文件 # config for mysql spring.datasource.url jdbc\:mysql\://127.0.0.1\:3306/數據庫名?characterEncoding\utf8&useSSL\false spring.datasource.username 賬號 spring.datasource.password 密碼 spring.datasource.validation-quer…

Qt聯合Halcon開發一:Qt配置Halcon環境【詳細圖解流程】

在Qt中使用Halcon庫進行圖像處理開發,可以有效地結合Qt的圖形界面和Halcon強大的計算機視覺功能。下面是詳細的配置過程,幫助你在Qt項目中成功集成Halcon庫。 步驟 1: 安裝Halcon軟件并授權 首先,確保你已經在電腦上安裝了Halcon軟件&#x…

一體化(HIS系統)醫院信息系統,讓醫療數據互聯互通

在醫療信息化浪潮下,HIS系統、LIS系統、PACS系統、電子病歷系統等信息系統成為醫療機構必不可少的一部分,從患者掛號到看診,從各種檢查到用藥,從院內治療到院外管理……醫療機構不同部門、不同科室的各類醫療、管理業務幾乎都初步…

Spring Boot 的 3 種二級緩存落地方式

在高并發系統設計中,緩存是提升性能的關鍵策略之一。隨著業務的發展,單一的緩存方案往往無法同時兼顧性能、可靠性和一致性等多方面需求。 此時,二級緩存架構應運而生,本文將介紹在Spring Boot中實現二級緩存的三種方案。 一、二…

Android Studio Profiler使用

一:memory 參考文獻: AndroidStudio之內層泄漏工具Profiler使用指南_android studio profiler-CSDN博客

Zephyr boot

<!DOCTYPE html> <html lang"zh-CN"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>Zephyr設備初始化機制交互式解析…

騰訊地圖Web版解決熱力圖被輪廓覆蓋的問題

前言 你好&#xff0c;我是喵喵俠。 還記得那天傍晚&#xff0c;我正對著電腦調試一個騰訊地圖的熱力圖頁面。項目是一個區域人流密度可視化模塊&#xff0c;我加了一個淡藍色的輪廓圖層用于表示區域范圍&#xff0c;熱力圖放在下面用于展示人流熱度。效果一預覽&#xff0c;…

【JVMGC垃圾回收場景總結】

文章目錄 CMS在并發標記階段&#xff0c;已經被標記的對象&#xff0c;又被新生代跨帶引用&#xff0c;這時JVM會怎么處理?為什么 Minor GC 會發生 STW&#xff1f;有哪些對象是在棧上分配的&#xff1f;對象在 JVM 中的內存結構為什么需要對齊填充&#xff1f;JVM 對象分配空…

3_STM32開發板使用(STM32F103ZET6)

STM32開發板使用(STM32F103ZET6) 一、概述 當前所用開發板為正點原子精英板,MCU: STM32F103ZET6。一般而言,拿到板子之后先要對板子有基礎的認識,包括對開發板上電開機、固件下載、調試方法這三個部分有基本的掌握。 二、系統開機 2.1 硬件連接 直接接電源線或Type-c線…

crackme012

crackme012 名稱值軟件名稱attackiko.exe加殼方式無保護方式serial編譯語言Delphi v1.0調試環境win10 64位使用工具x32dbg,PEid破解日期2025-06-18 -發現是 16位windows 程序環境還沒搭好先留坑

CppCon 2016 學習:I Just Wanted a Random Integer

你想要一個隨機整數&#xff0c;用于模擬隨機大小的DNA讀取片段&#xff08;reads&#xff09;&#xff0c;希望覆蓋不同長度范圍&#xff0c;也能測試邊界情況。 代碼部分是&#xff1a; #include <cstdlib> auto r std::rand() % 100;它生成一個0到99之間的隨機整數&…

MySQL層級查詢實戰:無函數實現部門父路徑

本次需要擊斃的MySQL函數 函數主要用于獲取部門的完整層級路徑&#xff0c;方便在應用程序或SQL查詢中直接調用&#xff0c;快速獲得部門的上下級關系信息。執行該函數之后簡單使用SQL可以實現數據庫中部門名稱查詢。例如下面sql select name,GetDepartmentParentNames(du.de…

Python初學者教程:如何從文本中提取IP地址

Python初學者教程:如何從文本中提取IP地址 在網絡安全和數據分析領域,經常需要從文本文件中提取IP地址。本文將引導您使用Python創建一個簡單但實用的工具,用于從文本文件提取所有IP地址并將其保存到新文件中。即使您是編程新手,也可以跟隨本教程學習Python的基礎知識! …