【PyTorch】 暫退法(dropout)

文章目錄

  • 1. 理論介紹
  • 2. 實例解析
    • 2.1. 實例描述
    • 2.2. 代碼實現
      • 2.2.1. 主要代碼
      • 2.2.2. 完整代碼
      • 2.2.3. 輸出結果

1. 理論介紹

  • 線性模型泛化的可靠性是有代價的,因為線性模型沒有考慮到特征之間的交互作用,由此模型靈活性受限。
  • 泛化性和靈活性之間的基本權衡被描述為偏差-方差權衡
    • 線性模型有很高的偏差,因此它們只能表示一小類函數,但其方差很低,因此它們在不同的隨機數據樣本上可以得出相似的結果。
    • 神經網絡不局限于單獨查看每個特征,而是學習特征之間的交互,但即使我們有比特征多得多的樣本,深度神經網絡也有可能過擬合。
  • 經典泛化理論認為,為了縮小訓練和測試性能之間的差距,應該以簡單的模型為目標。 簡單性以較小維度的形式展現,簡單性的另一個角度是平滑性,即函數不應該對其輸入的微小變化敏感
  • 暫退法在前向傳播過程中,計算每一內部層的同時注入噪聲。 因為當訓練一個有多層的深層網絡時,注入噪聲只會在輸入-輸出映射上增強平滑性。從表面上看是在訓練過程中丟棄(drop out)一些神經元。 在整個訓練過程的每一次迭代中,標準暫退法包括在計算下一層之前將當前層中的一些節點置零。
  • 神經網絡過擬合與每一層都依賴于前一層激活值相關,這種情況稱為共適應性,而暫退法會破壞共適應性。
  • 可以以一種無偏向的方式注入噪聲,在固定住其他層時,每一層的期望值等于沒有噪音時的值。
  • 標準暫退正則化通過按未丟棄的節點的分數進行規范化來消除每一層的偏差,換言之,每個中間活性值 h h h以暫退概率 p p p由隨機變量 h ′ h' h替換,即:
    h ′ = { 0 概率為? p h 1 ? p 其他情況 \begin{aligned} h' = \begin{cases} 0 & \text{ 概率為 } p \\ \frac{h}{1-p} & \text{ 其他情況} \end{cases} \end{aligned} h={01?ph???概率為?p?其他情況??
  • 通常,我們在測試時不用暫退法。 給定一個訓練好的模型和一個新的樣本,我們不會丟棄任何節點,因此不需要標準化。 然而也有一些例外:一些研究人員在測試時使用暫退法, 用于估計神經網絡預測的不確定性: 如果通過許多不同的暫退法遮蓋后得到的預測結果都是一致的,那么我們可以說網絡發揮更穩定。
  • 我們可以將暫退法應用于每個隱藏層的輸出(在激活函數之后), 并且可以為每一層分別設置暫退概率,常見的技巧是在靠近輸入層的地方設置較低的暫退概率。

2. 實例解析

2.1. 實例描述

使用具有兩個隱藏層的多層感知機和暫退法,擬合Fashion-MNIST數據集。

2.2. 代碼實現

2.2.1. 主要代碼

net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(0.2),	# 暫退概率為0.2nn.Linear(256, 256),nn.ReLU(),nn.Dropout(0.5),	# 暫退概率為0.5nn.Linear(256, 10)
).to(device)

2.2.2. 完整代碼

import os
from tensorboardX import SummaryWriter
from rich.progress import track
from torchvision.transforms import Compose, ToTensor
from torchvision.datasets import FashionMNIST
import torch
from torch.utils.data import DataLoader
from torch import nn, optimdef load_dataset():"""加載數據集"""root = "./dataset"transform = Compose([ToTensor()])mnist_train = FashionMNIST(root, True, transform, download=True)mnist_test = FashionMNIST(root, False, transform, download=True)dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers,)dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers,)return dataloader_train, dataloader_testif __name__ == '__main__':# 全局參數設置num_epochs = 10batch_size = 256num_workers = 3device = torch.device('cuda:0')lr = 0.5# 創建記錄器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir=log_dir())# 數據集配置dataloader_train, dataloader_test = load_dataset()# 定義模型net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(0.2),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(0.5),nn.Linear(256, 10)).to(device)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)criterion = nn.CrossEntropyLoss(reduction='none')optimizer = optim.SGD(net.parameters(), lr=lr)# 訓練循環for epoch in track(range(num_epochs), description='dropout'):for X, y in dataloader_train:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()with torch.no_grad():train_loss, train_acc, num_samples = 0.0, 0.0, 0for X, y in dataloader_train:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)train_loss += loss.sum()train_acc += (y_hat.argmax(dim=1) == y).sum()num_samples += y.numel()train_loss /= num_samplestrain_acc /= num_samplestest_acc, num_samples = 0.0, 0for X, y in dataloader_test:X, y = X.to(device), y.to(device)y_hat = net(X)test_acc += (y_hat.argmax(dim=1) == y).sum()num_samples += y.numel()test_acc /= num_sampleswriter.add_scalars('metrics', {'train_loss': train_loss,'train_acc': train_acc,'test_acc': test_acc}, epoch)writer.close()

2.2.3. 輸出結果

dropout

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/206749.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/206749.shtml
英文地址,請注明出處:http://en.pswp.cn/news/206749.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Docker構建自定義鏡像

創建一個docker-demo的文件夾,放入需要構建的文件 主要是配置Dockerfile文件 第一種配置方法 # 指定基礎鏡像 FROM ubuntu:16.04 # 配置環境變量,JDK的安裝目錄 ENV JAVA_DIR/usr/local# 拷貝jdk和java項目的包 COPY ./jdk8.tar.gz $JAVA_DIR/ COPY ./docker-demo…

Java基礎50題: 21.實現一個方法printArray, 以數組為參數,循環訪問數組中的每個元素,打印每個元素的值.

概述 實現一個方法printArray, 以數組為參數,循環訪問數組中的每個元素,打印每個元素的值. 代碼 public static void printArray(int[] array) {for (int i 0; i < array.length; i) {System.out.println(array[i] " ");}System.out.println();}public static…

【數據結構c實現】順序表實現

文章目錄 線性表線性表的順序實現結點結構結點初始化增配空間Inc打印順序表show_list線性表長度length尾部插入push_back頭部插入push_front尾部刪除pop_back頭部刪除pop_front按位置插入insert_pos按值查找find按位置刪除delete_pos按值刪除delete_val排序sort(冒泡&#xff1…

云上業務DDoS與CC攻擊防護實踐

案例背景&#xff1a;DDoS攻擊來勢洶洶&#xff0c;云上業務面臨威脅 某網絡科技有限公司&#xff0c;SaaS化創業公司&#xff0c;業務基于云上開展。其業務主要為各大網站提供安全驗證服務&#xff0c;且市場占有率較高&#xff0c;服務客戶遍布金融、直播、教育、電商等多個領…

【日常總結】mybatis-plus WHERE BINARY 中文查不出來

目錄 一、場景 二、問題 三、原因 四、解決方案 五、拓展&#xff08;全表全字段修改字符集一鍵更改&#xff09; 準備工作&#xff1a;做好整個庫備份 1. 全表一鍵修改 Stage 1&#xff1a;運行如下查詢 Stage 2&#xff1a;復制sql語句 Stage 3&#xff1a;執行即可…

100. 相同的樹(Java)

目錄 解法&#xff1a; 官方解法&#xff1a; 方法一&#xff1a;深度優先搜索 復雜度分析 時間復雜度&#xff1a; 空間復雜度&#xff1a; 方法二&#xff1a;廣度優先搜索 復雜度分析 時間復雜度&#xff1a; 空間復雜度&#xff1a; 給你兩棵二叉樹的根節點 p 和…

L1-028:判斷素數

題目描述 本題的目標很簡單&#xff0c;就是判斷一個給定的正整數是否素數。 輸入格式&#xff1a; 輸入在第一行給出一個正整數N&#xff08;≤ 10&#xff09;&#xff0c;隨后N行&#xff0c;每行給出一個小于231的需要判斷的正整數。 輸出格式&#xff1a; 對每個需要判斷的…

Kotlin(十五) 高階函數詳解

高階函數的定義 高階函數和Lambda的關系是密不可分的。在之前的文章中&#xff0c;我們熟悉了Lambda編程的基礎知識&#xff0c;并且掌握了一些與集合相關的函數式API的用法&#xff0c;如map、filter函數等。另外&#xff0c;我們也了解了Kotlin的標準函數&#xff0c;如run、…

vuepress-----22、其他評論方案

vuepress 支持評論 本文講述 vuepress 站點如何集成評論系統&#xff0c;選型是 valineleancloud, 支持匿名評論&#xff0c;缺點是數據沒有存儲在自己手里。市面上也有其他的方案, 如 gitalk,vssue 等, 但需要用戶登錄 github 才能發表評論, 但 github 經常無法連接,導致體驗…

[wp]“古劍山”第一屆全國大學生網絡攻防大賽 Web部分wp

“古劍山”第一屆全國大學生網絡攻防大賽 群友說是原題杯 哈哈哈哈 我也不懂 我比賽打的少 Web Web | unse 源碼&#xff1a; <?phpinclude("./test.php");if(isset($_GET[fun])){if(justafun($_GET[fun])){include($_GET[fun]);}}else{unserialize($_GET[…

使用cmake構建的工程的編譯方法

1、克隆項目工程 2、進入到工程目錄 3、執行 mkdir build && cd build 4、執行 cmake .. 5、執行 make 執行以上步驟即可完成對cmake編寫的工程進行編譯 &#xff0c;后面只需執行你的編譯結果即可 $ git clone 你想要克隆的代碼路徑 $ cd 代碼文件夾 $ mkdir bu…

測試:SRE

SRE&#xff08;Site Reliability Engineering&#xff0c;站點可靠性工程&#xff09;是一種關注于構建、運行和維護大規模分布式系統的工程學科。它旨在確保系統在各種故障情況下仍然可用、可靠和高效。 SRE的核心目標是通過軟件工程的方法來解決系統可靠性問題&#xff0c;…

WPF DataGrid 里面的ToggleButton點擊不生效

已解決&#xff1a;根本原因是沒寫UpdateSourceTriggerPropertyChanged <ToggleButton IsChecked"{Binding PathIsEnabled,ModeTwoWay,UpdateSourceTriggerPropertyChanged}"/> 具體原因參考下面文章&#xff1a;鳴謝作者 WPF 數據集合綁定到DataGrid、ListV…

vmware安裝centos7總結

vmware安裝centos7總結 文章目錄 vmware安裝centos7總結一、配置網絡&#xff08;橋接模式&#xff09;二、配置yum源&#xff08;連網配置&#xff09;三、可視化界面四、安裝Docker五、安裝DockerUI 一、配置網絡&#xff08;橋接模式&#xff09; 網絡連接模式選擇橋接模式…

Ubuntu安裝nvidia GPU顯卡驅動教程

Ubuntu安裝nvidia顯卡驅動 1.安裝前安裝必要的依賴 sudo apt-get install build-essential sudo apt-get install g sudo apt-get install make2.到官網下載對應驅動 https://www.nvidia.cn/Download/index.aspx?langcn 3.卸載原有驅動 sudo apt-get remove --purge nvidi…

深度學習:注意力機制(Attention Mechanism)

1 注意力機制概述 1.1 定義 注意力機制&#xff08;Attention Mechanism&#xff09;是深度學習領域中的一種重要技術&#xff0c;特別是在序列模型如自然語言處理&#xff08;NLP&#xff09;和計算機視覺中。它使模型能夠聚焦于輸入數據的重要部分&#xff0c;從而提高整體…

孩子都能學會的FPGA:第二十五課——用FPGA實現頻率計

&#xff08;原創聲明&#xff1a;該文是作者的原創&#xff0c;面向對象是FPGA入門者&#xff0c;后續會有進階的高級教程。宗旨是讓每個想做FPGA的人輕松入門&#xff0c;作者不光讓大家知其然&#xff0c;還要讓大家知其所以然&#xff01;每個工程作者都搭建了全自動化的仿…

基于SpringBoot+maven+Mybatis+html慢性病報銷系統(源碼+數據庫)

一、項目簡介 本項目是一套基于SpringBootmavenMybatishtml慢性病報銷系統&#xff0c;主要針對計算機相關專業的正在做bishe的學生和需要項目實戰練習的Java學習者。 包含&#xff1a;項目源碼、數據庫腳本等&#xff0c;該項目可以直接作為bishe使用。 項目都經過嚴格調試&a…

二十一章(網絡通信)

計算機網絡實現了多臺計算機間的互聯&#xff0c;使得它們彼此之間能夠進行數據交流。網絡應用程序就是在已連接的不同計算機上運行的程序&#xff0c;這些程序借助于網絡協議&#xff0c;相互之間可以交換數據。編寫網絡應用程序前&#xff0c;首先必須明確所要使用的網絡協議…

C++_命名空間(namespace)

目錄 1、namespace的重要性 2、 namespace的定義及作用 2.1 作用域限定符 3、命名空間域與全局域的關系 4、命名空間的嵌套 5、展開命名空間的方法 5.1 特定展開 5.1 部分展開 5.2 全部展開 結語&#xff1a; 前言&#xff1a; C作為c語言的“升級版”&#xff0c;其在…