PyTorch DDP 隨機卡死復盤:最后一個 batch 掛起,NCCL 等待不返回,三步修復 Sampler & drop_last
很多人在接觸深度學習的過程往往都是從自己的筆記本開始的,但是從接觸工作后,更多的是通過分布式的訓練來模型。由于經驗的不足常常會遇到分布式訓練“玄學卡死”:多卡的訓練偶發在 epoch 尾部停住不動,并且GPU 利用率掉到 0%,日志無異常。為了解決首次接觸分布式訓練的人的疑問,本文從bug現象以及調試逐一分析。
? Bug 現象
在我們進行多卡訓練的時候,偶爾會出現隨機在某些 epoch 尾部卡住,無異常棧;nvidia-smi
顯示兩卡功耗接近空閑。偶爾能看到 NCCL 打印(并不總出現):
NCCL WARN Reduce failed: ... Async operation timed out
接著通過kill -SIGQUIT
打印 Python 棧后發現停在 反向傳播的梯度 allreduce*上(DistributedDataParallel
內部)。
但是這個現象在關掉 DDP(單卡訓練)完全正常;把 batch_size 改小/大,卡住概率改變但仍會發生。
📽? 場景重現
當我們的問題在單卡不會出現,但是多卡會出現問題的時候,問題點集中在數據的問題上。主要原因以下:
1?? shuffle=True
與 DistributedSampler
混用(會被忽略但容易誤導)。
2?? drop_last=False
時,最后一個小批的樣本數在不同 rank 上可能不一致(當 len(dataset)
不是 world_size
的整數倍且某些數據被過濾/增強丟棄時尤其明顯)。
3?? 每個 epoch 忘記調用 sampler.set_epoch(epoch)
,導致各 rank 的隨機順序不一致。
以下是筆者在多卡訓練遇到的問題代碼
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSamplerclass DummyDS(Dataset):def __init__(self, N=1003): # 刻意設成非 world_size 整數倍self.N = Ndef __len__(self): return self.Ndef __getitem__(self, i):x = torch.randn(32, 3, 224, 224)y = torch.randint(0, 10, (32,)) # 模擬有時會丟棄某些樣本的增強(省略)return x, ydef setup():dist.init_process_group("nccl")torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))def main():setup()rank = dist.get_rank()device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))ds = DummyDS()sampler = DistributedSampler(ds, shuffle=True, drop_last=False) # ? drop_last=False# ? DataLoader 里又寫了 shuffle=True(被忽略,但容易誤以為生效)loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)model = torch.nn.Linear(3*224*224, 10).to(device)model = DDP(model, device_ids=[device.index])opt = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range(5):# ? 忘記 sampler.set_epoch(epoch)for x, y in loader:x = x.view(x.size(0), -1).to(device)y = y.to(device)opt.zero_grad()loss = torch.nn.functional.cross_entropy(model(x), y)loss.backward() # 🔥 偶發卡在這里(allreduce)opt.step()if rank == 0:print(f"epoch {epoch} done")dist.destroy_process_group()if __name__ == "__main__":main()
🔴觸發條件(滿足一兩個就可能復現):
1?? len(dataset)
不是 world_size
的整數倍。
2?? 動態數據過濾/增強(例如有時返回 None
或丟樣),導致各 rank 實際步數不同。
3?? 忘記 sampler.set_epoch(epoch)
,各 rank 洗牌次序不同。
4?? drop_last=False
,導致最后一個 batch 在各 rank 的樣本數不同。
5?? 某些自定義 collate_fn 在“空 batch”時直接 continue
。
?? Debug
1?? 先確認“各 rank 步數一致”
在訓練 loop 里加統計(不要只在 rank0 打印):
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每個 rank 各自 print,檢查是否相等
我的現象:有的 epoch,rank0 比 rank1 多 1–2 個 step。
2??開啟 NCCL 調試
在啟動前設置:
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1
再跑一遍,可看到某些 allreduce 一直等不到某 rank 進來。
3??檢查 Sampler 與 DataLoader 參數
DistributedSampler
必須搭配sampler.set_epoch(epoch)
。- DataLoader 里不要再寫
shuffle=True
。 - 若數據不可整除,優先
drop_last=True
;否則確保各 rank 最后一個 batch 大小一致(例如補齊/填充)。
🟢 解決方案(修復版)
- 嚴格對齊 Sampler 語義 + 丟最后不齊整的 batch
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Datasetclass DummyDS(Dataset):def __init__(self, N=1003): self.N=Ndef __len__(self): return self.Ndef __getitem__(self, i):x = torch.randn(32, 3, 224, 224)y = torch.randint(0, 10, (32,))return x, ydef setup():dist.init_process_group("nccl")torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))def main():setup()rank = dist.get_rank()device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))ds = DummyDS()# 關鍵 1:使用 DistributedSampler,統一交給它洗牌sampler = DistributedSampler(ds, shuffle=True, drop_last=True) # ?# 關鍵 2:DataLoader 里不要再寫 shuffleloader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)model = torch.nn.Linear(3*224*224, 10).to(device)ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False) # 如無動態分支,關掉更穩更快opt = torch.optim.SGD(ddp.parameters(), lr=0.1)for epoch in range(5):sampler.set_epoch(epoch) # ? 關鍵 3:每個 epoch 設置不同隨機種子for x, y in loader:x = x.view(x.size(0), -1).to(device, non_blocking=True)y = y.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)loss = torch.nn.functional.cross_entropy(ddp(x), y)loss.backward()opt.step()if rank == 0:print(f"epoch {epoch} ok")dist.barrier() # ? 收尾同步,避免 rank 提前退出dist.destroy_process_group()if __name__ == "__main__":main()
- 必須保留最后一批
如果確實不能 drop_last=True
(例如小數據集),可考慮對齊 batch 大小:
- Padding/Repeat:在
collate_fn
里把最后一批補齊到一致大小; - EvenlyDistributedSampler:自定義 sampler,確保各 rank 拿到完全等長的 index 列表(對總長度做上采樣)。
示例(最簡單的“循環補齊”):
class EvenSampler(DistributedSampler):def __iter__(self):# 先拿到原始 index,再做均勻補齊indices = list(super().__iter__())# 使得 len(indices) 可整除 num_replicasrem = len(indices) % self.num_replicasif rem != 0:pad = self.num_replicas - remindices += indices[:pad] # 簡單重復前幾個樣本return iter(indices)
總結
以上是這次 DDP 卡死問題從現象 → 排查 → 解決的完整記錄。這個坑非常高頻,尤其在課程項目/科研代碼里常被忽視。希望這篇復盤能讓你在分布式訓練時少掉一把汗。最終定位是 DistributedSampler
使用不當 + drop_last=False
+ 忘記 set_epoch
引發各 rank 步數不一致,導致 allreduce 永久等待。