關于pytorch使用多個dataloader并使用zip和cycle來進行循環時出現的顯存泄漏的問題
如果我們想要在 Pytorch 中同時迭代兩個 dataloader 來處理數據,會有兩種情況:一是我們按照較短的 dataloader 來迭代,長的 dataloader 超過的部分就丟棄掉;二是比較常見的,我們想要按照較長的 dataloader 來迭代,短的 dataloader 在循環完一遍再循環一遍,直到長的 dataloader 循環完一遍。
兩個dataloader的寫法及問題的出現
第一種情況很好寫,直接用 zip
包一下兩個 dataloader 即可:
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10for epoch in range(num_epochs):for i, data in enumerate(zip(dataloaders1, dataloaders2)):print(data)# 開始寫你的訓練腳本
第二種情況筆者一開始時參考的一篇博客的寫法,用 cycle
將較短的 dataloader 包一下:
from itertools import cycle
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10for epoch in range(num_epochs):for i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):print(data)# 開始寫你的訓練腳本
是可以運行,但是這樣出現了明顯顯存泄漏的問題,在筆者自己的實驗中,顯存占用量會隨著訓練的進行,每輪增加 20M 左右,最終導致顯存溢出,程序失敗。
解決方法
筆者找了半天,終于在 StackOverflow 的一篇貼子中找到了解決方法,該貼的一個答案指出:cycle
加 zip
的方法確實可能會造成顯存泄漏(memory leakage)的問題,尤其是在使用圖像數據集時,可以通過以下寫法來迭代兩個 dataloader 并避免這個問題:
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10for epoch in range(num_epochs):dataloader_iterator1 = iter(dataloaders1)for i, data2 in enumerate(dataloaders2):try:data1 = next(dataloader_iterator1)except StopIteration:dataloader_iterator1 = iter(dataloaders1)data1 = next(dataloader_iterator1)print(data1, data2)# 開始你的訓練腳本
筆者親測這種方式是可以正常運行且不會有顯存泄漏問題的。