【深度學習】PyTorch從0到1——手寫你的第一個卷積神經網絡模型,AI模型開發全過程實戰

引言

本次準備建立一個卷積神經網絡模型,用于區分鳥和飛機,并從CIFAR-10數據集中選出所有鳥和飛機作為本次的數據集。

以此為例,介紹一個神經網絡模型從數據集準備、數據歸一化處理、模型網絡函數定義、模型訓練、結果驗證、模型文件保存,端到端的模型全生命周期,方便大家深入了解AI模型開發的全過程。

?

一、網絡場景定義與數據集準備

1.1 數據集準備

本次我準備使用CIFAR10數據集,它是一個簡單有趣的數據集,由60000張小RGB圖片構成(32像素*32像素),每張圖類別標簽用1~10數字表示

%matplotlib inline
from matplotlib import pyplot as pltfrom torchvision import datasets
data_path = '/content/sample_data'
cifar10 = datasets.CIFAR10(data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)type(cifar10).__mro__

?

1.2 查看數據集類別示例

class_names = ['airplane', 'aotomobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']fig = plt.figure(figsize=(8, 3))
num_classes = 10
for i in range(num_classes):ax = fig.add_subplot(2, 5 ,1 + i, xticks=[], yticks=[])ax.set_title(class_names[i])img = next(img for img, label in cifar10 if label == i)plt.imshow(img)
plt.show()

?

1.2.1 輸出單張圖像類別及展示圖片

img, label = cifar10[99]
img, label, class_names[label]

plt.imshow(img)
plt.show()

?

1.3 數據集Dataset變換

使用torchvision.transforms模塊,將PIL圖像變換為PyTorch張量,用于圖像分類

1.3.1 將單張圖像轉換為張量,輸出張量大小

from torchvision import transformsto_tensor = transforms.ToTensor()
img_t = to_tensor(img)
img_t.shape

1.3.2 將CIFAR10數據集轉換為張量

tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())tensor_cifar10.__len__()

1.4 數據歸一化

使用transforms.Compose()將圖像連接起來,在數據加載器中直接進行數據歸一化和數據增強操作

使用transforms.Normalize(),計算數據集中每個通道的平均值和標準差,使每個通道的均值為0,標準差為1

imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3)
imgs.shape

1.4.1 計算每個通道的平均值(mean)

imgs.view(3, -1).mean(dim=1)

1.4.2 計算每個通道的標準差(stdev)

imgs.view(3, -1).std(dim=1)

1.4.3 使用transforms.Normailze()對數據集歸一化

使每個數據集的通道的均值為0,標準差為1

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

二、使用nn.Module編寫第一個識別鳥與飛機的網絡模型

2.1 構建鳥與飛機的訓練集和驗證集

2.1.1 準備CIFAR10數據集

cifar10 = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))]))
cifar10_val = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))]))

2.1.2 構建CIFAR2-數據集

label_map = {0:0, 2:1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])for img, label in cifar10if label in [0, 2]
]cifar2.__len__()

2.1.3 構建CIFAR2-驗證集

cifar2_val = [(img, label_map[label])for img, label in cifar10_valif label in [0, 2]]cifar2_val.__len__()

2.1.4 準備批處理圖像

img, _ = cifar2[0]plt.imshow(img.permute(1, 2, 0))
plt.show()

img
img.shape

2.2 編寫第一個nn.Module子模塊的網絡定義

import torch
import torch.nn as nn
import torch.optim as optimclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.act1 = nn.Tanh()self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)self.act2 = nn.Tanh()self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(8 * 8 * 8, 32)self.act3 = nn.Tanh()self.fc2 = nn.Linear(32, 2)def forward(self, x):out = self.pool1(self.act1(self.conv1(x)))out = self.pool2(self.act2(self.conv2(out)))out = out.view(-1, 8 * 8 * 8)out = self.act3(self.fc1(out))out = self.fc2(out)return out

?

2.2.1 將網絡模型實例化,并輸出模型參數

model = Net()numel_list = [p.numel() for p in model.parameters()]
sum(numel_list), numel_list

2.3 使用函數式API,優化nn.Module網絡函數定義

import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)self.fc1 = nn.Linear(8 * 8 * 8, 32)self.fc2 = nn.Linear(32, 2)def forward(self, x):out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)out = out.view(-1, 8 * 8 * 8)out = torch.tanh(self.fc1(out))out = self.fc2(out)return outmodel = Net()
model(img.unsqueeze(0))

2.4 定義網絡模型的訓練循環函數,并執行訓練

import datetimedef training_loop(n_epochs, optimizer, model, loss_fn, train_loader):for epoch in range(1, n_epochs + 1):loss_train = 0.0for imgs, labels in train_loader:outputs = model(imgs)loss = loss_fn(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()loss_train += loss.item()if epoch == 1 or epoch %10 == 0:print('{} Epoch {}, Training loss{}'.format(datetime.datetime.now(), epoch, loss_train / len(train_loader)))train_loader = torch.utils.data.DataLoader(cifar2, batch_size = 64, shuffle=True)model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()training_loop(n_epochs = 100,optimizer = optimizer,model = model,loss_fn = loss_fn,train_loader = train_loader,

?

2.4.1 訓練結果(耗時7分鐘)

2025-08-17 15:13:20.123706 Epoch 1, Training loss0.5672952472024663
2025-08-17 15:14:01.667640 Epoch 10, Training loss0.32902660861516453
2025-08-17 15:14:47.187795 Epoch 20, Training loss0.2960508146863075
2025-08-17 15:15:33.119990 Epoch 30, Training loss0.26820498961172284
2025-08-17 15:16:19.303661 Epoch 40, Training loss0.24607981879050564
2025-08-17 15:17:04.858228 Epoch 50, Training loss0.22783752284042394
2025-08-17 15:17:50.712569 Epoch 60, Training loss0.2095268357806145
2025-08-17 15:18:36.846523 Epoch 70, Training loss0.19460647420328894
2025-08-17 15:19:22.404563 Epoch 80, Training loss0.18098321051639357
2025-08-17 15:20:08.067236 Epoch 90, Training loss0.16757476806735536
2025-08-17 15:20:54.041604 Epoch 100, Training loss0.15512346253273593

2.5 測量準確率(使用驗證集)

train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)def validate(model, train_loader, val_loader):for name, loader in [("train", train_loader), ("val", val_loader)]:correct = 0total = 0with torch.no_grad():for imgs, labels in loader:outputs = model(imgs)_, predicted = torch.max(outputs, dim=1)total += labels.shape[0]correct += int((predicted == labels).sum())print("Accuracy {}: {:.2f}".format(name, correct/total))validate(model, train_loader, val_loader)

2.6 保存并加載我們的模型

2.6.1 保存模型

torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')

torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')

2.6.2 模型pt文件生成

包含模型的所有參數,即2個卷積模塊和2個線性模塊的權重和偏置

2.6.3 加載參數到模型實例

loaded_model = Net()
loaded_model.load_state_dict(torch.load(data_path+'birds_vs_airplanes.pt'))

三、小結

至此,我們完成一個卷積神經網絡模型birds_vs_airplanes的構建,可用于圖像分類識別,區分圖片是鳥還是飛機,準確性高達94!

我們從數據集準備、數據集準備、數據歸一化處理、模型網絡函數定義、模型訓練、結果驗證、模型文件保存,并將模型參數加載到另一個新模型實例中,端到端完整串聯一個神經網絡模型全生命周期的過程,加深對AI模型開發的理解,這是個經典案例,快來試試吧~

?

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/96081.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/96081.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/96081.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

云計算核心技術之容器技術

一、容器技術 1.1、為什么需要容器 在使用虛擬化一段時間后,發現它存在一些問題:不同的用戶,有時候只是希望運行各自的一些簡單程序,跑一個小進程。為了不相互影響,就要建立虛擬機。如果建虛擬機,顯然浪費就…

微信小程序通過uni.chooseLocation打開地圖選擇位置,相關設置及可能出現的問題

前言 uni.chooseLocation打開地圖選擇位置,看官方文檔介紹的比較簡單,但是需要注意的細節不少,如果沒有注意可能就無法使用該API或者報錯,下面就把詳細的配置方法做一下介紹。 一、勾選位置接口 ①在uniapp項目根目錄找到manif…

從財務整合到患者管理:德國醫療集團 Asklepios完成 SAP S/4HANA 全鏈條升級路徑

目錄 挑戰 解決方案 詳細信息 Asklepios成立于1985年,目前擁有約170家醫療機構,是德國大型私營診所運營商。Asklepios是希臘和羅馬神話中的醫神。 挑戰 Asklepios希望進一步擴大其作為數字醫療保健集團的地位。2020年9月,該公司與SNP合作…

高頻PCB廠家及工藝能力分析

一、技術領先型廠商(適合高復雜度、高可靠性設計)這類廠商在高頻材料處理、超精密加工和信號完整性控制方面具備深厚積累,尤其適合軍工、衛星通信、醫療設備等嚴苛場景:深南電路:在超高層板和射頻PCB領域是行業標桿&am…

AJAX 與 ASP 的融合:技術深度解析與應用

AJAX 與 ASP 的融合:技術深度解析與應用 引言 隨著互聯網技術的不斷發展,AJAX(Asynchronous JavaScript and XML)和ASP(Active Server Pages)技術逐漸成為構建動態網頁和應用程序的重要工具。本文將深入探討AJAX與ASP的融合,分析其原理、應用場景以及在實際開發中的優…

MuMu模擬器Pro Mac 安卓手機平板模擬器(Mac中文)

原文地址:MuMu模擬器Pro Mac 安卓手機平板模擬器 MuMu模擬器 Pro mac版,是一款MuMuPlayer安卓模擬器,可以暢快運行安卓游戲和應用。 MuMu模擬器Pro搭載安卓12操作系統,極致釋放設備性能,最高支持240幀畫面效果&#…

Oracle維護指南

Part 1 Oracle 基礎與架構#### **1.1 概述** - **Oracle 數據庫版本歷史與特性對比** - **版本演進**: - Oracle 8i(1999):支持 Internet 應用,引入 Java 虛擬機(JVM)。 - Oracle 9i&#…

如何為PDF文件批量添加騎縫章?

騎縫章跨越多頁文件的邊緣加蓋,一旦文件被替換其中某一頁或順序被打亂,印章就無法對齊,能立刻發現異常。這有效保障了文件的完整性和真實性。它是純凈免費,不帶廣告,專治各類PDF蓋章需求。用法極簡:文件直接…

組合時代的 TOGAF?:為模塊化企業重新思考架構

隨著企業努力追求敏捷性和創新性,組合性正逐漸成為一項基礎性的設計原則。組合思維改變了企業交付能力的方式 —— 更傾向于采用模塊化、獨立的組件,這些組件可以快速組裝和重組。本文探討了長期以來作為企業架構框架的TOGAF標準如何演進以支持組合架構。…

電子元器件-電阻終篇:基本原理,電阻分類及特點,參數/手冊詳解,電阻作用及應用場景,電阻選型及實戰案例

目錄 一、基本原理 1.1 介紹 1.2 計算公式?編輯 1.3 單位 1.4 標稱值 二、分類及特點 2.1電阻分類及特點介紹 2.2常用電阻器件詳細介紹 三、參數/數據手冊解讀 3.1 阻值 3.2 封裝&功率 3.3 精度 3.5 額定電壓 3.6 溫度系數(TCR) 3.7 擴展 四、作用與使用場…

【軟件測試】電商購物項目-各個測試點整理(六)

目錄:導讀 前言一、Python編程入門到精通二、接口自動化項目實戰三、Web自動化項目實戰四、App自動化項目實戰五、一線大廠簡歷六、測試開發DevOps體系七、常用自動化測試工具八、JMeter性能測試九、總結(尾部小驚喜) 前言 1、優惠券測試點 …

心路歷程-啟動流程的概念

我們之前已經安裝過系統,其實興奮的內心已經無以言表; 記得剛開始的那份喜悅是沒辦法演說的;可是高興之余,好像突然又心情EMO了; 為何呢?因為系統裝完了,你也不知道能夠干什么; 所以…

Kubernetes Ingress實戰:從環境搭建到應用案例

目錄 一、概述 版本對比圖 二、 Ingress應用案例 2.1 環境準備 2.2 驗證-NodePort模式 設置Http代理 2.3 驗證-LoadBalancer模式 修改ARP模式,啟用嚴格ARP模式 搭建metallb支持LoadBalancer 普通的service測試 ingress訪問測試: 一、概述 Ser…

項目發布上線清單

說明:博主想整理一份項目發布上線的清單,在每次發布上線前,對照清單一一核對,避免遺漏(往事不堪回首),歡迎大家補充。 前端是否有與后端協同發布的接口? 如果有,先發前端…

HTB Information Gathering - Web Edition最后的測驗

因為它沒有DNS解析,,所以不要嘗試去使用dns枚舉所有枚舉出來的子域,馬上修改hosts文件,與ip和域名填好,因為它不依賴dns通過vhost子域爆破 爬蟲登場 w*****.inlanefreight.htb:32508爬到之后不要去理會那個api,除了填答案,,,其他任何用處都沒有,不要浪費時間后面就不能劇透了,可…

IDEA、Pycharm、DataGrip等激活破解沖突問題解決方案之一

Jetbranis旗下的軟件破解沖突問題解決方案之一,不一定適用所有破解包 問題:在使用Pycharm破解包破解該軟件后,同樣是jetbranis旗下軟件的Datagrip卻失去了之前破解的效果,需要重新破解,重新成功破解datagrip后&#xf…

使用 uv管理 Python 虛擬環境:比conda更快、更輕量的現代方案

文章目錄什么是 uv?安裝 uv在線安裝(推薦)Windows 系統Linux / macOS 系統離線安裝步驟 1:獲取二進制包步驟 2:解壓并移動到可執行路徑步驟 3:設置環境變量驗證安裝創建并激活虛擬環境創建虛擬環境輸出示例…

課堂記憶項目開發日志

課堂記憶項目開發日志 日期: 2025年8月18日 1. 基礎實現 項目目標: 創建一個動態、美觀的“課堂記憶”頁面,展示教師信息、教學成果、學生反饋、未來計劃、教學成就和教學金句。 實現交互功能,包括按鈕點擊展開內容、圖片點擊彈出詳細信息、圖表展示數據。 技術棧: HTML5 C…

藍橋杯算法之搜索章 - 7

大家好,不同的時間,相同的地點!又和大家見面了,接下來我將帶來多源BFS的內容 通過多源BFS的學習,大家將對BFS理解更加深入! lets go! 前言 通過前面內容的學習,大家肯定已經對于BFS有了一定理解…

onRequestHide at ORIGIN_CLIENT reason HIDE_SOFT_INPUT fromUser false

這個錯誤日志 onRequestHide at ORIGIN_CLIENT reason HIDE_SOFT_INPUT fromUser false 通常出現在 Android 平臺的 WebView 或混合應用(如 Cordova/Capacitor)中,與軟鍵盤(Soft Input)的隱藏行為有關。以下是可能的原…