如何用更少的顯存訓練 PyTorch 模型

文章目錄

1、引言

2、自動混合精度訓練

3、低精度訓練

4、梯度檢查點

5、通過梯度累積減小批量大小

6、張量分片與分布式訓練

7、高效數據加載

8、使用 In-Place 操作

9、Activation and Parameter Offloading

10、使用更精簡的優化器

11、高級策略

12、總結


1、引言

在訓練大型深度學習模型(包括LLM和視覺Transformer)時,最常見的瓶頸之一就是顯存消耗達到峰值。由于大多數人無法使用大規模的GPU集群,因此在本文中將概述一些技術和策略,在不犧牲模型性能和預測準確性的情況下,將顯存消耗降低近20倍。請記住,這些技術中的大多數應用并不互相排斥,可以很容易地結合使用,以提高顯存效率。

2、自動混合精度訓練

混合精度訓練結合了16位(FP16)和32位(FP32)浮點格式。其核心思想是在低精度下執行大部分數學運算,從而降低顯存帶寬和存儲需求,同時在計算的關鍵環節保留必要的精度保障。通過使用FP16存儲激活值和梯度,這些張量的顯存占用量可減少約一半。但需注意,某些網絡層或運算仍需保持FP32精度以避免數值不穩定問題。值得慶幸的是,PyTorch對自動混合精度(AMP)的原生支持極大簡化了這一過程。

注意這里是混合精度訓練而不是低精度訓練

什么是混合精度訓練?

混合精度訓練結合使用16位(FP16)和32位(FP32)浮點格式以保持模型精度。通過使用16位精度計算梯度,相比全32位精度計算,這一過程可大幅加快運算速度并顯著減少顯存占用。這種方法在顯存或計算資源受限的場景下尤為實用。

之所以采用混合精度而非低精度這一表述,是因為并非所有參數或運算都被轉換為16位格式。實際上,訓練過程會在32位與16位運算之間動態切換,這種精度層級的有序交替正是該技術被稱為混合精度的根本原因。

如上述示意圖所示,混合精度訓練流程首先將權重轉換為低精度格式(FP16)以加速計算,隨后梯度計算在低精度環境下完成,但為確保數值穩定性,這些梯度會被重新轉換為高精度格式(FP32),最終經過縮放處理的梯度將用于更新原始權重。因此,通過這種機制既能提升訓練效率,又不會犧牲網絡的整體精度與穩定性。

如前所述,使用 torch.cuda.amp.autocast( ) 可以輕松啟用該功能,一個簡單的代碼示例片段如下:

import?torch
from?torch.cuda.amp?import?autocast, GradScaler# Assume your model and optimizer have been defined elsewhere.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
for?data, target?in?data_loader:optimizer.zero_grad()# Enable mixed precisionwith?autocast():output = model(data)loss = loss_fn(output, target)# Scale the loss and backpropagatescaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

3、低精度訓練

如原文所述,理論上可以更進一步嘗試完全使用16位低精度(而非混合精度)進行訓練。但此時可能因16位浮點數的固有精度限制出現NaN值異常。為解決這一問題,業界開發了多種新型浮點格式,其中由谷歌專門為此研發的BF16應用較為廣泛。簡而言之,相較于標準的FP16,BF16擁有更大的動態范圍——這種擴展的動態范圍使其能夠更精確地表示極大或極小的數值,從而更適配可能遭遇廣泛數值區間的深度學習場景。雖然其較低的尾數精度在某些情況下可能影響計算準確性或引發舍入誤差,但在大多數實踐中對模型性能的影響微乎其微。

FP16與BF16的動態范圍對比

雖然這種格式最初是為TPU開發的,但在大多數現代GPU(Nvidia Ampere架構及更高版本)也支持這種格式。大家可以使用以下方法檢查您的GPU是否支持這種格式:

import?torch
print(torch.cuda.is_bf16_supported()) ?# should print True

4、梯度檢查點

即使采用混合精度與低精度訓練,這些大型模型仍會生成大量中間張量,消耗可觀的顯存。梯度檢查點技術通過在前向傳播過程中選擇性存儲部分中間結果來解決這一問題——未被保存的中間張量將在反向傳播階段重新計算。盡管這會引入額外的計算開銷,卻能顯著節省顯存資源。

通過策略性選擇需設置檢查點的網絡層,大家可通過動態重新計算激活值而非存儲它們來減少顯存使用。這種時間與內存的折中策略對于具有深層架構的模型特別有益,因為中間激活值占用了大部分內存消耗。以下是一個簡單的使用示例:

import?torch
from?torch.utils.checkpoint?import?checkpoint
def?checkpointed_segment(input_tensor):# This function represents a portion of your model# which will be recomputed during the backward pass.# You can create a custom forward pass for this segment.return?model_segment(input_tensor)
# Instead of a conventional forward pass, wrap the segment with checkpoint.
output = checkpoint(checkpointed_segment, input_tensor)

采用該方法,在大多數情況下可使激活值的顯存占用量降低40%至50%。盡管反向傳播階段因此增加了額外的計算量,但在GPU顯存成為瓶頸的場景下,這種以時間換空間的策略通常是可接受的。

5、通過梯度累積減小批量大小

通過最初的方法,你可能會問自己:

為什么不干脆減少batchsize大小?

通過減小批量大小的確是減少顯存占用最直接的方法,但需注意的是,這種方式在多數情況下會導致模型預測性弱于使用更大批量訓練的模型。因此需要在顯存限制與模型效果之間謹慎權衡。

那么如何達到平衡呢?

這正是梯度累積技術發揮作用之處!該方法通過在訓練過程中虛擬增大有效批量規模:其核心原理是先在較小的批量上計算梯度,并經過多次迭代的累積(通過采用累加或平均方式),而非在每批次處理后立即更新模型參數。當累積梯度達到目標“虛擬”批量規模時,才使用聚合后的梯度一次性完成模型權重的更新。

這種技術的一個主要缺點是大大增加了訓練時間。

6、張量分片與分布式訓練

對于單個GPU無法容納的龐大訓練模型(即使經過上述優化),完全分片數據并行(FSDP)是不可或缺的。FSDP將模型參數、梯度和優化器狀態分散到多個GPU上。這不僅能將巨大的模型放入顯存,還能通過更好地分配通信開銷提高訓練效率。

FSDP不在每個GPU上維護模型的完整副本,而是在可用設備之間分配模型參數。在執行前向或后向傳遞時,只有相關的分片被加載到顯存中。這種分片機制大大降低了對每臺設備顯存的需求,結合上述技術,在某些情況下甚至可以將顯存需求降低10倍。

Tensor Parallel

樣例如下:

import?torch
from?torch.distributed.fsdp?import?FullyShardedDataParallel?as?FSDP
# Initialize your model and ensure it is on the correct device.
model = MyLargeModel().cuda()
# Wrap the model in FSDP for sharded training across GPUs.
fsdp_model = FSDP(model)

7、高效數據加載

在顯存優化實踐中,數據加載環節常被忽視。雖然優化重點通常集中在模型內部結構與計算過程上,但低效的數據處理可能引發不必要的性能瓶頸,同時影響顯存占用與訓練速度。若不確定如何優化數據加載器,可遵循以下經驗法則:優先啟用固定內存(Pinned Memory)與多工作進程(Multiple Workers)配置。

from?torch.utils.data?import?DataLoader# Create your dataset instance and then the DataLoader with pinned memory enabled.
train_loader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4, ? ? ?# Adjust based on your CPU capabilitiespin_memory=True? ? ?# Enables faster host-to-device transfers
)

8、使用 In-Place 操作

在張量運算中,若未謹慎管理內存,每次操作都可能生成新對象。原地(In-Place)操作通過直接修改現有張量而非創建副本,可有效減少內存碎片化與總體內存占用。這種特性尤其有利于降低迭代訓練循環中的臨時內存分配開銷。例如:

import?torch
x = torch.randn(100,?100, device='cuda')
y = torch.randn(100,?100, device='cuda')
# Using in-place addition
x.add_(y) ?# Here x is modified directly instead of creating a new tensor

9、Activation and Parameter Offloading

即便綜合運用前述所有優化技術,在訓練超大規模模型時,仍可能因海量中間激活值的瞬時占用而觸及GPU顯存容量極限。此時,中間數據卸載技術可作為額外的安全閥機制——其核心思路是將部分非即時必需的中間數據臨時轉換至CPU內存,從而為GPU顯存騰出關鍵空間,確保訓練流程持續進行。

我們通過策略性將部分激活值和或模型參數卸載至CPU內存,從而將GPU顯存專用于核心計算任務。雖然如DeepSpeed、Fabric等專業框架已內置管理此類數據遷移的機制,大家仍可通過以下方式自主實現該功能。

def?offload_activation(tensor):# Move tensor to CPU to save GPU memoryreturn?tensor.cpu()def?process_batch(data):# Offload some activations explicitlyintermediate = model.layer1(data)intermediate = offload_activation(intermediate)intermediate = intermediate.cuda() ?# Move back when neededoutput = model.layer2(intermediate)return?output

10、使用更精簡的優化器

并非所有優化器對內存的需求均等。以廣泛使用的Adam優化器為例,其針對每個模型參數需額外維護兩個狀態變量(均值與方差),導致內存占用倍增。相比之下,采用無狀態優化器(如SGD)可將參數總量減少近三分之二——這對于訓練大語言模型(LLMs)及其他大規模模型具有顯著意義。

盡管普通SGD優化器存在收斂性能較弱的缺陷,但通過引入余弦衰減學習率調整策略(Cosine Decay Learning Rate Scheduler)可有效補償這一不足。簡而言之:

# instead of this
optimizer?= torch.optim.Adam(model.parameters(), lr=5e-5)
# use this
optimizer?= torch.optim.SGD(model.parameters(), lr=0.01)
num_steps?= NUM_EPOCHS * len(train_loader)
scheduler?= torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

通過這一調整,大家可以在顯著改變峰值內存占有量的同時(具體取決于實際任務需求),仍能保持模型精度接近97%的水平。

11、高級策略

雖然上面列出的技術確實為我們奠定了堅實的基礎,但我還想列出一些其他高級策略,我們可以考慮將 GPU 提升到極限:

  • 內存剖析和高速緩存管理

如果無法測量,就很難優化。PyTorch 提供了一些檢查 GPU 內存使用情況的默認實用程序。使用方法如下:

import?torch
# print a detailed report of current GPU memory usage and fragmentation
print(torch.cuda.memory_summary(device=None, abbreviated=False))
# free up cached memory that’s no longer needed by PyTorch
torch.cuda.empty_cache()
  • 使用TorchScript進行JIT編譯

PyTorch 的即時(JIT)編譯器使大家使用 TorchScript 將 Python 模型轉換為優化的、可序列化的程序。通過優化內核啟動和減少開銷,這種轉換可以提高內存和性能。您可以通過以下方式輕松訪問它:

import torch
# Suppose `model` is an instance of your PyTorch network.
scripted_model = torch.jit.script(model)
# Now, you can run the scripted model just like before.
output = scripted_model(input_tensor)

盡管框架原生方法已能實現基礎功能,但模型編譯技術通常能帶來更深層次的性能優化。

  • 自定義內核融合

編譯的另一個主要好處是將多個操作融合到一個內核中。這有助于減少內存讀/寫,提高整體吞吐量。融合后的操作如下:

  • 使用torch.compile()進行動態內存分配

再次從編譯中獲益--使用 JIT 編譯器可通過利用跟蹤和圖形優化技術的編譯時優化來優化動態內存分配,從而進一步壓縮內存并提高性能,尤其是在大型模型和Transformer架構中。

12、總結

隨著 GPU 和云計算變得異常昂貴,只有充分利用現有資源才有意義。這有時可能意味著要在單個 GPU 工作站/筆記本電腦上對 LLM 或視覺Transformer進行訓練/微調。上面列出的技術是研究人員/專業人士在算力緊張的情況下進行訓練所使用的眾多策略中的一部分。

參考資料:AI算法之道

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/80201.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/80201.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/80201.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

極速輕量,Rust 網絡開發新選擇:Hyperlane 框架深度解析

極速輕量,Rust 網絡開發新選擇:Hyperlane 框架深度解析 在高性能網絡服務開發領域,Rust 憑借其內存安全與高效并發的特性備受青睞。今天,我們迎來一款專為現代 Web 服務打造的明星框架——Hyperlane,它以“輕量高效、…

單片機裸機環境下臨界區保護

目錄 1、直接中斷屏蔽法 2、嵌套計數優化法 3、BASEPRI寄存器應用 4、動態優先級調整策略 5、LDREX/STREX指令應用 6、位帶別名區原子訪問 7、上下文感知保護 8、中斷延遲優化技術 在嵌入式系統開發中,臨界區保護是確保系統可靠性的關鍵技術。本文以ARM Cor…

【deepseek教學應用】001:deepseek如何撰寫教案并自動實現word排版

本文講述利用deepseek如何撰寫教案并自動實現word高效完美排版。 文章目錄 一、訪問deepseek官網二、輸入教案關鍵詞三、格式轉換四、word進一步排版 一、訪問deepseek官網 官網:https://www.deepseek.com/ 進入主頁后,點擊【開始對話】,如…

springboot使用mybatisPlus進行數據庫增刪改查

springboot使用mybatisPlus進行數據庫增刪改查 提示:幫幫志會陸續更新非常多的IT技術知識,希望分享的內容對您有用。本章分享的是springboot的使用。前后每一小節的內容是存在的有:學習and理解的關聯性。【幫幫志系列文章】:每個…

基于SpringBoot的校園周邊美食探索及分享平臺的設計與實現

資源詳情: 私信我或點擊鏈接獲取: 基于SpringBoot的校園周邊美食探索及分享平臺的設計與實現資源-CSDN文庫 摘要 美食一直是與人們日常生活息息相關的產業。傳統的電話訂餐或者到店消費已經不能適應市場發展的需求。隨著網絡的迅速崛起,互聯…

到達最后一個房間的最少時間II 類似棋盤轉移規律查找

文章目錄 3342.到達最后一個房間的最少時間II 思路分析:最短路徑問題,當然,由于不同的格子之間的移動的代價不統一,所以這個最短路徑需要使用Dijkstra算法進行求解,對于直接使用Dijkstra算法模版的題目,大家可以先去做…

基于開源AI大模型AI智能名片S2B2C商城小程序源碼的私域流量穩定性構建研究

摘要:在私域流量時代,傳統實體零售的"時間積累"邏輯被直播電商等新業態顛覆。完美日記等新銳品牌通過構建私域流量池,實現了從0到1的指數級增長,而傳統品牌卻陷入"流量焦慮"。本文提出以開源AI大模型AI智能名…

做 iOS 調試時,我嘗試了 5 款抓包工具

日常做開發的人,特別是和客戶端接口打交道的同學,應該對“抓包”這件事不陌生。 調試登錄流程、分析接口格式、排查錯誤返回、分析網絡性能、甚至研究第三方 App 的數據通信……說到底,都繞不開“抓 HTTPS 包”這一步。 而這一步&#xff0…

Algolia - Docsearch的申請配置安裝【以踩坑解決版】

👨?🎓博主簡介 🏅CSDN博客專家 ??🏅云計算領域優質創作者 ??🏅華為云開發者社區專家博主 ??🏅阿里云開發者社區專家博主 💊交流社區:運維交流社區 歡迎大家的加入&#xff01…

nginx 配置后端健康檢查模塊

nginx自帶的針對后端節點健康檢查的功能比較簡單,通過默認自帶的ngx_http_proxy_module 模塊和ngx_http_upstream_module模塊中的參數來完成,當后端節點出現故障時,自動切換到健康節點來提供訪問。但是nginx不能事先知道后端節點狀態是否健康,后端即使有不健康節點,負載均…

平板收銀系統、國產系統,鴻蒙系統,小鍵盤的封裝與應用—仙盟創夢IDE

數字小鍵盤封裝 數組小鍵盤封裝是指將與數組小鍵盤相關的功能、操作、數據等進行整合,形成一個獨立的、可復用的模塊。封裝數組小鍵盤具有以下幾方面重要意義: 提高代碼可維護性 降低復雜度:數組小鍵盤在實際應用中,可能涉及到…

網工實驗——OSPF配置

網絡拓撲圖 配置 1.為每個路由器配置接口(略)(詳細見RIP實驗) 2.配置OSPF AR1 [AR1]ospf [AR1-ospf-1]area 1 [AR1-ospf-1-area-0.0.0.1]network 172.16.1.1 0.0.0.0 #精確配置網絡,也可以像下面那條命令那樣配置 …

Kubernetes client-go 客戶端類型與初始化指南

Kubernetes client-go 客戶端類型與初始化指南 在 Kubernetes 的 client-go 庫中,存在多種客戶端用于與 API 服務器交互。以下介紹主要客戶端類型,包括用途、初始化方式及 Demo。 1. RESTClient 用途 RESTClient 是底層 REST 客戶端,直接…

java加強 -泛型

概念 定義類、接口、方法時&#xff0c;同時聲明了一個或多個類型變量&#xff08;如<E>&#xff09;&#xff0c;稱為泛型類、泛型接口、泛型方法、它們統稱為泛型。 語法 public class ArrayList<E>{} E可以接收不同類型的數據&#xff0c;可以是字符串&…

C++ 項目 -- 高并發內存池

目錄 項目介紹 內存池概念 池化技術 內存池 內存池主要解決的問題 malloc 定長內存池 申請內存 釋放內存 整體框架設計 thread cache 申請內存 釋放內存 central cache 申請內存 釋放內存 page cache 申請內存 釋放內存 大塊內存申請實現 定長內存…

高效C/C++之九:Coverity修復問題:關于數組操作 和 內存操作

【關注我&#xff0c;后續持續新增專題博文&#xff0c;謝謝&#xff01;&#xff01;&#xff01;】 上一篇我們講了&#xff1a; 這一篇我們開始講&#xff1a; 高效C/C之九&#xff1a;Coverity修復問題&#xff1a;關于數組操作 和 內存操作 目錄 【關注我&#xff0c;后…

vfrom表單設計器使用事件機制控制字段顯示隱藏

1. 使用表單設計器進行debug調試 依據 vform3.0開發者文檔 https://www.ganweicloud.com/docs/6.1.0/pages/d3e6d9/ 對switch組件設置事件邏輯 調試中

iPhone 和 Android 在日期格式方面的區別

整篇文章由iPhone 和 Android 在日期格式方面有所不同引起,大致介紹了,兩種時間標準,以及在 JavaScript 下的格式轉換方法。 Unix 時間戳是從1970年1月1日(UTC/GMT的午夜)開始所經過的秒數,不考慮閏秒。 iPhone 和 Android 在日期格式方面有所不同。其中,iPhone(iOS)使…

985高校查重率“隱性閾值”:低于5%可能被重點審查!

你是不是也以為&#xff1a; “查重率越低越好&#xff0c;最好壓到1%、0%&#xff0c;導師看了都感動哭&#x1f979;” 但是你不知道的是——在985/211等重點高校&#xff0c;查重率太低反而可能引起導師和學術辦公室的“特別關注”&#xff01; 今天就來扒一扒這個查重圈“…

【NLP】33. Pinecone + OpenAI :構建自定義語義搜索系統

Pinecone OpenAI 中文教學教程&#xff1a;構建自定義語義搜索系統 一、背景介紹 當下 AI 問答系統、矩陣檢索、短文本分類等場景中&#xff0c;都需要很好地實現 “根據輸入進行相似給點搜索”。這種算法基礎稱為 “向量搜索”&#xff0c;它的核心是將文本轉換為向量后&am…