[pytorch、學習] - 4.2 模型參數的訪問、初始化和共享

參考

4.2 模型參數的訪問、初始化和共享

在3.3節(線性回歸的簡潔實現)中,我們通過init模塊來初始化模型的參數。我們也介紹了訪問模型參數的簡單方法。本節將深入講解如何訪問和初始化模型參數,以及如何在多個層之間共享同一份模型參數。

import torch
from torch import nn
from torch.nn import initnet = nn.Sequential(nn.Linear(4,3), nn.ReLU(), nn.Linear(3, 1))print(net)
X = torch.rand(2, 4)
Y = net(X).sum()

在這里插入圖片描述

4.2.1 訪問模型的參數

回憶上一節中提到的Sequential類與Module類的繼承關系。對于Sequential實例中含模型參數的層,我們可以通過Module類的parameters()或者named_parameters方法來訪問所有參數(以迭代器的形式返回),后者除了返回參數Tensor外還會返回其名字

print(type(net.named_parameters()))
for name, param in net.named_parameters():print(name, param.size())

在這里插入圖片描述
可見返回的名字自動加上了層數的索引作為前綴。我們再來訪問net中單層的參數。對于使用Sequential類構造的神經網絡,我們可以通過方括號[]來訪問網絡的任一層。索引0表示隱藏層為Sequential實例最先添加的層.

for name, param in net[0].named_parameters():print(name, param.size(), type(param))

在這里插入圖片描述

# 如果一個Tensor是Parameter,那么它會自動被添加到模型的參數列表里
class MyModel(nn.Module):def __init__(self, **kwargs):super(MyModel, self).__init__(**kwargs)self.weight1 = nn.Parameter(torch.rand(20, 20))self.weight2 = torch.rand(20, 20)def forward(self, x):passn = MyModel()
for name, param in n.named_parameters():print(name)

在這里插入圖片描述

# 上面代碼中weight1在參數列表中,但是weight2卻沒在參數列表中
# 因為Parameters是Tensor,即Tensor擁有的屬性它都有,比如可以根據data來訪問參數數值,用grad來訪問參數梯度
weight_0 = list(net[0].parameters())[0]  # 將第0層的W取出
print(net)
print(weight_0)
print(weight_0.grad)   # 此時并未對Y做梯度下降,因此會顯示None
Y.backward()
print(weight_0.grad)

在這里插入圖片描述

4.2.2 初始化模型參數

在下面的例子中,我們將權重參數初始化為均值為0、標準差為0.01的正態分布隨機數,并依然將偏差參數清零。

for name, param in net.named_parameters():if 'weight' in name:init.normal_(param, mean=0, std=0.01)print(name, param.data)

在這里插入圖片描述

# 使用常數來初始化權重參數
for name, param in net.named_parameters():if 'bias' in name:init.constant_(param, val=0)print(name, param.data)

在這里插入圖片描述

4.2.3 自定義初始化方法

有時候我們需要的初始化方法并沒有在init模塊中提供。這時,可以實現一個初始化方法,從而能夠像使用其他方法那樣使用它。

# 我們先看看pytorch如何實現的
def normal_(tensor, mean=0, std= 1):with torch.no_grad:return tensor.normal_(mean, std)

可以看到這就是一個inplace改變Tensor值的函數,而且這個過程是不記錄梯度的。 類似的我們來實現一個自定義的初始化方法。在下面的例子里,我們令權重有一半概率初始化為0,有另一半概率初始化為[?10,?5]和[5,10]兩個區間里均勻分布的隨機數。

def init_weight_(tensor):with torch.no_grad():tensor.uniform_(-10, 10)tensor *= (tensor.abs() >= 5).float()for name, param in net.named_parameters():if 'weight' in name:init_weight_(param)print(name, param.data)

在這里插入圖片描述

4.2.4 共享模型參數

在有些情況下,我們希望在多個層之間共享模型參數。下面來看一個例子

linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear)
print(net)for name, param in net.named_parameters():init.constant_(param, val = 3)print(name, param.data)

在這里插入圖片描述

# 在內存中,這兩個線性層其實是一個對象
print(id(net[0]) == id(net[1]))
print(id(net[0].weight) == id(net[1].weight))

在這里插入圖片描述

# 因為模型參數里包含了梯度,所以在反向傳播時,這些共享的參數的梯度是累加的
x = torch.ones(1, 1)
y = net(x).sum()
print(y)
y.backward()
print(net[0].weight.grad)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250162.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250162.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250162.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

spring-boot注解詳解(三)

1.SpringBoot/spring SpringBootApplication: 包含Configuration、EnableAutoConfiguration、ComponentScan通常用在主類上; Repository: 用于標注數據訪問組件,即DAO組件; Service: 用于標注業務層組件; RestController: 用于…

IEnumerableT和IQueryableT區分

哎,看了那么多,這個知識點還是得開一個文章 IQueryable和IEnumerable都是延時執行(Deferred Execution)的,而IList是即時執行(Eager Execution) IQueryable和IEnumerable在每次執行時都必須連接數據庫讀取,而IList讀取一次后&…

表的轉置 行轉列: DECODE(Oracle) 和 CASE WHEN 的異同點

異同點 都可以對表行轉列;DECODE功能上和簡單Case函數比較類似,不能像Case搜索函數一樣,進行更復雜的判斷在Case函數中,可以使用BETWEEN, LIKE, IS NULL, IN, EXISTS等等(也可以使用NOT IN和NOT EXISTS,但是…

[pytorch、學習] - 4.4 自定義層

參考 4.4 自定義層 深度學習的一個魅力在于神經網絡中各式各樣的層,例如全連接層和后面章節將要用介紹的卷積層、池化層與循環層。雖然PyTorch提供了大量常用的層,但有時候我們依然希望自定義層。本節將介紹如何使用Module來自定義層,從而可以被重復調用。 4.4.1 不含模型參…

樹的存儲

父親表示法 顧名思義,就是只記錄每個結點的父結點。 int n; int p[MAX_N]; // 指向每個結點的父結點 孩子表示法 如上,就是只記錄每個結點的子結點。 int n; int cnt[MAX_N]; // 記錄每個結點的子結點的數量 int p[MAX_N][MAX_CNT]; // 指向每個結點的子…

spring-boot注解詳解(四)

repository repository跟Service,Compent,Controller這4種注解是沒什么本質區別,都是聲明作用,取不同的名字只是為了更好區分各自的功能.下圖更多的作用是mapper注冊到類似于以前mybatis.xml中的mappers里. 也是因為接口沒辦法在spring.xml中用bean的方式來配置實現類吧(接口…

令人叫絕的EXCEL函數功能

http://club.excelhome.net/thread-166725-1-1.html https://wenku.baidu.com/view/db319da0bb0d4a7302768e9951e79b8969026864.html轉載于:https://www.cnblogs.com/cqufengchao/articles/9150401.html

[pytorch、學習] - 4.5 讀取和存儲

參考 4.5 讀取和存儲 到目前為止,我們介紹了如何處理數據以及如何構建、訓練和測試深度學習模型。然而在實際中,我們有時需要把訓練好的模型部署到很多不同的設備。在這種情況下,我們可以把內存中訓練好的模型參數存儲在硬盤上供后續讀取使用。 4.5.1 讀寫tensor 我們可以直…

JAVA排序的方法

//冒泡排序法: package fuxi;public class Bubble { public static void main(String[] args) { int a[] { 10,23,11,56,45,26,59,28,84,79 }; int i,temp; System.out.println("輸出原始數組數據:"); for (i…

spring-boot注解詳解(五)

AutoWired 首先要知道另一個東西,default-autowire,它是在xml文件中進行配置的,可以設置為byName、byType、constructor和autodetect;比如byName,不用顯式的在bean中寫出依賴的對象,它會自動的匹配其它bea…

什么是p12證書?ios p12證書怎么獲取?

.cer是蘋果的默認證書,在xcode開發打包可以使用,如果在lbuilder、phonegap、HBuilder、AppCan、APICloud這些跨平臺開發工具打包,就需要用到p12文件。 .cer證書僅包含公鑰,.p12證書可能既包含公鑰也包含私鑰,這就是他們…

[pytorch、學習] - 4.6 GPU計算

參考 4.6 GPU計算 到目前為止,我們一直使用CPU進行計算。對復雜的神經網絡和大規模數據來說,使用CPU來計算可能不夠高效。 在本節中,將要介紹如何使用單塊NIVIDA GPU進行計算 4.6.1 計算設備 PyTorch可以指定用來存儲和計算的設備,如果用內存的CPU或者顯存的GPU。默認情況下…

adb connect 192.168.1.10 failed to connect to 192.168.1.10:5555

adb connect 192.168.1.10 輸出 failed to connect to 192.168.1.10:5555 關閉安卓端Wi-Fi,重新打開連接即可 轉載于:https://www.cnblogs.com/sea-stream/p/10020995.html

創建oracle數據庫表空間并分配用戶

我們在本地的oracle上或者virtualbox的oracle上 創建新的數據庫表空間操作:通過system賬號來創建并授權/*--創建表空間create tablespace YUJKDATAdatafile c:\yujkdata200.dbf --指定表空間對應的datafile文件的具體的路徑size 100mautoextend onnext 10m*/ /*--創…

spring-boot注解詳解(六)

Target Target說明了Annotation所修飾的對象范圍:Annotation可被用于 packages、types(類、接口、枚舉、Annotation類型)、類型成員(方法、構造方法、成員變量、枚舉值)、方法參數和本地變量(如循環變量、…

[pytorch、學習] - 5.1 二維卷積層

參考 5.1 二維卷積層 卷積神經網絡(convolutional neural network)是含有卷積層(convolutional layer)的神經網絡。本章介紹的卷積神經網絡均使用最常見的二維卷積層。它有高和寬兩個空間維度,常用來處理圖像數據。本節中,我們將介紹簡單形式的二維卷積層的工作原理。 5.1.1…

[51CTO]給您介紹Windows10各大版本之間區別

給您介紹Windows10各大版本之間區別 隨著win10的不斷普及和推廣,越來越多的朋友想安裝win10系統了,但是很多朋友不知道win10哪個版本好用,為了讓大家能夠更好的選擇win10系統版本,下面小編就來告訴你 http://os.51cto.com/art/201…

iOS中NSString轉換成HEX(十六進制)-NSData轉換成int

NSString *str "0xff055008"; //先以16為參數告訴strtoul字符串參數表示16進制數字,然后使用0x%X轉為數字類型 unsigned long red strtoul([str UTF8String],0,16); //strtoul如果傳入的字符開頭是“0x”,那么第三個參數是0,也是會轉為十…

spring-boot注解詳解(七)

Configuration 從Spring3.0,Configuration用于定義配置類,可替換xml配置文件,被注解的類內部包含有一個或多個被Bean注解的方法,這些方法將會被AnnotationConfigApplicationContext或AnnotationConfigWebApplicationContext類進行…

[pytorch、學習] - 5.2 填充和步幅

參考 5.2 填充和步幅 5.2.1 填充 填充(padding)是指在輸入高和寬的兩側填充元素(通常是0元素)。圖5.2里我們在原輸入高和寬的兩側分別添加了值為0的元素,使得輸入高和寬從3變成了5,并導致輸出高和寬由2增加到4。圖5.2中的陰影部分為第一個輸出元素及其計算所使用的輸入和核數…