深度學習——基于卷積神經網絡實現食物圖像分類【3】(保存最優模型)

文章目錄

    • 引言
    • 一、項目概述
    • 二、環境配置
    • 三、數據預處理
      • 3.1 數據轉換設置
      • 3.2 數據集準備
    • 四、自定義數據集類
    • 五、CNN模型架構
    • 六、訓練與評估流程
      • 6.1 訓練函數
      • 6.2 評估與模型保存
    • 七、完整訓練流程
    • 八、模型保存與加載
      • 8.1 保存模型
      • 8.2 加載模型
    • 九、優化建議
    • 十、常見問題解決
    • 十一、完整代碼
    • 十二、總結

引言

本文將詳細介紹如何使用PyTorch框架構建一個完整的食物圖像分類系統,包含數據預處理、模型構建、訓練優化以及模型保存等關鍵環節。與上一篇博客介紹的版本相比,本版本增加了模型保存與加載功能,并優化了測試評估流程。

一、項目概述

本項目的目標是構建一個能夠識別20種不同食物的圖像分類系統。主要技術特點包括:

  1. 簡化但高效的數據預處理流程
  2. 三層CNN網絡架構設計
  3. 訓練過程中自動保存最佳模型
  4. 完整的訓練-評估流程實現

二、環境配置

首先確保已安裝必要的Python庫:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、數據預處理

3.1 數據轉換設置

我們為訓練集和驗證集定義了不同的轉換策略:

data_transforms = {'train': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}

簡化說明

  • 本版本簡化了數據增強,僅保留基本的resize和tensor轉換
  • 實際應用中可根據需求添加更多增強策略

3.2 數據集準備

def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()

該函數會生成包含圖像路徑和標簽的文本文件,格式為:

path/to/image1.jpg 0
path/to/image2.jpg 1
...

四、自定義數據集類

我們繼承PyTorch的Dataset類實現自定義數據集:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

關鍵改進

  • 更清晰的數據加載邏輯
  • 完善的類型轉換處理
  • 支持靈活的數據變換

五、CNN模型架構

我們設計了一個三層CNN網絡:

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)return self.out(x)

架構特點

  1. 每層包含卷積、ReLU激活和最大池化
  2. 使用padding保持特征圖尺寸
  3. 最后通過全連接層輸出分類結果

六、訓練與評估流程

6.1 訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 1 == 0:print(f"loss: {loss.item():>7f} [batch:{batch_size_num}]")batch_size_num += 1

6.2 評估與模型保存

best_acc = 0def Test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size# 保存最佳模型if correct > best_acc:best_acc = correcttorch.save(model.state_dict(), "best_model.pth")print(f"\n測試結果: \n 準確率:{(100*correct):.2f}%, 平均損失:{test_loss:.4f}")

關鍵改進

  1. 增加全局變量best_acc跟蹤最佳準確率
  2. 實現兩種模型保存方式:
    • 只保存模型參數(state_dict)
    • 保存整個模型
  3. 更詳細的測試結果輸出

七、完整訓練流程

# 初始化
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 訓練循環
epochs = 10
for t in range(epochs):print(f"Epoch {t+1}\n{'-'*20}")train(train_dataloader, model, loss_fn, optimizer)# 最終評估
Test(test_dataloader, model, loss_fn)

八、模型保存與加載

8.1 保存模型

# 方法1:只保存參數
torch.save(model.state_dict(), "model_params.pth")# 方法2:保存完整模型
torch.save(model, "full_model.pt")

8.2 加載模型

# 方法1對應加載方式
model = CNN().to(device)
model.load_state_dict(torch.load("model_params.pth"))# 方法2對應加載方式
model = torch.load("full_model.pt").to(device)

九、優化建議

  1. 數據增強:添加更多變換提高模型泛化能力
  2. 學習率調度:使用torch.optim.lr_scheduler動態調整學習率
  3. 早停機制:當驗證集性能不再提升時停止訓練
  4. 模型微調:使用預訓練模型作為基礎

十、常見問題解決

  1. 內存不足

    • 減小batch size
    • 使用梯度累積
    • 嘗試混合精度訓練
  2. 過擬合

    • 增加Dropout層
    • 添加L2正則化
    • 使用更多數據增強
  3. 訓練不穩定

    • 檢查數據標準化
    • 調整學習率
    • 檢查損失函數

十一、完整代碼

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([            #對圖片預處理的組合transforms.Resize([256,256]),   #對數據進行改變大小transforms.ToTensor(),          #數據轉換為tensor]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}def train_test_file(root,dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots,directories,files in os.walk(path):if len(directories) !=0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)print(path_1)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()root = r'.\食物分類\food_dataset'
train_dir = 'train'
test_dir = 'test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)#Dataset是用來處理數據的
class food_dataset(Dataset):        # food_dataset是自己創建的類名稱,可以改為你需要的名稱def __init__(self,file_path,transform=None):    #類的初始化,解析數據文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的圖片路徑保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)  #圖像的路徑self.labels.append(label)   #標簽,還不是tensor# 初始化:把圖片目錄加到selfdef __len__(self):  #類實例化對象后,可以使用len函數測量對象的個數return  len(self.imgs)#training_data[1]def __getitem__(self, idx):    #關鍵,可通過索引的形式獲取每一個圖片的數據及標簽image = Image.open(self.imgs[idx])  #讀取到圖片數據,還不是tensor,BGRif self.transform:                  #將PIL圖像數據轉換為tensorimage = self.transform(image)   #圖像處理為256*256,轉換為tensorlabel = self.labels[idx]    #label還不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))    #label也轉換為tensorreturn image,label
#training_data包含了本次需要訓練的全部數據集
training_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])#training_data需要具備索引的功能,還要確保數據是tensor
train_dataloader = DataLoader(training_data,batch_size=16,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=16,shuffle=True)'''判斷當前設備是否支持GPU,其中mps是蘋果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驅動軟件的功能:pytorch能夠去執行cuda的命令
# 神經網絡的模型也需要傳入到GPU,1個batch_size的數據集也需要傳入到GPU,才可以進行訓練''' 定義神經網絡  類的繼承這種方式'''
class CNN(nn.Module): #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.mdouledef __init__(self): #輸入大小:(3,256,256)super(CNN,self).__init__()  #初始化父類self.conv1 = nn.Sequential( #將多個層組合成一起,創建了一個容器,將多個網絡組合在一起nn.Conv2d(              # 2d一般用于圖像,3d用于視頻數據(多一個時間維度),1d一般用于結構化的序列數據in_channels=3,      # 圖像通道個數,1表示灰度圖(確定了卷積核 組中的個數)out_channels=16,     # 要得到多少個特征圖,卷積核的個數kernel_size=5,      # 卷積核大小 3×3stride=1,           # 步長padding=2,          # 一般希望卷積核處理后的結果大小與處理前的數據大小相同,效果會比較好),                      # 輸出的特征圖為(16,256,256)nn.ReLU(),  # Relu層,不會改變特征圖的大小nn.MaxPool2d(kernel_size=2),    # 進行池化操作(2×2操作),輸出結果為(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  #輸出(32,128,128)nn.ReLU(),  #Relu層  (32,128,128)nn.MaxPool2d(kernel_size=2),    #池化層,輸出結果為(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 輸出(64,64,64)nn.ReLU(),  # Relu層  (64,64,64)nn.MaxPool2d(kernel_size=2),  # 池化層,輸出結果為(64,32,32))self.out = nn.Linear(64*32*32,20)  # 全連接層得到的結果def forward(self,x):   #前向傳播,你得告訴它 數據的流向 是神經網絡層連接起來,函數名稱不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)    # flatten操作,結果為:(batch_size,32 * 64 * 64)output = self.out(x)return output# 提取模型的2種方法:
#   1、讀取參數的方法
model = CNN().to(device) #初始化模型,w都是隨機初始化的
# model.load_state_dict(torch.load("best.pth"))
#   2、讀取完整模型的方法,無需提前創建model
#   model = CNN().to(device)
#   model = torch.load('best.pt')#w,b,cnn
# 模型保存的對不對?
# model.eval() #固定模型參數和數據,防止后面被修改
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告訴模型,我要開始訓練,模型中w進行隨機化操作,已經更新w,在訓練過程中,w會被修改的
# pytorch提供2種方式來切換訓練和測試的模式,分別是:model.train() 和 mdoel.eval()
# 一般用法是:在訓練開始之前寫上model.train(),在測試時寫上model.eval()batch_size_num = 1for X,y in dataloader:              #其中batch為每一個數據的編號X,y = X.to(device),y.to(device) #把訓練數據集和標簽傳入cpu或GPUpred = model.forward(X)         # .forward可以被省略,父類種已經對此功能進行了設置loss = loss_fn(pred,y)          # 通過交叉熵損失函數計算損失值loss# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()           # 梯度值清零loss.backward()                 # 反向傳播計算得到每個參數的梯度值woptimizer.step()                # 根據梯度更新網絡w參數loss_value = loss.item()        # 從tensor數據種提取數據出來,tensor獲取損失值if batch_size_num %1 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0def Test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)  # 打包的數量model.eval()        #測試,w就不能再更新test_loss,correct =0,0with torch.no_grad():       #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X) #等價于model.forward(X)test_loss += loss_fn(pred,y).item() #test_loss是會自動累加每一個批次的損失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值對應的索引號,dim=0表示每一列中的最大值對應的索引號b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能來衡量模型測試的好壞correct /= size  #平均的正確率# 保存最優模型的2種方法(模型的文件擴展名一般:pt\pth,t7) #opencvif correct > best_acc:best_acc = correct
#   1.保存模型參數方法:torch.save(model.state_dict(),path)  (w,b)
#         print(model.state_dict().keys())    #輸出模型參數名稱   cnntorch.save(model.state_dict(), "best2025-04.pth")
#   2.保存完整模型(w,b,模型cnn)
#         torch.save(model,'best.pt')print(f"\n最終測試結果: \n 準確率:{(100*correct):.2f}%, 平均損失:{test_loss:.4f}")loss_fn = nn.CrossEntropyLoss()  #創建交叉熵損失函數對象,因為手寫字識別一共有十種數字,輸出會有10個結果optimizer = torch.optim.Adam(model.parameters(),lr=0.001) #創建一個優化器,SGD為隨機梯度下降算法
# # params:要訓練的參數,一般我們傳入的都是model.parameters()
# # lr:learning_rate學習率,也就是步長
#
# # loss表示模型訓練后的輸出結果與樣本標簽的差距。如果差距越小,就表示模型訓練越好,越逼近真實的模型
# train(train_dataloader,model,loss_fn,optimizer) #訓練1次完整的數據。多輪訓練
# Test(test_dataloader,model,loss_fn)epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)

十二、總結

本文詳細介紹了使用PyTorch實現食物分類的完整流程,重點講解了:

  1. 自定義數據集的處理方法
  2. CNN網絡的設計與實現
  3. 訓練過程中的模型保存策略
  4. 完整的訓練-評估流程

通過本教程,讀者可以掌握PyTorch圖像分類的基本方法,并能夠根據實際需求進行調整和優化。完整代碼已包含在文章中,建議在實際應用中根據具體數據集調整相關參數。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/85007.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/85007.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/85007.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

《棒球百科》棒球怎么玩·棒球9號位

用最簡單的方式介紹棒球的核心玩法和規則,完全零基礎也能看懂: 一句話目標 進攻方:用球棒把球打飛,然后拼命跑完4個壘包(逆時針繞一圈)得分。 防守方:想盡辦法讓進攻方出局,阻止他…

語言模型是怎么工作的?通俗版原理解讀!

大模型為什么能聊天、寫代碼、懂醫學? 我們從四個關鍵模塊,一步步拆開講清楚 👇 ? 模塊一:模型的“本事”從哪來?靠訓練數據 別幻想它有意識,它的能力,全是“喂”出來的: 吃過成千…

nrf52811墨水屏edp_service.c文件學習

on_connect函數 /**brief Function for handling the ref BLE_GAP_EVT_CONNECTED event from the S110 SoftDevice.** param[in] p_epd EPD Service structure.* param[in] p_ble_evt Pointer to the event received from BLE stack.*/ static void on_connect(ble_epd_t …

Nginx-2 詳解處理 Http 請求

Nginx-2 詳解處理 Http 請求 Nginx 作為當今最流行的開源 Web 服務器之一,以其高性能、高穩定性和豐富的功能而聞名。在處理 HTTP請求 的過程中,Nginx 采用了模塊化的設計,將整個請求處理流程劃分為若干個階段,每個階段都可以由特…

40-Oracle 23 ai Bigfile~Smallfile-Basicfile~Securefile矩陣對比

小伙伴們是不是在文件選擇上還默認給建文件4G/個么,在oracle每個版本上系統默認屬性是什么,選擇困難癥了沒,一起一次性文件存儲和默認屬性看透。 基于Oracle歷代在存儲架構的技術演進分析,結合版本升級和23ai新特性,一…

【一】零基礎--分層強化學習概覽

分層強化學習(Hierarchical Reinforcement Learning, HRL)最早一般視為1993 年封建強化學習的提出. 一、HL的基礎理論 1.1 MDP MDP(馬爾可夫決策過程):MDP是一種用于建模序列決策問題的框架,包含狀態&am…

Java延時

在 Java 中實現延時操作主要有以下幾種方式,根據使用場景選擇合適的方法: 1. Thread.sleep()(最常用) java 復制 下載 try {// 延時 1000 毫秒(1秒)Thread.sleep(1000); } catch (InterruptedExcepti…

電阻篇---下拉電阻的取值

下拉電阻的取值需要綜合考慮電路驅動能力、功耗、信號完整性、噪聲容限等多方面因素。以下是詳細的取值分析及方法: 一、下拉電阻的核心影響因素 1. 驅動能力與電流限制 單片機 IO 口驅動能力:如 STM32 的 IO 口在輸入模式下的漏電流通常很小&#xf…

NY271NY274美光科技固態NY278NY284

美光科技NY系列固態硬盤深度剖析:技術、市場與未來 技術前沿:232層NAND架構與性能突破 在存儲技術的賽道上,美光科技(Micron)始終是行業領跑者。其NY系列固態硬盤(SSD)憑借232層NAND閃存架構的…

微信開發者工具 插件未授權使用,user uni can not visit app

參考:https://www.jingpinma.cn/archives/159.html 問題描述 我下載了一個別人的小程序,想運行看看效果,結果報錯信息如下 原因 其實就是插件沒有安裝,需要到小程序平臺安裝插件。處理辦法如下 在 app.json 里,聲…

UE5 讀取配置文件

使用免費的Varest插件,可以讀取本地的json數據 獲取配置文件路徑:當前配置文件在工程根目錄,打包后在 Windows/項目名稱 下 讀取json 打包后需要手動復制配置文件到Windows/項目名稱 下

【kdump專欄】KEXEC機制中SME(安全內存加密)

【kdump專欄】KEXEC機制中SME&#xff08;安全內存加密&#xff09; 原始代碼&#xff1a; /* Ensure that these pages are decrypted if SME is enabled. */ 533 if (pages) 534 arch_kexec_post_alloc_pages(page_address(pages), 1 << order, 0);&#x1f4cc…

C# vs2022 找不到指定的 SDK“Microsof.NET.Sdk

找不到指定的 SDK"Microsof.NET.Sdk 第一查 看 系統盤目錄 C:\Program Files\dotnet第二 命令行輸入 dotnet --version第三 檢查環境變量總結 只要執行dotnet --version 正常返回版本號此問題即解決 第一查 看 系統盤目錄 C:\Program Files\dotnet 有2種方式 去檢查 是否…

Pytest斷言全解析:掌握測試驗證的核心藝術

Pytest斷言全解析&#xff1a;掌握測試驗證的核心藝術 一、斷言的本質與重要性 什么是斷言&#xff1f; 斷言是自動化測試中的驗證檢查點&#xff0c;用于確認代碼行為是否符合預期。在Pytest中&#xff0c;斷言直接使用Python原生assert語句&#xff0c;當條件不滿足時拋出…

【編譯原理】題目合集(一)

未經許可,禁止轉載。 文章目錄 選擇填空綜合選擇 將編譯程序分成若干個“遍”是為了 (D.利用有限的機器內存,但降低了執行效率) A.提高程序的執行效率 B.使程序的結構更加清晰 C.利用有限的機器內存并提高執行效率 D.利用有限的機器內存,但降低了執行效率 詞法分析…

uni-app項目實戰筆記13--全屏頁面的absolute定位布局和fit-content自適應內容寬度

本篇主要實現全屏頁面的布局&#xff0c;其中還涉及內容自適應寬度。 創建一個preview.vue頁面用于圖片預覽&#xff0c;寫入以下代碼&#xff1a; <template><view class"preview"><swiper circular><swiper-item v-for"item in 5&quo…

OVS Faucet Tutorial筆記(下)

官方文檔&#xff1a; OVS Faucet Tutorial 5、Routing Faucet Router 通過控制器模擬三層網關&#xff0c;提供 ARP 應答、路由轉發功能。 5.1 控制器配置 5.1.1 編輯控制器yaml文件&#xff0c;增加router配置 rootserver1:~/faucet/inst# vi faucet.yaml dps:switch-1:d…

PCB設計教程【大師篇】stm32開發板PCB布線(信號部分)

前言 本教程基于B站Expert電子實驗室的PCB設計教學的整理&#xff0c;為個人學習記錄&#xff0c;旨在幫助PCB設計新手入門。所有內容僅作學習交流使用&#xff0c;無任何商業目的。若涉及侵權&#xff0c;請隨時聯系&#xff0c;將會立即處理 1. 布線優先級與原則 - 遵循“重…

Phthon3 學習記錄-0613

List&#xff08;列表&#xff09;、Tuple&#xff08;元組&#xff09;、Set&#xff08;集合&#xff09;和 Dictionary&#xff08;字典&#xff09; 在接口自動化測試中&#xff0c;List&#xff08;列表&#xff09;、Tuple&#xff08;元組&#xff09;、Set&#xff08…

UVa12298 3KP-BASH Project

UVa12298 3KP-BASH Project 題目鏈接題意輸入格式輸出格式 分析AC 代碼 題目鏈接 UVa12298 3KP-BASH Project 題意 摘自 《算法競賽入門經典&#xff1a;訓練指南》劉汝佳&#xff0c;陳鋒著。有刪改。 你的任務是為一個假想的 3KP 操作系統編寫一個簡單的 Bash 模擬器。由于操…