使用 PyTorch 完全分片數據并行技術加速大模型訓練

本文,我們將了解如何基于 PyTorch 最新的 完全分片數據并行 (Fully Sharded Data Parallel,FSDP) 功能用 Accelerate 庫來訓練大模型。

動機

隨著機器學習 (ML) 模型的規模、大小和參數量的不斷增加,ML 從業者發現在自己的硬件上訓練甚至加載如此大的模型變得越來越難。 一方面,人們發現大模型與較小的模型相比,學習速度更快 (數據和計算效率更高) 且會有顯著的提升 [1]; 另一方面,在大多數硬件上訓練此類模型變得令人望而卻步。

分布式訓練是訓練這些機器學習大模型的關鍵。大規模分布式訓練 領域最近取得了不少重大進展,我們將其中一些最突出的進展總結如下:

  1. 使用 ZeRO 數據并行 - 零冗余優化器 [2]

  2. 階段 1: 跨數據并行進程 / GPU 對優化器狀態 進行分片

  3. 階段 2: 跨數據并行進程/ GPU 對優化器狀態 + 梯度 進行分片

  4. 階段 3: 跨數據并行進程 / GPU 對優化器狀態 + 梯度 + 模型參數 進行分片

  5. CPU 卸載: 進一步將 ZeRO 階段 2 的優化器狀態 + 梯度 卸載到 CPU 上 [3]

  6. 張量并行 [4]: 模型并行的一種形式,通過對各層參數進行精巧的跨加速器 / GPU 分片,在實現并行計算的同時避免了昂貴的通信同步開銷。

  7. 流水線并行 [5]: 模型并行的另一種形式,其將模型的不同層放在不同的加速器 / GPU 上,并利用流水線來保持所有加速器同時運行。舉個例子,在第 2 個加速器 / GPU 對第 1 個 micro batch 進行計算的同時,第 1 個加速器 / GPU 對第 2 個 micro batch 進行計算。

  8. 3D 并行 [3]: 采用 ZeRO 數據并行 + 張量并行 + 流水線并行 的方式來訓練數百億參數的大模型。例如,BigScience 176B 語言模型就采用了該并行方式 [6]。

本文我們主要關注 ZeRO 數據并行,更具體地講是 PyTorch 最新的 完全分片數據并行 (Fully Sharded Data Parallel,FSDP) 功能。DeepSpeedFairScale 實現了 ZeRO 論文的核心思想。我們已經將其集成到了 transformersTrainer 中,詳見博文 通過 DeepSpeed 和 FairScale 使用 ZeRO 進行更大更快的訓練[10]。最近,PyTorch 已正式將 Fairscale FSDP 整合進其 Distributed 模塊中,并增加了更多的優化。

Accelerate 🚀: 無需更改任何代碼即可使用 PyTorch FSDP

我們以基于 GPT-2 的 Large (762M) 和 XL (1.5B) 模型的因果語言建模任務為例。

以下是預訓練 GPT-2 模型的代碼。其與 此處 的官方因果語言建模示例相似,僅增加了 2 個參數 n_train (2000) 和 n_val (500) 以防止對整個數據集進行預處理/訓練,從而支持更快地進行概念驗證。

run_clm_no_trainer.py

運行 accelerate config 命令后得到的 FSDP 配置示例如下:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
fsdp_config:min_num_params: 2000offload_params: falsesharding_strategy: 1
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
use_cpu: false

多 GPU FSDP

本文我們使用單節點多 GPU 上作為實驗平臺。我們比較了分布式數據并行 (DDP) 和 FSDP 在各種不同配置下的性能。我們可以看到,對 GPT-2 Large(762M) 模型而言,DDP 尚能夠支持其中某些 batch size 而不會引起內存不足 (OOM) 錯誤。但當使用 GPT-2 XL (1.5B) 時,即使 batch size 為 1,DDP 也會失敗并出現 OOM 錯誤。同時,我們看到,FSDP 可以支持以更大的 batch size 訓練 GPT-2 Large 模型,同時它還可以使用較大的 batch size 訓練 DDP 訓練不了的 GPT-2 XL 模型。

硬件配置: 2 張 24GB 英偉達 Titan RTX GPU。

GPT-2 Large 模型 (762M 參數) 的訓練命令如下:

export BS=#`try with different batch sizes till you don't get OOM error,
#i.e., start with larger batch size and go on decreasing till it fits on GPU`time accelerate launch run_clm_no_trainer.py \
--model_name_or_path gpt2-large \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size $BS
--per_device_eval_batch_size $BS
--num_train_epochs 1
--block_size 12

FSDP 運行截屏:

圖片

FSDP 運行截屏

圖片

表 1: GPT-2 Large (762M) 模型 FSDP 訓練性能基準測試

從表 1 中我們可以看到,相對于 DDP 而言,FSDP 支持更大的 batch size,在不使用和使用 CPU 卸載設置的情況下 FSDP 支持的最大 batch size 分別可達 DDP 的 2 倍及 3 倍。從訓練時間來看,混合精度的 DDP 最快,其后是分別使用 ZeRO 階段 2 和階段 3 的 FSDP。由于因果語言建模的任務的上下文序列長度 ( --block_size ) 是固定的,因此 FSDP 在訓練時間上加速還不是太高。對于動態 batch size 的應用而言,支持更大 batch size 的 FSDP 可能會在訓練時間方面有更大的加速。目前,FSDP 的混合精度支持在 transformers 上還存在一些 問題。一旦問題解決,訓練時間將會進一步顯著縮短。

使用 CPU 卸載來支持放不進 GPU 顯存的大模型訓練

訓練 GPT-2 XL (1.5B) 模型的命令如下:

export BS=#`try with different batch sizes till you don't get OOM error,
#i.e., start with larger batch size and go on decreasing till it fits on GPU`time accelerate launch run_clm_no_trainer.py \
--model_name_or_path gpt2-xl \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size $BS
--per_device_eval_batch_size $BS
--num_train_epochs 1
--block_size 12

圖片

表 2: GPT-2 XL (1.5B) 模型上的 FSDP 基準測試

從表 2 中,我們可以觀察到 DDP (帶和不帶 fp16) 甚至在 batch size 為 1 的情況下就會出現 CUDA OOM 錯誤,從而無法運行。而開啟了 ZeRO- 階段 3 的 FSDP 能夠以 batch size 為 5 (總 batch size = 10 (5 2) ) 在 2 個 GPU 上運行。當使用 2 個 GPU 時,開啟了 CPU 卸載的 FSDP 還能將最大 batch size 進一步增加到每 GPU 14。開啟了 CPU 卸載的 FSDP 可以在單個 GPU 上訓練 GPT-2 1.5B 模型,batch size 為 10。這使得機器學習從業者能夠用最少的計算資源來訓練大模型,從而助力大模型訓練民主化。

Accelerate 的 FSDP 集成的功能和限制

下面,我們深入了解以下 Accelerate 對 FSDP 的集成中,支持了那些功能,有什么已知的限制。

支持 FSDP 所需的 PyTorch 版本: PyTorch Nightly 或 1.12.0 之后的版本。

命令行支持的配置:

  1. 分片策略: [1] FULL_SHARD, [2] SHARD_GRAD_OP

  2. Min Num Params: FSDP 默認自動包裝的最小參數量。

  3. Offload Params: 是否將參數和梯度卸載到 CPU。

如果想要對更多的控制參數進行配置,用戶可以利用 FullyShardedDataParallelPlugin ,其可以指定 auto_wrap_policybackward_prefetch 以及 ignored_modules

創建該類的實例后,用戶可以在創建 Accelerator 對象時把該實例傳進去。

有關這些選項的更多信息,請參閱 PyTorch FullyShardedDataParallel 代碼。

接下來,我們體會下 min_num_params 配置的重要性。以下內容摘自 [8],它詳細說明了 FSDP 自動包裝策略的重要性。

圖片

FSDP 自動包裝策略的重要性

(圖源: 鏈接)

當使用 default_auto_wrap_policy 時,如果該層的參數量超過 min_num_params ,則該層將被包裝在一個 FSDP 模塊中。官方有一個在 GLUE MRPC 任務上微調 BERT-Large (330M) 模型的示例代碼,其完整地展示了如何正確使用 FSDP 功能,其中還包含了用于跟蹤峰值內存使用情況的代碼。

fsdp_with_peak_mem_tracking.py

我們利用 Accelerate 的跟蹤功能來記錄訓練和評估期間的峰值內存使用情況以及模型準確率指標。下圖展示了 wandb 實驗臺 頁面的截圖。

圖片

wandb 實驗臺

我們可以看到,DDP 占用的內存是使用了自動模型包裝功能的 FSDP 的兩倍。不帶自動模型包裝的 FSDP 比帶自動模型包裝的 FSDP 的內存占用更多,但比 DDP 少得多。與 min_num_params=1M 時相比, min_num_params=2k 時帶自動模型包裝的 FSDP 占用的內存略少。這凸顯了 FSDP 自動模型包裝策略的重要性,用戶應該調整 min_num_params 以找到能顯著節省內存又不會導致大量通信開銷的設置。如 [8] 中所述,PyTorch 團隊也在為此開發自動配置調優工具。

需要注意的一些事項

  • PyTorch FSDP 會自動對模型子模塊進行包裝、將參數攤平并對其進行原位分片。因此,在模型包裝之前創建的任何優化器都會被破壞并導致更多的內存占用。因此,強烈建議在對模型調用 prepare 方法后再創建優化器,這樣效率會更高。對單模型而言,如果沒有按照順序調用的話, Accelerate 會拋出以下告警信息,并自動幫你包裝模型并創建優化器。

    FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer

即使如此,我們還是推薦用戶在使用 FSDP 時用以下方式顯式準備模型和優化器:

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
+ model = accelerator.prepare(model)optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)- model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(model,
- optimizer, train_dataloader, eval_dataloader, lr_scheduler
- )+ optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
+ optimizer, train_dataloader, eval_dataloader, lr_scheduler
+ )
  • 對單模型而言,如果你的模型有多組參數,而你想為它們設置不同優化器超參。此時,如果你對整個模型統一調用 prepare 方法,這些參數的組別信息會丟失,你會看到如下告警信息:

    FSDP Warning: When using FSDP, several parameter groups will be conflated into a single one due to nested module wrapping and parameter flattening.

告警信息表明,在使用 FSDP 對模型進行包裝后,之前創建的參數組信息丟失了。因為 FSDP 會將嵌套式的模塊參數攤平為一維數組 (一個數組可能包含多個子模塊的參數)。舉個例子,下面是 GPU 0 上 FSDP 模型的有名稱的參數 (當使用 2 個 GPU 時,FSDP 會把第一個分片的參數給 GPU 0, 因此其一維數組中大約會有 55M (110M / 2) 個參數)。此時,如果我們在 FSDP 包裝前將 BERT-Base 模型的 [bias, LayerNorm.weight] 參數的權重衰減設為 0,則在模型包裝后,該設置將無效。原因是,你可以看到下面這些字符串中均已不含這倆參數的名字,這倆參數已經被并入了其他層。想要了解更多細節,可參閱本 問題 (其中寫道: 原模型參數沒有 .grads 屬性意味著它們無法單獨被優化器優化 (這就是我們為什么不能支持對多組參數設置不同的優化器超參) )。

{
'_fsdp_wrapped_module.flat_param': torch.Size([494209]),'_fsdp_wrapped_module._fpw_module.bert.embeddings.word_embeddings._fsdp_wrapped_module.flat_param': torch.Size([11720448]),'_fsdp_wrapped_module._fpw_module.bert.encoder._fsdp_wrapped_module.flat_param': torch.Size([42527232])
}
  • 如果是多模型情況,須在創建優化器之前調用模型 prepare 方法,否則會拋出錯誤。

  • FSDP 目前不支持混合精度,我們正在等待 PyTorch 修復對其的支持。

工作原理 📝

圖片

FSDP 工作流

(圖源: 鏈接)

上述工作流概述了 FSDP 的幕后流程。我們先來了解一下 DDP 是如何工作的,然后再看 FSDP 是如何改進它的。在 DDP 中,每個工作進程 (加速器 / GPU) 都會保留一份模型的所有參數、梯度和優化器狀態的副本。每個工作進程會獲取不同的數據,這些數據會經過前向傳播,計算損失,然后再反向傳播以生成梯度。接著,執行 all-reduce 操作,此時每個工作進程從其余工作進程獲取梯度并取平均。這樣一輪下來,每個工作進程上的梯度都是相同的,且都是全局梯度,接著優化器再用這些梯度來更新模型參數。我們可以看到,每個 GPU 上都保留完整副本會消耗大量的顯存,這限制了該方法所能支持的 batch size 以及模型尺寸。

FSDP 通過讓各數據并行工作進程分片存儲優化器狀態、梯度和模型參數來解決這個問題。進一步地,還可以通過將這些張量卸載到 CPU 內存來支持那些 GPU 顯存容納不下的大模型。在具體運行時,與 DDP 類似,FSDP 的每個工作進程獲取不同的數據。在前向傳播過程中,如果啟用了 CPU 卸載,則首先將本地分片的參數搬到 GPU/加速器。然后,每個工作進程對給定的 FSDP 包裝模塊/層執行 all-gather 操作以獲取所需的參數,執行計算,然后釋放/清空其他工作進程的參數分片。在對所有 FSDP 模塊全部執行該操作后就是計算損失,然后是后向傳播。在后向傳播期間,再次執行 all-gather 操作以獲取給定 FSDP 模塊所需的所有參數,執行計算以獲得局部梯度,然后再次釋放其他工作進程的分片。最后,使用 reduce-scatter 操作對局部梯度進行平均并將相應分片給對應的工作進程,該操作使得每個工作進程都可以更新其本地分片的參數。如果啟用了 CPU 卸載的話,梯度會傳給 CPU,以便直接在 CPU 上更新參數。

如欲深入了解 PyTorch FSDP 工作原理以及相關實驗及其結果,請參閱 [7,8,9]。

問題

如果在 accelerate 中使用 PyTorch FSDP 時遇到任何問題,請提交至 accelerate。

但如果你的問題是跟 PyTorch FSDP 配置和部署有關的 - 你需要提交相應的問題至 PyTorch。

參考文獻

[1] Train Large, Then Compress: Rethinking Model Size for Efficient Training and Inference of Transformers

[2] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

[3] DeepSpeed: Extreme-scale model training for everyone - Microsoft Research

[4] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

[5] Introducing GPipe, an Open Source Library for Efficiently Training Large-scale Neural Network Models

[6] Which hardware do you need to train a 176B parameters model?

[7] Introducing PyTorch Fully Sharded Data Parallel (FSDP) API | PyTorch

[8] Getting Started with Fully Sharded Data Parallel(FSDP) — PyTorch Tutorials 1.11.0+cu102 documentation

[9] Training a 1 Trillion Parameter Model With PyTorch Fully Sharded Data Parallel on AWS | by PyTorch | PyTorch | Mar, 2022 | Medium

[10] Fit More and Train Faster With ZeRO via DeepSpeed and FairScale

[11] https://huggingface.co/blog/zh/pytorch-fsdp

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/215950.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/215950.shtml
英文地址,請注明出處:http://en.pswp.cn/news/215950.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

小程序域名SSL證書能用免費的嗎?

眾所周知,目前小程序要求域名強制使用https協議,否則無法上線。但是對于大多數開發者來說,為每一個小程序都使用上付費的SSL證書,也是一筆不小的支出。那么小程序能使用免費的SSL證書嗎? 答案是肯定的。目前市面上可選…

HCIP---RSTP/MSTP

文章目錄 目錄 文章目錄 前言 一.RSTP誕生背景 二.RSTP對比STP的快速收斂機制 端口角色變化 接口狀態變化 RSTP-BPDU 指定端口- P/A機制 BPDU發送變化 端口狀態快速切換 優化拓撲變更機制 三.MSTP MSTP誕生背景 MSTP相關概念 MSTP配置 總結 前言 STP協議雖然能夠解決環…

TypeScript中的函數注釋

一. 概覽 函數注釋主要分為顯示注釋、類型推斷、隱式的any&#xff0c;現在來詳細總結下 二. 顯示注釋 舉個例子 let str1: string hello,jacklet intArr: number[] [1,2,3] let strArr&#xff1a;Array<string> [1,2,3]function test(a: number,b: number): num…

記錄 | xftp遠程連接兩臺windows

1、打開openssh 設置 -> 應用 -> 可選功能 -> 添加功能 -> OpenSSH 客戶端&#xff0c;將 ssh 客戶端安裝將兩臺電腦的 ssh 開啟&#xff0c;cmd 中輸入 net start sshd2、配置 win10 賬號密碼 3、進行 xftp 連接

MATLAB安裝

親自驗證有效&#xff0c;多謝這位網友的分享&#xff1a; https://blog.csdn.net/xiajinbiaolove/article/details/88907232

租一臺服務器多少錢決定服務器的價格因素有哪些

租一臺服務器多少錢決定服務器的價格因素有哪些 大家好我是艾西&#xff0c;服務器這個名詞對于不從業網絡行業的人們看說肯定還是比較陌生的。在21世紀這個時代發展迅速的年代服務器在現實生活中是不可缺少的一環&#xff0c;平時大家上網瀏覽自己想要查詢的信息等都是需要服…

加減乘除簡單嗎?不,一點都不,利用位運算實現加減乘除(代碼中不含+ - * /)

文章目錄 &#x1f680;前言&#x1f680;異或運算以及與運算&#x1f680;加法的實現&#x1f680;減法的實現&#x1f680;乘法的實現&#x1f680;除法的實現 &#x1f680;前言 這也是阿輝開的新專欄&#xff0c;知識將會很零散不成體系&#xff0c;不過絕對干貨滿滿&…

華為鴻蒙HarmonyOS應用開發者高級認證試題及答案

判斷 1只要使用端云一體化的云端資源就需要支付費用&#xff08;錯&#xff09; 2所有使用Component修飾的自定義組件都支持onPageShow&#xff0c;onBackPress和onPageHide生命周期函數。&#xff08;錯&#xff09; 3 HarmonyOS應用可以兼容OpenHarmony生態&#xff08;對…

多維時序 | MATLAB實現SAO-CNN-BiGRU-Multihead-Attention多頭注意力機制多變量時間序列預測

多維時序 | MATLAB實現SAO-CNN-BiGRU-Multihead-Attention多頭注意力機制多變量時間序列預測 目錄 多維時序 | MATLAB實現SAO-CNN-BiGRU-Multihead-Attention多頭注意力機制多變量時間序列預測預測效果基本介紹模型描述程序設計參考資料 預測效果 基本介紹 MATLAB實現SAO-CNN-B…

CommonJs模塊化實現原理ES Module模塊化原理

CommonJs模塊化實現原理 首先看一個案例 初始化項目 npm init npm i webpack -D目錄結構如下&#xff1a; webpack.config.js const path require("path"); module.exports {mode: "development",entry: "./src/index.js",output: {path: p…

硬件開發筆記(十六):RK3568底板電路mipi攝像頭接口原理圖分析、mipi攝像頭詳解

若該文為原創文章&#xff0c;轉載請注明原文出處 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/134922307 紅胖子網絡科技博文大全&#xff1a;開發技術集合&#xff08;包含Qt實用技術、樹莓派、三維、OpenCV、OpenGL、ffmpeg、OSG、單片機、軟硬…

Redis緩存主要異常及解決方案

1 導讀 Redis 是當前最流行的 NoSQL數據庫。Redis主要用來做緩存使用,在提高數據查詢效率、保護數據庫等方面起到了關鍵性的作用,很大程度上提高系統的性能。當然在使用過程中,也會出現一些異常情景,導致Redis失去緩存作用。 2 異常類型 異常主要有 緩存雪崩 緩存穿透 緩…

【sqli靶場】第二關和第三關通關思路

目錄 前言 一、sqli靶場第二關 1.1 判斷注入類型 1.2 判斷數據表中的列數 1.3 使用union聯合查詢 1.4 使用group_concat()函數 1.5 爆出users表中的列名 1.6 爆出users表中的數據 二、sqli靶場第三關 2.1 判斷注入類型 2.2 觀察報錯 2.3 判斷數據表中的列數 2.4 使用union聯合…

Emutouch學習筆記

1 項目依賴 DeviceFarmer/minitouch 1.1 確認submodule引用的 commit ID git submodule status1.2 更新子模塊到最新版本 git submodule init && git submodule update --remote

Android:監聽開機廣播自己喚醒

要通過代碼獲取安卓系統的開機廣播消息&#xff0c;并在收到消息后拉起當前apk&#xff0c;您可以使用以下步驟&#xff1a; 創建一個廣播接收器&#xff08;Broadcast Receiver&#xff09;來接收開機廣播消息。在接收到開機廣播消息時&#xff0c;您可以在接收器中編寫代碼來…

什么是 web 組態?web 組態與傳統組態的區別是什么?

組態軟件是一種用于控制和監控各種設備的軟件&#xff0c;也是指在自動控制系統監控層一級的軟件平臺和開發環境。這類軟件實際上也是一種通過靈活的組態方式&#xff0c;為用戶提供快速構建工業自動控制系統監控功能的、通用層次的軟件工具。通常用于工業控制&#xff0c;自動…

Spring Boot整合 Spring Security

Spring Boot整合 1、RBAC 權限模型 RBAC模型&#xff08;Role-Based Access Control&#xff1a;基于角色的訪問控制&#xff09; 在RBAC模型里面&#xff0c;有3個基礎組成部分&#xff0c;分別是&#xff1a;用戶、角色和權限&#xff0c;它們之間的關系如下圖所示 SELECT…

02.類模板

2、類模板 2.1 類模板語法 建立一個通用類&#xff0c;類中的成員、數據類型可以不具體制定&#xff0c;用一個虛擬的類型來代表。 template<typename T> // 類template&#xff1a;聲明創建模板typename&#xff1a;表名其后面的符號是一種數據類型&#xff0c;可以用 …

【算法】算法題-20231211

這里寫目錄標題 一、387. 字符串中的第一個唯一字符二、1189. “氣球” 的最大數量三、1221. 分割平衡字符串 一、387. 字符串中的第一個唯一字符 簡單 給定一個字符串 s &#xff0c;找到 它的第一個不重復的字符&#xff0c;并返回它的索引 。如果不存在&#xff0c;則返回…

算法通關村第十五關 | 青銅 | 用4KB內存尋找重復元素

處理海量數據的思路 1.使用位存儲&#xff1a;占用的空間是存整數的 1/8 。 2.分塊&#xff1a;也叫外部排序&#xff0c;將大文件劃分為若干小塊&#xff0c;先處理小塊再逐步得到想要的結果&#xff0c;需要至少遍歷兩次全部序列&#xff0c;是用時間換空間的方法。 3.堆&…