深度學習——基于卷積神經網絡實現食物圖像分類【4】(使用最優模型)

文件目錄

    • 引言
    • 一、環境準備
    • 二、數據預處理
      • 訓練集預處理說明:
      • 驗證集預處理說明:
    • 三、自定義數據集類
    • 四、設備選擇
    • 五、CNN模型構建
    • 六、模型加載與評估
      • 1. 加載預訓練模型
      • 2. 準備測試數據
      • 3. 測試函數
      • 4. 計算準確率
    • 七、完整代碼
    • 八、總結

引言

本文將詳細介紹如何使用PyTorch框架構建一個完整的食物圖像分類系統,包含數據預處理、模型構建、訓練優化以及模型保存等關鍵環節。與上一篇博客介紹的版本相比,本版本增加了使用最優模型這一流程。

一、環境準備

首先,我們需要導入必要的Python庫:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

這些庫中:

  • torchtorchvision是PyTorch的核心庫
  • DatasetDataLoader用于數據加載和處理
  • transforms提供圖像預處理功能
  • PIL用于圖像處理
  • numpy用于數值計算

二、數據預處理

數據預處理是深度學習項目中至關重要的一環。PyTorch提供了transforms模塊來方便地進行圖像預處理:

data_transforms = {'train': transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

訓練集預處理說明:

  1. Resize([300,300]):將圖像調整為300×300像素
  2. RandomRotation(45):隨機旋轉圖像(-45°到45°之間)
  3. CenterCrop(256):從中心裁剪256×256的區域
  4. RandomHorizontalFlip(p=0.5):以50%概率水平翻轉圖像
  5. RandomVerticalFlip(p=0.5):以50%概率垂直翻轉圖像
  6. ColorJitter:隨機調整亮度、對比度、飽和度和色調
  7. RandomGrayscale(p=0.1):以10%概率將圖像轉為灰度
  8. ToTensor():將PIL圖像轉為PyTorch張量
  9. Normalize:標準化處理(使用ImageNet的均值和標準差)

驗證集預處理說明:

驗證集的預處理相對簡單,只包括調整大小、轉為張量和標準化,因為驗證階段不需要數據增強。

三、自定義數據集類

PyTorch的Dataset類允許我們自定義數據加載方式。我們創建了一個food_dataset類:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

這個類的主要功能:

  1. __init__:初始化函數,讀取包含圖像路徑和標簽的文本文件
  2. __len__:返回數據集大小
  3. __getitem__:根據索引返回圖像和對應的標簽

四、設備選擇

PyTorch支持在CPU、GPU(CUDA)和蘋果M系列芯片(MPS)上運行。我們使用以下代碼自動選擇可用設備:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

五、CNN模型構建

我們構建了一個簡單的CNN模型,包含三個卷積塊和一個全連接層:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

模型結構說明:

  1. conv1:輸入3通道,輸出16通道,5×5卷積核,ReLU激活,2×2最大池化
  2. conv2:輸入16通道,輸出32通道,同上結構
  3. conv3:輸入32通道,輸出64通道,同上結構
  4. out:全連接層,將64×32×32的特征圖映射到20個類別

六、模型加載與評估

1. 加載預訓練模型

model = CNN().to(device)
model.load_state_dict(torch.load("best2025-04.pth"))
model.eval()

2. 準備測試數據

test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

3. 測試函數

result = []
labels = []def Test_true(dataloader, model):model.eval()with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)result.append(pred.argmax(1).item())labels.append(y.item())Test_true(test_dataloader, model)

4. 計算準確率

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(labels, result)
print(f"準確率:{accuracy:.2%}")

七、完整代碼

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([            #對圖片預處理的組合transforms.Resize([300,300]),   #對數據進行改變大小transforms.RandomRotation(45),  #隨機旋轉,-45到45之間隨機選transforms.CenterCrop(256),     #從中心開始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),#隨機水平翻轉,p是指選擇一個概率翻轉,p=0.5表示百分之50transforms.RandomVerticalFlip(p=0.5),#隨機垂直翻轉transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),transforms.RandomGrayscale(p=0.1),#概率轉換成灰度率,3通道就是R=G=Btransforms.ToTensor(),#數據轉換為tensortransforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#標準化,均值,標準差]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 標準化,均值,標準差]),
}#Dataset是用來處理數據的
class food_dataset(Dataset):        # food_dataset是自己創建的類名稱,可以改為你需要的名稱def __init__(self,file_path,transform=None):    #類的初始化,解析數據文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的圖片路徑保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)  #圖像的路徑self.labels.append(label)   #標簽,還不是tensor# 初始化:把圖片目錄加到selfdef __len__(self):  #類實例化對象后,可以使用len函數測量對象的個數return  len(self.imgs)#training_data[1]def __getitem__(self, idx):    #關鍵,可通過索引的形式獲取每一個圖片的數據及標簽image = Image.open(self.imgs[idx])  #讀取到圖片數據,還不是tensor,BGRif self.transform:                  #將PIL圖像數據轉換為tensorimage = self.transform(image)   #圖像處理為256*256,轉換為tensorlabel = self.labels[idx]    #label還不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))    #label也轉換為tensorreturn image,label'''判斷當前設備是否支持GPU,其中mps是蘋果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驅動軟件的功能:pytorch能夠去執行cuda的命令
# 神經網絡的模型也需要傳入到GPU,1個batch_size的數據集也需要傳入到GPU,才可以進行訓練''' 定義神經網絡  類的繼承這種方式'''
class CNN(nn.Module): #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.mdouledef __init__(self): #輸入大小:(3,256,256)super(CNN,self).__init__()  #初始化父類self.conv1 = nn.Sequential( #將多個層組合成一起,創建了一個容器,將多個網絡組合在一起nn.Conv2d(              # 2d一般用于圖像,3d用于視頻數據(多一個時間維度),1d一般用于結構化的序列數據in_channels=3,      # 圖像通道個數,1表示灰度圖(確定了卷積核 組中的個數)out_channels=16,     # 要得到多少個特征圖,卷積核的個數kernel_size=5,      # 卷積核大小 3×3stride=1,           # 步長padding=2,          # 一般希望卷積核處理后的結果大小與處理前的數據大小相同,效果會比較好),                      # 輸出的特征圖為(16,256,256)nn.ReLU(),  # Relu層,不會改變特征圖的大小nn.MaxPool2d(kernel_size=2),    # 進行池化操作(2×2操作),輸出結果為(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  #輸出(32,128,128)nn.ReLU(),  #Relu層  (32,128,128)nn.MaxPool2d(kernel_size=2),    #池化層,輸出結果為(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 輸出(64,64,64)nn.ReLU(),  # Relu層  (64,64,64)nn.MaxPool2d(kernel_size=2),  # 池化層,輸出結果為(64,32,32))self.out = nn.Linear(64*32*32,20)  # 全連接層得到的結果def forward(self,x):   #前向傳播,你得告訴它 數據的流向 是神經網絡層連接起來,函數名稱不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)    # flatten操作,結果為:(batch_size,32 * 64 * 64)output = self.out(x)return output
# 提取模型的2種方法:
#   1、讀取參數的方法
model = CNN().to(device) #初始化模型,w都是隨機初始化的
model.load_state_dict(torch.load("best2025-04.pth"))
#   2、讀取完整模型的方法,無需提前創建model
#   model = CNN().to(device)
#   model = torch.load('best.pt')#w,b,cnn
# 模型保存的對不對?
model.eval() #固定模型參數和數據,防止后面被修改
print(model)test_data = food_dataset(file_path='test.txt', transform = data_transforms['valid'])
test_dataloader = DataLoader(test_data,batch_size=1,shuffle=True)result = [] #保存的預測的結果
labels = [] #真實結果def Test_true(dataloader,model):model.eval()        #測試,w就不能再更新with torch.no_grad():   #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X) #預測之后的結果result.append(pred.argmax(1).item())labels.append(y.item())
Test_true(test_dataloader,model)
print('預測值:\t',result)
print('真實值:\t',labels)from sklearn.metrics import accuracy_score
accuracy = accuracy_score(labels,result)
print(f"準確率:{accuracy:.2%}")

八、總結

本文詳細介紹了使用PyTorch實現圖像分類任務的完整流程,包括:

  1. 數據預處理與增強
  2. 自定義數據集類
  3. CNN模型構建
  4. 模型加載與評估

關鍵點:

  • 數據增強可以提高模型的泛化能力
  • 自定義Dataset類可以靈活處理不同格式的數據
  • CNN是圖像分類任務的經典模型結構
  • 模型評估需要使用eval()模式和torch.no_grad()上下文

通過這個示例,讀者可以掌握PyTorch進行圖像分類的基本方法,并可以根據自己的需求調整模型結構和數據處理方式。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/85645.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/85645.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/85645.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C++基礎算法————并查集

C++并查集詳解與實戰指南 一、引言 并查集(Union-Find)是一種高效的數據結構,用于處理一些不相交集合的合并與查詢問題。它在圖論、社交網絡、網絡連通性等領域有廣泛的應用。并查集的核心思想是通過一個數組來記錄每個元素的父節點,從而將元素組織成若干棵樹,每棵樹代表…

系統性能優化的關鍵手段

系統性能的提升方向 服務器并發處理能力:通過優化內存管理策略、選擇合適的連接模式(長連接或短連接)、改進 I/O 模型(如 epoll、IOCP)、以及采用高效的服務器并發策略(如多線程、事件驅動等)&a…

httpclient實現http連接池

HTTP連接池是一種優化網絡通信性能的技術,通過復用已建立的TCP連接減少重復握手開銷,提升資源利用率。以下是關鍵要點: 核心原理與優勢 ?連接復用機制? 維護活躍連接隊列,避免每次請求重復TCP三次握手/SSL協商,降低…

廣義焦點丟失:學習用于密集目標檢測的合格和分布式邊界盒之GFL論文閱讀

摘要 一階段檢測器通常將目標檢測形式化為密集的分類與定位(即邊界框回歸)問題。分類部分通常使用 Focal Loss 進行優化,而邊界框位置則在狄拉克δ分布下進行學習。最近,一階段檢測器的發展趨勢是引入獨立的預測分支來估計定位質量,所預測的質量可以輔助分類,從而提升檢…

Real-World Deep Local Motion Deblurring論文閱讀

Real-World Deep Local Motion Deblurring 1. 研究目標與實際問題意義1.1 研究目標1.2 實際問題1.3 產業意義2. 創新方法:LBAG模型與關鍵技術2.1 整體架構設計2.2 關鍵技術細節2.2.1 真實模糊掩碼生成(LBFMG)2.2.2 門控塊(Gate Block)2.2.3 模糊感知補丁裁剪(BAPC)2.3 損…

【Docker基礎】Docker鏡像管理:docker commit詳解

目錄 引言 1 docker commit命令概述 1.1 什么是docker commit 1.2 使用場景 1.3 優缺點分析 2 docker commit命令詳解 2.1 基本語法 2.2 常用參數選項 2.3 實際命令示例 2.4 提交流程 2.5 步驟描述 3 docker commit與Dockerfile構建對比 3.1 構建流程對比 3.2 對…

可調式穩壓二極管

1.與普通穩壓二極管的比較: 項目普通穩壓二極管可調式穩壓二極管(如 TL431)輸出電壓固定(如5.1V、3.3V)可調(2.5V ~ 36V,取決于外部分壓)精度低(5%~10%)高&a…

Kafka使用Elasticsearch Service Sink Connector直接傳輸topic數據到Elasticsearch

鏈接:Elasticsearch Service Sink Connector for Confluent Platform | Confluent Documentation 鏈接:Apache Kafka 一、搭建測試環境 下載Elasticsearch Service Sink Connector https://file.zjwlyy.cn/confluentinc-kafka-connect-elasticsearch…

訊方“教學有方”平臺獲華為昇騰應用開發技術認證!

教學有方 華為昇騰應用開發技術認證 權威認證 彰顯實力 近日,訊方技術自研的教育行業大模型平臺——“教學有方”,成功獲得華為昇騰應用開發技術認證。這一認證不僅是對 “教學有方” 平臺技術實力的高度認可,更標志著訊方在智慧教育領域的…

保護你的Electron應用:深度解析asar文件與Virbox Protector的安全策略

在現代軟件開發中,Electron框架因其跨平臺特性而備受開發者青睞。然而,隨著Electron應用的普及,如何保護應用中的核心資源文件——asar文件,成為了開發者必須面對的問題。今天,我們將深入探討asar文件的特性&#xff0…

端口安全配置示例

組網需求 如圖所示,用戶PC1、PC2、PC3通過接入設備連接公司網絡。為了提高用戶接入的安全性,將接入設備Router的接口使能端口安全功能,并且設置接口學習MAC地址數的上限為接入用戶數,這樣其他外來人員使用自己帶來的PC無法訪問公…

零基礎RT-thread第四節:電容按鍵

電容按鍵 其實只需要理解,手指按上去后充電時間變長,我們可以利用定時器輸入捕獲功能計算充電時間,超過無觸摸時的充電時間一定的閾值就認為是有手指觸摸。 基本原理就是這樣,我們開始寫代碼: 其實,看過了…

SQL基礎操作:從增刪改查開始

好的!SQL(Structured Query Language)是用于管理關系型數據庫的標準語言。讓我們從最基礎的增刪改查(CRUD)?? 操作開始學習,我會用簡單易懂的方式講解每個操作。 🛠 準備工作(建表…

vim 編輯模式/命令模式/視圖模式常用命令

以下是一份 Vim 命令大全,涵蓋 編輯模式(Insert Mode)、命令模式(Normal Mode) 和 視圖模式(Visual Mode) 的常用操作,適合初學者和進階用戶使用。 🧾 Vim 模式簡介 Vim…

每天看一個Fortran文件(10)

今天來看下MCV模式調用物理過程的相關代碼。我想改進有關于海氣邊界層方面的內容,因此我尋找相關的代碼,發現在physics目錄下有一個sfc_ocean.f的文件。 可以看見這個文件是在好多好多年前更新的了,里面內容不多,總共146行。是計算…

python打卡day37

疏錦行 知識點回顧: 1. 過擬合的判斷:測試集和訓練集同步打印指標 2. 模型的保存和加載 a. 僅保存權重 b. 保存權重和模型 c. 保存全部信息checkpoint,還包含訓練狀態 3. 早停策略 作業:對信貸數據集訓練后保存權重&#xf…

【Spark征服之路-2.9-Spark-Core編程(五)】

RDD行動算子: 行動算子就是會觸發action的算子,觸發action的含義就是真正的計算數據。 1. reduce ? 函數簽名 def reduce(f: (T, T) > T): T ? 函數說明 聚集 RDD 中的所有元素,先聚合分區內數據,再聚合分區間數據 val…

【入門】【練17.3 】比大小

| 時間限制:C/C 1000MS,其他語言 2000MS 內存限制:C/C 64MB,其他語言 128MB 難度:中等 分數:100 OI排行榜得分:12(0.1分數2難度) 出題人:root | 描述 試編一個程序,輸入…

CppCon 2017 學習:Free Your Functions!

“Free Your Functions!” 這句話在C設計中有很深的含義,意思是: “Free Your Functions!” 的理解 “解放你的函數”,鼓勵程序員: 不要把所有的函數都綁在類的成員函數里,優先考慮寫成自由函數(non-mem…

日常運維問題匯總-19

60. OVF3維護成本中心與訂貨原因之間的對應關系時,報錯提示,SYST: 不期望的日期 00/00/0000。消息號 FGV004,如下圖所示: OVF3往右邊拉動,有一個需要填入的字段“有效期自”,此字段值必須在成本中心定義的有…