使用PyTorch實現圖像增廣與模型訓練實戰

本文通過完整代碼示例演示如何利用PyTorch和torchvision實現常用圖像增廣方法,并在CIFAR-10數據集上訓練ResNet-18模型。我們將從基礎圖像變換到復雜數據增強策略逐步講解,最終實現一個完整的訓練流程。


一、圖像增廣基礎操作

1.1 準備工作

#matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2ld2l.set_figsize()
img = d2l.Image.open('/workspace/data/cat.png')
d2l.plt.imshow(img)

1.2 圖像變換工具函數

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5, titles=None):y = [aug(img) for _ in range(num_rows*num_cols)]d2l.show_images(y, num_rows, num_cols, titles, scale)

二、常用圖像增廣方法

2.1 水平/垂直翻轉

# 水平翻轉
apply(img, torchvision.transforms.RandomHorizontalFlip())# 垂直翻轉
apply(img, torchvision.transforms.RandomVerticalFlip())

2.2 隨機裁剪

shape_aug = torchvision.transforms.RandomResizedCrop((200,200), scale=(0.1,1), ratio=(0.5,2))
apply(img, shape_aug)

2.3 顏色調整

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.3, hue=0.5)
apply(img, color_aug)

2.4 組合增廣策略

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),color_aug,shape_aug
])
apply(img, augs)

三、CIFAR-10數據增強實戰

3.1 數據加載與可視化

all_images = torchvision.datasets.CIFAR10(train=True, root='/workspace/data', download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8)

3.2 數據預處理配置

train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()
])test_augs = torchvision.transforms.ToTensor()

3.3 數據加載函數

def load_cifar10(is_train, augs, batch_size):dataset = torchvision.datasets.CIFAR10(root='../data', train=is_train,transform=augs, download=True)return torch.utils.data.DataLoader(dataset, batch_size=batch_size,shuffle=is_train, num_workers=4)

四、模型訓練實現

4.1 訓練核心函數

def train_batch_ch13(net, X, y, loss, trainer, devices):if isinstance(X, list):X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])y = y.to(devices[0])net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum

4.2 模型初始化

batch_size = 1024
devices = d2l.try_all_gpus()
net = d2l.resnet18(10, 3)def init_weights(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)

4.3 訓練入口函數

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)loss = nn.CrossEntropyLoss(reduction='none')optimizer = torch.optim.Adam(net.parameters(), lr=lr)d2l.train_ch13(net, train_iter, test_iter, loss, optimizer, 10, devices)# 啟動訓練
train_with_data_aug(train_augs, test_augs, net)

五、訓練結果分析

執行訓練后可以看到類似如下輸出:

train loss 0.018, train acc 0.895
test acc 0.856

典型訓練過程特征:

  1. 訓練損失持續下降

  2. 驗證準確率穩步提升

  3. 最終測試準確率可達85%以上


六、關鍵知識點總結

  1. 圖像增廣作用:通過隨機變換增加數據多樣性,提升模型泛化能力

  2. 組合策略:合理組合幾何變換與顏色變換可以達到最佳效果

  3. 訓練技巧

    • 使用Xavier初始化保證參數合理分布

    • Adam優化器自動調整學習率

    • 多GPU并行加速訓練


七、擴展改進方向

1.嘗試更多增廣組合:

advanced_augs = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(15),torchvision.transforms.RandomPerspective(),torchvision.transforms.RandomGrayscale(p=0.1)
])

2.調整網絡結構:

net = d2l.resnet50(10, 3)  # 使用更深層的ResNet-50

3.優化參數:

optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

?完整代碼已通過測試,可直接復制到Jupyter Notebook中運行。實際效果可能因硬件配置有所差異,建議使用GPU環境進行訓練。如果遇到數據集下載問題,請檢查root參數指定的路徑是否正確。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/80347.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/80347.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/80347.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

解決Mac 安裝 PyICU 依賴失敗

失敗日志: 解決辦法 1、使用 homebrew 安裝相關依賴 brew install icu4c 安裝完成后,設置環境變量 echo export PATH"/opt/homebrew/opt/icu4c77/bin:$PATH" >> ~/.zshrcecho export PATH"/opt/homebrew/opt/icu4c77/sbin:$PATH…

Springboot后端查詢參數接收

1.實現方式 假設前端發送的接口: /users?nameJohn&age30 后端怎么接收里面的name和age呢?以及再發別的參數后端怎么接收呢? 1.比較簡單的方式 當控制器方法的參數類型是簡單類型(如 String、Integer、Long 等&#xff09…

桌面應用中VUE使用新瀏覽器窗口打開頁面

1、瀏覽器應用忽略此方式,可任意方式打開。針對桌面應用設置 newWindowClick(){try {this.fileUrl "";this.params.year ""this.params.date ""axios({method: post,url: /url/pdf/preview,data: this.params,}).then(res> {t…

華為手機怎么進行音頻降噪?音頻降噪技巧分享:提升聽覺體驗

在當今數字化時代,音頻質量對于提升用戶體驗至關重要,無論是在通話、視頻錄制還是音頻文件播放中,清晰的音頻都能帶來更佳的聽覺享受。 而華為手機憑借其強大的音頻處理技術,為用戶提供了多種音頻降噪功能,幫助用戶在…

【數據可視化-22】脫發因素探索的可視化分析

?? 博主簡介:曾任某智慧城市類企業算法總監,目前在美國市場的物流公司從事高級算法工程師一職,深耕人工智能領域,精通python數據挖掘、可視化、機器學習等,發表過AI相關的專利并多次在AI類比賽中獲獎。CSDN人工智能領域的優質創作者,提供AI相關的技術咨詢、項目開發和個…

青少年編程與數學 02-018 C++數據結構與算法 06課題、樹

青少年編程與數學 02-018 C數據結構與算法 06課題、樹 一、樹(Tree)1. 樹的定義2. 樹的基本術語3. 常見的樹類型4. 樹的主要操作5. 樹的應用 二、二叉樹(Binary Tree)1. 二叉樹的定義2. 二叉樹的基本術語3. 二叉樹的常見類型4. 二叉樹的主要操作5. 二叉樹的實現代碼說明輸出示例…

【論文閱讀】Visual Instruction Tuning

文章目錄 導言1、論文簡介2、論文主要方法3、論文針對的問題4、論文創新點總結 導言 本論文介紹了一個新興的多模態模型——LLaVA(Large Language and Vision Assistant),旨在通過指令調優提升大型語言模型(LLM)在視覺…

【學習筆記】Cadence電子設計全流程(三)Capture CIS 原理圖繪制(下)

【學習筆記】Cadence電子設計全流程(三)Capture CIS 原理圖繪制(下) 3.16 原理圖中元件的編輯與更新3.17 原理圖元件跳轉與查找3.18 原理圖常見錯誤設置于編譯檢查3.19 低版本原理圖文件輸出3.20 原理圖文件的鎖定與解鎖3.21 Orca…

js使用IntersectionObserver實現目標元素可見度的交互

文章目錄 1、前言2、代碼實現3、使用場景4、兼容性5、成熟的Hooks推薦 1、前言 IntersectionObserver 是瀏覽器原生提供的一個Api。可以"觀察"我們的元素是否可見,原理是判斷目標元素與可見區域的交叉比例,所以也被稱為"交叉觀察器"…

linux 中斷子系統 層級中斷編程

虛擬中斷控制器代碼&#xff1a; #include<linux/kernel.h> #include<linux/module.h> #include<linux/clk.h> #include<linux/err.h> #include<linux/init.h> #include<linux/interrupt.h> #include<linux/io.h> #include<linu…

蝦皮(Shopee)商品詳情 API 接口概述及 JSON 數據返回參考

前言 一、接口概述 Shopee 商品詳情 API 接口是 Shopee 平臺為開發者提供的&#xff0c;用于獲取商品詳細信息的接口服務。通過該接口&#xff0c;開發者可以獲取商品的標題、價格、庫存、描述、圖片、規格參數、銷量、評價等詳細信息。這些數據為電商數據分析、商品比價工具…

three.js中的instancedMesh類優化渲染多個同網格材質的模型

three.js小白的學習之路。 在上上一篇博客中&#xff0c;簡單驗證了一下three.js中的網格共享。寫的時候就有一些想法&#xff0c;如果說某個場景中有一萬棵樹&#xff0c;這些樹共享一個geometry和material&#xff0c;有沒有好的辦法將其進行一定程度上的渲染優化&#xff0…

MySQL-自定義函數

自定義函數 函數的作用 mysql數據庫中已經提供了內置的函數&#xff0c;比如&#xff1a;sum&#xff0c;avg&#xff0c;concat等等&#xff0c;方便我們日常的使用&#xff0c;當需要時mysql支持定義自定義的函數&#xff0c;方便與我們對于需用復用的功能進行封裝。 基本…

ESP32上C語言實現JSON對象的創建和解析

在ESP32上使用C語言實現JSON對象的創建和解析&#xff0c;同樣可以借助cJSON庫。ESP-IDF&#xff08;Espressif IoT Development Framework&#xff09;本身已經集成了cJSON庫&#xff0c;你可以直接使用。以下是詳細的步驟和示例代碼。 1. 創建一個新的ESP-IDF項目 首先&…

【FAQ】PCoIP 會話后物理工作站本地顯示器黑屏

# 問題 工作人員從家里建立了到辦公室工作站的 PCoIP 連接&#xff0c;該工作站安裝了 HP Anyware Graphics Agent&#xff0c;并且還連接了本地顯示器。然后&#xff0c;遠程用戶決定去辦公室進行本地工作&#xff0c;工作站顯示器顯示黑屏&#xff08;有時沒有信號&#xff…

el-table 目錄樹列表本地實現模糊查詢

table目錄樹結構實現模糊查詢 <el-form :model"queryParams" ref"queryForm" size"small" :inline"true" v-show"showSearch"><el-form-item label"名稱:" prop"Name"><el-input v-mode…

力扣hot100 LeetCode 熱題 100 Java 哈希篇

兩數之和 1. 兩數之和 - 力扣&#xff08;LeetCode&#xff09; 直接暴力 class Solution {public int[] twoSum(int[] nums, int target) {for(int i0;i<nums.length;i){for(int ji1;j<nums.length;j){long ans nums[i]nums[j];if(ans>target)continue;if(anstarg…

前后端部署

#在學習JavaWeb之后&#xff0c;進行了蒼穹外賣的學習。在進行蒼穹外賣的部署的時候&#xff0c;作者遇到了下面的問題# 1.前端工程nginx無法啟動&#xff1a; 當我雙擊已經部署好的nginx工程中nginx.exe文件的時候&#xff0c;在服務中&#xff0c;并沒有找到ngnix成功運行。…

基于 EFISH-SBC-RK3588 的無人機環境感知與數據采集方案

一、核心硬件架構設計? ?高性能算力引擎&#xff08;RK3588 處理器&#xff09;? ?異構計算架構?&#xff1a;集成 8 核 CPU&#xff08;4Cortex-A762.4GHz 4Cortex-A551.8GHz&#xff09;&#xff0c;支持動態調頻與多任務并行處理&#xff0c;單線程性能較傳統四核方案…

什么是Maven

Maven的概念 Maven是一個一鍵式的自動化的構建工具。Maven 是 Apache 軟件基金會組織維護的一款自動化構建工具&#xff0c;專注服務于Java 平臺的項目構建和依賴管理。Maven 這個單詞的本意是&#xff1a;專家&#xff0c;內行。Maven 是目前最流行的自動化構建工具&#xff0…