PyTorch中的torch.nn.Parameter() 詳解
今天來聊一下PyTorch中的torch.nn.Parameter()這個函數,筆者第一次見的時候也是大概能理解函數的用途,但是具體實現原理細節也是云里霧里,在參考了幾篇博文,做過幾個實驗之后算是清晰了,本文在記錄的同時希望給后來人一個參考,歡迎留言討論。
分析
先看其名,parameter,中文意為參數。我們知道,使用PyTorch訓練神經網絡時,本質上就是訓練一個函數,這個函數輸入一個數據(如CV中輸入一張圖像),輸出一個預測(如輸出這張圖像中的物體是屬于什么類別)。而在我們給定這個函數的結構(如卷積、全連接等)之后,能學習的就是這個函數的參數了,我們設計一個損失函數,配合梯度下降法,使得我們學習到的函數(神經網絡)能夠盡量準確地完成預測任務。
通常,我們的參數都是一些常見的結構(卷積、全連接等)里面的計算參數。而當我們的網絡有一些其他的設計時,會需要一些額外的參數同樣很著整個網絡的訓練進行學習更新,最后得到最優的值,經典的例子有注意力機制中的權重參數、Vision Transformer中的class token和positional embedding等。
而這里的torch.nn.Parameter()就可以很好地適應這種應用場景。
下面是這篇博客的一個總結,筆者認為講的比較明白,在這里引用一下:
首先可以把這個函數理解為類型轉換函數,將一個不可訓練的類型
Tensor
轉換成可以訓練的類型parameter
并將這個parameter
綁定到這個module
里面(net.parameter()
中就有這個綁定的parameter
,所以在參數優化的時候可以進行優化的),所以經過類型轉換這個self.v
變成了模型的一部分,成為了模型中根據訓練可以改動的參數了。使用這個函數的目的也是想讓某些變量在學習的過程中不斷的修改其值以達到最優化。
ViT中nn.Parameter()的實驗
看過這個分析后,我們再看一下Vision Transformer中的用法:
...self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...
我們知道在ViT中,positonal embedding和class token是兩個需要隨著網絡訓練學習的參數,但是它們又不屬于FC、MLP、MSA等運算的參數,在這時,就可以用nn.Parameter()來將這個隨機初始化的Tensor注冊為可學習的參數Parameter。
為了確定這兩個參數確實是被添加到了net.Parameters()內,筆者稍微改動源碼,顯式地指定這兩個參數的初始數值為0.98,并打印迭代器net.Parameters()。
...self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...
實例化一個ViT模型并打印net.Parameters():
net_vit = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1)for para in net_vit.parameters():print(para.data)
輸出結果中可以看到,最前兩行就是我們顯式指定為0.98的兩個參數pos_embedding和cls_token:
tensor([[[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],...,[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064, 0.0111, ..., 0.0091, -0.0041, -0.0060],[ 0.0003, 0.0115, 0.0059, ..., -0.0052, -0.0056, 0.0010],[ 0.0079, 0.0016, -0.0094, ..., 0.0174, 0.0065, 0.0001],...,[-0.0110, -0.0137, 0.0102, ..., 0.0145, -0.0105, -0.0167],[-0.0116, -0.0147, 0.0030, ..., 0.0087, 0.0022, 0.0108],[-0.0079, 0.0033, -0.0087, ..., -0.0174, 0.0103, 0.0021]])
...
...
這就可以確定nn.Parameter()添加的參數確實是被添加到了Parameters列表中,會被送入優化器中隨訓練一起學習更新。
from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)
其他解釋
以下是國外StackOverflow的一個大佬的解讀,筆者自行翻譯并放在這里供大家參考,想查看原文的同學請戳這里。
我們知道Tensor相當于是一個高維度的矩陣,它是Variable類的子類。Variable和Parameter之間的差異體現在與Module關聯時。當Parameter作為model的屬性與module相關聯時,它會被自動添加到Parameters列表中,并且可以使用net.Parameters()迭代器進行訪問。
最初在Torch中,一個Variable(例如可以是某個中間state)也會在賦值時被添加為模型的Parameter。在某些實例中,需要緩存變量,而不是將它們添加到Parameters列表中。
文檔中提到的一種情況是RNN,在這種情況下,您需要保存最后一個hidden state,這樣就不必一次又一次地傳遞它。需要緩存一個Variable,而不是讓它自動注冊為模型的Parameter,這就是為什么我們有一個顯式的方法將參數注冊到我們的模型,即nn.Parameter類。
舉個例子:
import torch
import torch.nn as nn
from torch.optim import Adamclass NN_Network(nn.Module):def __init__(self,in_dim,hid,out_dim):super(NN_Network, self).__init__()self.linear1 = nn.Linear(in_dim,hid)self.linear2 = nn.Linear(hid,out_dim)self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))self.linear1.bias = torch.nn.Parameter(torch.ones(hid))self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))self.linear2.bias = torch.nn.Parameter(torch.ones(hid))def forward(self, input_array):h = self.linear1(input_array)y_pred = self.linear2(h)return y_predin_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)
然后檢查一下這個模型的Parameters列表:
for param in net.parameters():print(type(param.data), param.size())""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""
可以輕易地送入到優化器中:
opt = Adam(net.parameters(), learning_rate=0.001)
另外,請注意Parameter的require_grad會自動設定。
各位讀者有疑惑或異議的地方,歡迎留言討論。
參考:
https://www.jianshu.com/p/d8b77cc02410
https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter