PyTorch構建自定義模型

PyTorch 提供了靈活的方式來構建自定義神經網絡模型。下面我將詳細介紹從基礎到高級的自定義模型構建方法,包含實際代碼示例和最佳實踐。

一、基礎模型構建

1. 繼承 nn.Module 基類

所有自定義模型都應該繼承?torch.nn.Module?類,并實現兩個基本方法:

import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self):super().__init__()  # 必須調用父類初始化# 在這里定義網絡層self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 50, 5)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):# 定義前向傳播邏輯x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

2. 模型使用方式?

model = MyModel()
output = model(input_tensor)  # 自動調用forward方法
loss = criterion(output, target)
loss.backward()

二、中級構建技巧

1. 使用 nn.Sequential

nn.Sequential?是一種用于快速構建順序神經網絡的容器類,適用于模塊按線性順序排列的模型。

class MySequentialModel(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.classifier = nn.Sequential(nn.Linear(128 * 8 * 8, 512),nn.ReLU(inplace=True),nn.Linear(512, 10))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x

?2. 參數初始化

def initialize_weights(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)model.apply(initialize_weights)  # 遞歸應用初始化函數

三、高級構建模式

1. 殘差連接 (ResNet風格)

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1,stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)return F.relu(out)

2. 自定義層?

class MyCustomLayer(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.weight = nn.Parameter(torch.randn(output_dim, input_dim))self.bias = nn.Parameter(torch.randn(output_dim))def forward(self, x):return F.linear(x, self.weight, self.bias)

?

四、模型保存與加載

1. 保存整個模型

torch.save(model, 'model.pth')  # 保存
model = torch.load('model.pth')  # 加載

2. 保存狀態字典 (推薦)

torch.save(model.state_dict(), 'model_state.pth')  # 保存
model.load_state_dict(torch.load('model_state.pth'))  # 加載

五、模型部署準備

1. 模型導出為TorchScript

scripted_model = torch.jit.script(model)  # 或 torch.jit.trace
scripted_model.save('model_scripted.pt')

2. ONNX格式導出

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"])

六、完整示例:自定義CNN分類器

import torch
from torch import nn
from torch.nn import functional as Fclass CustomCNN(nn.Module):"""自定義CNN圖像分類器Args:num_classes (int): 輸出類別數dropout_prob (float): dropout概率,默認0.5"""def __init__(self, num_classes=10, dropout_prob=0.5):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.avgpool = nn.AdaptiveAvgPool2d((6, 6))self.classifier = nn.Sequential(nn.Dropout(p=dropout_prob),nn.Linear(128 * 6 * 6, 512),nn.ReLU(inplace=True),nn.Dropout(p=dropout_prob),nn.Linear(512, num_classes))# 初始化權重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def forward(self, x: torch.Tensor) -> torch.Tensor:"""前向傳播Args:x (torch.Tensor): 輸入張量,形狀為[B, C, H, W]Returns:torch.Tensor: 輸出logits,形狀為[B, num_classes]"""x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

七、注意事項

  1. ?輸入輸出維度匹配?

    • 需確保相鄰模塊的輸入/輸出維度兼容。例如,卷積層后接全連接層時需通過?Flatten?或自適應池化調整維度?。
  2. ?調試與驗證?

    • 可通過模擬輸入數據驗證模型結構,如:
      input = torch.ones(64, 3, 32, 32)  # 模擬 batch_size=64 的輸入
      output = model(input)
      print(output.shape)  # 檢查輸出形狀是否符合預期

      ?

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/76052.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/76052.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/76052.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AI智算-K8s如何利用GPFS分布式并行文件存儲加速訓練or推理

文章目錄 GPFS簡介核心特性存儲環境介紹存儲軟件版本客戶端存儲RoCEGPFS 管理(GUI)1. 創建 CSI 用戶2. 檢查GUI與k8s通信文件系統配置1. 開啟配額2. 啟用filesetdf文件系統3. 驗證文件系統配置4. 啟用自動inode擴展存儲集群配置1. 啟用對根文件集(root fileset)配額2. igno…

gbase8s之邏輯導出導入腳本(完美版本)

該腳本dbexport.sh用于快速導出庫和導入庫(使用多并發unload,和多并發dbload的方式) #!/bin/sh #腳本功能:將數據導出成文本,遷移至其他實例 #最后更新時間:2023-12-19 #使用方法: #1.執行該腳…

springMVC-攔截器詳解

攔截器 概述 SpringMVC的處理器攔截器類似于Servlet開發中的過濾器Filter,用于對處理器進行預處理和后處理。開發者可以自己定義一些攔截器來實現特定的功能。 過濾器與攔截器的區別:攔截器是AOP思想的具體應用。 過濾器 servlet規范中的一部分,任何ja…

網絡安全應急響應-系統排查

在網絡安全應急響應中,系統排查是快速識別潛在威脅的關鍵步驟。以下是針對Windows和Linux系統的系統基本信息排查指南,涵蓋常用命令及注意事項: 一、Windows系統排查 1. 系統信息工具(msinfo32.exe) 命令執行&#x…

基于YOLO的半自動化標注方法:提升鐵路視頻缺陷檢測效率

論文地址:https://arxiv.org/pdf/2504.01010 1. 論文結構概述 本文提出了一種半自動化標注方法,旨在解決鐵路缺陷檢測中大規模圖像/視頻數據集標注成本高、耗時長的問題。論文結構清晰,分為以下核心部分: ?引言(Introduction)? 強調傳統手動標注的痛點(耗時、易錯、…

Linux驅動開發:SPI驅動開發原理

前言 本文章是根據韋東山老師的教學視頻整理的學習筆記https://video.100ask.net/page/1712503 SPI 通信協議采用同步全雙工傳輸機制,拓撲架構支持一主多從連接模式,這種模式在實際應用場景中頗為高效。其有效傳輸距離大致為 10m ,傳輸速率…

Android Hilt 教程

Android Hilt 教程 —— 一看就懂,一學就會 1. 什么是 Hilt?為什么要用 Hilt? Hilt 是 Android 官方推薦的 依賴注入(DI)框架,基于 Dagger 開發,能夠大大簡化依賴注入的使用。 為什么要用 Hi…

【算法手記11】NC41 最長無重復子數組 NC379 重排字符串

🦄個人主頁:修修修也 🎏所屬專欄:刷題 ??操作環境:牛客網 目錄 一.NC41 最長無重復子數組 題目詳情: 題目思路: 解題代碼: 二.NC379 重排字符串 題目詳情: 題目思路: 解題代碼: 結語 一.NC41 最長無重復子數組 牛客網題目鏈接(點擊即可跳轉):NC41 最長…

C語言:字符串處理函數strstr分析

在 C 語言中,strstr 函數用于查找一個字符串中是否存在另一個字符串。它的主要功能是搜索指定的子字符串,并返回該子字符串在目標字符串中第一次出現的位置的指針。如果沒有找到子字符串,則返回 NULL。 詳細說明: 頭文件&#xf…

在windows下安裝spark

在windows下安裝spark完成 安裝過程:

MongoDB常見面試題總結(上)

MongoDB 基礎 MongoDB 是什么? MongoDB 是一個基于 分布式文件存儲 的開源 NoSQL 數據庫系統,由 C 編寫的。MongoDB 提供了 面向文檔 的存儲方式,操作起來比較簡單和容易,支持“無模式”的數據建模,可以存儲比較復雜…

【Java設計模式】第2章 UML急速入門

2-1 本章導航 UML類圖與時序圖入門 UML定義 統一建模語言(Unified Modeling Language):第三代非專利建模語言。特點:開放方法,支持可視化構建面向對象系統,涵蓋模型、流程、代碼等。UML分類(2.2版本) 結構式圖形:系統靜態建模(類圖、對象圖、包圖)。行為式圖形:事…

【4】搭建k8s集群系列(二進制部署)之安裝master節點組件(kube-apiserver)

一、下載k8s二進制文件 下載地址: https://github.com/kubernetes/kubernetes/blob/master/CHANGELOG/CHANGELOG -1.20.md 注:打開鏈接你會發現里面有很多包,下載一個 server 包就夠了,包含了 Master 和 Worker Node 二進制文件。…

電子電氣架構 --- AUTOSAR 的信息安全架構

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 周末洗了一個澡,換了一身衣服,出了門卻不知道去哪兒,不知道去找誰,漫無目的走著,大概這就是成年人最深的孤獨吧! 舊人不知我近況,新人不知我過…

ROS2與OpenAI Gym集成指南:從安裝到自定義環境與強化學習訓練

1.理解 ROS2 和 OpenAI Gym 的基本概念 ROS2(Robot Operating System 2):是一個用于機器人軟件開發的框架。它提供了一系列的工具、庫和通信機制,方便開發者構建復雜的機器人應用程序。例如,ROS2 可以處理機器人不同組…

【設計模式】創建型 -- 單例模式 (c++實現)

文章目錄 單例模式使用場景c實現靜態局部變量餓漢式(線程安全)懶漢式(線程安全)懶漢式(線程安全) 智能指針懶漢式(線程安全)智能指針call_once懶漢式(線程安全)智能指針call_onceCRTP 單例模式 單例模式是…

C語言之九九乘法表

一、代碼展示 二、運行結果 三、代碼分析 首先->是外層循環是小于等于9的 然后->是內層循環是小于等于外層循環的 最后->就是\n讓九九乘法表的格式更加美觀(當然 電腦不同 有可能%2d 也有可能%3d) 四、與以下素數題目邏輯相似 五、運行結果

自動化備份全網服務器數據平臺

自動化備份全網服務器數據平臺 項目背景知識 總體需求 某企業里有一臺Web服務器,里面的數據很重要,但是如果硬盤壞了數據就會丟失,現在領導要求把數據做備份,這樣Web服務器數據丟失在可以進行恢復。要求如下:1.每天0…

stm32+esp8266+機智云手機app

現在很多大學嵌入式畢設都要求云端控制,本文章就教一下大家如何使用esp8266去連接機智云的app去進行顯示stm32的外設傳感器數據啊,控制一些外設啊等。 因為本文章主要教大家如何移植機智云的代碼到自己的工程,所以前面的一些準備工作&#x…

時序數據庫 TDengine Cloud 私有連接實戰指南:4步實現數據安全傳輸與成本優化

小T導讀:在物聯網和工業互聯網場景下,企業對高并發、低延遲的數據處理需求愈發迫切。本文將帶你深入了解 TDengineCloud 如何通過全托管服務與私有連接,幫助企業實現更安全、更高效、更低成本的數據采集與傳輸,從架構解析到實際配…