Pytorch: model.eval(), model.train() 講解

文章目錄

  • 1. model.eval()
  • 2. model.train()

兩者只在一定的情況下有區別:訓練的模型中含有dropout 和 batch normalization

1. model.eval()

在模型測試階段使用

pytorch會自動把BN和DropOut固定住,不會取每個batchsize的平均,而是用訓練好的值。
不然的話,有輸入數據,即使不訓練,它也會改變權值。

一旦test的batch_size過小,很容易就會被BN層導致生成圖片顏色失真極大;

測試模型的時候
一般model.eval() 和 with torch.no_grad() 一起使用。

model.eval()with torch.no_grad():...out_data = model(data)...
model.train() #最后別忘記加上

2. model.train()

讓model變成訓練模式,此時 dropout和batch normalization的操作在訓練q起到防止網絡過擬合的問題

總結: model.train() 和 model.eval() 一般在模型訓練和評價的時候會加上這兩句,主要是針對由于model 在訓練時和評價時 Batch Normalization 和 Dropout 方法模式不同;
因此,在使用PyTorch進行訓練和測試時一定注意要把實例化的model指定train/eval;

在訓練的時候, 會計算一個batch內的mean 和var, 但是因為是小batch小batch的訓練的,所以會采用加權或者動量的形式來將每個batch的 mean和var來累加起來,也就是說再算當前的batch的時候,其實當前的權重只是占了0.1, 之前所有訓練過的占了0.9的權重,這樣做的好處是不至于因為某一個batch太過奇葩而導致的訓練不穩定。
好,現在假設訓練完成了, 那么在整個訓練集上面也得到了一個最終的”mean 和var”, BN層里面的參數也學習完了(如果指定學習的話),而現在需要測試了,測試的時候往往會一張圖一張圖的去測,這時候沒有batch而言了,對單獨一個數據做 mean和var是沒有意義的, 那么怎么辦,實際上在測試的時候BN里面用的mean和var就是訓練結束后的mean_final 和 val_final. 也可說是在測試的時候BN就是一個變換。所以在用pytorch的時候要注意這一點,在訓練之前要有model.train() 來告訴網絡現在開啟了訓練模式,在eval的時候要用”model.eval()”, 用來告訴網絡現在要進入測試模式了.因為這兩種模式下BN的作用是不同的。

https://blog.csdn.net/qq_32678471/article/details/102892930

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/535205.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/535205.shtml
英文地址,請注明出處:http://en.pswp.cn/news/535205.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Job for smbd.service failed because the control process exited with error code. See “systemctl statu

錯誤 $ sudo service smbd restartJob for smbd.service failed because the control process exited with error code. See "systemctl status smbd.service" and "journalctl -xe" for details.$ systemctl status smbd.servicesmbd.service - Samba SM…

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace o

問題 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4]] is at version 1; expected version 0 instead 分析 nn.relu(True) # 這個引起的問題原來的變量被替換掉了&#xff…

batchsize大小對訓練速度的影響

1.batchsize越大 是不是訓練越快? GPU :一塊2080Ti 平臺:pytorch cuda 圖片數量:2700 batchsize一個圖片處理時間GPU內存占用GPU算力使用一個epoch 所用時間10.117s2.5G20%2700 * 0.0117 318s50.516s8G90%2700 * 0.516/5 279s batchsize大…

os.environ[‘CUDA_VISIBLE_DEVICES‘]= ‘0‘設置環境變量

os.environ[‘環境變量名稱’]‘環境變量值’ #其中key和value均為string類型 import os os.environ["CUDA_VISIBLE_DEVICES"]‘6‘’,‘7’這個語句的作用是程序可見的顯卡ID。 注意 如果"CUDA_VISIBLE_DEVICES" 書寫錯誤,也不報…

nn.ReLU() 和 nn.ReLU(inplace=True)中inplace的作用

inplace 作用: 1.用relu后的變量進行替換,直接把原來的變量覆蓋掉 2.節省了內存 缺點: 有時候出現梯度回傳失敗的問題,因為之前的變量被替換了,找不到之前的變量了 參考這篇文章

pytorch:加載預訓練模型(多卡加載單卡預訓練模型,多GPU,單GPU)

在pytorch加載預訓練模型時,可能遇到以下幾種情況。 分為以下幾種在pytorch加載預訓練模型時,可能遇到以下幾種情況。1.多卡訓練模型加載單卡預訓練模型2. 多卡訓練模型加載多卡預訓練模型3. 單卡訓練模型加載單卡預訓練模型4. 單卡訓練模型加載多卡預訓…

python中 numpy轉list list 轉numpy

list to numpy import numpy as np a [1,2] b np.array(a)numpy to list a np.zero(1,1) a.tolist()

知識蒸餾 knowledge distill 相關論文理解

Knowledge Distil 相關文章1.FitNets : Hints For Thin Deep Nets (ICLR2015)2.A Gift from Knowledge Distillation:Fast Optimization, Network Minimization and Transfer Learning (CVPR 2017)3.Matching Guided Distillation&#xff08…

模型壓縮 相關文章解讀

模型壓縮相關文章Learning both Weights and Connections for Efficient Neural Networks (NIPS2015)Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding(ICLR2016)Learning both Weights and …

Linux 殺死進程

kill -9 進程名 殺死進程

計算圖片相似度的方法

文章目錄1.余弦相似度計算2.哈希算法計算圖片的相似度3.直方圖計算圖片的相似度4.SSIM(結構相似度度量)計算圖片的相似度5.基于互信息(Mutual Information)計算圖片的相似度1.余弦相似度計算 把圖片表示成一個向量,通…

.size .shape .size() type的運用

.size ndarray.size 數組元素的總個數,相當于 .shape 中 n*m 的值 a np.array([2,2]) print(a.size)2.shap ndarray.shape 數組的維度,對于矩陣,n 行 m 列 a np.array([2,2]) print(a.shape) (1,2)torch.tensor 數組的維度 x torch.r…

矩陣相加

tensor 類型 a torch.randn(1,3,3) b torch.randn(1,3,3) c a b numpy.array 類型 a np.array([2,2]) b np.array([2,2]) print(type(a)) print(ab)[4,4]

Latex 生成的PDF增加行號 左右兩邊

增加行號 \usepackage[switch]{lineno}\linenumbers \begin{document} \end{document}

pytorh .to(device) 和.cuda()的區別

原理 .to(device) 可以指定CPU 或者GPU device torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 單GPU或者CPU model.to(device) #如果是多GPU if torch.cuda.device_count() > 1:model nn.DataParallel(model,devic…

Linux 修改用戶名的主目錄 家目錄

首先root 登陸 sudo -i 輸入密碼然后 vim /etc/passwd 找到用戶名 然后修改后面的路徑即可

ubunt16.04 安裝3090顯卡驅動 cuda cudnn pytorch

安裝驅動 需要的安裝包 30系列顯卡是新一代架構,新驅動不支持cuda9以及cuda10,所以必須安裝cuda11、而pytorch現在穩定版為1.6,最高僅支持到cud10.2。所以唯一的辦法就是使用上處于beta測試的1.7或1.8。這也是為啥一開始就強調本文的寫作時…

3090顯卡 torch.cuda.is_available()返回false的解決辦法

問題 1.執行Nvidia-smi 命令沒有報錯,能夠顯示驅動信息; 2.執行 torch.backends.cudnn.enabled is TRUE 3.torch.cuda.is_available()一直返回False 解決 把torch,torchvision等相關安裝包全部刪除,安裝適合版本的torch。 30系…

測試項目:車牌檢測,行人檢測,紅綠燈檢測,人流檢測,目標識別

本項目為2020年中國軟件杯A組第一批賽題"基于計算機視覺的交通場景智能應用".項目用python實現,主要使用YOLO模型實現道路目標如人、車、交通燈等物體的識別,使用開源的"中文車牌識別HyperLPR&a…

pytorch: 在訓練中保存模型,加載模型

文章目錄1. 保存整個模型2.僅保存和加載模型參數(推薦使用)3. 保存其他參數到模型中,比如optimizer,epoch 等1. 保存整個模型 torch.save(model, model.pkl) model torch.load(model.pkl)2.僅保存和加載模型參數(推薦使用) torch.save(model_object.state_dict()…