深度學習:PyTorch卷積神經網絡圖像分類案例分享

本文目錄:

  • 一、了解CIFAR-10數據集
  • 二、案例之導包
  • 三、案例之創建數據集
  • 四、案例之搭建神經網絡(模型構建)
  • 五、案例之編寫訓練函數(訓練模型)
  • 六、案例之編寫預測函數(模型測試)

前言:此前分享了卷積神經網絡相關知識,今天實戰下:搭建一個卷積神經網絡來實現圖像分類任務。

一、了解CIFAR-10數據集

CIFAR-10數據集5萬張訓練圖像、1萬張測試圖像、10個類別、每個類別有6k個圖像,圖像大小32×32×3。下圖列舉了10個類,每一類隨機展示了10張圖片:
在這里插入圖片描述
PyTorch 中的 torchvision.datasets 計算機視覺模塊封裝了 CIFAR10 數據集,如果需要使用可以直接導入。

導入代碼:

from torchvision.datasets import CIFAR10

二、案例之導包

import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor  # pip install torchvision -i https://mirrors.aliyun.com/pypi/simple/
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
from torchsummary import summary# 每批次樣本數
BATCH_SIZE = 8

三、案例之創建數據集

# 1. 數據集基本信息
def create_dataset():# 加載數據集:訓練集數據和測試數據# ToTensor: 將image(一個PIL.Image對象)轉換為一個Tensortrain = CIFAR10(root='data', train=True, transform=ToTensor())valid = CIFAR10(root='data', train=False, transform=ToTensor())# 返回數據集結果return train, validif __name__ == '__main__':# 數據集加載train_dataset, valid_dataset = create_dataset()# 數據集類別print("數據集類別:", train_dataset.class_to_idx)# 數據集中的圖像數據print("訓練集數據集:", train_dataset.data.shape)print("測試集數據集:", valid_dataset.data.shape)# 圖像展示plt.figure(figsize=(2, 2))plt.imshow(train_dataset.data[1])plt.title(train_dataset.targets[1])plt.show()

運行結果:

數據集類別: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
訓練集數據集: (50000, 32, 32, 3)
測試集數據集: (10000, 32, 32, 3)

圖像:

在這里插入圖片描述

四、案例之搭建神經網絡(模型構建)

需要搭建的CNN網絡結構如下:
在這里插入圖片描述
我們要搭建的網絡結構如下:

  1. 輸入形狀: 32x32;
  2. 第一個卷積層輸入 3 個 Channel, 輸出 6 個 Channel, Kernel Size 為: 3x3;
  3. 第一個池化層輸入 30x30, 輸出 15x15, Kernel Size 為: 2x2, Stride 為: 2;
  4. 第二個卷積層輸入 6 個 Channel, 輸出 16 個 Channel, Kernel Size 為 3x3;
  5. 第二個池化層輸入 13x13, 輸出 6x6, Kernel Size 為: 2x2, Stride 為: 2;
  6. 第一個全連接層輸入 576 維, 輸出 120 維;
  7. 第二個全連接層輸入 120 維, 輸出 84 維;
  8. 最后的輸出層輸入 84 維, 輸出 10 維。

我們在每個卷積計算之后應用 relu 激活函數來給網絡增加非線性因素。

# 模型構建
class ImageClassification(nn.Module):# 定義網絡結構def __init__(self):super(ImageClassification, self).__init__()# 定義網絡層:卷積層+池化層# 第一個卷積層, 輸入圖像為3通道,輸出特征圖為6通道,卷積核3*3self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)# 第一個池化層, 核寬高2*2self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二個卷積層, 輸入圖像為6通道,輸出特征圖為16通道,卷積核3*3self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)# 第二個池化層, 核寬高2*2self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全連接層# 第一個隱藏層 輸入特征576(一張圖像為16*6*6), 輸出特征120個self.linear1 = nn.Linear(576, 120)# 第二個隱藏層self.linear2 = nn.Linear(120, 84)# 輸出層self.out = nn.Linear(84, 10)# 定義前向傳播def forward(self, x):# 卷積+relu+池化x = torch.relu(self.conv1(x))x = self.pool1(x)# 卷積+relu+池化x = torch.relu(self.conv2(x))x = self.pool2(x)# 將特征圖做成以為向量的形式:相當于特征向量 全連接層只能接收二維數據集# 由于最后一個批次可能不夠8,所以需要根據批次數量來改變形狀# x[8, 16, 6, 6] --> [8, 576] -->8個樣本,576個特征# x.size(0):1個值是樣本數 行數# -1:第2個值由原始x剩余3個維度值相乘計算得到 列數(特征個數)x = x.reshape(x.size(0), -1)# 全連接層x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))# 返回輸出結果return self.out(x)if __name__ == '__main__':# 模型實例化model = ImageClassification()summary(model, input_size=(3,32,32), batch_size=1)

運行結果:

在這里插入圖片描述

五、案例之編寫訓練函數(訓練模型)

在訓練時,使用多分類交叉熵損失函數,Adam 優化器。具體實現代碼如下:

def train(model, train_dataset):# 構建數據加載器dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)criterion = nn.CrossEntropyLoss() # 構建損失函數optimizer = optim.Adam(model.parameters(), lr=1e-3) # 構建優化方法epoch = 100  # 訓練輪數for epoch_idx in range(epoch):sum_num = 0   # 樣本數量total_loss = 0.0  # 損失總和correct = 0  # 預測正確樣本數start = time.time()  # 開始時間# 遍歷數據進行網絡訓練for x, y in dataloader:model.train()output = model(x)loss = criterion(output, y)  # 計算損失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向傳播optimizer.step()  # 參數更新correct += (torch.argmax(output, dim=-1) == y).sum()  # 計算預測正確樣本數# 計算每次訓練模型的總損失值 loss是每批樣本平均損失值total_loss += loss.item()*len(y)  # 統計損失和sum_num += len(y)print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %(epoch_idx + 1,total_loss / sum_num,correct / sum_num,time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/image_classification.pth')#聯合上面代碼一起運行本代碼
if __name__ == '__main__':# 數據集加載train_dataset, valid_dataset = create_dataset()# 模型實例化model = ImageClassification()# 模型訓練train(model,train_dataset)

運行結果:

epoch: 1 loss:1.67102 acc:0.38 time:26.23s
epoch: 2 loss:1.35650 acc:0.51 time:27.63s
epoch: 3 loss:1.22355 acc:0.57 time:31.10s
epoch: 4 loss:1.14639 acc:0.59 time:66.37s
epoch: 5 loss:1.09468 acc:0.61 time:40.38s
。。。。。。

六、案例之編寫預測函數(模型測試)

當已經訓練好模型(model),并保存了模型參數(model.state_dict()),可直接實例化模型,并加載訓練好的模型參數,然后對測試集中的1萬條樣本進行預測,查看模型在測試集上的準確率。

def eval(valid_dataset):# 構建數據加載器dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)# 加載模型并加載訓練好的權重model = ImageClassification()model.load_state_dict(torch.load('model/image_classification.pth'))# 模型切換評估模式, 如果網絡模型中有dropout/BN等層, 評估階段不進行相應操作model.eval()# 計算精度total_correct = 0total_samples = 0# 遍歷每個batch的數據,獲取預測結果,計算精度for x, y in dataloader:output = model(x)total_correct += (torch.argmax(output, dim=-1) == y).sum()total_samples += len(y)# 打印精度print('Acc: %.2f' % (total_correct / total_samples))if __name__ == '__main__':train_dataset, valid_dataset = create_dataset()eval(valid_dataset)

運行結果:

Acc: 0.57

最后,大家還可以通過調整lr(學習率)、神經元失活(dropout)、增加神經網絡層數等方式來調整模型,提升acc,各看本領吧!

今天的分享到此結束。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/912274.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/912274.shtml
英文地址,請注明出處:http://en.pswp.cn/news/912274.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

記錄多功能按鍵第二種寫法使用定時器周期間隔判斷.

邏輯是通過定時器溢出周期進行判斷按下次數 比如設置定時器溢出周期為500MS,每次溢出都會判斷按鍵按下次數,如果下個周期前沒有觸發按下,則結束鍵值判斷.并確定觸發鍵值.清空按下次數標志.測試比一個定時器周期按下按鍵次數判斷寫法要穩定... 記錄STM32實現多功能按鍵_stm32一…

【安卓Sensor框架-1】SensorService 的啟動流程

內核啟動后,首個用戶空間進程init(pid1)解析init.rc配置文件,啟動關鍵服務(如Zygote和ServiceManager)。 Zygote服務配置為/system/bin/app_process --zygote --start-system-server,后續用于孵…

centos網卡綁定參考

同事整理分享: 1. 加載 Bonding 模塊 modprobe bonding 獲取網卡名稱 ip a 找到接了網線的網卡名稱,記下。 3. 配置物理網卡 創建并編輯 /etc/sysconfig/network-scripts/ifcfg-ens36(ifcfg-后面的內容根據上面找到的具體網卡名稱決定&#…

mbedtls ssl handshake error,res:-0x2700

用LinkSDK.c連接第三方云平臺出現現象 解決方案: 在_tls_network_establish函數中加入 mbedtls_ssl_conf_authmode(&adapter_handle->mbedtls.ssl_config, MBEDTLS_SSL_VERIFY_NONE);原因解釋:用連接方式是不用證書認證/跳過服務端認證。

Spring Security 的方法級權限控制是如何利用 AOP 的?

Spring Security 的方法級權限控制是 AOP 技術在實際應用中一個極其強大的應用典范。它允許我們以聲明式的方式保護業務方法,將安全規則與業務邏輯徹底解耦。 核心思想:權限檢查的“門衛” 你可以把 AOP 在方法級安全中的作用想象成一個盡職盡責的“門…

一鍵內網穿透,無需域名和服務器,自動https訪問

cloudflare能將內網web轉為外網可訪問的地址。(這和apiSQL有點類似,apiSQ可以將內網數據庫輕松轉換為外網的API,并且還支持代理內網已有API,增強安全增加API Key,以https訪問等等) 但Cloudfalre tunnel這個…

Sentinel(二):Sentinel流量控制

一、Sentinel 流控規則基本介紹 1、Snetinel 流控規則配置方式 Sentinel 支持可視化的流控規則配置,使用非常簡單;可以在監控服務下的“簇點鏈路” 或 “流控規則” 中 給指定的請求資源配置流控規則;一般推薦在 “簇點鏈路” 中配置流控規則…

支持PY普冉系列單片機調試工具PY32linK仿真器

PY32 Link是專為 ?PY32系列ARM-Cortex內核單片機?(如PY32F002A/030/071/040/403等)設計的仿真器,支持全系列芯片的?調試和仿真?功能。?開發環境兼容性?支持主流IDE:?Keil MDK? 和 ?IAR Embedded Workbench?,…

深入解析Python多服務器監控告警系統:從原理到生產部署

深入解析Python多服務器監控告警系統:從原理到生產部署 整體架構圖 核心設計思想 無代理監控:通過SSH直接獲取數據,無需在目標服務器安裝代理故障隔離:單臺服務器故障不影響整體監控多級檢測:網絡層→資源層→服務層層…

JUC:10.線程、monitor管程、鎖對象之間在synchronized加鎖的流程(未完)

一、monitor管程工作原理: 首先,synchronized是一個對象鎖,當線程運行到某個臨界區,這個臨界區使用synchronized對對象obj進行了上鎖,此時底層發生了什么? 1.當synchronized對obj上鎖后,synch…

Elasticsearch(ES)分頁

Elasticsearch(簡稱 ES)本身不適合傳統意義上的“深分頁”,但提供了多種分頁方式,每種適用不同場景。我們來詳細講解: 一、基本分頁(from size) 最常用的分頁方式,類似 SQL 的 LIM…

原生微信小程序:用 `setData` 正確修改數組中的對象項狀態(附實戰技巧)

📌 背景介紹 在微信小程序開發中,我們經常需要修改數組中某個對象的某個字段,比如: 列表中的某一項展開/收起多選狀態切換數據列表中的臨時標記等 一個常見的場景是: lists: [{ show: true }, { show: true }, { s…

Oracle 臨時表空間相關操作

一、臨時表空間概述 臨時表空間(Temporary Tablespace)是Oracle數據庫中用于存儲臨時數據的特殊存儲區域,其數據在會話結束或事務提交后自動清除,重啟數據庫后徹底消失。主要用途包括: 存儲排序操作(如OR…

從靜態到動態:Web渲染模式的演進和突破

渲染模式有好多種,了解下web的各種渲染模式,對技術選型有很大的參考作用。 一、靜態HTML時代 早期(1990 - 1995年)網頁開發完全依賴手工編寫HTML(HyperText Markup Language)和CSS(層疊樣式表…

Flask(六) 數據庫操作SQLAlchemy

文章目錄 一、準備工作二、最小化可運行示例? 補充延遲綁定方式(推薦方式) 三、數據庫基本操作(增刪改查)1. 插入數據(增)2. 查詢數據(查)3. 更新數據(改)4.…

PYTHON從入門到實踐7-獲取用戶輸入與while循環

# 【1】獲取用戶輸入 # 【2】python數據類型的轉換 input_res input("請輸入一個數字\n") if int(input_res) % 10 0:print("你輸入的數是10的倍數") else:print("你輸入的數不是10的倍數") # 【3】while循環,適合不知道循環多少次…

學習筆記(C++篇)—— Day 8

1.STL簡介 STL(standard template libaray-標準模板庫):是C標準庫的重要組成部分,不僅是一個可復用的組件庫,而且是一個包羅數據結構與算法的軟件框架。 2.STL的六大組件 先這樣,下一部分是string的內容,內容比較多&a…

ant+Jmeter+jenkins接口自動化,如何實現把執行失敗的接口信息單獨發郵件?

B站講的最好的自動化測試教程,工具框架附項目實戰一套速通,零基礎完全輕松掌握!自動化測試課程、web/app/接口 實現AntJMeterJenkins接口自動化失敗接口郵件通知方案 要實現只發送執行失敗的接口信息郵件通知,可以通過以下步驟實…

惡意Python包“psslib“實施拼寫錯誤攻擊,可強制關閉Windows系統

Socket威脅研究團隊發現一個名為psslib的惡意Python包,該軟件包偽裝成提供密碼安全功能,實則會突然關閉Windows系統。這個由化名umaraq的威脅行為者開發的軟件包,是對知名密碼哈希工具庫passlib的拼寫錯誤仿冒(typosquatting&…

云原生灰度方案對比:服務網格灰度(Istio ) 與 K8s Ingress 灰度(Nginx Ingress )

服務網格灰度與 Kubernetes Ingress 灰度是云原生環境下兩種主流的灰度發布方案,它們在架構定位、實現方式和適用場景上存在顯著差異。以下從多個維度對比分析,并給出選型建議: 一、核心區別對比 維度服務網格灰度(以 Istio 為例…