本文通過PyTorch實現二維互相關運算、自定義卷積層,并演示如何通過卷積核檢測圖像邊緣。同時,我們將訓練一個卷積核參數,使其能夠從數據中學習邊緣特征。
1. 二維互相關運算的實現
互相關運算(Cross-Correlation)是卷積操作的基礎。以下代碼實現了二維互相關運算:
import torch
from torch import nndef corr2d(x, k):h, w = k.shapey = torch.zeros((x.shape[0] - h + 1, x.shape[1] - w + 1))for i in range(y.shape[0]):for j in range(y.shape[1]):y[i, j] = (x[i:i+h, j:j+w] * k).sum() # 逐元素相乘后求和return y
驗證輸出:
輸入矩陣和卷積核如下,輸出結果為互相關運算后的張量:
x = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
k = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
print(corr2d(x, k))
輸出:
tensor([[19., 25.],[37., 43.]])
2. 自定義二維卷積層
通過繼承nn.Module
實現一個自定義卷積層,包含可學習的權重和偏置:
class Conv2D(nn.Module):def __init__(self, kernel_size):super().__init__()self.weight = nn.Parameter(torch.rand(kernel_size))self.bias = nn.Parameter(torch.zeros(1))def forward(self, x):return corr2d(x, self.weight) + self.bias
3. 邊緣檢測應用
3.1 構造輸入圖像
創建一個6x8的矩陣,中間4列為黑色(值為0),兩側為白色(值為1):
x = torch.ones(6, 8)
x[:, 2:6] = 0
print(x)
輸出:
tensor([[1., 1., 0., 0., 0., 0., 1., 1.],[1., 1., 0., 0., 0., 0., 1., 1.],[1., 1., 0., 0., 0., 0., 1., 1.],[1., 1., 0., 0., 0., 0., 1., 1.],[1., 1., 0., 0., 0., 0., 1., 1.],[1., 1., 0., 0., 0., 0., 1., 1.]])
3.2 定義卷積核
使用卷積核[[1, -1]]
檢測垂直邊緣:
k = torch.tensor([[1.0, -1.0]])
y = corr2d(x, k)
print(y)
輸出:
tensor([[ 0., 1., 0., 0., 0., -1., 0.],[ 0., 1., 0., 0., 0., -1., 0.],[ 0., 1., 0., 0., 0., -1., 0.],[ 0., 1., 0., 0., 0., -1., 0.],[ 0., 1., 0., 0., 0., -1., 0.],[ 0., 1., 0., 0., 0., -1., 0.]])
-
結果解釋:
輸出中1
表示從白到黑的邊緣,-1
表示從黑到白的邊緣。
3.3 水平邊緣檢測
若將輸入矩陣轉置,原卷積核無法檢測水平邊緣:
print(corr2d(x.T, k))
輸出:全零矩陣(無法檢測到水平邊緣)
tensor([[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],...])
4. 學習卷積核參數
使用PyTorch內置的nn.Conv2d
,通過梯度下降學習卷積核參數:
# 定義模型
conv2d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, 2), bias=False)# 調整輸入輸出形狀
x = x.reshape((1, 1, 6, 8)) # (batch_size, channels, height, width)
y = y.reshape((1, 1, 6, 7))# 訓練過程
for i in range(10):y_hat = conv2d(x)loss = (y_hat - y).pow(2)conv2d.zero_grad()loss.sum().backward()conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad # 更新權重if (i+1) % 2 == 0:print(f'batch{i+1}, loss{loss.sum():.3f}')
輸出:
batch2, loss5.270
batch4, loss0.884
batch6, loss0.148
batch8, loss0.025
batch10, loss0.004
4.1 查看學習后的卷積核
訓練后的權重接近理想值[1, -1]
:
print(conv2d.weight.data.reshape((1, 2)))
輸出:
tensor([[ 0.9883, -0.9878]])
5. 總結
-
互相關運算:通過逐窗口計算實現基礎的卷積操作。
-
邊緣檢測:方向特定的卷積核可提取圖像邊緣特征。
-
參數學習:利用梯度下降可自動學習卷積核參數,無需手動設計。
完整代碼已驗證,讀者可自行調整輸入或卷積核探索更多效果。
提示:實際項目中建議使用PyTorch內置的高效卷積層(如nn.Conv2d
),而非手動實現,以充分利用GPU加速。