PyTorch DDP 隨機卡死復盤:最后一個 batch 掛起,NCCL 等待不返回

PyTorch DDP 隨機卡死復盤:最后一個 batch 掛起,NCCL 等待不返回,三步修復 Sampler & drop_last

很多人在接觸深度學習的過程往往都是從自己的筆記本開始的,但是從接觸工作后,更多的是通過分布式的訓練來模型。由于經驗的不足常常會遇到分布式訓練“玄學卡死”:多卡的訓練偶發在 epoch 尾部停住不動,并且GPU 利用率掉到 0%,日志無異常。為了解決首次接觸分布式訓練的人的疑問,本文從bug現象以及調試逐一分析。

? Bug 現象

在我們進行多卡訓練的時候,偶爾會出現隨機在某些 epoch 尾部卡住,無異常棧;nvidia-smi 顯示兩卡功耗接近空閑。偶爾能看到 NCCL 打印(并不總出現):

NCCL WARN Reduce failed: ... Async operation timed out

接著通過kill -SIGQUIT 打印 Python 棧后發現停在 反向傳播的梯度 allreduce*上(DistributedDataParallel 內部)。

但是這個現象在關掉 DDP(單卡訓練)完全正常;把 batch_size 改小/大,卡住概率改變但仍會發生。

📽? 場景重現

當我們的問題在單卡不會出現,但是多卡會出現問題的時候,問題點集中在數據的問題上。主要原因以下:

1?? shuffle=TrueDistributedSampler 混用(會被忽略但容易誤導)。

2?? drop_last=False 時,最后一個小批的樣本數在不同 rank 上可能不一致(當 len(dataset) 不是 world_size 的整數倍且某些數據被過濾/增強丟棄時尤其明顯)。

3?? 每個 epoch 忘記調用 sampler.set_epoch(epoch),導致各 rank 的隨機順序不一致

以下是筆者在多卡訓練遇到的問題代碼

import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSamplerclass DummyDS(Dataset):def __init__(self, N=1003):  # 刻意設成非 world_size 整數倍self.N = Ndef __len__(self): return self.Ndef __getitem__(self, i):x = torch.randn(32, 3, 224, 224)y = torch.randint(0, 10, (32,))   # 模擬有時會丟棄某些樣本的增強(省略)return x, ydef setup():dist.init_process_group("nccl")torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))def main():setup()rank = dist.get_rank()device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))ds = DummyDS()sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ? drop_last=False# ? DataLoader 里又寫了 shuffle=True(被忽略,但容易誤以為生效)loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)model = torch.nn.Linear(3*224*224, 10).to(device)model = DDP(model, device_ids=[device.index])opt = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range(5):# ? 忘記 sampler.set_epoch(epoch)for x, y in loader:x = x.view(x.size(0), -1).to(device)y = y.to(device)opt.zero_grad()loss = torch.nn.functional.cross_entropy(model(x), y)loss.backward()      # 🔥 偶發卡在這里(allreduce)opt.step()if rank == 0:print(f"epoch {epoch} done")dist.destroy_process_group()if __name__ == "__main__":main()

🔴觸發條件(滿足一兩個就可能復現):

1?? len(dataset) 不是 world_size 的整數倍。

2?? 動態數據過濾/增強(例如有時返回 None 或丟樣),導致各 rank 實際步數不同。

3?? 忘記 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

4?? drop_last=False,導致最后一個 batch 在各 rank 的樣本數不同。

5?? 某些自定義 collate_fn 在“空 batch”時直接 continue

?? Debug

1?? 先確認“各 rank 步數一致”

在訓練 loop 里加統計(不要只在 rank0 打印):

from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每個 rank 各自 print,檢查是否相等

我的現象:有的 epoch,rank0 比 rank1 多 1–2 個 step

2??開啟 NCCL 調試

在啟動前設置:

export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 進來。

3??檢查 Sampler 與 DataLoader 參數
  • DistributedSampler 必須搭配 sampler.set_epoch(epoch)
  • DataLoader 里不要再寫 shuffle=True
  • 若數據不可整除,優先 drop_last=True;否則確保各 rank 最后一個 batch 大小一致(例如補齊/填充)。

🟢 解決方案(修復版)

  • 嚴格對齊 Sampler 語義 + 丟最后不齊整的 batch
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Datasetclass DummyDS(Dataset):def __init__(self, N=1003): self.N=Ndef __len__(self): return self.Ndef __getitem__(self, i):x = torch.randn(32, 3, 224, 224)y = torch.randint(0, 10, (32,))return x, ydef setup():dist.init_process_group("nccl")torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))def main():setup()rank = dist.get_rank()device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))ds = DummyDS()# 關鍵 1:使用 DistributedSampler,統一交給它洗牌sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ?# 關鍵 2:DataLoader 里不要再寫 shuffleloader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)model = torch.nn.Linear(3*224*224, 10).to(device)ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如無動態分支,關掉更穩更快opt = torch.optim.SGD(ddp.parameters(), lr=0.1)for epoch in range(5):sampler.set_epoch(epoch)  # ? 關鍵 3:每個 epoch 設置不同隨機種子for x, y in loader:x = x.view(x.size(0), -1).to(device, non_blocking=True)y = y.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)loss = torch.nn.functional.cross_entropy(ddp(x), y)loss.backward()opt.step()if rank == 0:print(f"epoch {epoch} ok")dist.barrier()  # ? 收尾同步,避免 rank 提前退出dist.destroy_process_group()if __name__ == "__main__":main()
  • 必須保留最后一批

如果確實不能 drop_last=True(例如小數據集),可考慮對齊 batch 大小

  1. Padding/Repeat:在 collate_fn 里把最后一批補齊到一致大小
  2. EvenlyDistributedSampler:自定義 sampler,確保各 rank 拿到完全等長的 index 列表(對總長度做上采樣)。

示例(最簡單的“循環補齊”):

class EvenSampler(DistributedSampler):def __iter__(self):# 先拿到原始 index,再做均勻補齊indices = list(super().__iter__())# 使得 len(indices) 可整除 num_replicasrem = len(indices) % self.num_replicasif rem != 0:pad = self.num_replicas - remindices += indices[:pad]     # 簡單重復前幾個樣本return iter(indices)

總結

以上是這次 DDP 卡死問題從現象 → 排查 → 解決的完整記錄。這個坑非常高頻,尤其在課程項目/科研代碼里常被忽視。希望這篇復盤能讓你在分布式訓練時少掉一把汗。最終定位是 DistributedSampler 使用不當 + drop_last=False + 忘記 set_epoch引發各 rank 步數不一致,導致 allreduce 永久等待。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/95858.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/95858.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/95858.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

計算機專業考研備考建議

對于全國碩士研究生招生考試(考研),考試科目主要由兩大部分組成:全國統一命題的公共課 和 由招生單位自主命題的專業課。具體的考試科目取決于你報考的專業和學校。下面我為你詳細拆解:一、考試科目構成(絕…

關于嵌入式學習——單片機1

基礎整體概念以應用為中心:消費電子(手機、藍牙耳機、智能音響)、醫療電子(心率脈搏、呼吸機)、無人機(大疆D)、機器人(人形四足機器人) 計算機技術:計算機五大組成:運算器(數據運算)、控制器(指令控制)、存儲器(內存外存)、輸入設備(鼠標、鍵盤、攝像頭)、輸出設備(顯示器)軟件…

LightDock.server liunx 雙跑比較

LightDock: a new multi-scale approach to protein–protein docking The LightDock server is free and open to all users and there is no login requirement server 1示例 故去除約束 next step 結果有正有負合理 2.常見警告? Structure contains HETATM entries. P…

SQL面試題及詳細答案150道(61-80) --- 多表連接查詢篇

《前后端面試題》專欄集合了前后端各個知識模塊的面試題,包括html,javascript,css,vue,react,java,Openlayers,leaflet,cesium,mapboxGL,threejs,nodejs,mangoDB,MySQL,Linux… 。 前后端面試題-專欄總目錄 文章目錄 一、本文面試題目錄 61. 什么是內連接(INNE…

【實操】Noej4圖數據庫安裝和mysql表銜接實操

目錄 一、圖數據庫介紹 二、安裝Neo4j 2.1 安裝java環境 2.2 安裝 Neo4j(社區版) 2.3 修改配置 2.4 驗證測試 2.5 卸載 2.6 基本用法 2.7 windows連接服務器可視化 三、neo4j和mysql對比 3.1 場景對比 3.2 Mysql和neo4j的映射對比 3.3 mys…

【mysql】SQL查詢全解析:從基礎分組到高級自連接技巧

SQL查詢全解析:從基礎分組到高級自連接技巧詳解玩家首次登錄查詢的多種實現方式與優化技巧在數據庫查詢中,同一個需求往往有多種實現方式。本文將通過"查詢每個玩家第一次登錄的日期"這一常見需求,深入解析SQL查詢的多種實現方法&a…

MySQL常見報錯分析及解決方案總結(9)---出現interactive_timeout/wait_timeout

關于超時報錯,一共有五種超時參數,詳見:MySQL常見報錯分析及解決方案總結(7)---超時參數connect_timeout、interactive_timeout/wait_timeout、lock_wait_timeout、net等-CSDN博客 以下是當前報錯的排查方法和解決方案: MySQL 中…

第13章 Jenkins性能優化

13.1 性能優化概述 性能問題識別 常見性能瓶頸: Jenkins性能問題分類:1. 系統資源瓶頸- CPU使用率過高- 內存不足或泄漏- 磁盤I/O瓶頸- 網絡帶寬限制2. 應用層面問題- JVM配置不當- 垃圾回收頻繁- 線程池配置問題- 數據庫連接池不足3. 架構設計問題- 單點…

Python+DRVT 從外部調用 Revit:批量創建梁

今天讓我們繼續,看看如何批量創建常用的基礎元素:梁。 跳過軸線為直線段形的,先從圓弧形的開始: from typing import List, Tuple import math # drvt_pybind 支持多會話、多文檔,先從簡單的單會話、單文檔開始 # My…

水上樂園票務管理系統設計與開發(代碼+數據庫+LW)

摘 要 隨著旅游業的蓬勃發展,水上樂園作為夏日娛樂的重要組成部分,其票務管理效率和服務質量直接影響游客體驗。然而,傳統的票務管理模式往往面臨信息更新不及時、服務響應慢等問題。因此,本研究旨在通過設計并實現一個基于Spri…

【前端教程】JavaScript DOM 操作實戰案例詳解

案例1&#xff1a;操作div子節點并修改樣式與內容 功能說明 獲取div下的所有子節點&#xff0c;設置它們的背景顏色為紅色&#xff1b;如果是p標簽&#xff0c;將其內容設置為"我愛中國"。 實現代碼 <!DOCTYPE html> <html> <head><meta ch…

qiankun+vite+react配置微前端

微前端框架&#xff1a;qiankun。 主應用&#xff1a;react19vite7&#xff0c;子應用1&#xff1a;react19vite7&#xff0c;子應用2 &#xff1a;react19vite7 一、主應用 1. 安裝依賴 pnpm i qiankun 2. 注冊子應用 (1) 在src目錄下創建個文件夾&#xff0c;用來存儲關于微…

git: 取消文件跟蹤

場景&#xff1a;第一次初始化倉庫的時候沒有忽略.env或者node_modules&#xff0c;導致后面將.env加入.gitignore也不生效。 取消文件跟蹤&#xff1a;如果是因為 node_modules 已被跟蹤導致忽略無效&#xff0c; 可以使用命令git rm -r --cached node_modules來刪除緩存&…

開講啦|MBSE公開課:第五集 MBSE中期設想(下)

第五集 在本集課程中&#xff0c;劉玉生教授以MBSE建模工具選型及二次定制開發為核心切入點&#xff0c;系統闡釋了"為何需要定制開發"與"如何實施定制開發"的實踐邏輯&#xff0c;并提煉出MBSE中期實施的四大核心要素&#xff1a;高效高質建摸、跨域協同…

CSDN個人博客文章全面優化過程

兩天前達到博客專家申請條件&#xff0c;興高采烈去申請博客專家&#xff1a; 結果今天一看&#xff0c;申請被打回了&#xff1a; 我根據“是Yu欸”大神的博客&#xff1a; 【2024-完整版】python爬蟲 批量查詢自己所有CSDN文章的質量分&#xff1a;附整個實現流程_抓取csdn的…

Websocket的Key多少個字節

在WebSocket協議中&#xff0c;握手過程中的Sec-WebSocket-Key是一個由客戶端生成的隨機字符串&#xff0c;用于安全地建立WebSocket連接。這個Sec-WebSocket-Key是基于Base64編碼的&#xff0c;并且通常由客戶端在WebSocket握手請求的頭部字段中發送。根據WebSocket協議規范&a…

SVT-AV1編碼器中實現WPP依賴管理核心調度

一 assign_enc_dec_segments 函數。這個函數是 SVT-AV1 編碼器中實現波前并行處理&#xff08;WPP&#xff09; 和分段依賴管理的核心調度器之一。//函數功能&#xff1a;分配編碼解碼段任務//返回值Bool//True 成功分配了一個段給當前線程&#xff0c;調用者應該處理這個段//F…

直接讓前端請求代理到自己的本地服務器,告別CV報文到自己的API工具,解放雙手

直接使用前端直接調用本地服務器&#xff0c;在自己的瀏覽器搜索插件proxyVerse&#xff0c;類似的插件應該還有一些&#xff0c;可以選擇自己喜歡的這類插件可以將瀏覽器請求&#xff0c;直接轉發到本地服務器&#xff0c;這樣在本地調試的時候&#xff0c;不需要前端項目&…

Golang Goroutine 與 Channel:構建高效并發程序的基石

在當今這個多核處理器日益普及的時代&#xff0c;利用并發來提升程序的性能和響應能力已經成為軟件開發的必然趨勢。而Go語言&#xff0c;作為一門為并發而生的語言&#xff0c;其設計哲學中將“并發”置于核心地位。其中&#xff0c;Goroutines 和 Channels 是Go實現并發編程的…

17 C 語言宏進階必看:從宏替換避坑到宏函數用法,不定參數模擬實現一次搞定

預處理詳解1. 預定義符號//C語?設置了?些預定義符號&#xff0c;可以直接使?&#xff0c;預定義符號也是在預處理期間處理的。 __FILE__ //進?編譯的源?件--預處理階段被替換成指向文件名字符串的指針--char* 類型的變量 __LINE__ //?件當前的?號 --預處理階段替換成使用…