第P2周:Pytorch實現CIFAR10彩色圖片識別

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

目標

  1. 實現CIFAR-10的彩色圖片識別
  2. 實現比P1周更復雜一點的CNN網絡

具體實現

(一)環境

語言環境:Python 3.10
編 譯 器: PyCharm
框 架: Pytorch 2.5.1

(二)具體步驟
1.
import torch  
import torch.nn as nn  
import matplotlib.pyplot as plt  
import torchvision  # 第一步:設置GPU  
def USE_GPU():  if torch.cuda.is_available():  print('CUDA is available, will use GPU')  device = torch.device("cuda")  else:  print('CUDA is not available. Will use CPU')  device = torch.device("cpu")  return device  device = USE_GPU()  

輸出:CUDA is available, will use GPU

  
# 第二步:導入數據。同樣的CIFAR-10也是torch內置了,可以自動下載  
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=torchvision.transforms.ToTensor())  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,  transform=torchvision.transforms.ToTensor())  batch_size = 32  
train_dataload = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
test_dataload = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)  # 取一個批次查看數據格式  
# 數據的shape為:[batch_size, channel, height, weight]  
# 其中batch_size為自己設定,channel,height和weight分別是圖片的通道數,高度和寬度。  
imgs, labels = next(iter(train_dataload))  
print(imgs.shape)  # 查看一下圖片  
import numpy as np  
plt.figure(figsize=(20, 5))  
for i, images in enumerate(imgs[:20]):  # 使用numpy的transpose將張量(C,H, W)轉換成(H, W, C),便于可視化處理  npimg = imgs.numpy().transpose((1, 2, 0))  # 將整個figure分成2行10列,并繪制第i+1個子圖  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()  

輸出:
Files already downloaded and verified
Files already downloaded and verified
torch.Size([32, 3, 32, 32])
image.png

# 第三步,構建CNN網絡  
import torch.nn.functional as F  num_classes = 10  # 因為CIFAR-10是10種類型  
class Model(nn.Module):  def __init__(self):  super(Model, self).__init__()  # 提取特征網絡  self.conv1 = nn.Conv2d(3, 64, 3)  self.pool1 = nn.MaxPool2d(kernel_size=2)  self.conv2 = nn.Conv2d(64, 64, 3)  self.pool2 = nn.MaxPool2d(kernel_size=2)  self.conv3 = nn.Conv2d(64, 128, 3)  self.pool3 = nn.MaxPool2d(kernel_size=2)  # 分類網絡  self.fc1 = nn.Linear(512, 256)  self.fc2 = nn.Linear(256, num_classes)  # 前向傳播  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = self.pool3(F.relu(self.conv3(x)))  x = torch.flatten(x, 1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  from torchinfo import summary  
# 將模型轉移到GPU中  
model = Model().to(device)  
summary(model)  

image.png

# 訓練模型  
loss_fn = nn.CrossEntropyLoss() # 創建損失函數  
learn_rate = 1e-2   # 設置學習率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)    # 設置優化器  # 編寫訓練函數  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 訓練集的大小 ,這里一共是60000張圖片  num_batches = len(dataloader)   # 批次大小,這里是1875(60000/32=1875)  train_acc, train_loss = 0, 0    # 初始化訓練正確率和損失率都為0  for X, y in dataloader: # 獲取圖片及標簽,X-圖片,y-標簽(也是實際值)  X, y = X.to(device), y.to(device)  # 計算預測誤差  pred = model(X) # 網絡輸出預測值  loss = loss_fn(pred, y) # 計算網絡輸出的預測值和實際值之間的差距  # 反向傳播  optimizer.zero_grad()   # grad屬性歸零  loss.backward() # 反向傳播  optimizer.step()    # 第一步自動更新  # 記錄正確率和損失率  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 測試函數  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset) # 測試集大小,這里一共是10000張圖片  num_batches = len(dataloader)   # 批次大小 ,這里312,即10000/32=312.5,向上取整  test_acc, test_loss = 0, 0  # 因為是測試,因此不用訓練,梯度也不用計算不用更新  with torch.no_grad():  for imgs, target in dataloader:  imgs, target = imgs.to(device), target.to(device)  # 計算loss  target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式訓練  
epochs = 10  
train_acc, train_loss, test_acc, test_loss = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dataload, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dataload, model, loss_fn)  train_acc.append(epoch_train_acc)  train_loss.append(epoch_train_loss)  test_acc.append(epoch_test_acc)  test_loss.append(epoch_test_loss)  template = 'Epoch:{:2d}, 訓練正確率:{:.1f}%, 訓練損失率:{:.3f}, 測試正確率:{:.1f}%, 測試損失率:{:.3f}'  print(template.format(epoch+1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  print('Done')  # 結果可視化  
# 隱藏警告  
import warnings  
warnings.filterwarnings('ignore')   # 忽略警告信息  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 正常顯示中文標簽  
plt.rcParams['axes.unicode_minus'] = False  # 正常顯示+/-號  
plt.rcParams['figure.dpi'] = 100    # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  plt.subplot(1, 2, 1)    # 第一張子圖  
plt.plot(epochs_range, train_acc, label='訓練正確率')  
plt.plot(epochs_range, test_acc, label='測試正確率')  
plt.legend(loc='lower right')  
plt.title('訓練和測試正確率比較')  plt.subplot(1, 2, 2)    # 第二張子圖  
plt.plot(epochs_range, train_loss, label='訓練損失率')  
plt.plot(epochs_range, test_loss, label='測試損失率')  
plt.legend(loc='upper right')  
plt.title('訓練和測試損失率比較')  plt.show()# 保存模型  
torch.save(model, './models/cnn-cifar10.pth')

image.png
再次設置epochs為50訓練結果:
image.png
epochs增加到100,訓練結果:
image.png
可以看到訓練集和測試集的差距有點大,不太理想。做一下數據增加試試:

data_transforms= {  'train': transforms.Compose([  transforms.RandomHorizontalFlip(),  transforms.ToTensor(),  ]),  'test': transforms.Compose([  transforms.ToTensor(),  ])  
}

在dataset中:

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,  transform=data_transforms['train'])  
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['test'])

運行結果:
image.png
image.png
比較漂亮了,再調整batch_size=16和epochs=20,提高了近6個百分點。
image.png
batch_size=16,epochs=50:有第20輪左右的時候,驗證集的確認性基本就沒有再提高了。和上面基本一樣。
image.png

(三)總結
  1. epochs并不是越多越好。batch_size同樣的道理
  2. 數據增強確實可以提高模型訓練的準確性。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/62905.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/62905.shtml
英文地址,請注明出處:http://en.pswp.cn/web/62905.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Quant connect的優勢和不足,學習曲線難

Quant connect的優勢和不足 Quant connect作為一個成熟的算法交易平臺,具有許多優勢,包括: 強大的回測功能:Quant connect提供了豐富的數據源和回測功能,可以對各種交易策略進行全面的回測和分析。 容易上手&#xf…

深入理解 Ansible Playbook:組件與實戰

目錄 1 playbook介紹 2 YAML語言 2.1語法簡介 2.2數據類型 3 Playbook核心組件 3.1 hosts組件 3.2 remote_user組件 3.3 task列表和action組件 3.4 handlers 3.5 tags組件 3.6 其他組件說明 1 playbook介紹 playbook 劇本是由一個或多個"play"組成的列表。…

2024年食堂采購系統源碼技術趨勢:如何開發智能的供應鏈管理APP

本篇文章,小編將與大家一同探討2024年食堂采購系統的技術趨勢,并提供開發更智能的供應鏈管理APP的策略。 一、2024年食堂采購系統的技術趨勢 1.人工智能與機器學習的深度應用 在2024年,AI和機器學習在食堂采購系統中的應用將更加普遍。這些…

代碼隨想錄-算法訓練營-番外(圖論01:圖論理論基礎,所有可到達的路徑)

day01 圖論part01 今日任務:圖論理論基礎/所有可到達的路徑 代碼隨想錄圖論視頻部分還沒更新 https://programmercarl.com/kamacoder/圖論理論基礎.html#圖的基本概念 day01 所有可達路徑 鄰接矩陣 import java.util.Scanner;import java.util.List;import java.util.ArrayL…

系統架構的演變

什么是系統架構? 系統架構是系統的一種整體的高層次的結構表示,它確定了系統的基本組織、組件之間的關系、組件與環境的關系,以及指導其設計和發展的原則。隨著技術的發展和業務需求的增長,系統架構經歷了從簡單到復雜、從集中到…

c++總復習

C 中多態性在實際項目中的應用場景 圖形繪制系統 描述:在一個圖形繪制軟件中,可能有多種圖形,如圓形、矩形、三角形等。這些圖形都有一個共同的操作,比如繪制(draw)。通過多態性,可以定義一個基…

pip離線安裝一個github倉庫

要使用pip安裝一個本地Git倉庫,你可以按照以下步驟操作: 確保你已經克隆了Git倉庫到本地。 進入倉庫所在的目錄。 使用pip安裝。 以下是具體的命令: 克隆Git倉庫到本地(替換下面的URL為你的倉庫URL) git clone https…

【從零開始入門unity游戲開發之——C#篇04】棧(Stack)和堆(Heap),值類型和引用類型,以及特殊的引用類型string

文章目錄 知識回顧一、棧(Stack)和堆(Heap)1、什么是棧和堆2、為什么要分棧和堆3、棧和堆的區別棧堆 4、總結 二、值類型和引用類型1、那么值類型和引用類型到底有什么區別呢?值類型引用類型 2、總結 三、特殊的引用類…

【C語言實現:用隊列模擬棧與用棧模擬隊列(LeetCode 225 232)】

LeetCode刷題記錄 🌐 我的博客主頁:iiiiiankor🎯 如果你覺得我的內容對你有幫助,不妨點個贊👍、留個評論?,或者收藏?,讓我們一起進步!📝 專欄系列:LeetCode…

【Python】Selenium 爬蟲的使用技巧和案例

引言 Selenium 是 Python 中功能強大的自動化測試工具,因其能夠操控瀏覽器進行模擬操作,被廣泛應用于網頁數據爬取。相比傳統的 requests 等庫,Selenium 能更好地應對動態加載內容和復雜交互場景。本文將詳細介紹 Selenium 爬蟲的使用技巧,并提供實際案例來幫助讀者快速上…

MySQL SQL語句性能優化

MySQL SQL語句性能優化指南 一、查詢設計優化1. 避免 SELECT *2. 使用 WHERE 進行條件過濾3. 避免在索引列上使用函數和表達式4. 使用 LIMIT 限制返回行數5. 避免使用子查詢6. 優化 JOIN 操作7. 避免全表掃描 二、索引優化1. 使用合適的索引2. 覆蓋索引3. 索引選擇性4. 多列索引…

Mybatis動態sql執行過程

動態SQL的執行原理主要涉及到在運行時根據條件動態地生成SQL語句,然后將其發送給數據庫執行。以下是動態SQL執行原理的詳細解釋: 一、接收參數 動態SQL首先會根據用戶的輸入或系統的條件接收參數。這些參數可以是查詢條件、更新數據等,它們…

java jar包加密 jar-protect

介紹 java 本身是開放性極強的語言,代碼也容易被反編譯,沒有語言層面的一些常規保護機制,jar包很容易被反編譯和破解。 受classfinal(已停止維護)設計啟發,針對springboot日常項目開發,重新編寫安全可靠的jar包加殼加密技術,用于保護軟件版權。 使用說…

Linux:Git

Git常見指令: git help xx_command git xx_command --help git --version 查看git版本git config --global user.name "xxx_name" 全局級別的簽名設置,全局的放在本用 git config --global user.ema…

【WiFi】WiFi中RSSI、SNR、NF之間關系及說明

RSSI(接收信號強度指示) 定義: RSSI 是一個相對值,用于表示接收到的無線信號的強度。它通常由無線設備的硬件(如無線網卡或無線芯片)直接提供。 計算: RSSI 的計算通常是由設備的無線芯片完成的…

提升音頻轉錄準確性:VAD技術的應用與挑戰

引言 在音頻轉錄技術飛速發展的今天,我們面臨著一個普遍問題:在嘈雜環境中,轉錄系統常常將非人聲誤識別為人聲,導致轉錄結果出現錯誤。例如,在whisper模式下,系統可能會錯誤地轉錄出“謝謝大家”。本文將探…

[ZMQ] -- ZMQ通信Protobuf數據結構 1

1、前言背景 工作需要域間實現zmq通信,剛開始需要比較簡單的數據結構,比如兩個bool,后面可能就需要傳輸比較大的數據,所以記錄下實現流程,至于為啥選擇proto數據結構去做大數據傳輸,可能是地平線也用這個&…

順序表的使用,對數據的增刪改查

主函數: 3.c #include "3.h"//頭文件調用 SqlListptr sql_cerate()//創建順序表函數 {SqlListptr ptr(SqlListptr)malloc(sizeof(SqlList));//在堆區申請連續的空間if(NULLptr){printf("創建失敗\n");return NULL;//如果沒有申請成功&#xff…

React和Vue中暴露子組件的屬性和方法給父組件用,并且控制子組件暴露的顆粒度的做法

React 在 React 中,forwardRef 是一種高級技術,它允許你將 ref 從父組件傳遞到子組件,從而直接訪問子組件的 DOM 節點或公開的方法。這對于需要操作子組件內部狀態或 DOM 的場景非常有用。為了使子組件能夠暴露其屬性和方法給父組件&#xf…

《C++ 實時視頻流物體跟蹤與行為分析全解析》

在當今科技飛速發展的時代,視頻監控與智能分析技術在眾多領域發揮著極為重要的作用。從安防監控到智能交通,從工業自動化到人機交互,利用 C 處理實時視頻流中的物體跟蹤和行為分析成為了熱門且極具挑戰性的研究與開發方向。本文將深入探討其中…