pytorch-nn.Module

目錄

  • 1. nn.Module
  • 2. nn.Sequential容器
  • 3. 網絡參數parameters
  • 4. Modules內部管理
  • 5. checkpoint
  • 6. train/test狀態切換
  • 6. 實現自己的網絡層
    • 6.1 實現打平操作
    • 6.2 實現自己的線性層
  • 7. 代碼

1. nn.Module

是所有nn.類的父類,其中包括nn.Linear nn.BatchNorm2d nn.Conv2d nn.ReLU nn.Sigmoid等等

2. nn.Sequential容器

如下圖,定義一個net網絡,將所有繼承自nn.Module的子類定義的網絡層加入到了nn.Sequential容器中,與一層一層的單獨調用模塊組成序列相比,nn.Sequential() 可以允許將整個容器視為單個模塊(即相當于把多個模塊封裝成一個模塊),forward()方法接收輸入之后,nn.Sequential()按照內部模塊的順序自動依次計算并輸出結果。因此可以利用nn.Sequential()搭建模型架構

在這里插入圖片描述

3. 網絡參數parameters

如下圖,通過net.parameters()可以獲取到net的參數,轉換成list后,通過index訪問第幾個參數,比如:圖中的list(net.named_parameters())[0]就可以獲取到網絡的第一個參數,也就是網絡第一層的w參數。
通過list(net.named_parameters()).items()獲取到所有網絡層,從獲取結果可以看到,每一層都被pytorch命名了,比如:‘0.weight’,‘0.bias’,即第一層網絡的weight和bias.
在這里插入圖片描述

4. Modules內部管理

與根節點相連的直系親屬叫children,其他再與children連接的節點都叫modules
如下圖,nn.Sequential是Net的children,其他的是modules,包括nn.ReLU、nn.Linear、BasicNet
在這里插入圖片描述
從下面這張截圖可以看出,Net本身和Children也都是modules
在這里插入圖片描述

5. checkpoint

為了防止train過程意外停止,需從頭train的問題,train過程需要定期保持checkpoint,而一旦出現train意外停止,就可以從最后一次checkpoint接著訓練。
torch.save保存checkpoint
torch.load_state_dict(torch.load(‘chpt.md’))用于load checkpoint
在這里插入圖片描述

6. train/test狀態切換

所有nn.類都繼承自nn.Module,因此在切換train和test狀態時,只需要調用一次net.train()或net.eval即可,而不需要那些train和test(dropout)行為不一致的類每個單獨去切換.
在這里插入圖片描述

6. 實現自己的網絡層

6.1 實現打平操作

全連接層層需要打平輸入,打平操作通過.view方法實現,由于Flatten繼承自nn.Module,因此可以直接放到nn.Sequential中。
在這里插入圖片描述

6.2 實現自己的線性層

通過net.parameters()可以將網絡參數加到優化器中。
在這里插入圖片描述
troch.tensor是不會自動加到nn.parameters中,因此需要使用nn.Parameter將tensor加到nn.parameters,從而才能加到SGD等優化器中。

在這里插入圖片描述

7. 代碼

import  torch
from    torch import nn
from    torch import optimclass MyLinear(nn.Module):def __init__(self, inp, outp):super(MyLinear, self).__init__()# requires_grad = Trueself.w = nn.Parameter(torch.randn(outp, inp))self.b = nn.Parameter(torch.randn(outp))def forward(self, x):x = x @ self.w.t() + self.breturn xclass Flatten(nn.Module):def __init__(self):super(Flatten, self).__init__()def forward(self, input):return input.view(input.size(0), -1)class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),nn.MaxPool2d(2, 2),Flatten(),nn.Linear(1*14*14, 10))def forward(self, x):return self.net(x)class BasicNet(nn.Module):def __init__(self):super(BasicNet, self).__init__()self.net = nn.Linear(4, 3)def forward(self, x):return self.net(x)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(BasicNet(),nn.ReLU(),nn.Linear(3, 2))def forward(self, x):return self.net(x)def main():device = torch.device('cuda')net = Net()net.to(device)net.train()net.eval()# net.load_state_dict(torch.load('ckpt.mdl'))### torch.save(net.state_dict(), 'ckpt.mdl')for name, t in net.named_parameters():print('parameters:', name, t.shape)for name, m in net.named_children():print('children:', name, m)for name, m in net.named_modules():print('modules:', name, m)if __name__ == '__main__':main()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/23835.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/23835.shtml
英文地址,請注明出處:http://en.pswp.cn/web/23835.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

每日一練 - OSPF協議驗證機制

01 真題題目 OSPF 只有在 Hello 報文中有驗證信息,OSPF 支持 MD5 密文驗證. A.正確 B.錯誤 02 真題答案 B 03 答案解析 這個陳述是不完全正確的。首先,OSPF確實使用Hello報文來攜帶認證信息,但這不意味著只有Hello報文包含驗證信息。 OSPF的認證機制可…

政府績效考核第三方評估的含義

政府績效考核第三方評估是指由獨立于政府的外部機構(如專業評估公司、研究機構或非政府組織)對政府部門或其下屬單位的績效進行客觀、公正、系統的評估。其主要目的是通過引入獨立的第三方評估機構,對政府績效進行科學、全面的考核&#xff0…

【AIGC調研系列】Qwen2與llama3對比的優勢

Qwen2與Llama3的對比中,Qwen2展現出了多方面的優勢。首先,從性能角度來看,Qwen2在多個基準測試中表現出色,尤其是在代碼和數學能力上有顯著提升[1][9]。此外,Qwen2還在自然語言理解、知識、多語言等多項能力上均顯著超…

肺結節14問,查出肺結節怎么辦?哪些能用中醫調治消散?快來了解一下吧

近些年,隨著大眾防癌意識的加強,和胸部低劑量CT的普及,肺結節的檢出率也逐年升高,不少患者CT報告上,寫著“肺小結”“肺部磨玻璃結節”的字樣,當你看到這幾個字時,會不會瞬間緊張起來&#xff1…

編程規范-代碼檢測-格式化-規范化提交

適用于vue項目的編程規范 – 在多人開發時統一編程規范至關重要 1、代碼檢測 --Eslint Eslint:一個插件化的 javascript 代碼檢測工具 在 .eslintrc.js 文件中進行配置 // ESLint 配置文件遵循 commonJS 的導出規則,所導出的對象就是 ESLint 的配置對…

簡化電動汽車充電器和光伏逆變器的高壓電流檢測

在任何電氣系統中,電流都是一個至關重要的參數。電動汽車 (EV) 充電系統和太陽能系統都需要檢測電流的大小,以便控制和監測功率轉換、充電和放電。電流傳感器通過監測分流電阻器上的壓降或導體中電流產生的磁場來測量電流。 金屬氧化物半導體場效應晶體…

DBeaver連接MySQL提示“Public Key Retrieval is not allowed“問題的解決方式

問題描述 客戶端root用戶連接數據庫出現出現Public Key Retrieval is not allowed 原因分析: 加上allowPublicKeyRetrievalfalse: 解決方案: allowPublicKeyRetrievaltrue:

Java Web學習筆記14——BOM對象

BOM: 概念:瀏覽器對象模型(Browser Object Model),允許JavaScript與瀏覽器對話,JavaScript將瀏覽器的各個組成部分封裝為對象。 組成: Window:瀏覽器窗口對象 介紹:瀏覽…

opencv銳化卷積核的定義和應用(圖像銳化)。

定義銳化卷積核 卷積核(Kernel)是一個小矩陣,它用于在圖像處理操作中,比如模糊、銳化、邊緣檢測等。卷積核通過卷積操作應用于圖像像素,產生新的圖像。 在銳化操作中,我們通常使用一個 3x3 的卷積核。以下…

注解 - @RestController

注解簡介 在今天的每日一注解中,我們將探討RestController注解。RestController是Spring框架中的一個組合注解,方便創建RESTful Web服務。 注解定義 RestController注解是Controller和ResponseBody注解的組合,用于定義RESTful控制器。以下是…

物聯網(IoT)及物聯網網絡協議面試題及參考答案(2萬字長文)

什么是物聯網(IoT)? 物聯網(Internet of Things,簡稱IoT)是一個由互聯網、傳統電信網、傳感器網絡等多種網絡組成的網絡概念。它允許物體與物體、物體與人、人與人之間通過智能傳感器、軟件和網絡進行信息交換和通信,實現智能化識別、定位、跟蹤、監控和管理。物聯網的…

光伏電站鳥害解決方案,列式沖擊波聲壓光伏驅鳥器

光伏電站的運營過程中,最怕遇上鳥糞污染。鳥糞不僅難以清洗,還可能導致光伏組件損壞、降低發電效率。因此,制定并實施有效的驅鳥策略對于光伏電站的穩定運營至關重要。 針對光伏電站的鳥害問題,我們可以從以下幾個方面來解決&…

知名優秀定制線纜生產源頭工廠推薦-精工電聯:全程跟蹤監制,打造水下機器人線纜定制新標桿

在科技飛速發展的今天,精工電聯作為高科技智能化產品及自動化設備專用連接線束和連接器配套服務商,始終站在行業前沿。我們專注于為高科技行業提供高品質、優匹配的集成線纜和連接器定制服務,特別是在水下機器人線纜定制領域,通過…

CAN的TP模式和COM模式的區別

CAN的TP(傳輸協議)模式和COM(通信)模式主要涉及汽車網絡中的數據傳輸機制,兩者在功能、尋址方式和幀類型等方面有所不同。具體分析如下: 功能 TP模式:TP模式,即傳輸協議模式&#…

sql死鎖分析

一、重要參數 獲取事務信息:SELECT * FROM information_schema.INNODB_TRX; 獲取鎖等待:SELECT * FROM information_schema.INNODB_LOCK_WAITS; 查看鎖信息:SELECT * FROM information_schema.INNODB_LOCKS WHERE lock_trx_id IN () 二、case1:間隙鎖和x鎖互斥導致死鎖 1、背景…

安全高效海外倉系統:中小海外倉標準化管理的第一步

在當今全球化的商業背景中,可以說海外倉已經成為跨境電商供應鏈中不可或缺的一環。 尤其是對于那些處于成長階段的中小型海外倉來說,選擇一款安全高效并且符合其海外倉規模特點的wms管理系統尤其重要。 今天我們就來系統的了解一下,安全高效…

大廠AI團戰高考作文,華師一附中特級教師這樣打分

在人工智能的浪潮中, 人們不禁疑問: AI真的能超越人類嗎? 這究竟是現實還是幻想? 我們將目睹一場前所未有的較量: 百度文心一言、阿里通義千問、 騰訊混元、字節豆包 四家國內頂尖互聯網企業 精心打造的AI大模…

HBM簡介

1、什么是HBM HBMHigh Bandwidth Memory 是一種用于某些 GPU的 3D 堆疊 DRAM存儲器 (動態隨機存取存儲器)以及服務器、高性能計算 (HPC) 、網絡連接的內存接口。其實就是將很多個DDR芯片堆疊在一起后和GPU封裝在一起,實…

ROS socketcan_bridge使用說明

ROS socketcan_bridge使用說明(以ubuntu20.04為例) socketcan_bridge是什么 ROS針對socketcan提供了三個層次的驅動庫,分別是ros_canopen,socketcan_bridge和socketcan_interface。 socketcan_interface: 功能&#x…

k-means聚類模型的原理和應用

k-means聚類算法是一種迭代求解的聚類分析算法,其步驟是,預將數據分為K組,然后隨機選取K個對象作為初始的聚類中心;計算每個對象與各個種子聚類中心之間的距離,把每個對象分配給距離它最近的聚類中心;聚類中…