Pytorch的BatchNorm層使用中容易出現的問題

前言

本文主要介紹在pytorch中的Batch Normalization的使用以及在其中容易出現的各種小問題,本來此文應該歸屬于[1]中的,但是考慮到此文的篇幅可能會比較大,因此獨立成篇,希望能夠幫助到各位讀者。如有謬誤,請聯系指出,如需轉載,請注明出處,謝謝。

? \nabla ? 聯系方式:

e-mail: FesianXu@gmail.com

QQ: 973926198

github: https://github.com/FesianXu

知乎專欄: 計算機視覺/計算機圖形理論與應用

微信公眾號:
qrcode
Batch Normalization,批規范化

Batch Normalization(簡稱為BN)[2],中文翻譯成批規范化,是在深度學習中普遍使用的一種技術,通常用于解決多層神經網絡中間層的協方差偏移(Internal Covariate Shift)問題,類似于網絡輸入進行零均值化和方差歸一化的操作,不過是在中間層的輸入中操作而已,具體原理不累述了,見[2-4]的描述即可。

在BN操作中,最重要的無非是這四個式子:

注意到這里的最后一步也稱之為仿射(affine),引入這一步的目的主要是設計一個通道,使得輸出output至少能夠回到輸入input的狀態(當 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0時)使得BN的引入至少不至于降低模型的表現,這是深度網絡設計的一個套路。
整個過程見流程圖,BN在輸入后插入,BN的輸出作為規范后的結果輸入的后層網絡中。

好了,這里我們記住了,在BN中,一共有這四個參數我們要考慮的:

??? γ , β \gamma, \beta γ,β:分別是仿射中的 w e i g h t \mathrm{weight} weight和 b i a s \mathrm{bias} bias,在pytorch中用weight和bias表示。
??? μ B \mu_{\mathcal{B}} μB?和 σ B 2 \sigma_{\mathcal{B}}^2 σB2?:和上面的參數不同,這兩個是根據輸入的batch的統計特性計算的,嚴格來說不算是“學習”到的參數,不過對于整個計算是很重要的。在pytorch中,這兩個統計參數,用running_mean和running_var表示[5],這里的running指的就是當前的統計參數不一定只是由當前輸入的batch決定,還可能和歷史輸入的batch有關,詳情見以下的討論,特別是參數momentum那部分。

Update 2020/3/16:
因為BN層的考核,在工作面試中實在是太常見了,在本文順帶補充下BN層的參數的具體shape大小。
以圖片輸入作為例子,在pytorch中即是nn.BatchNorm2d(),我們實際中的BN層一般是對于通道進行的,舉個例子而言,我們現在的輸入特征(可以視為之前討論的batch中的其中一個樣本的shape)為 x ∈ R C × W × H \mathbf{x} \in \mathbb{R}^{C \times W \times H} x∈RC×W×H(其中C是通道數,W是width,H是height),那么我們的 μ B ∈ R C \mu_{\mathcal{B}} \in \mathbb{R}^{C} μB?∈RC,而方差 σ B 2 ∈ R C \sigma^{2}_{\mathcal{B}} \in \mathbb{R}^C σB2?∈RC。而仿射中 w e i g h t , γ ∈ R C \mathrm{weight}, \gamma \in \mathbb{R}^{C} weight,γ∈RC以及 b i a s , β ∈ R C \mathrm{bias}, \beta \in \mathbb{R}^{C} bias,β∈RC。我們會發現,這些參數,無論是學習參數還是統計參數都會通道數有關,其實在pytorch中,通道數的另一個稱呼是num_features,也即是特征數量,因為不同通道的特征信息通常很不相同,因此需要隔離開通道進行處理。

有些朋友可能會認為這里的weight應該是一個張量,而不應該是一個矢量,其實不是的,這里的weight其實應該看成是 對輸入特征圖的每個通道得到的歸一化后的 x ^ \hat{\mathbf{x}} x^進行尺度放縮的結果,因此對于一個通道數為 C C C的輸入特征圖,那么每個通道都需要一個尺度放縮因子,同理,bias也是對于每個通道而言的。這里切勿認為 y i ← γ x ^ i + β y_i \leftarrow \gamma \hat{x}_i+\beta yi?←γx^i?+β這一步是一個全連接層,他其實只是一個尺度放縮而已。關于這些參數的形狀,其實可以直接從pytorch源代碼看出,這里截取了_NormBase層的部分初始代碼,便可一見端倪。

class _NormBase(Module):
??? """Common base of _InstanceNorm and _BatchNorm"""
??? _version = 2
??? __constants__ = ['track_running_stats', 'momentum', 'eps',
???????????????????? 'num_features', 'affine']

??? def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
???????????????? track_running_stats=True):
??????? super(_NormBase, self).__init__()
??????? self.num_features = num_features
??????? self.eps = eps
??????? self.momentum = momentum
??????? self.affine = affine
??????? self.track_running_stats = track_running_stats
??????? if self.affine:
??????????? self.weight = Parameter(torch.Tensor(num_features))
??????????? self.bias = Parameter(torch.Tensor(num_features))
??????? else:
??????????? self.register_parameter('weight', None)
??????????? self.register_parameter('bias', None)
??????? if self.track_running_stats:
??????????? self.register_buffer('running_mean', torch.zeros(num_features))
??????????? self.register_buffer('running_var', torch.ones(num_features))
??????????? self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
??????? else:
??????????? self.register_parameter('running_mean', None)
??????????? self.register_parameter('running_var', None)
??????????? self.register_parameter('num_batches_tracked', None)
??????? self.reset_parameters()

?

在Pytorch中使用

Pytorch中的BatchNorm的API主要有:

torch.nn.BatchNorm1d(num_features,
???????????????????? eps=1e-05,
???????????????????? momentum=0.1,
???????????????????? affine=True,
???????????????????? track_running_stats=True)

?

一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓練狀態,訓練狀態與否將會影響到某些層的參數是否是固定的,比如BN層或者Dropout層。通常用model.train()指定當前模型model為訓練狀態,model.eval()指定當前模型為測試狀態。
同時,BN的API中有幾個參數需要比較關心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當前batch的統計特性。容易出現問題也正好是這三個參數:trainning,affine,track_running_stats。

??? 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False,則 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0,并且不能學習被更新。一般都會設置成affine=True[10]
??? trainning和track_running_stats,track_running_stats=True表示跟蹤整個訓練過程中的batch的統計特性,得到方差和均值,而不只是僅僅依賴與當前輸入的batch的統計特性。相反的,如果track_running_stats=False那么就只是計算當前輸入的batch的統計特性中的均值和方差了。當在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那么其統計特性就會和全局統計特性有著較大偏差,可能導致糟糕的效果。

一般來說,trainning和track_running_stats有四種組合[7]

??? trainning=True, track_running_stats=True。這個是期望中的訓練階段的設置,此時BN將會跟蹤整個訓練過程中batch的統計特性。
??? trainning=True, track_running_stats=False。此時BN只會計算當前輸入的訓練batch的統計特性,可能沒法很好地描述全局的數據統計特性。
??? trainning=False, track_running_stats=True。這個是期望中的測試階段的設置,此時BN會用之前訓練好的模型中的(假設已經保存下了)running_mean和running_var并且不會對其進行更新。一般來說,只需要設置model.eval()其中model中含有BN層,即可實現這個功能。[6,8]
??? trainning=False, track_running_stats=False 效果同(2),只不過是位于測試狀態,這個一般不采用,這個只是用測試輸入的batch的統計特性,容易造成統計特性的偏移,導致糟糕效果。

同時,我們要注意到,BN層中的running_mean和running_var的更新是在forward()操作中進行的,而不是optimizer.step()中進行的,因此如果處于訓練狀態,就算你不進行手動step(),BN的統計特性也會變化的。如

model.train() # 處于訓練狀態

for data, label in self.dataloader:
?? ?pred = model(data) ?
?? ?# 在這里就會更新model中的BN的統計特性參數,running_mean, running_var
?? ?loss = self.loss(pred, label)
?? ?# 就算不要下列三行代碼,BN的統計特性參數也會變化
?? ?opt.zero_grad()
?? ?loss.backward()
?? ?opt.step()

?

這個時候要將model.eval()轉到測試階段,才能固定住running_mean和running_var。有時候如果是先預訓練模型然后加載模型,重新跑測試的時候結果不同,有一點性能上的損失,這個時候十有八九是trainning和track_running_stats設置的不對,這里需要多注意。 [8]

假設一個場景,如下圖所示:

此時為了收斂容易控制,先預訓練好模型model_A,并且model_A內含有若干BN層,后續需要將model_A作為一個inference推理模型和model_B聯合訓練,此時就希望model_A中的BN的統計特性值running_mean和running_var不會亂變化,因此就必須將model_A.eval()設置到測試模式,否則在trainning模式下,就算是不去更新該模型的參數,其BN都會改變的,這個將會導致和預期不同的結果。

Update 2020/3/17:
評論區的Oshrin朋友提出問題

??? 作者您好,寫的很好,但是是否存在問題。即使將track_running_stats設置為False,如果momentum不為None的話,還是會用滑動平均來計算running_mean和running_var的,而非是僅僅使用本batch的數據情況。而且關于凍結bn層,有一些更好的方法。

這里的momentum的作用,按照文檔,這個參數是在對統計參數進行更新過程中,進行指數平滑使用的,比如統計參數的更新策略將會變成:

其中的更新后的統計參數 x ^ n e w \hat{x}_{\mathrm{new}} x^new?,是根據當前觀察 x t x_t xt?和歷史觀察 x ^ \hat{x} x^進行加權平均得到的(差分的加權平均相當于歷史序列的指數平滑),默認的momentum=0.1。然而跟蹤歷史信息并且更新的這個行為是基于track_running_stats為true并且training=true的情況同時成立的時候,才會進行的,當在track_running_stats=true, training=false時(在默認的model.eval()情況下,即是之前談到的四種組合的第三個,既滿足這種情況),將不涉及到統計參數的指數滑動更新了。[12,13]

這里引用一個不錯的BN層凍結的例子,如:[14]

import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable
from apex.fp16_utils import *

def fix_bn(m):
??? classname = m.__class__.__name__
??? if classname.find('BatchNorm') != -1:
??????? m.eval()

model = models.resnet50(pretrained=True)
model.cuda()
model = network(model)
model.train()
model.apply(fix_bn) # fix batchnorm
input = Variable(torch.FloatTensor(8, 3, 224, 224).cuda())
output = model(input)
output_mean = torch.mean(output)
output_mean.backward()

總結來說,在某些情況下,即便整體的模型處于model.train()的狀態,但是某些BN層也可能需要按照需求設置為model_bn.eval()的狀態。

Update 2020.6.19:
評論區有個同學問了一個問題:

??? K.G.lee:想問博主,為什么模型測試時的參數為trainning=False, track_running_stats=True啊??測試不是用訓練時的滑動平均值嗎?為什么track_running_stats=True呢?為啥要跟蹤當前batch??

我感覺這個問題問得挺好的,我們需要去翻下源碼[15],我們發現我們所有的BatchNorm層都有個共同的父類_BatchNorm,我們最需要關注的是return F.batch_norm()這一段,我們發現,其對training的判斷邏輯是

training=self.training or not self.track_running_stats

那么,其實其在eval階段,這里的track_running_stats并不能設置為False,原因很簡單,這樣會使得上面談到的training=True,導致最終的期望程序錯誤。至于設置了track_running_stats=True是不是會導致在eval階段跟蹤測試集的batch的統計參數呢?我覺得是不會的,我們追蹤會發現[16],整個流程的最后一步其實是調用了torch.batch_norm(),其是調用C++的底層函數,其參數列表可和track_running_stats一點關系都沒有,只是由training控制,因此當training=False時,其不會跟蹤統計參數的,只是會調用訓練集訓練得到的統計參數。(當然,時間有限,我也沒有繼續追到C++層次去看源碼了)。

class _BatchNorm(_NormBase):

??? def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
???????????????? track_running_stats=True):
??????? super(_BatchNorm, self).__init__(
??????????? num_features, eps, momentum, affine, track_running_stats)

??? def forward(self, input):
??????? self._check_input_dim(input)

??????? # exponential_average_factor is set to self.momentum
??????? # (when it is available) only so that it gets updated
??????? # in ONNX graph when this node is exported to ONNX.
??????? if self.momentum is None:
??????????? exponential_average_factor = 0.0
??????? else:
??????????? exponential_average_factor = self.momentum

??????? if self.training and self.track_running_stats:
??????????? # TODO: if statement only here to tell the jit to skip emitting this when it is None
??????????? if self.num_batches_tracked is not None:
??????????????? self.num_batches_tracked = self.num_batches_tracked + 1
??????????????? if self.momentum is None:? # use cumulative moving average
??????????????????? exponential_average_factor = 1.0 / float(self.num_batches_tracked)
??????????????? else:? # use exponential moving average
??????????????????? exponential_average_factor = self.momentum

??????? return F.batch_norm(
??????????? input, self.running_mean, self.running_var, self.weight, self.bias,
??????????? self.training or not self.track_running_stats,
??????????? exponential_average_factor, self.eps)

?? def batch_norm(input, running_mean, running_var, weight=None, bias=None,
?????????????? training=False, momentum=0.1, eps=1e-5):
??? # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor? # noqa
??? r"""Applies Batch Normalization for each channel across a batch of data.

??? See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
??? :class:`~torch.nn.BatchNorm3d` for details.
??? """
??? if not torch.jit.is_scripting():
??????? if type(input) is not Tensor and has_torch_function((input,)):
??????????? return handle_torch_function(
??????????????? batch_norm, (input,), input, running_mean, running_var, weight=weight,
??????????????? bias=bias, training=training, momentum=momentum, eps=eps)
??? if training:
??????? _verify_batch_size(input.size())

??? return torch.batch_norm(
??????? input, weight, bias, running_mean, running_var,
??????? training, momentum, eps, torch.backends.cudnn.enabled
??? )

??

Reference

[1]. 用pytorch踩過的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3]. <深度學習優化策略-1>Batch Normalization(BN)
[4]. 詳解深度學習中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7]. BatchNorm2d增加的參數track_running_stats如何理解?
[8]. Why track_running_stats is not set to False during eval
[9]. How to train with frozen BatchNorm?
[10]. Proper way of fixing batchnorm layers during training
[11]. 大白話《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》
[12]. https://discuss.pytorch.org/t/what-does-model-eval-do-for-batchnorm-layer/7146/2
[13]. https://zhuanlan.zhihu.com/p/65439075
[14]. https://github.com/NVIDIA/apex/issues/122
[15]. https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
[16]. https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#batch_norm
————————————————
版權聲明:本文為CSDN博主「FesianXu」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/LoseInVain/article/details/86476010

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/458022.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/458022.shtml
英文地址,請注明出處:http://en.pswp.cn/news/458022.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

android 比較靠譜的圖片壓縮

2019獨角獸企業重金招聘Python工程師標準>>> 第一&#xff1a;我們先看下質量壓縮方法&#xff1a; private Bitmap compressImage(Bitmap image) { ByteArrayOutputStream baos new ByteArrayOutputStream(); image.compress(Bitmap.CompressFormat.JPEG, 100, …

jetty上手

jetty簡介&#xff1a;維基百科 Jetty是一個純粹的基于Java的網頁服務器和Java Servlet容器。盡管網頁服務器通常用來為人們呈現文檔&#xff0c;但是Jetty通常在較大的軟件框架中用于計算機與計算機之間的通信。Jetty支持最新的Java Servlet API&#xff08;帶JSP的支持&#…

常用公差配合表圖_ER彈簧夾頭配套BT刀柄常用規格型號表

ER彈簧夾頭具有定心精度高&#xff0c;夾緊力均勻的特點&#xff0c;廣泛用于機械類零件的精加工和半精加工&#xff0c;通常與BT刀柄匹配使用。BT刀柄是是機械主軸與刀具和其它附件工具連接件&#xff0c;BT為日本標準(MAS403)&#xff0c;現在也是普遍使用的一種標準。傳統刀…

Spatial Transformer Networks(STN)

詳細解讀Spatial Transformer Networks&#xff08;STN&#xff09;-一篇文章讓你完全理解STN了_多元思考力-CSDN博客_stn

Linux下python安裝升級詳細步驟 | Python2 升級 Python3

Linux下python升級步驟 Python2 ->Python3 多數情況下&#xff0c;系統自動的Python版本是2.x 或者yum直接安裝的也是2.x 但是&#xff0c;現在多數情況下建議使用3.x 那么如何升級呢&#xff1f; 下面老徐詳細講解升級步驟&#xff1b; 首先下載源tar包 可利用linux自帶下…

華為手機連電腦_手機、電腦無網高速互傳!華為神技逆天

Huawei Share是華為的一項自研多終端傳輸技術&#xff0c;可以在沒有網絡狀態下實現手機與手機、電腦等多終端設備間快速穩定的文件分享&#xff0c;尤其是在辦公場景下&#xff0c;可以極大提升辦公效率。華為表示&#xff0c;未來Huawei Share將應用于更多全場景跨設備無縫分…

【無標題】移動端深度學習開源框架及部署(對比)

移動端深度學習開源框架及部署 - 凌逆戰 - 博客園

Github基本操作的學習與溫習

GitHub是最先進的分布式版本控制工具&#xff0c;下面是我學習中總結的操作流程&#xff0c;僅供參考 -----------------------------------------------------------------------------------------------------------------------------------------------------------------…

excel統計行數_值得收藏的6個Excel函數公式(有講解)

收藏的Excel函數大全公式再多&#xff0c;幾天不用也會忘記。怎么才能不忘&#xff1f;你需要了解公式的運行原理。小編今天不再推送一大堆函數公式&#xff0c;而是根據提問最多的問題&#xff0c;精選出6個實用的&#xff0c;然后詳細的解釋給大家。1、計算兩個時間差TEXT(B2…

Studio One正版多少錢 Studio One正版怎么購買

隨著版權意識的增強&#xff0c;打擊盜版的力度越來越大&#xff0c;現在網絡上的盜版資源越來越少&#xff0c;資源少很難找是一方面&#xff0c;另一方面使用盜版軟件不僅很多功能不能使用&#xff0c;而且很多盜版軟件都被植入各種木馬病毒&#xff0c;從而帶來各種各樣的風…

DNS簡述

常見DNS記錄SOA&#xff1a;域權威開始NS&#xff1a;權威域名服務器A&#xff1a;主機地址CNAME&#xff1a;別名對應的正規名稱MX&#xff1a;郵件傳遞服務器PTR&#xff1a;域名指針 (用于反向 DNS)查詢過程瀏覽器緩存->hosts->LDNS->LDNS緩存->ISP->ISP緩存…

cuda gpu相關匯總

1.Ubuntu16.04:在anaconda下安裝pytorch-gpu 轉自&#xff1a;Ubuntu16.04:在anaconda下安裝pytorch-gpu_莫等閑996的博客-CSDN博客 1 創建虛擬環境并進入 conda create -n pytorch-gpu python3.6 conda activate pytorch-gpu 2 下載對應的安裝包和配件 方法一(推薦)&#…

普通人學python有意義嗎_學python難嗎

首先&#xff0c;對于初學者來說學習Python是不錯的選擇&#xff0c;一方面Python語言的語法比較簡單易學&#xff0c;另一方面Python的實驗環境也比較容易搭建。學習Python需要的時間取決于三方面因素。(推薦學習&#xff1a;Python視頻教程)其一是學習者是否具有一定的計算機…

karatsuba乘法

karatsuba乘法 Karatsuba乘法是一種快速乘法。此算法在1960年由Anatolii Alexeevitch Karatsuba 提出&#xff0c;并于1962年得以發表。[1]此算法主要用于兩個大數相乘。普通乘法的復雜度是n2&#xff0c;而Karatsuba算法的復雜度僅為3nlog3≈3n1.585&#xff08;log3是以2為底…

在Visual Studio上開發Node.js程序(2)——遠程調試及發布到Azure

【題外話】 上次介紹了VS上開發Node.js的插件Node.js Tools for Visual Studio&#xff08;NTVS&#xff09;&#xff0c;其提供了非常方便的開發和調試功能&#xff0c;當然很多情況下由于平臺限制等原因需要在其他機器上運行程序&#xff0c;進而需要遠程調試功能&#xff0c…

服務器定期監控數據_基礎設施硬件監控探索與實踐

本文選自 《交易技術前沿》總第三十六期文章(2019年9月)陳靖宇深圳證券交易所 系統運行部Email: jingyuchenszse.cn摘要&#xff1a;為了應對基礎設施規模不斷上升&#xff0c;數據中心兩地三中心帶來的運維挑戰&#xff0c;深交所結合現有基礎設施現狀&#xff0c;以通用性、靈…

LeetCode206:Reverse Linked List

Reverse a singly linked list. 分別用迭代和遞歸實現 struct ListNode {int val;struct ListNode *next; }; 迭代實現&#xff1a; struct ListNode* reverseList(struct ListNode* head) {struct ListNode *pre NULL;struct ListNode *cur head;while( cur ! NULL ){struct…

VS2010問題匯總

問題1&#xff1a;error C3872: "0xa0": 此字符不允許在標識符中使用 error C3872: "0xa0": 此字符不允許在標識符中使用 或者 error C3872: 0xa0: this character is not allowed in an identifier 解法&#xff1a;這是因為直接復制代碼的問題。0xa0是…

交叉編譯HTOP并移植到ARM嵌入式Linux系統

原創作品&#xff0c;允許轉載&#xff0c;轉載時請務必以超鏈接形式標明文章、作者信息和本聲明&#xff0c;否則將追究法律責任。 最近一直在完善基于Busybox做的ARM Linux的根文件系統&#xff0c;由于busybox是一個精簡的指令集組成的簡單文件系統&#xff0c;其優點就是極…

vue如何獲取年月日_好程序員web前端教程分享Vue相關面試題

好程序員web前端教程分享Vue相關面試題&#xff0c;Vue是一套構建用戶界面的漸進式框架&#xff0c;具有簡單易用、性能好、前后端分離等優勢&#xff0c;是web前端工程師工作的好幫手&#xff0c;也是企業選拔人才時考察的重點技能。接下來好程序員web前端教程資源就給大家分享…