文章目錄
- 前言
- 基本操作
- 填充、步幅和多通道
- 填充 (Padding)
- 步幅 (Stride)
- 多通道
- 總結
前言
在卷積神經網絡(CNN)的大家族中,我們熟悉的卷積層和匯聚(池化)層通常會降低輸入特征圖的空間維度(高度和寬度)。然而,在許多應用場景中,例如圖像的語義分割(需要對每個像素進行分類)或生成對抗網絡(GAN)中的圖像生成,我們反而需要增加特征圖的空間維度,即進行上采樣。
轉置卷積(Transposed Convolution),有時也被不那么準確地稱為反卷積(Deconvolution),正是實現這一目標的關鍵操作。它能夠將經過下采樣的低分辨率特征圖恢復到較高的分辨率,或者在生成模型中從低維噪聲逐步生成高分辨率圖像。
本文將通過具體的 PyTorch 代碼示例,帶您一步步理解轉置卷積的基本原理、填充(Padding)、步幅(Stride)以及在多通道情況下的應用。
完整代碼:下載連接
基本操作
讓我們從最基礎的轉置卷積開始。假設我們有一個 2x2 的輸入張量,并使用一個 2x2 的卷積核,步幅為1,沒有填充。轉置卷積的操作過程可以直觀地理解為:將輸入張量的每個元素作為標量,與卷積核相乘,得到多個中間結果;然后,將這些中間結果按照輸入元素在原張量中的位置進行“放置”和疊加,從而得到最終的輸出張量。
其核心思想可以看作是常規卷積操作的一種“逆向”映射,但它并非嚴格意義上的數學逆運算。
下圖形象地展示了這個過程:
圖中,輸入是 2x2,卷積核是 2x2。
- 輸入張量的左上角元素(0)與整個卷積核相乘,結果放置在輸出張量的左上角。
- 輸入張量的右上角元素(1)與整個卷積核相乘,結果向右移動一格放置。
- 輸入張量的左下角元素(2)與整個卷積核相乘,結果向下移動一格放置。
- 輸入張量的右下角元素(3)與整個卷積核相乘,結果向右和向下各移動一格放置。
- 所有這些放置的張量在重疊區域進行元素相加,得到最終的 3x3 輸出。
輸出張量的高度 (H_out) 和寬度 (W_out) 可以通過以下公式計算(當步幅為1,無填充時):
- H_out = H_in + H_kernel - 1
- W_out = W_in + W_kernel - 1
下面我們用代碼來實現這個基本操作:
import torch
from torch import nndef transposed_convolution(input_tensor, kernel):"""實現轉置卷積(反卷積)操作參數:input_tensor: 輸入張量,維度為 (input_height, input_width)kernel: 卷積核,維度為 (kernel_height, kernel_width)返回:output_tensor: 轉置卷積結果,維度為 (input_height + kernel_height - 1, input_width + kernel_width - 1)"""# 獲取卷積核的高度和寬度,維度分別為 scalarkernel_height, kernel_width = kernel.shape# 初始化輸出張量,維度為 (input_height + kernel_height - 1, input_width + kernel_width - 1)output_tensor = torch.zeros((input_tensor.shape[0] + kernel_height - 1, input_tensor.shape[1] + kernel_width - 1))# 對輸入張量中的每個元素進行處理for i in range(input_tensor.shape[0]): # 遍歷輸入張量的行for j in range(input_tensor.shape[1]): # 遍歷輸入張量的列# 對于輸入張量中的每個元素,將其與卷積核相乘,然后加到輸出張量的對應區域# input_tensor[i, j] 是標量,維度為 ()# kernel 維度為 (kernel_height, kernel_width)# 輸出區域 output_tensor[i:i+kernel_height, j:j+kernel_width] 維度為 (kernel_height, kernel_width)output_tensor[i:i + kernel_height, j:j + kernel_width] += input_tensor[i, j] * kernelreturn output_tensor# 示例使用
# 創建輸入張量X,維度為 (2, 2)
X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])# 創建卷積核K,維度為 (2, 2)
K = torch.tensor([