小白的進階之路系列之四----人工智能從初步到精通pytorch自定義數據集下

本篇涵蓋的內容

在之前的文章中,我們已經討論了如何獲取數據,轉換數據以及如何準備自定義數據集,本篇文章將涵蓋更加深入的問題,希望通過詳細的代碼示例,幫助大家了解PyTorch自定義數據集是如何應對各種復雜實際情況中,數據處理的。

更加詳細的,我們將討論下面一些內容:

主題內容
7 Model 0:沒有數據增強的TinyVGG到這個階段,我們已經準備好了數據,讓我們建立一個能夠擬合數據的模型。我們還將創建一些訓練和測試函數來訓練和評估我們的模型。
8 探索損失曲線損失曲線是觀察你的模型如何訓練/改進的好方法。它們也是一種很好的方法來判斷你的模型是過擬合還是欠擬合。
9 Model 1:帶數據增強功能的TinyVGG到目前為止,我們已經嘗試了一個沒有數據增強的模型?
10 比較模型結果讓我們比較不同模型的損失曲線,看看哪個表現更好,并討論一些改進性能的選項。
11 對自定義圖像進行預測我們的模型是在披薩、牛排和壽司圖像的數據集上訓練的。在本節中,我們將介紹如何使用我們訓練好的模型來預測現有數據集之外的圖像。

7 Model 0:沒有數據增強的TinyVGG

好了,我們已經看到了如何把數據從文件夾里的圖像變成變換后的張量。

現在讓我們構建一個計算機視覺模型,看看我們是否可以將圖像分類為披薩、牛排或壽司。

首先,我們將從一個簡單的變換開始,僅將圖像大小調整為(64,64)并將它們轉換為張量。

7.1 為模型0創建轉換和加載數據

# Create simple transform
simple_transform = transforms.Compose([ transforms.Resize((64, 64)),transforms.ToTensor(),
])

很好,現在我們有了一個簡單的變換,讓我們

  • 加載數據,首先使用torchvision.datasets.ImageFolder()將每個訓練和測試文件夾轉換為Dataset

  • 然后使用torch.utils.data.DataLoader())轉換為數據加載器。

  • 我們將把batch_size=32和num_workers設置為機器上盡可能多的cpu(這取決于您使用的機器)。

# 1. Load and transform data
from torchvision import datasets
train_data_simple = datasets.ImageFolder(root=train_dir, transform=simple_transform)
test_data_simple = datasets.ImageFolder(root=test_dir, transform=simple_transform)# 2. Turn data into DataLoaders
import os
from torch.utils.data import DataLoader# Setup batch size and number of workers 
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")# Create DataLoader's
train_dataloader_simple = DataLoader(train_data_simple, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)test_dataloader_simple = DataLoader(test_data_simple, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)print(train_dataloader_simple, test_dataloader_simple)

輸出為:

Creating DataLoader's with batch size 32 and 16 workers.
<torch.utils.data.dataloader.DataLoader object at 0x0000024974F734D0> <torch.utils.data.dataloader.DataLoader object at 0x0000024974F07A80>

很好dataloader已經創建好了,現在讓我們設立模型。

7.2創建TinyVGG模型類

在上一篇文章中,我們使用了來自CNN解釋器網站的TinyVGG模型。

讓我們重新創建相同的模型,只不過這次我們將使用彩色圖像而不是灰度圖像(對于RGB像素,in_channels=3而不是in_channels=1)。

class TinyVGG(nn.Module):"""Model architecture copying TinyVGG from: https://poloclub.github.io/cnn-explainer/"""def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:super().__init__()self.conv_block_1 = nn.Sequential(nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=3, # how big is the square that's going over the image?stride=1, # defaultpadding=1), # options = "valid" (no padding) or "same" (output has same shape as input) or int for specific number nn.ReLU(),nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units,kernel_size=3,stride=1,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2) # default stride value is same as kernel_size)self.conv_block_2 = nn.Sequential(nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Sequential(nn.Flatten(),# Where did this in_features shape come from? # It's because each layer of our network compresses and changes the shape of our input data.nn.Linear(in_features=hidden_units*16*16,out_features=output_shape))def forward(self, x: torch.Tensor):x = self.conv_block_1(x)# print(x.shape)x = self.conv_block_2(x)# print(x.shape)x = self.classifier(x)# print(x.shape)return x# return self.classifier(self.conv_block_2(self.conv_block_1(x))) # <- leverage the benefits of operator fusiontorch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB) hidden_units=10, output_shape=len(train_data.classes)).to(device)
print(model_0)

輸出為:

TinyVGG((conv_block_1): Sequential((0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(conv_block_2): Sequential((0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU()(2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU()(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Flatten(start_dim=1, end_dim=-1)(1): Linear(in_features=2560, out_features=3, bias=True)</

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/907328.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/907328.shtml
英文地址,請注明出處:http://en.pswp.cn/news/907328.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

DeepSeek實戰:打造智能數據分析與可視化系統

DeepSeek實戰:打造智能數據分析與可視化系統 1. 數據智能時代:DeepSeek數據分析系統入門 在數據驅動的決策時代,智能數據分析系統正成為企業核心競爭力。本節將使用DeepSeek構建一個從數據清洗到可視化分析的全流程智能系統。 1.1 系統核心功能架構 class DataAnalysisS…

力扣100題---字母異位詞分組

1.字母異位詞分組 給你一個字符串數組&#xff0c;請你將 字母異位詞 組合在一起。可以按任意順序返回結果列表。 字母異位詞 是由重新排列源單詞的所有字母得到的一個新單詞。 方法一&#xff1a;字母排序 class Solution {public List<List<String>> groupAnagr…

使用子查詢在 SQL Server 中進行數據操作

在 SQL Server 中&#xff0c;子查詢&#xff08;Subquery&#xff09;是一種在查詢中嵌套另一個查詢的技術&#xff0c;可以用來執行復雜的查詢、過濾數據或進行數據計算。子查詢通常被用在 SELECT、INSERT、UPDATE 或 DELETE 語句中&#xff0c;可以幫助我們高效地解決問題。…

Flask集成pyotp生成動態口令

Python中的pyotp模塊是一個用于生成和驗證一次性密碼&#xff08;OTP&#xff09;的庫&#xff0c;支持基于時間&#xff08;TOTP&#xff09;和計數器&#xff08;HOTP&#xff09;的兩種主流算法。它遵循RFC 4226&#xff08;HOTP&#xff09;和RFC 6238&#xff08;TOTP&…

觸控精靈 ADB運行模式填寫電腦端IP教程

?ADB模式&#xff0c;如果你手機已經root則可以直接運行&#xff0c;無需安裝電腦端。 ?ADB模式&#xff0c;如果你手機沒有root&#xff0c;那你可以windows電腦下載【極限投屏】軟件&#xff0c;然后你的手機和電腦的網絡要同一個wifi&#xff0c;然后把你電腦的ip地址填寫…

【Python】 -- 趣味代碼 - 佩奇

文章目錄 文章目錄 00 佩奇程序設計框架1. 繪圖設置2. 繪制卡通人物的各個部分3. 主程序總結01 佩奇程序設計00 佩奇程序設計框架 這段代碼使用 turtle 模塊繪制了一個粉色的卡通人物圖像,主要功能包括繪制鼻子、頭、耳朵、眼睛、腮、嘴、身體、手、腳和尾巴等部分。代碼的主…

uniapp-商城-69-shop(2-商品列表,點擊商品展示,商品的詳情, vuex的使用,rich-text使用)

頁面中將我們的數據進行了羅列,對于單個數據的展示,還需要進行開發,這里使用了點擊商品后,進行彈窗展示。 同樣這里用一個組件來進行實現該彈窗的展示。 本文介紹了商品詳情彈窗的實現方案。主要采用Vuex進行狀態管理,通過幾個關鍵組件協同工作: 商品列表組件productItem…

C# Datatable篩選過濾各方式詳解

在C#中&#xff0c;DataTable提供了多種篩選過濾數據的方法&#xff0c;以下是常用的幾種方式及其特點&#xff1a; 1. ?Select方法篩選? 這是最基礎的篩選方式&#xff0c;支持類似SQL的表達式語法 // 單條件篩選 DataRow[] rows dt.Select("Age > 25");// …

計算機網絡中的路由算法:互聯網的“路徑規劃師”

計算機網絡中的路由算法&#xff1a;互聯網的“路徑規劃師” 當你打開瀏覽器&#xff0c;輸入 www.example.com 并敲下回車&#xff0c;數據會從你的電腦出發&#xff0c;穿越一個個路由器&#xff0c;最終抵達目標服務器。這一路上&#xff0c;數據包是怎么知道該走哪條路的&…

硬件工程師筆記——三極管Multisim電路仿真實驗匯總

目錄 1 三極管基礎 更多電子器件基礎知識匯總鏈接 1.1 工作原理 NPN型三極管的工作原理 PNP型三極管的工作原理 1.2 三極管的特性曲線 輸入特性曲線 理想和現實輸出特性 三極管的主要參數包括&#xff1a; 2 三極管伏安特性 2.1 伏安特性仿真 Multisim使用說明鏈接…

Linux 進階命令篇

一、Linux 系統軟件安裝命令 &#xff08;一&#xff09;Ubuntu 系統&#xff08;基于 Debian&#xff09; apt &#xff1a;是 Ubuntu 系統中常用的包管理工具&#xff0c;可以自動處理軟件依賴關系。 安裝命令格式 &#xff1a;sudo apt install 軟件名 示例 &#xff1a;…

LVS-DR 負載均衡群集

目錄 一、LVS-DR集群 1、LVS-DR 工作原理 2、數據包流向分析 3、LVS-DR 模式特點 二、直接路由模式&#xff08;LVS-DR&#xff09; 1、準備案例環境 2、配置負載調度器&#xff08;101&#xff09; &#xff08;1&#xff09;配置虛擬IP 地址&#xff08;VIP&#xff…

提升 GitHub Stats 的 6 個關鍵策略

哈哈&#xff0c;GitHub 的 “B-” 評級 其實是個玄學問題&#xff0c;但確實有一些 快速提升的技巧&#xff01;你的數據看起來 提交數&#xff08;147&#xff09;和 PR&#xff08;9&#xff09;不算少&#xff0c;但 Stars&#xff08;21&#xff09;和貢獻項目數&#xff…

常見的垃圾回收算法原理及其模擬實現

1.標記 - 清除&#xff08;Mark - Sweep&#xff09;算法&#xff1a; 這是一種基礎的垃圾回收算法。首先標記所有可達的對象&#xff0c;然后清除未被標記的對象。 缺點是會產生內存碎片。 原理&#xff1a; 如下圖分配一段內存&#xff0c;假設已經存儲上數據了 標記所有…

卷積神經網絡(CNN):原理、架構與實戰

卷積神經網絡&#xff08;CNN&#xff09;&#xff1a;原理、架構與實戰 卷積神經網絡&#xff08;Convolutional Neural Network, CNN&#xff09;是深度學習領域的一項重要突破&#xff0c;特別擅長處理具有網格結構的數據&#xff0c;如圖像、音頻和視頻。自 2012 年 AlexN…

RabbitMQ 集群與高可用方案設計(二)

三、為什么需要集群與高可用方案 &#xff08;一&#xff09;業務需求驅動 隨著業務的快速發展和用戶量的急劇增長&#xff0c;系統面臨的挑戰也日益嚴峻。在這種情況下&#xff0c;對消息隊列的可靠性、吞吐量和負載均衡能力提出了更高的要求&#xff0c;而單機部署的 Rabbi…

《ChatGPT o3抗命:AI失控警鐘還是成長陣痛?》

ChatGPT o3 “抗命” 事件起底 在人工智能的飛速發展進程中&#xff0c;OpenAI 于 2025 年推出的 ChatGPT o3 推理模型&#xff0c;猶如一顆重磅炸彈投入了技術的海洋&#xff0c;激起千層浪。它被視為 “推理模型” 系列的巔峰之作&#xff0c;承載著賦予 ChatGPT 更強大問題解…

RK3568DAYU開發板-平臺驅動開發:I2C驅動(原理、源碼、案例分析)

1、程序介紹 本程序是基于OpenHarmony標準系統編寫的平臺驅動案例&#xff1a;I2C 系統版本:openharmony5.0.0 開發板:dayu200 編譯環境:ubuntu22 部署路徑&#xff1a; //sample/04_platform_i2c 2、基礎知識 2.1、I2C簡介 I2C&#xff08;Inter Integrated Circuit&a…

在UniApp中開發微信小程序實現圖片、音頻和視頻下載功能

隨著微信小程序的迅猛發展&#xff0c;越來越多的開發者選擇通過UniApp框架來進行跨平臺應用開發。UniApp能夠讓開發者在一個代碼庫中同時發布iOS、Android和小程序等多平臺應用。而在實際開發過程中&#xff0c;很多應用都需要實現一些常見的下載功能&#xff0c;例如圖片、音…

鴻蒙5.0項目開發——接入有道大模型翻譯

鴻蒙5.0項目開發——接入有道大模型翻譯 【高心星出品】 項目效果圖 項目功能 文本翻譯功能 支持文本輸入和翻譯結果顯示 使用有道翻譯API進行翻譯 支持自動檢測語言&#xff08;auto&#xff09; 支持雙向翻譯&#xff08;源語言和目標語言可互換&#xff09; 文本操作…