PyTorch中的torch.nn.Parameter() 詳解

PyTorch中的torch.nn.Parameter() 詳解

今天來聊一下PyTorch中的torch.nn.Parameter()這個函數,筆者第一次見的時候也是大概能理解函數的用途,但是具體實現原理細節也是云里霧里,在參考了幾篇博文,做過幾個實驗之后算是清晰了,本文在記錄的同時希望給后來人一個參考,歡迎留言討論。

分析

先看其名,parameter,中文意為參數。我們知道,使用PyTorch訓練神經網絡時,本質上就是訓練一個函數,這個函數輸入一個數據(如CV中輸入一張圖像),輸出一個預測(如輸出這張圖像中的物體是屬于什么類別)。而在我們給定這個函數的結構(如卷積、全連接等)之后,能學習的就是這個函數的參數了,我們設計一個損失函數,配合梯度下降法,使得我們學習到的函數(神經網絡)能夠盡量準確地完成預測任務。

通常,我們的參數都是一些常見的結構(卷積、全連接等)里面的計算參數。而當我們的網絡有一些其他的設計時,會需要一些額外的參數同樣很著整個網絡的訓練進行學習更新,最后得到最優的值,經典的例子有注意力機制中的權重參數、Vision Transformer中的class token和positional embedding等。

而這里的torch.nn.Parameter()就可以很好地適應這種應用場景。

下面是這篇博客的一個總結,筆者認為講的比較明白,在這里引用一下:

首先可以把這個函數理解為類型轉換函數,將一個不可訓練的類型Tensor轉換成可以訓練的類型parameter并將這個parameter綁定到這個module里面(net.parameter()中就有這個綁定的parameter,所以在參數優化的時候可以進行優化的),所以經過類型轉換這個self.v變成了模型的一部分,成為了模型中根據訓練可以改動的參數了。使用這個函數的目的也是想讓某些變量在學習的過程中不斷的修改其值以達到最優化。

ViT中nn.Parameter()的實驗

看過這個分析后,我們再看一下Vision Transformer中的用法:

...self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...

我們知道在ViT中,positonal embedding和class token是兩個需要隨著網絡訓練學習的參數,但是它們又不屬于FC、MLP、MSA等運算的參數,在這時,就可以用nn.Parameter()來將這個隨機初始化的Tensor注冊為可學習的參數Parameter。

為了確定這兩個參數確實是被添加到了net.Parameters()內,筆者稍微改動源碼,顯式地指定這兩個參數的初始數值為0.98,并打印迭代器net.Parameters()。

...self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...

實例化一個ViT模型并打印net.Parameters():

net_vit = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)for para in net_vit.parameters():print(para.data)

輸出結果中可以看到,最前兩行就是我們顯式指定為0.98的兩個參數pos_embedding和cls_token:

tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],...,[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064,  0.0111,  ...,  0.0091, -0.0041, -0.0060],[ 0.0003,  0.0115,  0.0059,  ..., -0.0052, -0.0056,  0.0010],[ 0.0079,  0.0016, -0.0094,  ...,  0.0174,  0.0065,  0.0001],...,[-0.0110, -0.0137,  0.0102,  ...,  0.0145, -0.0105, -0.0167],[-0.0116, -0.0147,  0.0030,  ...,  0.0087,  0.0022,  0.0108],[-0.0079,  0.0033, -0.0087,  ..., -0.0174,  0.0103,  0.0021]])
...
...

這就可以確定nn.Parameter()添加的參數確實是被添加到了Parameters列表中,會被送入優化器中隨訓練一起學習更新。

from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)

其他解釋

以下是國外StackOverflow的一個大佬的解讀,筆者自行翻譯并放在這里供大家參考,想查看原文的同學請戳這里。

我們知道Tensor相當于是一個高維度的矩陣,它是Variable類的子類。Variable和Parameter之間的差異體現在與Module關聯時。當Parameter作為model的屬性與module相關聯時,它會被自動添加到Parameters列表中,并且可以使用net.Parameters()迭代器進行訪問。
最初在Torch中,一個Variable(例如可以是某個中間state)也會在賦值時被添加為模型的Parameter。在某些實例中,需要緩存變量,而不是將它們添加到Parameters列表中。
文檔中提到的一種情況是RNN,在這種情況下,您需要保存最后一個hidden state,這樣就不必一次又一次地傳遞它。需要緩存一個Variable,而不是讓它自動注冊為模型的Parameter,這就是為什么我們有一個顯式的方法將參數注冊到我們的模型,即nn.Parameter類。

舉個例子:

import torch
import torch.nn as nn
from torch.optim import Adamclass NN_Network(nn.Module):def __init__(self,in_dim,hid,out_dim):super(NN_Network, self).__init__()self.linear1 = nn.Linear(in_dim,hid)self.linear2 = nn.Linear(hid,out_dim)self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))self.linear1.bias = torch.nn.Parameter(torch.ones(hid))self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))self.linear2.bias = torch.nn.Parameter(torch.ones(hid))def forward(self, input_array):h = self.linear1(input_array)y_pred = self.linear2(h)return y_predin_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

然后檢查一下這個模型的Parameters列表:

for param in net.parameters():print(type(param.data), param.size())""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

可以輕易地送入到優化器中:

opt = Adam(net.parameters(), learning_rate=0.001)

另外,請注意Parameter的require_grad會自動設定。

各位讀者有疑惑或異議的地方,歡迎留言討論。

參考:

https://www.jianshu.com/p/d8b77cc02410

https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532358.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532358.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532358.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vision Transformer(ViT)PyTorch代碼全解析(附圖解)

Vision Transformer&#xff08;ViT&#xff09;PyTorch代碼全解析 最近CV領域的Vision Transformer將在NLP領域的Transormer結果借鑒過來&#xff0c;屠殺了各大CV榜單。本文將根據最原始的Vision Transformer論文&#xff0c;及其PyTorch實現&#xff0c;將整個ViT的代碼做一…

hdfs的副本數為啥增加了_HDFS詳解之塊大小和副本數

1.HDFSHDFS : 偽分布式(學習)NNDNSNNsbin/start-dfs.sh(開啟hdfs使用的腳本)bin/hdfs dfs -ls (輸入命令加前綴bin/hdfs dfs)2.block(塊)dfs.blocksize &#xff1a; 134217728(字節) / 128M 官網默認一個塊的大小128M*舉例理解塊1個文件 130M&#xff0c;默認一個塊的大小128M…

Linux下的ELF文件、鏈接、加載與庫(含大量圖文解析及例程)

Linux下的ELF文件、鏈接、加載與庫 鏈接是將將各種代碼和數據片段收集并組合為一個單一文件的過程&#xff0c;這個文件可以被加載到內存并執行。鏈接可以執行與編譯時&#xff0c;也就是在源代碼被翻譯成機器代碼時&#xff1b;也可以執行于加載時&#xff0c;也就是被加載器加…

mysql gender_Mysql第一彈

1、創建數據庫pythoncreate database python charsetutf8;2、設計班級表結構為id、name、isdelete&#xff0c;編寫創建表的語句create table classes(id int unsigned auto_increment primary key not null,name varchar(10),isdelete bit default 0);向班級表中插入數據pytho…

python virtualenv nginx_Ubuntu下搭建Nginx+supervisor+pypy+virtualenv

系統&#xff1a;Ubuntu 14.04 LTS搭建python的運行環境&#xff1a;NginxSupervisorPypyVirtualenv軟件說明&#xff1a;Nginx&#xff1a;通過upstream進行負載均衡Supervisor&#xff1a;管理python進程Pypy&#xff1a;用Python實現的Python解釋器PyPy is a fast, complian…

如何設置mysql表中文亂碼_php mysql表中文亂碼問題如何解決

為避免mysql中出現中文亂碼&#xff0c;建議在創建數據庫時指定編碼格式&#xff1a;復制代碼 代碼示例:create database zzjz CHARACTER SET gbk COLLATE gbk_chinese_ci;create table zz_employees (employeeid int unsigned not null auto_increment primary key,name varch…

java 按鈕 監聽_Button的四種監聽方式

Button按鈕設置點擊的四種監聽方式注&#xff1a;加粗放大的都是改變的代碼1.使用匿名內部類的形式進行設置使用匿名內部類的形式&#xff0c;直接將需要設置的onClickListener接口對象初始化&#xff0c;內部的onClick方法會在按鈕被點擊的時候執行第一個活動的java代碼&#…

java int轉bitmap_Java Base64位編碼與String字符串的相互轉換,Base64與Bitmap的相互轉換實例代碼...

首先是網上大神給的類package com.duanlian.daimengmusic.utils;public final class Base64Util {private static final int BASELENGTH 128;private static final int LOOKUPLENGTH 64;private static final int TWENTYFOURBITGROUP 24;private static final int EIGHTBIT …

linux查看java虛擬機內存_深入理解java虛擬機(linux與jvm內存關系)

本文轉載自美團技術團隊發表的同名文章https://tech.meituan.com/linux-jvm-memory.html一, linux與進程內存模型要理解jvm最重要的一點是要知道jvm只是linux的一個進程,把jvm的視野放大,就能很好的理解JVM細分的一些概念下圖給出了硬件系統進程三個層面內存之間的關系.從硬件上…

java 循環stringbuffer_java常用類-----StringBuilder和StringBuffer的用法

一、可變字符常用方法package cn.zxg.PackgeUse;/*** 測試StringBuilder,StringBuffer可變字符序列常用方法*/public class TestStringBuilder2 {public static void main(String[] args) {StringBuilder sbnew StringBuilder();for(int i0;i<26;i){char temp(char)(ai);sb.…

java function void_Java8中你可能不知道的一些地方之函數式接口實戰

什么時候可以使用 Lambda&#xff1f;通常 Lambda 表達式是用在函數式接口上使用的。從 Java8 開始引入了函數式接口&#xff0c;其說明比較簡單&#xff1a;函數式接口(Functional Interface)就是一個有且僅有一個抽象方法&#xff0c;但是可以有多個非抽象方法的接口。 java8…

java jvm內存地址_JVM--Java內存區域

Java虛擬機在執行Java程序的過程中會把它所管理的內存劃分為若干個不同的數據區域&#xff0c;如圖&#xff1a;1.程序計數器可以看作是當前線程所執行的字節碼的行號指示器&#xff0c;通俗的講就是用來指示執行哪條指令的。為了線程切換后能恢復到正確的執行位置Java多線程是…

java情人節_情人節寫給女朋友Java Swing代碼程序

馬上又要到情人節了&#xff0c;再不解風情的人也得向女友表示表示。作為一個程序員&#xff0c;示愛的時候自然也要用我們自己的方式。這里給大家上傳一段我在今年情人節的時候寫給女朋友的一段簡單的Java Swing代碼&#xff0c;主要定義了一個對話框&#xff0c;讓女友選擇是…

java web filter鏈_filter過濾鏈:Filter鏈是如何構建的?

在一個Web應用程序中可以注冊多個Filter程序&#xff0c;每個Filter程序都可以針對某一個URL進行攔截。如果多個Filter程序都對同一個URL進行攔截&#xff0c;那么這些Filter就會組成一個Filter鏈(也叫過濾器鏈)。Filter鏈用FilterChain對象來表示&#xff0c;FilterChain對象中…

java web 應用技術與案例教程_《Java Web應用開發技術與案例教程》怎么樣_目錄_pdf在線閱讀 - 課課家教育...

出版說明前言第1章 java Web應用開發技術概述1.1 Java Web應用開發技術簡介1.1.1 Java Web應用1.1.2 Java Web應用開發技術1.2 Java Web開發環境及開發工具1.2.1 JDK的下載與安裝1.2.2 Tomcat服務器的安裝和配置1.2.3 MyEclipse集成開發工具的安裝與操作1.3 Java Web應用程序的…

java環境變量自動設置_自動設置Java環境變量

echo offSETLOCALENABLEDELAYEDEXPANSIONfor /f "tokens2* delims " %%i in(reg query "HKLM\Software\JavaSoft\Java Development Kit" /s ^|find /I"JavaHome") do (echo 找到目錄 %%jset /p isOK該目錄是不是JDK^(JavaDevelopment Kit^)的安裝…

mysql運行狀態監控研究內容_如何監控mysql主從的運行狀態shell腳本實例介紹

如何監控mysql主從的運行狀態shell腳本實例介紹。#!/bin/bash#define mysql variablemysql_user”root”mysql_pass”123456″email_addr”slavecentos.bz”mysql_statusnetstat -nl | awk ‘NR>2{if ($4 ~ /.*:3306/) {print “Yes”;exit 0}}’if [ "$mysql_status&q…

java 100% cpu_Java服務,CPU 100%問題如何快速定位?

Java服務&#xff0c;有時候會遇到CPU 100%的問題&#xff0c;對于這樣的問題&#xff0c;我們如何快速定位并解決呢&#xff1f;一般會有如下三個步驟&#xff1a;1、找到最耗CPU的進程2、找到這個進程中最耗CPU的線程3、查看堆棧信息&#xff0c;定位線程的什么操作消耗了大量…

java 泛型 加_Java泛型并將數字加在一起

為了一般地計算總和,您需要提供兩個動作&#xff1a;>一種總計零項的方法>一種總結兩個項目的方法在Java中,您可以通過界面完成.這是一個完整的例子&#xff1a;import java.util.*;interface adder {T zero(); // Adding zero itemsT add(T lhs, T rhs); // Adding two …

java 字母金字塔_LeetCode756:金字塔轉換矩陣(JAVA題解)

題目描述現在&#xff0c;我們用一些方塊來堆砌一個金字塔。 每個方塊用僅包含一個字母的字符串表示。使用三元組表示金字塔的堆砌規則如下&#xff1a;對于三元組(A, B, C) &#xff0c;“C”為頂層方塊&#xff0c;方塊“A”、“B”分別作為方塊“C”下一層的的左、右子塊。當…