[pytorch、學習] - 4.4 自定義層

參考

4.4 自定義層

深度學習的一個魅力在于神經網絡中各式各樣的層,例如全連接層和后面章節將要用介紹的卷積層、池化層與循環層。雖然PyTorch提供了大量常用的層,但有時候我們依然希望自定義層。本節將介紹如何使用Module來自定義層,從而可以被重復調用。

4.4.1 不含模型參數的自定義層

我們先介紹如何定義一個不含模型參數的自定義層。

import torch
from torch import nnclass CenteredLayer(nn.Module):def __init__(self, **kwargs):super(CenteredLayer, self).__init__(**kwargs)def forward(self, x):return x - x.mean()
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

在這里插入圖片描述
我們也可以用它來構造更復雜的模型。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
y.mean().item()

在這里插入圖片描述

4.4.2 含模型參數的自定義層

我們還可以自定義含模型參數的自定義層。其中的模型參數可以通過訓練學習。
Parameter類其實是Tensor的子類,如果一個Tensor是Parameter,那么它會自動被添加到模型的參數列表里。所以在自定義含模型參數的層時,我們應該將參數定義成Parameter,除了像4.2.1節那樣直接定義成Parameter類外,還可以使用ParameterListParameterDict分別定義參數的列表和字典。

ParameterList接收一個Parameter實例的列表作為輸入然后得到一個參數列表,使用的時候可以用索引來訪問某個參數,另外也可以使用appendextend在列表后面新增參數。

class MyDense(nn.Module):def __init__(self):super(MyDense, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])self.params.append(nn.Parameter(torch.randn(4, 1)))def forward(self, x):for i in range(len(self.params)):x =  torch.mm(x, self.params[i])return xnet = MyDense()
print(net)

在這里插入圖片描述
ParameterDict接收一個Parameter實例的字典作為輸入然后得到一個參數字典,然后可以按照字典的規則使用了。

class MyDictDense(nn.Module):def __init__(self):super(MyDictDense, self).__init__()self.params = nn.ParameterDict({'linear1': nn.Parameter(torch.randn(4, 4)),'linear2': nn.Parameter(torch.randn(4, 1))})self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})def forward(self, x, choice='linear1'):return torch.mm(x, self.params[choice])net = MyDictDense()
print(net)

在這里插入圖片描述

x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))

在這里插入圖片描述
我們也可以使用自定義層構造模型。它和PyTorch的其他層在使用上很類似。

net = nn.Sequential(MyDictDense(),MyDictDense()
)
print(net)
print(net(x))

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250158.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250158.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250158.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

樹的存儲

父親表示法 顧名思義,就是只記錄每個結點的父結點。 int n; int p[MAX_N]; // 指向每個結點的父結點 孩子表示法 如上,就是只記錄每個結點的子結點。 int n; int cnt[MAX_N]; // 記錄每個結點的子結點的數量 int p[MAX_N][MAX_CNT]; // 指向每個結點的子…

spring-boot注解詳解(四)

repository repository跟Service,Compent,Controller這4種注解是沒什么本質區別,都是聲明作用,取不同的名字只是為了更好區分各自的功能.下圖更多的作用是mapper注冊到類似于以前mybatis.xml中的mappers里. 也是因為接口沒辦法在spring.xml中用bean的方式來配置實現類吧(接口…

令人叫絕的EXCEL函數功能

http://club.excelhome.net/thread-166725-1-1.html https://wenku.baidu.com/view/db319da0bb0d4a7302768e9951e79b8969026864.html轉載于:https://www.cnblogs.com/cqufengchao/articles/9150401.html

[pytorch、學習] - 4.5 讀取和存儲

參考 4.5 讀取和存儲 到目前為止,我們介紹了如何處理數據以及如何構建、訓練和測試深度學習模型。然而在實際中,我們有時需要把訓練好的模型部署到很多不同的設備。在這種情況下,我們可以把內存中訓練好的模型參數存儲在硬盤上供后續讀取使用。 4.5.1 讀寫tensor 我們可以直…

JAVA排序的方法

//冒泡排序法: package fuxi;public class Bubble { public static void main(String[] args) { int a[] { 10,23,11,56,45,26,59,28,84,79 }; int i,temp; System.out.println("輸出原始數組數據:"); for (i…

spring-boot注解詳解(五)

AutoWired 首先要知道另一個東西,default-autowire,它是在xml文件中進行配置的,可以設置為byName、byType、constructor和autodetect;比如byName,不用顯式的在bean中寫出依賴的對象,它會自動的匹配其它bea…

什么是p12證書?ios p12證書怎么獲取?

.cer是蘋果的默認證書,在xcode開發打包可以使用,如果在lbuilder、phonegap、HBuilder、AppCan、APICloud這些跨平臺開發工具打包,就需要用到p12文件。 .cer證書僅包含公鑰,.p12證書可能既包含公鑰也包含私鑰,這就是他們…

[pytorch、學習] - 4.6 GPU計算

參考 4.6 GPU計算 到目前為止,我們一直使用CPU進行計算。對復雜的神經網絡和大規模數據來說,使用CPU來計算可能不夠高效。 在本節中,將要介紹如何使用單塊NIVIDA GPU進行計算 4.6.1 計算設備 PyTorch可以指定用來存儲和計算的設備,如果用內存的CPU或者顯存的GPU。默認情況下…

adb connect 192.168.1.10 failed to connect to 192.168.1.10:5555

adb connect 192.168.1.10 輸出 failed to connect to 192.168.1.10:5555 關閉安卓端Wi-Fi,重新打開連接即可 轉載于:https://www.cnblogs.com/sea-stream/p/10020995.html

創建oracle數據庫表空間并分配用戶

我們在本地的oracle上或者virtualbox的oracle上 創建新的數據庫表空間操作:通過system賬號來創建并授權/*--創建表空間create tablespace YUJKDATAdatafile c:\yujkdata200.dbf --指定表空間對應的datafile文件的具體的路徑size 100mautoextend onnext 10m*/ /*--創…

spring-boot注解詳解(六)

Target Target說明了Annotation所修飾的對象范圍:Annotation可被用于 packages、types(類、接口、枚舉、Annotation類型)、類型成員(方法、構造方法、成員變量、枚舉值)、方法參數和本地變量(如循環變量、…

[pytorch、學習] - 5.1 二維卷積層

參考 5.1 二維卷積層 卷積神經網絡(convolutional neural network)是含有卷積層(convolutional layer)的神經網絡。本章介紹的卷積神經網絡均使用最常見的二維卷積層。它有高和寬兩個空間維度,常用來處理圖像數據。本節中,我們將介紹簡單形式的二維卷積層的工作原理。 5.1.1…

[51CTO]給您介紹Windows10各大版本之間區別

給您介紹Windows10各大版本之間區別 隨著win10的不斷普及和推廣,越來越多的朋友想安裝win10系統了,但是很多朋友不知道win10哪個版本好用,為了讓大家能夠更好的選擇win10系統版本,下面小編就來告訴你 http://os.51cto.com/art/201…

iOS中NSString轉換成HEX(十六進制)-NSData轉換成int

NSString *str "0xff055008"; //先以16為參數告訴strtoul字符串參數表示16進制數字,然后使用0x%X轉為數字類型 unsigned long red strtoul([str UTF8String],0,16); //strtoul如果傳入的字符開頭是“0x”,那么第三個參數是0,也是會轉為十…

spring-boot注解詳解(七)

Configuration 從Spring3.0,Configuration用于定義配置類,可替換xml配置文件,被注解的類內部包含有一個或多個被Bean注解的方法,這些方法將會被AnnotationConfigApplicationContext或AnnotationConfigWebApplicationContext類進行…

[pytorch、學習] - 5.2 填充和步幅

參考 5.2 填充和步幅 5.2.1 填充 填充(padding)是指在輸入高和寬的兩側填充元素(通常是0元素)。圖5.2里我們在原輸入高和寬的兩側分別添加了值為0的元素,使得輸入高和寬從3變成了5,并導致輸出高和寬由2增加到4。圖5.2中的陰影部分為第一個輸出元素及其計算所使用的輸入和核數…

java實現Comparable接口和Comparator接口,并重寫compareTo方法和compare方法

原文地址https://segmentfault.com/a/1190000005738975 實體類:java.lang.Comparable(接口) comareTo(重寫方法),業務排序類 java.util.Comparator(接口) compare(重寫方法). 這兩個接口我們非常的熟悉,但是 在用的時候會有一些不知道怎么下手的感覺&a…

hdu 4714 樹+DFS

題目鏈接:http://acm.hdu.edu.cn/showproblem.php?pid4714 本來想直接求樹的直徑,再得出答案,后來發現是錯的。 思路:任選一個點進行DFS,對于一棵以點u為根節點的子樹來說,如果它的分支數大于1&#xff0c…

springboot----shiro集成

springboot中集成shiro相對簡單,只需要兩個類:一個是shiroConfig類,一個是CustonRealm類。 ShiroConfig類: 顧名思義就是對shiro的一些配置,相對于之前的xml配置。包括:過濾的文件和權限,密碼加…

[pytorch、學習] - 5.3 多輸入通道和多輸出通道

參考 5.3 多輸入通道和多輸出通道 前面兩節里我們用到的輸入和輸出都是二維數組,但真實數據的維度經常更高。例如,彩色圖像在高和寬2個維度外還有RGB(紅、綠、藍)3個顏色通道。假設彩色圖像的高和寬分別是h和w(像素),那么它可以表示為一個3 * h * w的多維數組。我們將大小為3…