nn.functional 和 nn.Module入門講解

本文來自《20天吃透Pytorch》

一,nn.functional 和 nn.Module

前面我們介紹了Pytorch的張量的結構操作和數學運算中的一些常用API。

利用這些張量的API我們可以構建出神經網絡相關的組件(如激活函數,模型層,損失函數)。

Pytorch和神經網絡相關的功能組件大多都封裝在 torch.nn模塊下。

這些功能組件的絕大部分既有函數形式實現,也有類形式實現。

其中nn.functional(一般引入后改名為F)有各種功能組件的函數實現。例如:

(激活函數) * F.relu * F.sigmoid * F.tanh * F.softmax
(模型層) * F.linear * F.conv2d * F.max_pool2d * F.dropout2d * F.embedding
(損失函數) * F.binary_cross_entropy * F.mse_loss * F.cross_entropy

為了便于對參數進行管理,一般通過繼承 nn.Module 轉換成為類的實現形式,并直接封裝在 nn 模塊下。例如:

(激活函數) * nn.ReLU * nn.Sigmoid * nn.Tanh * nn.Softmax
(模型層) * nn.Linear * nn.Conv2d * nn.MaxPool2d * nn.Dropout2d * nn.Embedding
(損失函數) * nn.BCELoss * nn.MSELoss * nn.CrossEntropyLoss

二,使用nn.Module來管理參數

在Pytorch中,模型的參數是需要被優化器訓練的,因此,通常要設置參數為 requires_grad = True 的張量。
同時,在一個模型中,往往有許多的參數,要手動管理這些參數并不是一件容易的事情。
Pytorch一般將參數用nn.Parameter來表示,并且用nn.Module來管理其結構下的所有參數。

# nn.Parameter 具有 requires_grad = True 屬性
w = nn.Parameter(torch.randn(2,2))
print(w)
print(w.requires_grad)# nn.ParameterList 可以將多個nn.Parameter組成一個列表
params_list = nn.ParameterList([nn.Parameter(torch.rand(8,i)) for i in range(1,3)])
print(params_list)
print(params_list[0].requires_grad)# nn.ParameterDict 可以將多個nn.Parameter組成一個字典params_dict = nn.ParameterDict({"a":nn.Parameter(torch.rand(2,2)),"b":nn.Parameter(torch.zeros(2))})
print(params_dict)
print(params_dict["a"].requires_grad)# 可以用Module將它們管理起來
# module.parameters()返回一個生成器,包括其結構下的所有parametersmodule = nn.Module()
module.w = w
module.params_list = params_list
module.params_dict = params_dictnum_param = 0
for param in module.parameters():print(param,"\n")num_param = num_param + 1
print("number of Parameters =",num_param)#實踐當中,一般通過繼承nn.Module來構建模塊類,并將所有含有需要學習的參數的部分放在構造函數中。#以下范例為Pytorch中nn.Linear的源碼的簡化版本
#可以看到它將需要學習的參數放在了__init__構造函數中,并在forward中調用F.linear函數來實現計算邏輯。class Linear(nn.Module):__constants__ = ['in_features', 'out_features']def __init__(self, in_features, out_features, bias=True):super(Linear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = nn.Parameter(torch.Tensor(out_features, in_features))if bias:self.bias = nn.Parameter(torch.Tensor(out_features))else:self.register_parameter('bias', None)def forward(self, input):return F.linear(input, self.weight, self.bias)

三,使用nn.Module來管理子模塊

實際上nn.Module除了可以管理其引用的各種參數,還可以管理其引用的子模塊,功能十分強大。

一般情況下,我們都很少直接使用 nn.Parameter來定義參數構建模型,而是通過一些拼裝一些常用的模型層來構造模型。

這些模型層也是繼承自nn.Module的對象,本身也包括參數,屬于我們要定義的模塊的子模塊。

nn.Module提供了一些方法可以管理這些子模塊。

children() 方法: 返回生成器,包括模塊下的所有子模塊。

named_children()方法:返回一個生成器,包括模塊下的所有子模塊,以及它們的名字。

modules()方法:返回一個生成器,包括模塊下的所有各個層級的模塊,包括模塊本身。

named_modules()方法:返回一個生成器,包括模塊下的所有各個層級的模塊以及它們的名字,包括模塊本身。

其中chidren()方法和named_children()方法較多使用。

modules()方法和named_modules()方法較少使用,其功能可以通過多個named_children()的嵌套使用實現。

i = 0
for child in net.children():i+=1print(child,"\n")
print("child number",i)
i = 0
for name,child in net.named_children():i+=1print(name,":",child,"\n")
print("child number",i)
i = 0
for module in net.modules():i+=1print(module)
print("module number:",i)

下面我們通過named_children方法找到embedding層,并將其參數設置為不可訓練(相當于凍結embedding層)。

children_dict = {name:module for name,module in net.named_children()}print(children_dict)
embedding = children_dict["embedding"]
embedding.requires_grad_(False) #凍結其參數

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389364.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389364.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389364.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

10.30PMP試題每日一題

SC>0&#xff0c;CPI<1&#xff0c;說明項目截止到當前&#xff1a;A、進度超前&#xff0c;成本超值B、進度落后&#xff0c;成本結余C、進度超前&#xff0c;成本結余D、無法判斷 答案將于明天和新題一起揭曉&#xff01; 10.29試題答案&#xff1a;A轉載于:https://bl…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到請求的url路徑# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按著http請求協議解析數據# 專注于web業…

ai驅動數據安全治理_AI驅動的Web數據收集解決方案的新起點

ai驅動數據安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…

從Text文本中讀值插入到數據庫中

/// <summary> /// 轉換數據&#xff0c;從Text文本中導入到數據庫中 /// </summary> private void ChangeTextToDb() { if(File.Exists("Storage Card/Zyk.txt")) { try { this.RecNum.Visibletrue; SqlCeCommand sqlCreateTable…

Dataset和DataLoader構建數據通道

重點在第二部分的構建數據通道和第三部分的加載數據集 Pytorch通常使用Dataset和DataLoader這兩個工具類來構建數據管道。 Dataset定義了數據集的內容&#xff0c;它相當于一個類似列表的數據結構&#xff0c;具有確定的長度&#xff0c;能夠用索引獲取數據集中的元素。 而D…

鐵拳nat映射_鐵拳如何重塑我的數據可視化設計流程

鐵拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

Django2 Web 實戰03-文件上傳

作者&#xff1a;Hubery 時間&#xff1a;2018.10.31 接上文&#xff1a;接上文&#xff1a;Django2 Web 實戰02-用戶注冊登錄退出 視頻是一種可視化媒介&#xff0c;因此視頻數據庫至少應該存儲圖像。讓用戶上傳文件是個很大的隱患&#xff0c;因此接下來會討論這倆話題&#…

BZOJ.2738.矩陣乘法(整體二分 二維樹狀數組)

題目鏈接 BZOJ洛谷 整體二分。把求序列第K小的樹狀數組改成二維樹狀數組就行了。 初始答案區間有點大&#xff0c;離散化一下。 因為這題是一開始給點&#xff0c;之后詢問&#xff0c;so可以先處理該區間值在l~mid的修改&#xff0c;再處理詢問。即二分標準可以直接用點的標號…

從數據庫里讀值往TEXT文本里寫

/// <summary> /// 把預定內容導入到Text文檔 /// </summary> private void ChangeDbToText() { this.RecNum.Visibletrue; //建立文件&#xff0c;并打開 string oneLine ""; string filename "Storage Card/YD" DateTime.Now.…

DengAI —如何應對數據科學競賽? (EDA)

了解機器學習 (Understanding ML) This article is based on my entry into DengAI competition on the DrivenData platform. I’ve managed to score within 0.2% (14/9069 as on 02 Jun 2020). Some of the ideas presented here are strictly designed for competitions li…

Pytorch模型層簡單介紹

模型層layers 深度學習模型一般由各種模型層組合而成。 torch.nn中內置了非常豐富的各種模型層。它們都屬于nn.Module的子類&#xff0c;具備參數管理功能。 例如&#xff1a; nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.Co…

有效溝通的技能有哪些_如何有效地展示您的數據科學或軟件工程技能

有效溝通的技能有哪些What is the most important thing to do after you got your skills to be a data scientist? It has to be to show off your skills. Otherwise, there is no use of your skills. If you want to get a job or freelance or start a start-up, you ha…

java.net.SocketException: Software caused connection abort: socket write erro

場景&#xff1a;接口測試 編輯器&#xff1a;eclipse 版本&#xff1a;Version: 2018-09 (4.9.0) testng版本&#xff1a;TestNG version 6.14.0 執行testng.xml時報錯信息&#xff1a; 出現此報錯原因之一&#xff1a;網上有人說是testng版本與eclipse版本不一致造成的&#…

[博客..配置?]博客園美化

博客園搞定時間 -> 18年6月27日 [讓我歇會兒 搞這個費腦子 代碼一個都看不懂] 轉載于:https://www.cnblogs.com/Steinway/p/9235437.html

使用K-Means對美因河畔法蘭克福的社區進行聚類

介紹 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

Pytorch損失函數losses簡介

一般來說&#xff0c;監督學習的目標函數由損失函數和正則化項組成。(Objective Loss Regularization) Pytorch中的損失函數一般在訓練模型時候指定。 注意Pytorch中內置的損失函數的參數和tensorflow不同&#xff0c;是y_pred在前&#xff0c;y_true在后&#xff0c;而Ten…

讀取Mc1000的 唯一 ID 機器號

先引用Symbol.ResourceCoordination 然后引用命名空間 using System;using System.Security.Cryptography;using System.IO; 以下為類程序 /// <summary> /// 獲取設備id /// </summary> /// <returns></returns> public static string GetDevi…

樣本均值的抽樣分布_抽樣分布樣本均值

樣本均值的抽樣分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩轉ceph性能測試---對象存儲(一)

筆者最近在工作中需要測試ceph的rgw&#xff0c;于是邊測試邊學習。首先工具采用的intel的一個開源工具cosbench&#xff0c;這也是業界主流的對象存儲測試工具。 1、cosbench的安裝&#xff0c;啟動下載最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]絕世好題

Description 題庫鏈接 給定一個長度為 \(n\) 的數列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最長長度&#xff0c;滿足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位與&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 為二進制第 \(i…