x.view()
就是對tensor進行reshape:
我們在創建一個網絡的時候,會在Foward函數內看到view的使用。
首先這里是一個簡單的網絡,有卷積和全連接組成。它的foward函數如下:
class NET(nn.Module):def __init__(self,batch_size):super(NET,self).__init__()self.conv = nn.Conv2d(outchannels=3,in_channels=64,kernel_size=3,stride=1)self.fc = nn.Linear(64*batch_size,10)def forward(self,x):x = self.conv(x)x = x.view(x.size(0), -1) out = self.fc(x)
在CNN中卷積或者池化之后需要連接全連接層,所以需要把多維度的tensor展平成一維,x.view(x.size(0), -1)
就實現的這個功能。
卷積或者池化之后的tensor的維度為(batchsize,channels,x,y),其中x.size(0)指batchsize的值,x = x.view(x.size(0), -1)簡化x = x.view(batchsize, -1)
。( 通過x.view(x.size(0), -1)
將tensor的結構轉換為了(batchsize, channelsxy),即將(channels,x,y)拉直,然后就可以和fc層連接了。)
示例:
x變量的本質就是一個4維向量,而在conv1層的輸入的x為一個10 ? * ? 3 ? * ? 100 ? * ? 100的向量,參數分別表示batchsize,RGB,100 ? * ? 100圖片大小,x經過一層層的卷積,最后10 ? * ? 256 ? * ? 4 ? * ? 4向量作為第四層卷積輸出。
最后使用x.view(x.shape(0),-1)將x轉化成一個10行的矩陣,矩陣的每一行就是這個批量(批量大小為10)中每張圖片的各個參數(即256 ? * ? 4 ? * ? 4),即矩陣中一行對應一張圖片。
view()函數的功能根reshape類似,用來轉換size大小。
x = x.view(batchsize, -1)
中batchsize指轉換后有幾行,而-1指在不告訴函數有多少列的情況下,根據原tensor數據和batchsize自動分配列數。