PyTorch 提供了靈活的方式來構建自定義神經網絡模型。下面我將詳細介紹從基礎到高級的自定義模型構建方法,包含實際代碼示例和最佳實踐。
一、基礎模型構建
1. 繼承 nn.Module 基類
所有自定義模型都應該繼承 torch.nn.Module 類,並實現兩個基本方法:
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    """A small LeNet-style CNN: two conv + max-pool stages followed by two
    fully connected layers, returning log-probabilities over 10 classes.

    Expects input shaped [B, 1, 28, 28] (28x28 -> 24 -> 12 -> 8 -> 4, so the
    flattened feature size is 4*4*50 = 800).
    """

    def __init__(self):
        # The parent initializer must run before any submodule is assigned,
        # otherwise parameter registration fails.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Forward pass: conv/pool feature extraction, then classification."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)  # flatten to [B, 800]
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
2. 模型使用方式
# Typical training step (criterion, input_tensor and target are defined elsewhere).
model = MyModel()
output = model(input_tensor)  # calling the model invokes forward() via __call__
loss = criterion(output, target)
loss.backward()  # backprop: populates .grad on every parameter
二、中級構建技巧
1. 使用 nn.Sequential
nn.Sequential 是一種用於快速構建順序神經網絡的容器類,適用於模塊按線性順序排列的模型。
class MySequentialModel(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.classifier = nn.Sequential(nn.Linear(128 * 8 * 8, 512),nn.ReLU(inplace=True),nn.Linear(512, 10))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x
2. 參數初始化
def initialize_weights(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)model.apply(initialize_weights) # 遞歸應用初始化函數
三、高級構建模式
1. 殘差連接 (ResNet風格)
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1,stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)return F.relu(out)
2. 自定義層
class MyCustomLayer(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.weight = nn.Parameter(torch.randn(output_dim, input_dim))self.bias = nn.Parameter(torch.randn(output_dim))def forward(self, x):return F.linear(x, self.weight, self.bias)
?
四、模型保存與加載
1. 保存整個模型
# Pickles the entire model object (architecture + weights).
torch.save(model, 'model.pth') # save
# NOTE(review): loading a whole pickled model requires the original class
# definition to be importable; newer PyTorch versions default torch.load to
# weights_only mode, which rejects full pickled models — confirm version.
model = torch.load('model.pth') # load
2. 保存狀態字典 (推薦)
# Recommended: persist only the parameter/buffer tensors (the state dict);
# the model class is reconstructed in code, so saved files stay portable.
torch.save(model.state_dict(), 'model_state.pth') # save
model.load_state_dict(torch.load('model_state.pth')) # load
五、模型部署準備
1. 模型導出為TorchScript
# Compile the model to TorchScript for Python-free deployment. script()
# preserves control flow; trace() records one example execution instead.
scripted_model = torch.jit.script(model) # or torch.jit.trace
scripted_model.save('model_scripted.pt')
2. ONNX格式導出
# Export to ONNX. The dummy input fixes the traced input shape
# [1, 3, 224, 224]; input/output names label the graph's endpoints.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"])
六、完整示例:自定義CNN分類器
import torch
from torch import nn
from torch.nn import functional as F


class CustomCNN(nn.Module):
    """Custom CNN image classifier.

    Three conv/ReLU/max-pool stages, adaptive average pooling to a fixed
    6x6 grid (so any reasonably sized RGB input works), then a
    dropout-regularised two-layer classifier head.

    Args:
        num_classes (int): number of output classes. Default 10.
        dropout_prob (float): dropout probability. Default 0.5.
    """

    def __init__(self, num_classes=10, dropout_prob=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Normalise spatial size to 6x6 regardless of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_prob),
            nn.Linear(512, num_classes),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal for conv weights, N(0, 0.01) for linear weights,
        zeros for all biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                # Fix: conv biases were previously left at their default
                # init while linear biases were zeroed; zero them too
                # (guarded, since conv layers may be created with bias=False).
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): input batch shaped [B, C, H, W].

        Returns:
            torch.Tensor: logits shaped [B, num_classes].
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
七、注意事項
- 輸入輸出維度匹配
  需確保相鄰模塊的輸入/輸出維度兼容。例如,卷積層後接全連接層時需通過 Flatten 或自適應池化調整維度。
- 調試與驗證
  可通過模擬輸入數據驗證模型結構,如:
input = torch.ones(64, 3, 32, 32) # 模擬 batch_size=64 的輸入 output = model(input) print(output.shape) # 檢查輸出形狀是否符合預期
輸出形狀符合預期即說明各層維度銜接正確。