使用DataLoader加載本地數據 食物分類案例

目錄

一.食物分類案例

1..整合訓練集測試集文檔

2.導入相關的庫

3.設置圖片數據的格式轉換

3.數據處理

4.數據打包

5.定義卷積神經網絡

6.創建模型

7.訓練和測試方法定義

8.損失函數和優化器

9.訓練模型,測試準確率

10.測試模型


之前我們DataLoader加載的Minist手寫數字集是已經封裝處理好的數據,所以我們可以直接使用,現在我們學習如何使用DataLoader加載本地數據

一.食物分類案例

1..整合訓練集測試集文檔

我們可以看到在food_dataset文件夾有訓練集和測試集兩個文件夾分別存放不同食物文件夾共20種食物的照片

為了方便后面模型的對數據的讀取和訓練,我們將每個圖片的地址和標簽(標簽可以自己定義)以空格分開都寫在一個txt文件中,形如:

定義一個函數用來完成填寫txt文件

需要注意的是os.walk()函數實在path里面漫步進入文件后還會出來進入下一個文件夾

第一次循環進入到train文件中directories為各食物名稱列表,len不為0,賦值給dirs方便后面命名標簽值

第二次循環時,root已經進入到第一個食物文件夾,遍歷該文件夾內的所有食物照片得到路徑,對root進行split()方便我們得到食物名now_dir[-1]后,再用dirs.index(now_dir[-1])獲取該食物在dires中的下標值作為標簽,最后將路徑與標簽寫入文件

遍歷完一個食物文件后順便在字典中保存對應的食物和其標簽

import os
dire={}
def train_test_file(root,dir):f_out=open(dir+'.txt','w')path=os.path.join(root,dir)for root,directories,files in os.walk(path):if len(directories)!=0:dirs=directorieselse:now_dir=root.split('\\')for file in files:path=os.path.join(root,file)f_out.write(path+' '+str(dirs.index(now_dir[-1]))+'\n')dire[dirs.index(now_dir[-1])]=now_dir[-1]f_out.close()

最后就是調用該函數,傳入相關路徑即可得到對應train.txt和test.txt

root=r'.\food_dataset'
train_dir='train'
test_dir='test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)

2.導入相關的庫

import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms

3.設置圖片數據的格式轉換

用字典來保存訓練集和測試集的相應格式轉換

transforms.Compose()是將一些格式的轉換組合在一起相當于一個容器

由于每個照片的大小都可能不同會影響到后面全連接層的展開輸入總個數,所以這里必須統一大小

還需要將數據轉化為Tensor張量類型

data_transforms={'train':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor()]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor()])
}

3.數據處理

由于我們使用的是自己的數據,所以我們必須要讓我們的數據集可以通過[]即索引來獲取,這樣DataLoader才能拿到數據并打包

定義一個類來實現上述要求,傳入文件路徑和上述的圖片數據轉換方式即可

在初始化方法__inint__中,完成一些共享空間self賦值后,我們將傳入文件中的路徑和標簽分別保存在存儲路徑和存儲標簽的列表中

__len__方法也是必不可少的,DataLoader打包數據前會先檢查總數據的大小長度夠不夠再完成打包

__getitem__則是我們通過索引獲取信息的關鍵方法,只要使用索引就會調用該方法,我們在該方法中通過傳進來的索引我們可以通過之前的存儲列表獲取圖片的路徑和標簽,并用PIL庫的Image.open()方法讀取圖片后根據初始化時傳進來的轉化格式進行轉換,標簽則用torch.from_numpy()方法也轉化為tensor張量,最后返回圖片數據和標簽

class food_dataset(Dataset):#能通過索引的方式返回圖片數據和標簽結果def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs_paths=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs_paths.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs_paths)def __getitem__(self, idx):image=Image.open(self.imgs_paths[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))#label也轉化為tensorreturn image,label

創建該類的對象,傳入訓練集和測試集的路徑得到可以被DataLoader打包的數據

train_data=food_dataset(file_path='./train.txt',transform=data_transforms['train'])
test_data=food_dataset(file_path='./test.txt',transform=data_transforms['valid'])

4.數據打包

由于圖片大小比較大我們就將一個圖片數據打包成一個批次

train_loader=DataLoader(train_data,batch_size=1,shuffle=True)
test_loader=DataLoader(test_data,batch_size=1,shuffle=True)device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')

5.定義卷積神經網絡

只需注意必須繼承nn.Module類,和我們此次訓練的是彩色圖片,所以第一個in_channel輸入通道數是3而不是之前手寫數字灰度圖的1,最后全連接層的輸出通道是20因為我們有20種食物

from  torch import  nn
class CNN(nn.Module):def __init__(self):super().__init__()#nn.Sequential()是將網絡層組合在一起,內部不能寫函數self.conv1=nn.Sequential(#1*3*256*256nn.Conv2d(in_channels=3,#輸入通道數out_channels=8,kernel_size=5,stride=1,padding=2),#1*8*256*256nn.ReLU(),nn.MaxPool2d(kernel_size=2)#1*8*128*128)self.conv2 = nn.Sequential(nn.Conv2d(8,16,5,1,2),#1*16*128*128nn.ReLU(),nn.Conv2d(16,32,5,1,2),#1*32*128*128nn.ReLU(),nn.MaxPool2d(kernel_size=2)##1*32*64*64)self.conv3 = nn.Sequential(nn.Conv2d(32,64,5,1,2),#1*64*64*64nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),#1*64*64*64nn.ReLU())self.flatten=nn.Flatten()self.out=nn.Linear(64*64*64,20)def forward(self,x):x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)# x=x.view(x.size(0),-1)x=self.flatten(x)output=self.out(x)return output

6.創建模型

model=CNN().to(device)

7.訓練和測試方法定義

與之前手寫數字的方法并無任何不同

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1
def test(dataloader,model,loss_fn):model.eval()len_data=len(dataloader.dataset)correct,loss_sum=0,0num_batch=0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss_sum += loss_fn(pred, y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()num_batch+=1loss_avg=loss_sum/num_batchaccuracy=correct/len_dataprint(f'Accuracy:{100 * accuracy}%\nLoss Avg:{loss_avg}')return pred.argmax(1)

8.損失函數和優化器

多分類問題選擇交叉熵損失函數

loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)

9.訓練模型,測試準確率

設置20輪訓練

epochs=20
for i in range(epochs):print(f'==========第{i + 1}輪訓練==============')train(train_loader, model, loss_fn, optimizer)print(f'第{i + 1}輪訓練結束')test(test_loader,model,loss_fn)

準確率很低,后續會改進

10.測試模型

自己輸入一個照片的路徑通過模型來判斷類別,我們需要將照片數據用上面的格式轉化處理,需要注意的是我們必須手動為其添加batch維度,因為這里沒用Dataloader加載數據不會自動添加batch維度

pred.argmax(1).item()是獲取前向傳播之后輸出的最大概率的標簽,.item()是將其轉化為可讀的形式

path=input('請輸入一個圖片地址: ')
image=Image.open(path)
image=data_transforms['valid'](image)
image=image.unsqueeze(0).to(device)#添加batch維度
# 注意使用DataLoader加載數據時,它會自動為批量數據添加 batch 維度
model.eval()
with torch.no_grad():pred=model.forward(image)label=pred.argmax(1).item()print('該圖片是: '+dire[label])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/95390.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/95390.shtml
英文地址,請注明出處:http://en.pswp.cn/web/95390.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

從零開始的python學習——函數(2)

? ? ? ? ? づ?ど 🎉 歡迎點贊支持🎉 個人主頁:勵志不掉頭發的內向程序員; 專欄主頁:python學習專欄; 文章目錄 前言 一、變量作用域 二、函數執行過程 三、鏈式調用 四、嵌套調用 五、函數遞歸 六、…

RAG 的完整流程是怎么樣的?

RAG(檢索增強生成)的完整流程可分為5個核心階段:數據準備:清洗文檔、分塊處理(如PDF轉文本切片);向量化:使用嵌入模型(如BERT、BGE)將文本轉為向量&#xff1…

研發文檔版本混亂的根本原因是什么,怎么辦

研發文檔版本混亂的根本原因通常包括缺乏統一的版本控制制度、團隊協作不暢、文檔管理工具使用不當以及項目需求頻繁變化等因素。這些問題使得研發團隊在日常工作中容易出現文檔版本混亂的情況,導致信息的不一致性、溝通不暢以及開發進度的延誤。為了解決這一問題&a…

ChartView的基本使用

Qt ChartView(準確類名 QChartView)是 Qt Charts 模塊里最常用的圖表顯示控件。一句話概括:“它把 QChart 畫出來,并自帶縮放、平移、抗鋸齒等交互能力”。QML ChartView 簡介(一句話先記住:ChartView 是 Q…

系統擴展策略

1、核心指導思想:擴展立方體 在討論具體策略前,先了解著名的擴展立方體(Scale Cube),它定義了三種擴展維度: X軸:水平復制(克隆) 策略:通過負載均衡器&#…

HBuilder X 4.76 開發微信小程序集成 uview-plus

簡介 本文記錄了在HBuilder中創建并配置uni-app項目的完整流程。 首先創建項目并測試運行,確認無報錯后添加uView-Plus組件庫。 隨后修改了main.js、uni.scss、App.vue等核心文件,配置manifest.json并安裝dayjs、clipboard等依賴庫。 通過調整vite.c…

第4章:內存分析與堆轉儲

本章概述內存分析是 Java 應用性能調優的核心環節之一。本章將深入探討如何使用 VisualVM 進行內存分析,包括堆內存監控、堆轉儲生成與分析、內存泄漏檢測以及內存優化策略。通過本章的學習,你將掌握識別和解決內存相關問題的專業技能。學習目標理解 Jav…

面經分享一:分布式環境下的事務難題:理論邊界、實現路徑與選型邏輯

一、什么是分布式事務? 分布式事務是指事務的參與者、支持事務的服務器、資源服務器以及事務管理器分別位于不同的分布式系統的不同節點之上。 一個典型的例子就是跨行轉賬: 用戶從銀行A的賬戶向銀行B的賬戶轉賬100元。 這個操作包含兩個步驟: 從A賬戶扣減100元。 向B賬戶…

C++的演化歷史

C是一門這樣的編程語言: 兼顧底層計算機硬件系統和高層應用抽象機制從實際問題出發,注重零成本抽象、性能、可移植性、與C兼容語言特性和細節很多,學習成本較高,是一門讓程序員很難敢說精通的語言 C是自由的,支持5種…

Qt6實現繪圖工具:12種繪圖工具全家桶!這個項目滿足全部2D場景

項目概述 一個基于Qt框架開發的專業繪圖工具,實現了完整的2D圖形繪制、編輯和管理功能。該項目采用模塊化設計,包含圖形繪制、圖層管理、命令模式撤銷重做、用戶界面等多個子系統,是學習現代C++和Qt框架的最佳實踐。 核心功能特性 12種專業繪圖工具 多圖層繪制系統 完整的…

Linux驅動開發學習筆記

第1章 Linux驅動開發的方式mmap映射型設計方法。【不推薦】將芯片上的物理地址映射到用戶空間的虛擬地址上,用戶操作虛擬地址來操作硬件。使用文件操作集(file_operatiopns)設計方法。【極致推薦】platfrom總線型設置方法。【比較流行】設備樹。【推薦】第2章 Linux…

mac中進行適用于IOS的靜態庫構建

前沿: 進行C開發完成之后,需要將代碼編譯成靜態庫,并且在IOS的手機系統中執行,因此記錄該實現過程. 1主要涉及內容 1.1 整體文件架構 gongyonglocalhost Attention % tree -L 2 . ├── build │ ├── __.SYMDEF │ ├── cmake_install.cmake │ ├── CMakeCache…

C++二維數組的前綴和

C二維數組的前綴和的方法很簡單&#xff0c;可以利用公式res[i][j]arr[i][j]res[i-1][j]prefix[i][j-1]-res[i-1][j-1]。輸入4 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16輸出1 3 6 10 6 14 24 36 15 33 54 78 28 60 96 136#include<bits/stdc.h> using namespace std; int…

Wifi開發上層學習1:實現一個wifi搜索以及打開的app

Wifi開發上層學習1&#xff1a;實現一個wifi搜索以及打開的app 文章目錄Wifi開發上層學習1&#xff1a;實現一個wifi搜索以及打開的app背景demo實現1.添加系統權限以及系統簽名2.布局配置3.邏輯設計3.1 wifi開關的實現3.2 wifi掃描功能3.3 連接wifi總結一、WiFi 狀態控制接口二…

【DSP28335 入門教程】定時器中斷:為你的系統注入精準的“心跳”

大家好&#xff0c;歡迎來到 DSP28335 的核心精講系列。我們已經掌握了如何通過外部中斷來響應“外部事件”&#xff0c;但系統內部同樣需要一個精準的節拍器來處理“內部周期性任務”。單純依靠 DELAY_US() 這樣的軟件延時&#xff0c;不僅精度差&#xff0c;而且會在延時期間…

從零開始:用代碼解析區塊鏈的核心工作原理

區塊鏈技術被譽為信任的機器&#xff0c;它正在重塑金融、供應鏈、數字身份等眾多領域。但對于許多開發者來說&#xff0c;它仍然像一個神秘的黑盒子。今天&#xff0c;我們將拋開炒作的泡沫&#xff0c;深入技術本質&#xff0c;用大約100行Python代碼構建一個簡易的區塊鏈&am…

網絡通信IP細節

目錄 1.通信的NAT技術 2.代理服務器 3.內網穿透和內網打洞 1.通信的NAT技術 NAT技術產生的背景是我們為了解決IPV4不夠用的問題&#xff0c;NAT在通信的時候可以對IP將私網IP轉化為公網IP&#xff0c;全局IP要求唯一&#xff0c;但是私人IP不是唯一的。 將報文發給路由器進行…

國內真實的交換機、路由器和分組情況

一、未考慮擁擠情況理想狀態的網絡通信 前面我對骨干網&#xff1a; 宜春城區SDH網圖分析-CSDN博客 數據鏈路層MAC傳輸&#xff1a; 無線通信網卡底層原理&#xff08;Inter Wi-Fi AX201&#xff09;_ax201ngw是cnvio轉pci-e-CSDN博客 物理層、數據鏈路層、網絡層及傳輸層…

atomic常用類方法

Java中的java.util.concurrent.atomic包提供了多種原子操作工具類&#xff0c;以下是核心類及其方法&#xff1a;?1. AtomicBoolean??方法?&#xff1a;get()&#xff1a;獲取當前值set(boolean newValue)&#xff1a;強制設置值compareAndSet(boolean expect, boolean upd…

算法題打卡力扣第3題:無重復字符的最長子串(mid)

文章目錄題目描述解法一&#xff1a;暴力解解法二&#xff1a;滑動窗口題目描述 解法一&#xff1a;暴力解 遍歷每一個可能的子串&#xff0c;然后逐一判斷每個子串中是否有重復字符。 具體步驟&#xff1a; 使用兩層嵌套循環來生成所有子串的起止位置&#xff1a; 外層循環 i…