[pytorch、學習] - 3.13 丟棄法

參考

3.13 丟棄法

過擬合問題的另一種解決辦法是丟棄法。當對隱藏層使用丟棄法時,隱藏單元有一定概率被丟棄。

3.12.1 方法

在這里插入圖片描述

3.13.2 從零開始實現

import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2ldef dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1keep_prob = 1 - drop_prob# 這種情況下把全部元素都丟棄if keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()return mask * X / keep_prob
X = torch.arange(16).view(2, 8)
X

在這里插入圖片描述

dropout(X, 0.5)

在這里插入圖片描述

dropout(X, 1)

在這里插入圖片描述

3.13.2.1 定義模型參數

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True)params = [W1, b1, W2, b2, W3, b3]

3.13.2.2 定義模型

drop_prob1, drop_prob2 = 0.2, 0.5def net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1) + b1).relu()if is_training:  # 只在訓練模型時使用丟棄法H1 = dropout(H1, drop_prob1)  # 在第一層全連接后添加丟棄層H2 = (torch.matmul(H1, W2) + b2).relu()if is_training:H2 = dropout(H2, drop_prob2)  # 在第二層全連接后添加丟棄層return torch.matmul(H2, W3) + b3# 本函數已保存在d2lzh_pytorch
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 評估模式, 這會關閉dropoutacc_sum += (net(X).argmax(dim=1) == y).float().sum().item()net.train() # 改回訓練模式else: # 自定義的模型if('is_training' in net.__code__.co_varnames): # 如果有is_training這個參數# 將is_training設置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n

3.13.2.3 訓練和測試模型

num_epochs, lr, batch_size = 5, 100.0, 256
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

在這里插入圖片描述

3.13.3 簡潔實現

net = nn.Sequential(d2l.FlattenLayer(),nn.Linear(num_inputs, num_hiddens1),nn.ReLU(),nn.Dropout(drop_prob1),nn.Linear(num_hiddens1, num_hiddens2),nn.ReLU(),nn.Dropout(drop_prob2),nn.Linear(num_hiddens2, 10)
)for param in net.parameters():nn.init.normal_(param, mean=0, std= 0.01)optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250168.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250168.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250168.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

springboot---request 中Parameter,Attribute區別

HttpServletRequest類既有getAttribute()方法&#xff0c;也由getParameter()方法&#xff0c;這兩個方法有以下區別&#xff1a; &#xff08;1&#xff09;HttpServletRequest類有setAttribute()方法&#xff0c;而沒有setParameter()方法 &#xff08;2&#xff09;當兩個…

Python之令人心煩意亂的字符編碼與轉碼

ASC-II碼&#xff1a;英文1個字節&#xff08;8 byte&#xff09;&#xff0c;不支持中文&#xff1b; 高大上的中國&#xff0c;擴展出自己的gbk、gb2312、gb2318等字符編碼。 由于各個國家都有自己的編碼&#xff0c;于是就需要統一的編碼形式用于國際流傳&#xff0c;防止亂…

[pytorch、學習] - 4.1 模型構造

參考 4.1 模型構造 讓我們回顧以下多重感知機的簡潔實現中包含單隱藏層的多重感知機的實現方法。我們首先構造Sequential實例,然后依次添加兩個全連接層。其中第一層的輸出大小為256,即隱藏層單元個數是256;第二層的輸出大小為10,即輸出層單元個數是10. 4.1.1 繼承Module類來…

springboot---基本模塊詳解

概述 1.基于Spring框架的“約定優先于配置&#xff08;COC&#xff09;”理念以及最佳實踐之路。 2.針對日常企業應用研發各種場景的Spring-boot-starter自動配置依賴模塊&#xff0c;且“開箱即用”&#xff08;約定spring-boot-starter- 作為命名前綴&#xff0c;都位于org.…

第二課 運算符(day10)

第二課 運算符(day10) 一、運算符 結果是值 算數運算 a 10 * 10 賦值運算 a a 1 a1 結果是布爾值 比較運算 a 1 > 5 邏輯運算 a 1>6 or 11 成員運算 a "蚊" in "鄭建文" 二、基本數據類型 1、數值…

[pytorch、學習] - 4.2 模型參數的訪問、初始化和共享

參考 4.2 模型參數的訪問、初始化和共享 在3.3節(線性回歸的簡潔實現)中,我們通過init模塊來初始化模型的參數。我們也介紹了訪問模型參數的簡單方法。本節將深入講解如何訪問和初始化模型參數,以及如何在多個層之間共享同一份模型參數。 import torch from torch import nn…

spring-boot注解詳解(三)

1.SpringBoot/spring SpringBootApplication: 包含Configuration、EnableAutoConfiguration、ComponentScan通常用在主類上&#xff1b; Repository: 用于標注數據訪問組件&#xff0c;即DAO組件&#xff1b; Service: 用于標注業務層組件&#xff1b; RestController: 用于…

IEnumerableT和IQueryableT區分

哎&#xff0c;看了那么多&#xff0c;這個知識點還是得開一個文章 IQueryable和IEnumerable都是延時執行(Deferred Execution)的&#xff0c;而IList是即時執行(Eager Execution) IQueryable和IEnumerable在每次執行時都必須連接數據庫讀取&#xff0c;而IList讀取一次后&…

表的轉置 行轉列: DECODE(Oracle) 和 CASE WHEN 的異同點

異同點 都可以對表行轉列&#xff1b;DECODE功能上和簡單Case函數比較類似&#xff0c;不能像Case搜索函數一樣&#xff0c;進行更復雜的判斷在Case函數中&#xff0c;可以使用BETWEEN, LIKE, IS NULL, IN, EXISTS等等&#xff08;也可以使用NOT IN和NOT EXISTS&#xff0c;但是…

[pytorch、學習] - 4.4 自定義層

參考 4.4 自定義層 深度學習的一個魅力在于神經網絡中各式各樣的層,例如全連接層和后面章節將要用介紹的卷積層、池化層與循環層。雖然PyTorch提供了大量常用的層,但有時候我們依然希望自定義層。本節將介紹如何使用Module來自定義層,從而可以被重復調用。 4.4.1 不含模型參…

樹的存儲

父親表示法 顧名思義&#xff0c;就是只記錄每個結點的父結點。 int n; int p[MAX_N]; // 指向每個結點的父結點 孩子表示法 如上&#xff0c;就是只記錄每個結點的子結點。 int n; int cnt[MAX_N]; // 記錄每個結點的子結點的數量 int p[MAX_N][MAX_CNT]; // 指向每個結點的子…

spring-boot注解詳解(四)

repository repository跟Service,Compent,Controller這4種注解是沒什么本質區別,都是聲明作用,取不同的名字只是為了更好區分各自的功能.下圖更多的作用是mapper注冊到類似于以前mybatis.xml中的mappers里. 也是因為接口沒辦法在spring.xml中用bean的方式來配置實現類吧(接口…

令人叫絕的EXCEL函數功能

http://club.excelhome.net/thread-166725-1-1.html https://wenku.baidu.com/view/db319da0bb0d4a7302768e9951e79b8969026864.html轉載于:https://www.cnblogs.com/cqufengchao/articles/9150401.html

[pytorch、學習] - 4.5 讀取和存儲

參考 4.5 讀取和存儲 到目前為止,我們介紹了如何處理數據以及如何構建、訓練和測試深度學習模型。然而在實際中,我們有時需要把訓練好的模型部署到很多不同的設備。在這種情況下,我們可以把內存中訓練好的模型參數存儲在硬盤上供后續讀取使用。 4.5.1 讀寫tensor 我們可以直…

JAVA排序的方法

//冒泡排序法&#xff1a; package fuxi;public class Bubble { public static void main(String[] args) { int a[] { 10,23,11,56,45,26,59,28,84,79 }; int i,temp; System.out.println("輸出原始數組數據&#xff1a;"); for (i…

spring-boot注解詳解(五)

AutoWired 首先要知道另一個東西&#xff0c;default-autowire&#xff0c;它是在xml文件中進行配置的&#xff0c;可以設置為byName、byType、constructor和autodetect&#xff1b;比如byName&#xff0c;不用顯式的在bean中寫出依賴的對象&#xff0c;它會自動的匹配其它bea…

什么是p12證書?ios p12證書怎么獲取?

.cer是蘋果的默認證書&#xff0c;在xcode開發打包可以使用&#xff0c;如果在lbuilder、phonegap、HBuilder、AppCan、APICloud這些跨平臺開發工具打包&#xff0c;就需要用到p12文件。 .cer證書僅包含公鑰&#xff0c;.p12證書可能既包含公鑰也包含私鑰&#xff0c;這就是他們…

[pytorch、學習] - 4.6 GPU計算

參考 4.6 GPU計算 到目前為止,我們一直使用CPU進行計算。對復雜的神經網絡和大規模數據來說,使用CPU來計算可能不夠高效。 在本節中,將要介紹如何使用單塊NIVIDA GPU進行計算 4.6.1 計算設備 PyTorch可以指定用來存儲和計算的設備,如果用內存的CPU或者顯存的GPU。默認情況下…

adb connect 192.168.1.10 failed to connect to 192.168.1.10:5555

adb connect 192.168.1.10 輸出 failed to connect to 192.168.1.10:5555 關閉安卓端Wi-Fi&#xff0c;重新打開連接即可 轉載于:https://www.cnblogs.com/sea-stream/p/10020995.html

創建oracle數據庫表空間并分配用戶

我們在本地的oracle上或者virtualbox的oracle上 創建新的數據庫表空間操作&#xff1a;通過system賬號來創建并授權/*--創建表空間create tablespace YUJKDATAdatafile c:\yujkdata200.dbf --指定表空間對應的datafile文件的具體的路徑size 100mautoextend onnext 10m*/ /*--創…