使用PyTorch進行熱狗圖像分類模型微調

本教程將演示如何使用PyTorch框架對預訓練模型進行微調,實現熱狗與非熱狗圖像的分類任務。我們將從數據準備開始,逐步完成數據加載、可視化等關鍵步驟。


1. 環境配置與庫導入

%matplotlib inline
import os
import torch
from torch import nn
from d2l import torch as d2l
import torchvision

2. 熱狗數據集準備

# 熱狗數據集配置
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')# 下載并加載數據集
data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

3. 數據可視化

# 可視化訓練集樣本
import matplotlib.pyplot as plt# 設置畫布大小
plt.figure(figsize=(12, 8))# 繪制前16張圖片
for i, (image, label) in enumerate(train_imgs[:16]):plt.subplot(4, 4, i+1)plt.imshow(image)plt.title('hotdog' if label == 0 else 'not hotdog')plt.axis('off')plt.tight_layout()
plt.show()

輸出結果

array([<Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >,<Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >,<Axes: >, <Axes: >, <Axes: >, <Axes: >], dtype=object)

(實際運行時將顯示4x4網格排列的16張圖像,包含熱狗和其他食品的圖片)?

4.數據增強?

normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize
])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize
])

5.定義并修改預訓練模型

# 使用預訓練的ResNet18模型
pretrained_net = torchvision.models.resnet18(pretrained=True)
print(pretrained_net.fc)  # 最后一層全連接層查看

輸出結果:

Linear(in_features=512, out_features=1000, bias=True)
# 修改最后一層,以適應我們二分類任務
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

6.微調模型

定義微調函數:

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolderdef train_fine_tuning(net, lr, batch_size=128, num_epochs=5, param_group=True):train_iter = DataLoader(ImageFolder(os.path.join(data_dir,'train'), transform=train_augs),batch_size=batch_size,shuffle=True)test_iter = DataLoader(ImageFolder(os.path.join(data_dir,'test'), transform=test_augs),batch_size=batch_size,shuffle=False)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction='mean')if param_group:params_lx = [param for name, param in net.named_parameters()if name not in ['fc.weight', 'fc.bias']]optim = torch.optim.SGD([{'params': params_lx},{'params': net.fc.parameters(), 'lr': lr * 10}], lr=lr, weight_decay=0.001)else:optim = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, optim, num_epochs, devices)

使用小的學習率進行微調:

train_fine_tuning(finetune_net, 5e-5)

輸出:

loss 0.006, train acc 0.606, test acc 0.599
18.3 examples/sec on [device(type='cuda', index=0)]

為了進行比較,所有模型參數初始化為隨機值?

scratch_net=torchvision.models.resnet18() # 沒有預訓練參數
scratch_net.fc=nn.Linear(scratch_net.fc.in_features,2) # 修改最后一層全連接層,輸出為2
train_fine_tuning(scratch_net,5e-4,param_group=False) # param_group=False使得所有層的參數都為默認的學習率   

輸出:

loss 0.005, train acc 0.752, test acc 0.750
10.6 examples/sec on [device(type='cuda', index=0)]

7.總結

本文完整展示了從數據準備到模型訓練的熱狗分類任務流程。關鍵步驟包括:

  1. 使用torchvision加載和預處理圖像數據

  2. 可視化數據集樣本

  3. 構建數據加載管道

  4. 修改預訓練模型進行微調

  5. 訓練和評估分類模型

實際應用中可以通過調整數據增強策略、嘗試不同網絡架構、優化超參數等方式進一步提升模型性能。后續可以擴展為部署到移動端的食品識別應用。


注意事項

  1. 確保GPU環境加速訓練

  2. 根據顯存調整batch_size大小

  3. 適當調整學習率等超參數

  4. 添加早停機制防止過擬合

希望本教程能幫助您快速上手PyTorch模型微調任務!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/903972.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/903972.shtml
英文地址,請注明出處:http://en.pswp.cn/news/903972.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

內容中臺與企業內容管理核心差異剖析

功能定位與架構設計差異 在企業數字化進程中&#xff0c;內容中臺與企業內容管理&#xff08;ECM&#xff09;的核心差異首先體現在功能定位層面。傳統ECM系統以文檔存儲、版本控制及權限管理為核心&#xff0c;主要服務于企業內部知識庫的靜態管理需求&#xff0c;例如通過Ba…

使用PyMongo連接MongoDB的基本操作

MongoDB是由C語言編寫的非關系型數據庫&#xff0c;是一個基于分布式文件存儲的開源數據庫系統&#xff0c;其內容存儲形式類似JSON對象&#xff0c;它的字段值可以包含其他文檔、數組及文檔數組。在這一節中&#xff0c;我們就來回顧Python 3下MongoDB的存儲操作。 常用命令:…

第 12 屆藍橋杯 C++ 青少組中 / 高級組省賽 2021 年真題

一、選擇題 第 1 題 題目&#xff1a;下列符號中哪個在 C 中表示行注釋 ( )。 A. ! B. # C. ] D. // 正確答案&#xff1a;D 答案解析&#xff1a; 在 C 中&#xff0c;//用于單行注釋&#xff08;行注釋&#xff09;&#xff0c;從//開始到行末的內容會被編譯器忽略。選項 A…

【python】【UV】一篇文章學完新一代 Python 環境與包管理器使用指南

&#x1f40d; UV&#xff1a;新一代 Python 環境與包管理器使用指南 一、UV 是什么&#xff1f; UV 是由 Astral 團隊開發的高性能 Python 環境管理器&#xff0c;旨在統一替代 pyenv、pip、venv、pip-tools、pipenv 等工具。 1.1 UV 的主要功能 &#x1f680; 極速包安裝&…

前端性能優化2:結合HTTPS與最佳實踐,全面優化你的網站性能

點亮極速體驗&#xff1a;結合HTTPS與最佳實踐&#xff0c;為你詳解網站性能優化的道與術 在如今這個信息爆炸、用戶耐心極其有限的數字時代&#xff0c;網站的性能早已不是一個可選項&#xff0c;而是關乎生存和發展的核心競爭力。一個遲緩的網站&#xff0c;無異于在數字世界…

JavaWeb:vueaxios

一、簡介 什么是vue? 快速入門 <!-- 3.準備視圖元素 --><div id"app"><!-- 6.數據渲染 --><h1>{{ msg }}</h1></div><script type"module">// 1.引入vueimport { createApp, ref } from https://unpkg.com/vu…

Tauri聯合Vue開發中Vuex與Pinia關系及前景分析

在 TauriVue 的開發場景中&#xff0c;Vuex 和 Pinia 是兩種不同的狀態管理工具&#xff0c;它們的關系和前景可以從以下角度分析&#xff1a; 一、Vuex 與 Pinia 的關系 繼承與發展 Pinia 最初是作為 Vuex 5 的提案設計的&#xff0c;其目標是簡化 Vuex 的復雜性并更好地適配 …

Linux中的時間同步

一、時間同步服務擴展總結 1. 時間同步的重要性 多主機協作需求&#xff1a;在分布式系統、集群、微服務架構中&#xff0c;時間一致性是日志排序、事務順序、數據一致性的基礎。 安全協議依賴&#xff1a;TLS/SSL證書、Kerberos認證等依賴時間有效性&#xff0c;時間偏差可能…

【算法基礎】三指針排序算法 - JAVA

一、基礎概念 1.1 什么是三指針排序 三指針排序是一種特殊的分區排序算法&#xff0c;通過使用三個指針同時操作數組&#xff0c;將元素按照特定規則進行分類和排序。這種算法在處理包含有限種類值的數組時表現出色&#xff0c;最經典的應用是荷蘭國旗問題&#xff08;Dutch …

《操作系統真象還原》第十二章(2)——進一步完善內核

文章目錄 前言可變參數的原理實現系統調用write更新syscall.h更新syscall.c更新syscall-init.c 實現printf編寫stdio.h編寫stdio.c 第一次測試main.cmakefile結果截圖 完善printf修改main.c 結語 前言 上部分鏈接&#xff1a;《操作系統真象還原》第十二章&#xff08;1&#…

ICML2021 | DeiT | 訓練數據高效的圖像 Transformer 與基于注意力的蒸餾

Training data-efficient image transformers & distillation through attention 摘要-Abstract引言-Introduction相關工作-Related Work視覺Transformer&#xff1a;概述-Vision transformer: overview通過注意力機制蒸餾-Distillation through attention實驗-Experiments…

深度學習:AI 機器人時代

在科技飛速發展的當下&#xff0c;AI 機器人時代正以洶涌之勢席卷而來&#xff0c;而深度學習作為其核心驅動力&#xff0c;正重塑著我們生活與工作的方方面面。 從智能工廠的自動化生產&#xff0c;到家庭中貼心服務的智能助手&#xff0c;再到復雜環境下執行特殊任務的專業機…

《告別試錯式開發:TDD的精準質量鍛造術》

深度解鎖TDD&#xff1a;應用開發的創新密鑰 在應用開發的復雜版圖中&#xff0c;如何雕琢出高質量、高可靠性的應用&#xff0c;始終是開發者們不懈探索的核心命題。測試驅動開發&#xff08;TDD&#xff09;&#xff0c;作為一種顛覆性的開發理念與方法&#xff0c;正逐漸成…

應用層自定義協議序列與反序列化

目錄 一、網絡版計算器 二、網絡版本計算器實現 2.1源代碼 2.2測試結果 一、網絡版計算器 應用層定義的協議&#xff1a; 應用層進行網絡通信能否使用如下的協議進行通信呢&#xff1f; 在操作系統內核中是以這種協議進行通信的&#xff0c;但是在應用層禁止以這種協議進行…

Excel-CLI:終端中的輕量級Excel查看器

在數據驅動的今天&#xff0c;Excel 文件處理成為了我們日常工作中不可或缺的一部分。然而&#xff0c;頻繁地在圖形界面與命令行界面之間切換&#xff0c;不僅效率低下&#xff0c;而且容易出錯。現在&#xff0c;有了 Excel-CLI&#xff0c;一款運行在終端中的輕量級Excel查看…

百度后端開發一面

mutex, rwmutex 在Go語言中&#xff0c;Mutex&#xff08;互斥鎖&#xff09;和RWMutex&#xff08;讀寫鎖&#xff09;是用于管理并發訪問共享資源的核心工具。以下是它們的常見問題、使用場景及最佳實踐總結&#xff1a; 1. Mutex 與 RWMutex 的區別 Mutex: 互斥鎖&#xf…

STM32 IIC總線

目錄 IIC協議簡介 IIC總線系統結構 IIC總線物理層特點 IIC總線協議層 空閑狀態 應答信號 數據的有效性 數據傳輸 STM32的IIC特性及架構 STM32的IIC結構體 0.96寸OLED顯示屏 SSD1306框圖及引腳定義 4針腳I2C接口模塊原理圖 字節傳輸-I2C 執行邏輯框圖 命令表…

【unity游戲開發入門到精通——UGUI】整體控制一個UGUI面板的淡入淡出——CanvasGroup畫布組組件的使用

注意&#xff1a;考慮到UGUI的內容比較多&#xff0c;我將UGUI的內容分開&#xff0c;并全部整合放在【unity游戲開發——UGUI】專欄里&#xff0c;感興趣的小伙伴可以前往逐一查看學習。 文章目錄 前言CanvasGroup畫布組組件參數 實戰專欄推薦完結 前言 如果我們想要整體控制…

大型語言模型個性化助手實現

大型語言模型個性化助手實現 目錄 大型語言模型個性化助手實現PERSONAMEM,以及用戶資料和對話模擬管道7種原位用戶查詢類型關于大語言模型個性化能力評估的研究大型語言模型(LLMs)已經成為用戶在各種任務中的個性化助手,從提供寫作支持到提供量身定制的建議或咨詢。隨著時間…

生成式 AI 的未來

在人類文明的長河中,技術革命始終是推動社會躍遷的核心引擎。從蒸汽機解放雙手,到電力點亮黑夜,再到互聯網編織全球神經網絡,每一次技術浪潮都在重塑人類的生產方式與認知邊界。而今天,生成式人工智能(Generative AI)正以一種前所未有的姿態登上歷史舞臺——它不再局限于…