常規訓練方式
for x,y in train_loader:pred = model(x)loss = criterion(pred, label)# 反向傳播loss.backward()# 根據新的梯度更新網絡參數optimizer.step()# 清空以往梯度,通過下面反向傳播重新計算梯度optimizer.zero_grad()
pytorch每次forward完都會得到一個用于梯度回傳的計算圖,pytorch構建的計算圖是動態的,其實在每次backward后計算圖都會從內存中釋放掉,但是梯度不會清空的。所以若不顯示的進行optimizer.zero_grad()清空過往梯度這一步操作,backward()的時候就會累加過往梯度。
梯度累加方法
accumulation_steps = 4
for i,(x,y) in enumerate(train_loader):pred = model(x)loss = criterion(pred, label)# 相當于對累加后的梯度取平均loss = loss/accumulation_steps# 反向傳播loss.backward()if (i+1) % accumulation_steps == 0:# 根據新的梯度更新網絡參數optimizer.step()# 清空以往梯度,通過下面反向傳播重新計算梯度optimizer.zero_grad()
????????代碼中設置accumulation_steps = 4,意思就是變相擴大batch_size四倍。因為代碼中每隔4次迭代才清空梯度,更新參數。
????????loss = loss/accumulation_steps,梯度累加了四次,那就要取平均,除以4。每次loss取4,其實就相當于最后將累加后的梯度除4。同時,因為累計了4個batch,那學習率也應該擴大4倍,讓更新的步子跨大點。
?看網上的帖子有討論對BN層是否有影響,因為BN的估算階段(計算batch內均值、方差)是在forward階段完成的,那真實的batch_size放大4倍效果肯定是比通過梯度累加放大4倍效果好的,畢竟計算真實的大batch_size內的均值、方差肯定更精確。
?還有討論說通過調低BN參數momentum可以得到更長序列的統計信息,應該意思是能夠記憶更久遠的統計信息(均值、方差),以逼近真實的擴大batch_size的效果。
?
參考
pytorch騷操作之梯度累加,變相增大batch size
pytorch里巧用optimizer.zero_grad增大batchsize