搭建神經網絡架構
在pytorch中,神經網絡被抽象成由一系列對數據執行特定操作的層或者模塊組成,比如下面的Attention實現,每個塊都是一個模塊或者層。
如果你想快速搭建網絡架構,torch.nn這個命名空間提供了所有很多開箱即用的層/模塊/算子:
如果你想自定義一個模塊也是完全可以的。每個模塊都是nn.Module
的子類,你只需要繼承然后復寫即可,這個后面有例子。
這種簡潔的架構抽象可以讓使用pytorch的人們快速搭建并管理精妙的模型架構。
接下來,我們將搭建一個神經網絡來分類FashionMNIST數據集,來過一遍搭建網絡的工作流。
import os
import torch
from torch import nn
from torch.utils.data import Dataloader
from torchvision import datasets, transforms
1. 獲取可能的加速設備
為了在 加速器(accelerator) 上訓練我們的模型,例如 CUDA、MPS、MTIA 或 XPU,我們將遵循以下邏輯:
如果當前設備有可用的加速器,我們就使用它;否則,我們將使用 CPU。
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
2. 搭建網絡結構
2.1 定義網絡類
通過繼承nn.Module
,我們可以定義我們的神經網絡類,并且在__init__
里面定義我們要用到的模塊或者層。然后實現forward
方法來定義對輸入模型的數據的實際操作以及操作順序,并且返回推理結果。
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.faltten = nn.Faltten() # 展平層self.linear_relu_stack = nn.Sequential( # 定義一個序列模塊,被調用時會依次執行所含模塊nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):x = self.flatten(x)logit = self.linearr_relu_stack(x)return logits
注意,
__init__
只負責把需要的塊給初始化出來,具體數據是怎么在塊間流動由forward
實現。
2.2 實例化網絡并查看結構
現在我們實例化網絡,并且把它搬到device側,然后打印出他的結構:
model = NeuralNetwork().to(device)
print(model)
2.3 進行網絡“冒煙測試”
搭建好網絡結構之后,強烈建議進行一次“冒煙測試”,用一個符合輸入shape的tensor看看整個網絡能不能跑通。
要給模型傳入數據進行推理,直接給模型傳入數據即可,千萬別直接調用forward方法,因為model(x)
還會做一些forward沒做的一些其他必要操作。
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
print(logits.shape)
pred_probab = nn.Softmax(dim=1)(logits)
print(pred_probab)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
給模型輸入數據之后,模型返回一個2維的tensor,dim=0的數據是batch中的具體樣本idx,dim=1的數據則是輸出的這個樣本的所屬10個不同類別的預測值。最后我們套一層nn.Softmax
, 就可以獲得每個類別的概率pred_probab
了。最后對其使用argmax(1)
找到該張量在dim=1維度上的最大值索引,就獲得了這一次推理的分類結果。
3. 進階操作:獲取模型當前的參數
如果你想要一點可解釋性,你可能得用到這個
神經網絡中的許多層都是參數化的,也就是說,它們有相關的權重(weights)和偏差(biases),這些值會在訓練過程中進行優化。
當你的模型繼承自 nn.Module
時,PyTorch 會自動追蹤模型對象中定義的所有字段。因此,你可以通過模型的 parameters()
或 named_parameters()
方法來訪問所有這些參數。
print(model)for name, param in model.named_parameters():pritn(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n") # 矩陣獲取前兩行,bias獲取前兩個
在這個例子中,我們遍歷了每一個參數,并打印出它的尺寸(size)和部分值預覽。