Pytorch模型層簡單介紹

模型層layers

深度學習模型一般由各種模型層組合而成。

torch.nn中內置了非常豐富的各種模型層。它們都屬于nn.Module的子類,具備參數管理功能。

例如:

nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d

nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.ConvTranspose2d

nn.Embedding,nn.GRU,nn.LSTM

nn.Transformer

如果這些內置模型層不能夠滿足需求,我們也可以通過繼承nn.Module基類構建自定義的模型層。

實際上,pytorch不區分模型和模型層,都是通過繼承nn.Module進行構建。

因此,我們只要繼承nn.Module基類并實現forward方法即可自定義模型層。

一,內置模型層

頭文件:

import numpy as np 
import torch 
from torch import nn

基礎層

nn.Linear:全連接層。參數個數 = 輸入層特征數× 輸出層特征數(weight)+ 輸出層特征數(bias)

nn.Flatten:壓平層,用于將多維張量樣本壓成一維張量樣本。

nn.BatchNorm1d:一維批標準化層。通過線性變換將輸入批次縮放平移到穩定的均值和標準差。可以增強模型對輸入不同分布的適應性,加快模型訓練速度,有輕微正則化效果。一般在激活函數之前使用。可以用afine參數設置該層是否含有可以訓練的參數。

nn.BatchNorm2d:二維批標準化層。

nn.BatchNorm3d:三維批標準化層。

nn.Dropout:一維隨機丟棄層。一種正則化手段。

nn.Dropout2d:二維隨機丟棄層。

nn.Dropout3d:三維隨機丟棄層。

nn.Threshold:限幅層。當輸入大于或小于閾值范圍時,截斷之。

nn.ConstantPad2d: 二維常數填充層。對二維張量樣本填充常數擴展長度。

nn.ReplicationPad1d: 一維復制填充層。對一維張量樣本通過復制邊緣值填充擴展長度。

nn.ZeroPad2d:二維零值填充層。對二維張量樣本在邊緣填充0值.

nn.GroupNorm:組歸一化。一種替代批歸一化的方法,將通道分成若干組進行歸一。不受batch大小限制,據稱性能和效果都優于BatchNorm。

nn.LayerNorm:層歸一化。較少使用。

nn.InstanceNorm2d: 樣本歸一化。較少使用。

卷積網絡相關層

nn.Conv1d:普通一維卷積,常用于文本。參數個數 = 輸入通道數×卷積核尺寸(如3)×卷積核個數 + 卷積核尺寸(如3)

nn.Conv2d:普通二維卷積,常用于圖像。參數個數 = 輸入通道數×卷積核尺寸(如3乘3)×卷積核個數 + 卷積核尺寸(如3乘3) 通過調整dilation參數大于1,可以變成空洞卷積,增大卷積核感受野。 通過調整groups參數不為1,可以變成分組卷積。分組卷積中不同分組使用相同的卷積核,顯著減少參數數量。 當groups參數等于通道數時,相當于tensorflow中的二維深度卷積層tf.keras.layers.DepthwiseConv2D。 利用分組卷積和1乘1卷積的組合操作,可以構造相當于Keras中的二維深度可分離卷積層tf.keras.layers.SeparableConv2D。

nn.Conv3d:普通三維卷積,常用于視頻。參數個數 = 輸入通道數×卷積核尺寸(如3乘3乘3)×卷積核個數 + 卷積核尺寸(如3乘3乘3) 。

nn.MaxPool1d: 一維最大池化。

nn.MaxPool2d:二維最大池化。一種下采樣方式。沒有需要訓練的參數。

nn.MaxPool3d:三維最大池化。

nn.AdaptiveMaxPool2d:二維自適應最大池化。無論輸入圖像的尺寸如何變化,輸出的圖像尺寸是固定的。 該函數的實現原理,大概是通過輸入圖像的尺寸和要得到的輸出圖像的尺寸來反向推算池化算子的padding,stride等參數。

nn.FractionalMaxPool2d:二維分數最大池化。普通最大池化通常輸入尺寸是輸出的整數倍。而分數最大池化則可以不必是整數。分數最大池化使用了一些隨機采樣策略,有一定的正則效果,可以用它來代替普通最大池化和Dropout層。

nn.AvgPool2d:二維平均池化。

nn.AdaptiveAvgPool2d:二維自適應平均池化。無論輸入的維度如何變化,輸出的維度是固定的。

nn.ConvTranspose2d:二維卷積轉置層,俗稱反卷積層。并非卷積的逆操作,但在卷積核相同的情況下,當其輸入尺寸是卷積操作輸出尺寸的情況下,卷積轉置的輸出尺寸恰好是卷積操作的輸入尺寸。在語義分割中可用于上采樣。

nn.Upsample:上采樣層,操作效果和池化相反。可以通過mode參數控制上采樣策略為"nearest"最鄰近策略或"linear"線性插值策略。

nn.Unfold:滑動窗口提取層。其參數和卷積操作nn.Conv2d相同。實際上,卷積操作可以等價于nn.Unfold和nn.Linear以及nn.Fold的一個組合。 其中nn.Unfold操作可以從輸入中提取各個滑動窗口的數值矩陣,并將其壓平成一維。利用nn.Linear將nn.Unfold的輸出和卷積核做乘法后,再使用 nn.Fold操作將結果轉換成輸出圖片形狀。

nn.Fold:逆滑動窗口提取層。

循環網絡相關層

nn.Embedding:嵌入層。一種比Onehot更加有效的對離散特征進行編碼的方法。一般用于將輸入中的單詞映射為稠密向量。嵌入層的參數需要學習。

nn.LSTM:長短記憶循環網絡層【支持多層】。最普遍使用的循環網絡層。具有攜帶軌道,遺忘門,更新門,輸出門。可以較為有效地緩解梯度消失問題,從而能夠適用長期依賴問題。設置bidirectional = True時可以得到雙向LSTM。需要注意的時,默認的輸入和輸出形狀是(seq,batch,feature), 如果需要將batch維度放在第0維,則要設置batch_first參數設置為True。

nn.GRU:門控循環網絡層【支持多層】。LSTM的低配版,不具有攜帶軌道,參數數量少于LSTM,訓練速度更快。

nn.RNN:簡單循環網絡層【支持多層】。容易存在梯度消失,不能夠適用長期依賴問題。一般較少使用。

nn.LSTMCell:長短記憶循環網絡單元。和nn.LSTM在整個序列上迭代相比,它僅在序列上迭代一步。一般較少使用。

nn.GRUCell:門控循環網絡單元。和nn.GRU在整個序列上迭代相比,它僅在序列上迭代一步。一般較少使用。

nn.RNNCell:簡單循環網絡單元。和nn.RNN在整個序列上迭代相比,它僅在序列上迭代一步。一般較少使用。

Transformer相關層

nn.Transformer:Transformer網絡結構。Transformer網絡結構是替代循環網絡的一種結構,解決了循環網絡難以并行,難以捕捉長期依賴的缺陷。它是目前NLP任務的主流模型的主要構成部分。Transformer網絡結構由TransformerEncoder編碼器和TransformerDecoder解碼器組成。編碼器和解碼器的核心是MultiheadAttention多頭注意力層。

nn.TransformerEncoder:Transformer編碼器結構。由多個 nn.TransformerEncoderLayer編碼器層組成。

nn.TransformerDecoder:Transformer解碼器結構。由多個 nn.TransformerDecoderLayer解碼器層組成。

nn.TransformerEncoderLayer:Transformer的編碼器層。

nn.TransformerDecoderLayer:Transformer的解碼器層。

nn.MultiheadAttention:多頭注意力層。

二,自定義模型層

如果Pytorch的內置模型層不能夠滿足需求,我們也可以通過繼承nn.Module基類構建自定義的模型層。

實際上,pytorch不區分模型和模型層,都是通過繼承nn.Module進行構建。

因此,我們只要繼承nn.Module基類并實現forward方法即可自定義模型層。

下面是Pytorch的nn.Linear層的源碼,我們可以仿照它來自定義模型層。

import torch
from torch import nn
import torch.nn.functional as Fclass Linear(nn.Module):__constants__ = ['in_features', 'out_features']def __init__(self, in_features, out_features, bias=True):super(Linear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = nn.Parameter(torch.Tensor(out_features, in_features))if bias:self.bias = nn.Parameter(torch.Tensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in)nn.init.uniform_(self.bias, -bound, bound)def forward(self, input):return F.linear(input, self.weight, self.bias)def extra_repr(self):return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None)
linear = nn.Linear(20, 30)
inputs = torch.randn(128, 20)
output = linear(inputs)
print(output.size())

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389353.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389353.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389353.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

有效溝通的技能有哪些_如何有效地展示您的數據科學或軟件工程技能

有效溝通的技能有哪些What is the most important thing to do after you got your skills to be a data scientist? It has to be to show off your skills. Otherwise, there is no use of your skills. If you want to get a job or freelance or start a start-up, you ha…

java.net.SocketException: Software caused connection abort: socket write erro

場景:接口測試 編輯器:eclipse 版本:Version: 2018-09 (4.9.0) testng版本:TestNG version 6.14.0 執行testng.xml時報錯信息: 出現此報錯原因之一:網上有人說是testng版本與eclipse版本不一致造成的&#…

[博客..配置?]博客園美化

博客園搞定時間 -> 18年6月27日 [讓我歇會兒 搞這個費腦子 代碼一個都看不懂] 轉載于:https://www.cnblogs.com/Steinway/p/9235437.html

使用K-Means對美因河畔法蘭克福的社區進行聚類

介紹 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

Pytorch損失函數losses簡介

一般來說,監督學習的目標函數由損失函數和正則化項組成。(Objective Loss Regularization) Pytorch中的損失函數一般在訓練模型時候指定。 注意Pytorch中內置的損失函數的參數和tensorflow不同,是y_pred在前,y_true在后,而Ten…

讀取Mc1000的 唯一 ID 機器號

先引用Symbol.ResourceCoordination 然后引用命名空間 using System;using System.Security.Cryptography;using System.IO; 以下為類程序 /// <summary> /// 獲取設備id /// </summary> /// <returns></returns> public static string GetDevi…

樣本均值的抽樣分布_抽樣分布樣本均值

樣本均值的抽樣分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩轉ceph性能測試---對象存儲(一)

筆者最近在工作中需要測試ceph的rgw&#xff0c;于是邊測試邊學習。首先工具采用的intel的一個開源工具cosbench&#xff0c;這也是業界主流的對象存儲測試工具。 1、cosbench的安裝&#xff0c;啟動下載最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]絕世好題

Description 題庫鏈接 給定一個長度為 \(n\) 的數列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最長長度&#xff0c;滿足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位與&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 為二進制第 \(i…

因果關系和相關關系 大數據_數據科學中的相關性與因果關系

因果關系和相關關系 大數據Let’s jump into it right away.讓我們馬上進入。 相關性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

Pytorch構建模型的3種方法

這個地方一直是我思考的地方&#xff01;因為學的代碼太多了&#xff0c;構建的模型各有不同&#xff0c;這里記錄一下&#xff01; 可以使用以下3種方式構建模型&#xff1a; 1&#xff0c;繼承nn.Module基類構建自定義模型。 2&#xff0c;使用nn.Sequential按層順序構建模…

vue取數據第一個數據_我作為數據科學家的第一個月

vue取數據第一個數據A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了數據科學家的第一份工作&#xff0c;并且像任何新工作一樣&#xff0c;一…

Flask-SocketIO 簡單使用指南

Flask-SocketIO 使 Flask 應用程序能夠訪問客戶端和服務器之間的低延遲雙向通信。客戶端應用程序可以使用 Javascript&#xff0c;C &#xff0c;Java 和 Swift 中的任何 SocketIO 官方客戶端庫或任何兼容的客戶端來建立與服務器的永久連接。 安裝 直接使用 pip 來安裝&#xf…

STL-開篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;標準模板庫 定義&#xff1a; c引入的一個標準類庫 特點&#xff1a;1&#xff09;數據結構和算法的 c實現&#xff08; 采用模板類和模板函數&#xff09;2&#xff09;數據的存儲和算法的分離3&#xff09;高…

Symbol Mc1000 聲音的設置以及播放

首先引用Symbol.Audio 加一命名空間using Symbol.Audio; /聲音設備的設置 //Select Device from device list Symbol.Audio.Device MyDevice (Symbol.Audio.Device)Symbol.StandardForms.SelectDevice.Select( Symbol.Audio.Controller.Title, Symbol.Audio.Devic…

/bin/bash^M: 壞的解釋器: 沒有那個文件或目錄

在win下編輯的時候&#xff0c;換行結尾是\n\r &#xff0c; 而在linux下 是\n&#xff0c;所以會多出來一個\r&#xff0c;這樣會出現錯誤 此時執行 sed -i s/\r$// file.sh 將file.sh中的\r都替換為空白&#xff0c;問題解決轉載于:https://www.cnblogs.com/zzdbullet/p/9890…

rcp rapido_為什么氣流非常適合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

pandas處理丟失數據與數據導入導出

3.4pandas處理丟失數據 頭文件&#xff1a; import numpy as np import pandas as pd丟棄數據部分&#xff1a; dates pd.date_range(20130101,periods6) df pd.DataFrame(np.random.randn(6,4),indexdates,columns[A,B,C,D]) df.iloc[0,1] np.nan df.iloc[1,2] np.nanp…

Mysql5.7開啟遠程

2019獨角獸企業重金招聘Python工程師標準>>> 1.注掉bind-address #bind-address 127.0.0.1 2.開啟遠程訪問權限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密碼"; 或 grant all privileges on *.* to root"%…

分類結果可視化python_可視化分類結果的另一種方法

分類結果可視化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜歡出色的數據可視化。 早在我獲得…