CNN卷積神經網絡
基礎概念:
以卷積操作為基礎的網絡結構,每個卷積核可以看成一個特征提取器。
思想:
每次觀察數據的一部分,如圖,在整個矩陣中只觀察黃色部分3×3的矩陣,將這【3×3】矩陣·(點乘)權重得到特征矩陣的第一項,然后進行平移進行第二項的計算。依此類推,得到最后的特征矩陣。
利用Pytorch框架實現CNN
import torch
import torch.nn as nn
import numpy as np"""
使用pytorch實現CNN
不考慮偏差值
"""class TorchCNN(nn.Module):def __init__(self, in_channel, out_channel, kernel):super(TorchCNN, self).__init__()self.layer = nn.Conv2d(in_channel, out_channel, kernel, bias=False)def forward(self, x):return self.layer(x)x = np.array([[0.1, 0.2, 0.3, 0.4],[-3, -4, -5, -6],[5.1, 6.2, 7.3, 8.4],[-0.7, -0.8, -0.9, -1]]) #網絡輸入#torch實驗
in_channel = 1 #單通道(NLP中一般用單通道)
out_channel = 3 #多少個卷積核(每一個卷積核代表一個獨立的權重)
kernel_size = 2 #2*2的方塊(功能就是圖中黃色[3×3]矩陣)
torch_model = TorchCNN(in_channel, out_channel, kernel_size)
# print(torch_model.state_dict())
torch_w = torch_model.state_dict()["layer.weight"]
# print(torch_w.numpy().shape)
torch_x = torch.FloatTensor([[x]])
#權重是4維,輸入應該也為四維,通過多一個[],將輸入由三維變成四維
output = torch_model.forward(torch_x)
output = output.detach().numpy()
print(output, output.shape, "torch模型預測結果\n")
自定義模型代碼實現CNN:
采用自定義模型實現CNN,不考慮偏差值,因為要與Pytorch框架結果相對比,需要調取在Pytorch模型中的輸入和隨機權重。因此如果要運行,須將此代碼放在Pytorch框架下運行。
"""
手動實現簡單的神經網絡
與Pytorch對比實驗
"""
#自定義CNN模型
class DiyModel:def __init__(self, input_height, input_width, weights, kernel_size):self.height = input_heightself.width = input_widthself.weights = weightsself.kernel_size = kernel_sizedef forward(self, x):output = []for kernel_weight in self.weights:kernel_weight = kernel_weight.squeeze().numpy()#weight取出來時是[1×2×2],通過squeeze變成[2×2],然后變成numpy取出kernel_output = np.zeros((self.height - kernel_size + 1, self.width - kernel_size + 1)) #全0輸出矩陣for i in range(self.height - kernel_size + 1):for j in range(self.width - kernel_size + 1):window = x[i:i+kernel_size, j:j+kernel_size] #x是原始輸入 剩下的是矩陣索引方法kernel_output[i, j] = np.sum(kernel_weight * window) #np.dot != x*y x*y是點乘(對應位置相乘)output.append(kernel_output)return np.array(output)diy_model = DiyModel(x.shape[0], x.shape[1], torch_w, kernel_size)
output = diy_model.forward(x)
print(output, "diy模型預測結果")
最終對比結果:
可以清楚看到Pytorch框架下的結果與自定義框架下的結果相同。