參考
5.3 多輸入通道和多輸出通道
前面兩節里我們用到的輸入和輸出都是二維數組,但真實數據的維度經常更高。例如,彩色圖像在高和寬2個維度外還有RGB(紅、綠、藍)3個顏色通道。假設彩色圖像的高和寬分別是h和w(像素),那么它可以表示為一個3 * h * w的多維數組。我們將大小為3的這一維稱為通道(channel)維。本節將介紹含多個輸入通道或多個輸出通道的卷積核。
5.3.1 多輸入通道
接下來我們實現含多個輸入通道的互相關運算。我們只需要對每個通道做互相關運算,然后通過add_n
函數來進行累加
import torch
import torch.nn as nn
import sys
sys.path.append("..")
import d2lzh_pytorch as d2ldef corr2d_multi_in(X, K):# 沿著X和K的第0維(通道維)分別計算再相加res = d2l.corr2d(X[0, :, :], K[0, :, :])print(res)for i in range(1, X.shape[0]): # X.shape[0]代表多少個通道,此處為2個res += d2l.corr2d(X[i, :, :], K[i, :, :])return res
X = torch.tensor([[[0,1,2],[3,4,5],[6,7,8]],[[1,2,3], [4,5,6], [7,8,9]] ])K = torch.tensor([[[0,1],[2,3]], [[1,2],[3,4]]])corr2d_multi_in(X, K)
5.3.2 多輸出通道
當輸入通道有多個時,因為我們對各自通道的結果做了累加,所以不論輸入通道數是多少,輸出通道數總是為1。設卷積核輸入通道數和輸出通道數分別為c(i)和c(o),高和寬分別為k(h)和k(w)。如果希望得到含多個通道的輸出,我們可以為每個輸出通道分別創建形狀為c(i) * k(k) * h(w)的核數組。將它們在輸出通道維上連結,卷積核的形狀即 c(o) * c(i) * k(h) * k(w)。在做互相關運算時,每個輸出通道上的結果由卷積核在輸出通道上的核數組與整個輸入數組計算而來。
簡單說就是,如果你想輸出N個通道,你就需要創建N個 C * H * W的卷積核
下面實現一個互相關運算函數來計算多個通道的輸出。
def corr2d_multi_in_out(X, K):# 對K的第0維遍歷,每次同輸入X做互相關計算。所有結果使用stack函數合并在一起return torch.stack([corr2d_multi_in(X, k) for k in K])
我們將核數組K同K+1(K中每個元素加一)和K+2連結在一起來構造一個輸出通道數為3的卷積核
K = torch.tensor([[[0,1],[2,3]], [[1,2],[3,4]]])# 構造3個卷積核
K = torch.stack([K, K+1, K+2])
K.shape
下面我們對輸入數組X與核數組K做互相關運算。此時的輸出含有3個通道。其中第一個通道的結果與之前輸入數組X與多輸入通道、單輸出通道核的計算結果一致。
# 輸入的規模為 2 * 3 * 3 輸出的規模為 3 * (3 - 2+ 1) * (3 - 2 + 1)
corr2d_multi_in_out(X, K)
5.3.3 1 * 1卷積層
def corr2d_multi_in_out_1x1(X, K):c_i, h, w = X.shapec_o = K.shape[0]X = X.view(c_i, h * w)K = K.view(c_o, c_i)Y = torch.mm(K, X) # 全連接層的矩陣乘法return Y.view(c_o, h, w)
X = torch.rand(3, 3, 3)
K = torch.rand(2, 3, 1, 1)Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)(Y1 - Y2).norm().item() < 1e-6