DNN全連接層(線性層)
計算公式:
y = w * x + b
W和b是參與訓練的參數
W的維度決定了隱含層輸出的維度,一般稱為隱單元個數(hidden size)
b是偏差值(本文沒考慮)
舉例:
輸入:x (維度1 x 3)
隱含層1:w(維度3 x 5)
隱含層2: w(維度5 x 2)
個人思想如下:
比如說如上圖,我們有輸入層是3個,中間層是5個,輸出層要求是2個。利用線性代數,輸入是【1×3】,那么需要乘【3×5】的權重矩陣得到【1×5】,再由【1×5】乘【5×2】的權重矩陣,最后得到【1×2】的結果。在本代碼中沒有考慮偏差值(bias),利用pytorch中隨機初始化的權重實現模型預測。
import torch
import torch.nn as nn
import numpy as np
"""
用pytorch框架實現單層的全連接網絡
不使用偏置bias
"""
class TorchModel(nn.Module): #nn.module是torch自帶的庫def __init__(self, input_size, hidden_size, output_size):super(TorchModel, self).__init__()self.layer1 = nn.Linear(input_size, hidden_size, bias=False)#nn.linear是torch的線性層,input_size是輸入的維度,hidden_size是這一層的輸出的維度self.layer2 = nn.Linear(hidden_size, output_size, bias=False)#這個線性層可以有很多個def forward(self, x): #開始計算的函數hidden = self.layer1(x) #傳入輸入第一層# print("torch hidden", hidden)y_pred = self.layer2(hidden) #傳入輸入第二層return y_pred
x = np.array([1, 0, 0]) #網絡輸入#torch實驗
torch_model = TorchModel(len(x), 5, 2) #這三個數分別代表輸入,中間,結果層的維度
#print(torch_model.state_dict()) #可以打印出pytorch隨機初始化的權重
torch_model_w1 = torch_model.state_dict()["layer1.weight"].numpy()
#通過取字典方式將權重取出來并把torch的權重轉化為numpy的
torch_model_w2 = torch_model.state_dict()["layer2.weight"].numpy()
#print(torch_model_w1, "torch w1 權重")
#這里你會發現隨機初始化的權重矩陣是5×3,所以當自定義模型時需要轉置,但是在pytorch中會自動轉置相乘
#print(torch_model_w2, "torch w2 權重")
torch_x = torch.FloatTensor([x]) #numpy的輸入轉化為torch
y_pred = torch_model.forward(torch_x)
print("torch模型預測結果:", y_pred)
以上是pytorch模型實現DNN的簡單方法。
自定義模型手工實現:
(注意因為自定義模型需要得到模型中的權重,而上面代碼利用的是pytorch的隨機自定義模型,為了能讓兩者對比答案是否相同,自定義模型中的權重需要繼承pytorch的隨機權重)
"""
手動實現簡單的神經網絡
用自定義框架實現單層的全連接網絡
不使用偏置bias
"""
#自定義模型
class DiyModel:def __init__(self, weight1, weight2):self.weight1 = weight1 #收到在torch隨機的權重self.weight2 = weight2def forward(self, x):hidden = np.dot(x, self.weight1.T) #將輸入與第一層權重的轉置相乘y_pred = np.dot(hidden, self.weight2.T)return y_preddiy_model = DiyModel(torch_model_w1, torch_model_w2)
y_pred_diy = diy_model.forward(np.array([x]))
print("diy模型預測結果:", y_pred_diy)
如需運行須將自定義模型放入pytorch的代碼下面繼承輸入和隨機權重,通過最后結果能發現兩者相同。
結果如下:
可以發現兩者代碼結果相同~