Pytorch構建模型的3種方法

這個地方一直是我思考的地方!因為學的代碼太多了,構建的模型各有不同,這里記錄一下!
可以使用以下3種方式構建模型:

1,繼承nn.Module基類構建自定義模型。

2,使用nn.Sequential按層順序構建模型。

3,繼承nn.Module基類構建模型并輔助應用模型容器進行封裝(nn.Sequential,nn.ModuleList,nn.ModuleDict)。

其中 第1種方式最為常見,第2種方式最簡單,第3種方式最為靈活也較為復雜。

推薦使用第1種方式構建模型。

頭文件:

import torch 
from torch import nn

一,繼承nn.Module基類構建自定義模型

以下是繼承nn.Module基類構建自定義模型的一個范例。模型中的用到的層一般在__init__函數中定義,然后在forward方法中定義模型的正向傳播邏輯。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)self.pool1 = nn.MaxPool2d(kernel_size = 2,stride = 2)self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)self.pool2 = nn.MaxPool2d(kernel_size = 2,stride = 2)self.dropout = nn.Dropout2d(p = 0.1)self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))self.flatten = nn.Flatten()self.linear1 = nn.Linear(64,32)self.relu = nn.ReLU()self.linear2 = nn.Linear(32,1)self.sigmoid = nn.Sigmoid()def forward(self,x):x = self.conv1(x)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.dropout(x)x = self.adaptive_pool(x)x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)y = self.sigmoid(x)return ynet = Net()
print(net)

二,使用nn.Sequential按層順序構建模型

使用nn.Sequential按層順序構建模型無需定義forward方法。僅僅適合于簡單的模型。

以下是使用nn.Sequential搭建模型的一些等價方法。

1,利用add_module方法

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,1))
net.add_module("sigmoid",nn.Sigmoid())print(net)

2,利用變長參數

這種方式構建時不能給每個層指定名稱。

net = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,1),nn.Sigmoid()
)print(net)

3,利用OrderedDict

from collections import OrderedDictnet = nn.Sequential(OrderedDict([("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)),("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)),("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)),("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2)),("dropout",nn.Dropout2d(p = 0.1)),("adaptive_pool",nn.AdaptiveMaxPool2d((1,1))),("flatten",nn.Flatten()),("linear1",nn.Linear(64,32)),("relu",nn.ReLU()),("linear2",nn.Linear(32,1)),("sigmoid",nn.Sigmoid())]))
print(net)

三,繼承nn.Module基類構建模型并輔助應用模型容器進行封裝

當模型的結構比較復雜時,我們可以應用模型容器(nn.Sequential,nn.ModuleList,nn.ModuleDict)對模型的部分結構進行封裝。

這樣做會讓模型整體更加有層次感,有時候也能減少代碼量。

注意,在下面的范例中我們每次僅僅使用一種模型容器,但實際上這些模型容器的使用是非常靈活的,可以在一個模型中任意組合任意嵌套使用。

1,nn.Sequential作為模型容器

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)))self.dense = nn.Sequential(nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,1),nn.Sigmoid())def forward(self,x):x = self.conv(x)y = self.dense(x)return y net = Net()
print(net)

2,nn.ModuleList作為模型容器

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,1),nn.Sigmoid()])def forward(self,x):for layer in self.layers:x = layer(x)return x
net = Net()
print(net)

3,nn.ModuleDict作為模型容器

注意下面中的ModuleDict不能用Python中的字典代替。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layers_dict = nn.ModuleDict({"conv1":nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),"pool": nn.MaxPool2d(kernel_size = 2,stride = 2),"conv2":nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),"dropout": nn.Dropout2d(p = 0.1),"adaptive":nn.AdaptiveMaxPool2d((1,1)),"flatten": nn.Flatten(),"linear1": nn.Linear(64,32),"relu":nn.ReLU(),"linear2": nn.Linear(32,1),"sigmoid": nn.Sigmoid()})def forward(self,x):layers = ["conv1","pool","conv2","pool","dropout","adaptive","flatten","linear1","relu","linear2","sigmoid"]for layer in layers:x = self.layers_dict[layer](x)return x
net = Net()
print(net)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389342.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389342.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389342.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue取數據第一個數據_我作為數據科學家的第一個月

vue取數據第一個數據A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了數據科學家的第一份工作,并且像任何新工作一樣,一…

Flask-SocketIO 簡單使用指南

Flask-SocketIO 使 Flask 應用程序能夠訪問客戶端和服務器之間的低延遲雙向通信。客戶端應用程序可以使用 Javascript,C ,Java 和 Swift 中的任何 SocketIO 官方客戶端庫或任何兼容的客戶端來建立與服務器的永久連接。 安裝 直接使用 pip 來安裝&#xf…

STL-開篇

基本概念 STL: Standard Template Library,標準模板庫 定義: c引入的一個標準類庫 特點:1)數據結構和算法的 c實現( 采用模板類和模板函數)2)數據的存儲和算法的分離3)高…

Symbol Mc1000 聲音的設置以及播放

首先引用Symbol.Audio 加一命名空間using Symbol.Audio; /聲音設備的設置 //Select Device from device list Symbol.Audio.Device MyDevice (Symbol.Audio.Device)Symbol.StandardForms.SelectDevice.Select( Symbol.Audio.Controller.Title, Symbol.Audio.Devic…

/bin/bash^M: 壞的解釋器: 沒有那個文件或目錄

在win下編輯的時候,換行結尾是\n\r , 而在linux下 是\n,所以會多出來一個\r,這樣會出現錯誤 此時執行 sed -i s/\r$// file.sh 將file.sh中的\r都替換為空白,問題解決轉載于:https://www.cnblogs.com/zzdbullet/p/9890…

rcp rapido_為什么氣流非常適合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

pandas處理丟失數據與數據導入導出

3.4pandas處理丟失數據 頭文件: import numpy as np import pandas as pd丟棄數據部分: dates pd.date_range(20130101,periods6) df pd.DataFrame(np.random.randn(6,4),indexdates,columns[A,B,C,D]) df.iloc[0,1] np.nan df.iloc[1,2] np.nanp…

Mysql5.7開啟遠程

2019獨角獸企業重金招聘Python工程師標準>>> 1.注掉bind-address #bind-address 127.0.0.1 2.開啟遠程訪問權限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密碼"; 或 grant all privileges on *.* to root"%…

分類結果可視化python_可視化分類結果的另一種方法

分類結果可視化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜歡出色的數據可視化。 早在我獲得…

算法組合 優化算法_算法交易簡化了風險價值和投資組合優化

算法組合 優化算法Photo by Markus Spiske (left) and Jamie Street (right) on UnsplashMarkus Spiske (左)和Jamie Street(右)在Unsplash上的照片 In the last post, we saw how actual algorithms are developed and tested. In this post, we will figure out the level of…

Symbol Mc1000 快捷鍵 的 設置 事件 開發

switch (e.KeyCode) { ///數據 case Keys.F1://清除數據 if(File.Exists("Storage Card/CG.sdf")) { Mc.gConn.Close(); Mc.gConn.Dispose(); File.Delete("Storage Card/CG.sdf"); } MessageBox.S…

pandas合并concatmerge和plot畫圖

3.6,3.7pandas合并concat&merge 頭文件: import pandas as pd import numpy as npconcat基礎合并用法 df1 pd.DataFrame(np.ones((3,4))*0,columns [a,b,c,d]) df2 pd.DataFrame(np.ones((3,4))*1,columns [a,b,c,d]) df3 pd.DataFrame(np.ones…

Android跳轉WIFI界面的四種方式

第一種 Intent intent new Intent(); intent.setAction("android.net.wifi.PICK_WIFI_NETWORK"); startActivity(intent); 第二種 startActivity(new Intent(android.provider.Settings.ACTION_WIFI_SETTINGS)); 第三種 Intent i new Intent(); if(android.os.Buil…

PS摳發絲技巧 「選擇并遮住…」

PS摳發絲技巧 「選擇并遮住…」 現在的海報設計,大多數都有模特MM,然而MM的頭發實用太多了,有的還飄起來…… 對于設計師(特別是淘寶美工)沒有一個強大、快速、實用的摳發絲技巧真的混不去哦。而PS CC 2017版本開始,就有了一個強大…

covid 19如何重塑美國科技公司的工作文化

未來 , 技術 , 觀點 (Future, Technology, Opinion) Who would have thought that a single virus would take down the whole world and make us stay inside our homes? A pandemic wave that has altered our lives in such a way that no human (bi…

Symbol Mc1000 Text文本閱讀器整體代碼

using System; using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms;using System.Collections;using System.IO;namespace text{ /// <summary> /// Form1 的摘要說明。 /// </summary> public c…

python生日悖論分析_生日悖論

python生日悖論分析If you have a group of people in a room, how many do you need to for it to be more likely than not, that two or more will have the same birthday?如果您在一個房間里有一群人&#xff0c;那么您需要多少個才能使兩個或兩個以上的人有相同的生日&a…

統計0-n數字中出現k的次數

/*** 統計0-n數字中出現k的次數&#xff0c;其中k范圍為0-9 */ public static int countOne(int k, int n) {if (k > n) {return 0;}int sum 0;int right 0;for (int i 0; n > 0; i) {int last n % 10;sum last * i * (int) Math.pow(10, i - 1);if (k 0) {sum - (…

房價預測 search Search 中對數據預處理的學習

對于缺失的數據&#xff1a; 我們對連續數值的特征做標準化&#xff08;standardization&#xff09;&#xff1a;設該特征在整個數據集上的均值為 μ &#xff0c;標準差為 σ 。那么&#xff0c;我們可以將該特征的每個值先減去 μ 再除以 σ 得到標準化后的每個特征值。對于…

3.6.1.非阻塞IO

本節講解什么是非阻塞IO&#xff0c;如何將文件描述符修改為非阻塞式 3.6.1.1、阻塞與非阻塞 &#xff08;1&#xff09;阻塞是指函數調用會被阻塞。本質是當前進程調用了函數&#xff0c;進入內核里面去后&#xff0c;因為當前進程的執行條件不滿足&#xff0c;內核無法里面完…