深度學習——基于卷積神經網絡實現食物圖像分類【2】(數據增強)

文章目錄

    • 引言
    • 一、項目概述
    • 二、環境準備
    • 三、數據預處理
      • 3.1 數據增強與標準化
      • 3.2 數據集準備
    • 四、自定義數據集類
    • 五、構建CNN模型
    • 六、訓練與評估
      • 6.1 訓練函數
      • 6.2 評估函數
      • 6.3 訓練流程
    • 七、關鍵技術與優化
    • 八、常見問題與解決
    • 九、完整代碼
    • 十、總結

引言

本文將詳細介紹如何使用PyTorch框架構建一個食物圖像分類系統,涵蓋數據預處理、模型構建、訓練和評估全過程。我們將使用自定義的食物數據集,構建一個卷積神經網絡(CNN)模型,并實現完整的訓練流程。

一、項目概述

食物圖像分類是計算機視覺中的一個常見應用場景。在本項目中,我們將構建一個能夠識別20種不同食物的分類系統。整個流程包括:

  1. 數據準備與預處理
  2. 構建自定義數據集類
  3. 設計CNN模型架構
  4. 訓練模型并評估性能
  5. 優化與結果分析

二、環境準備

首先確保已安裝必要的Python庫:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、數據預處理

3.1 數據增強與標準化

我們為訓練集和驗證集分別定義不同的轉換策略:

data_transforms = {'train': transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

關鍵點解析:

  1. 訓練集增強

    • 隨機旋轉(-45°到45°)
    • 隨機水平和垂直翻轉
    • 色彩抖動(亮度、對比度、飽和度和色調)
    • 隨機灰度化(概率10%)
  2. 標準化處理

    • 使用ImageNet的均值和標準差進行歸一化
    • 有助于模型更快收斂

3.2 數據集準備

我們編寫了一個函數來生成訓練和測試的標注文件:

def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()

該函數會遍歷指定目錄,生成包含圖像路徑和對應標簽的文本文件。

四、自定義數據集類

我們繼承PyTorch的Dataset類創建自定義數據集:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

關鍵方法:

  1. __init__: 初始化數據集,讀取標注文件
  2. __len__: 返回數據集大小
  3. __getitem__: 根據索引返回圖像和標簽,應用預處理

五、構建CNN模型

我們設計了一個三層的CNN網絡:

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

網絡結構分析:

  1. 卷積層1

    • 輸入通道:3 (RGB)
    • 輸出通道:16
    • 卷積核:5×5
    • 輸出尺寸:(16, 128, 128)
  2. 卷積層2

    • 輸入通道:16
    • 輸出通道:32
    • 輸出尺寸:(32, 64, 64)
  3. 卷積層3

    • 輸入通道:32
    • 輸出通道:64
    • 輸出尺寸:(64, 32, 32)
  4. 全連接層

    • 輸入:64×32×32 = 65536
    • 輸出:20 (對應20類食物)

六、訓練與評估

6.1 訓練函數

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 1 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

6.2 評估函數

def Test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")

6.3 訓練流程

# 初始化模型
model = CNN().to(device)# 定義損失函數和優化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 訓練10個epoch
epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader, model, loss_fn, optimizer)# 最終評估
Test(test_dataloader, model, loss_fn)

七、關鍵技術與優化

  1. 數據增強:通過多種變換增加數據多樣性,防止過擬合
  2. 批標準化:使用ImageNet統計量進行標準化,加速收斂
  3. 學習率選擇:使用Adam優化器,初始學習率0.001
  4. 設備選擇:自動檢測并使用GPU加速訓練

八、常見問題與解決

  1. 內存不足

    • 減小batch size
    • 使用更小的圖像尺寸
  2. 過擬合

    • 增加數據增強
    • 添加Dropout層
    • 使用L2正則化
  3. 訓練不收斂

    • 檢查學習率
    • 檢查數據預處理
    • 檢查模型結構

九、完整代碼

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([            #對圖片預處理的組合transforms.Resize([300,300]),   #對數據進行改變大小transforms.RandomRotation(45),  #隨機旋轉,-45到45之間隨機選transforms.CenterCrop(256),     #從中心開始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),#隨機水平翻轉,p是指選擇一個概率翻轉,p=0.5表示百分之50transforms.RandomVerticalFlip(p=0.5),#隨機垂直翻轉transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),transforms.RandomGrayscale(p=0.1),#概率轉換成灰度率,3通道就是R=G=Btransforms.ToTensor(),#數據轉換為tensortransforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#標準化,均值,標準差]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 標準化,均值,標準差]),
}
#做了數據增強不代表訓練效果一定會好,只能說大概率會變好
def train_test_file(root,dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots,directories,files in os.walk(path):if len(directories) !=0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)print(path_1)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()root = r'.\食物分類\food_dataset'
train_dir = 'train'
test_dir = 'test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)#Dataset是用來處理數據的
class food_dataset(Dataset):        # food_dataset是自己創建的類名稱,可以改為你需要的名稱def __init__(self,file_path,transform=None):    #類的初始化,解析數據文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的圖片路徑保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)  #圖像的路徑self.labels.append(label)   #標簽,還不是tensor# 初始化:把圖片目錄加到selfdef __len__(self):  #類實例化對象后,可以使用len函數測量對象的個數return  len(self.imgs)#training_data[1]def __getitem__(self, idx):    #關鍵,可通過索引的形式獲取每一個圖片的數據及標簽image = Image.open(self.imgs[idx])  #讀取到圖片數據,還不是tensor,BGRif self.transform:                  #將PIL圖像數據轉換為tensorimage = self.transform(image)   #圖像處理為256*256,轉換為tensorlabel = self.labels[idx]    #label還不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))    #label也轉換為tensorreturn image,label
#training_data包含了本次需要訓練的全部數據集
training_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])#training_data需要具備索引的功能,還要確保數據是tensor
train_dataloader = DataLoader(training_data,batch_size=16,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=16,shuffle=True)'''判斷當前設備是否支持GPU,其中mps是蘋果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驅動軟件的功能:pytorch能夠去執行cuda的命令
# 神經網絡的模型也需要傳入到GPU,1個batch_size的數據集也需要傳入到GPU,才可以進行訓練''' 定義神經網絡  類的繼承這種方式'''
class CNN(nn.Module): #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.mdouledef __init__(self): #輸入大小:(3,256,256)super(CNN,self).__init__()  #初始化父類self.conv1 = nn.Sequential( #將多個層組合成一起,創建了一個容器,將多個網絡組合在一起nn.Conv2d(              # 2d一般用于圖像,3d用于視頻數據(多一個時間維度),1d一般用于結構化的序列數據in_channels=3,      # 圖像通道個數,1表示灰度圖(確定了卷積核 組中的個數)out_channels=16,     # 要得到多少個特征圖,卷積核的個數kernel_size=5,      # 卷積核大小 3×3stride=1,           # 步長padding=2,          # 一般希望卷積核處理后的結果大小與處理前的數據大小相同,效果會比較好),                      # 輸出的特征圖為(16,256,256)nn.ReLU(),  # Relu層,不會改變特征圖的大小nn.MaxPool2d(kernel_size=2),    # 進行池化操作(2×2操作),輸出結果為(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  #輸出(32,128,128)nn.ReLU(),  #Relu層  (32,128,128)nn.MaxPool2d(kernel_size=2),    #池化層,輸出結果為(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 輸出(64,64,64)nn.ReLU(),  # Relu層  (64,64,64)nn.MaxPool2d(kernel_size=2),  # 池化層,輸出結果為(64,32,32))self.out = nn.Linear(64*32*32,20)  # 全連接層得到的結果def forward(self,x):   #前向傳播,你得告訴它 數據的流向 是神經網絡層連接起來,函數名稱不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)    # flatten操作,結果為:(batch_size,32 * 64 * 64)output = self.out(x)return output
model = CNN().to(device) #把剛剛創建的模型傳入到GPU
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告訴模型,我要開始訓練,模型中w進行隨機化操作,已經更新w,在訓練過程中,w會被修改的
# pytorch提供2種方式來切換訓練和測試的模式,分別是:model.train() 和 mdoel.eval()
# 一般用法是:在訓練開始之前寫上model.train(),在測試時寫上model.eval()batch_size_num = 1for X,y in dataloader:              #其中batch為每一個數據的編號X,y = X.to(device),y.to(device) #把訓練數據集和標簽傳入cpu或GPUpred = model.forward(X)         # .forward可以被省略,父類種已經對此功能進行了設置loss = loss_fn(pred,y)          # 通過交叉熵損失函數計算損失值loss# Backpropagation 進來一個batch的數據,計算一次梯度,更新一次網絡optimizer.zero_grad()           # 梯度值清零loss.backward()                 # 反向傳播計算得到每個參數的梯度值woptimizer.step()                # 根據梯度更新網絡w參數loss_value = loss.item()        # 從tensor數據種提取數據出來,tensor獲取損失值if batch_size_num %1 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def Test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)  # 打包的數量model.eval()        #測試,w就不能再更新test_loss,correct =0,0with torch.no_grad():       #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss += loss_fn(pred,y).item() #test_loss是會自動累加每一個批次的損失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值對應的索引號,dim=0表示每一列中的最大值對應的索引號b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能來衡量模型測試的好壞correct /= size  #平均的正確率print(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")loss_fn = nn.CrossEntropyLoss()  #創建交叉熵損失函數對象,因為手寫字識別一共有十種數字,輸出會有10個結果
#
optimizer = torch.optim.Adam(model.parameters(),lr=0.001) #創建一個優化器,SGD為隨機梯度下降算法
# # params:要訓練的參數,一般我們傳入的都是model.parameters()
# # lr:learning_rate學習率,也就是步長
#
# # loss表示模型訓練后的輸出結果與樣本標簽的差距。如果差距越小,就表示模型訓練越好,越逼近真實的模型
train(train_dataloader,model,loss_fn,optimizer) #訓練1次完整的數據。多輪訓練
Test(test_dataloader,model,loss_fn)epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)

十、總結

本文詳細介紹了使用PyTorch實現食物分類的全流程。通過合理的網絡設計、數據增強和訓練策略,我們能夠構建一個有效的分類系統。讀者可以根據實際需求調整網絡結構、超參數和數據增強策略,以獲得更好的性能。

完整代碼已在上文展示,建議在實際應用中根據具體數據集調整相關參數。希望本文能幫助讀者掌握PyTorch圖像分類的基本流程和方法。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/85421.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/85421.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/85421.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

詳細說說分布式Session的幾種實現方式

1. 基于客戶端存儲(Cookie-Based) 原理:將會話數據直接存儲在客戶端 Cookie 中 實現: // Spring Boot 示例 Bean public CookieSerializer cookieSerializer() {DefaultCookieSerializer serializer new DefaultCookieSerializ…

用mac的ollama訪問模型,為什么會出現模型胡亂輸出,然后過一會兒再訪問,就又變成正常的

例子:大模型推理遇到內存不足 1. 場景還原 你在Mac上用Ollama運行如下代碼(以Python為例,假設Ollama有API接口): import requestsprompt "請寫一首關于夏天的詩。" response requests.post("http:…

簡說 Linux 用戶組

Linux 用戶組 的核心概念、用途和管理方法,盡量簡明易懂。 🌟 什么是 Linux 用戶組? 在 Linux 系統中: 👉 用戶組(group) 是一組用戶的集合,用來方便地管理權限。 👉 用…

S32DS上進行S32K328的時鐘配置,LPUART時鐘配置步驟詳解

1:S32K328的基礎信息 S32K328官網介紹 由下圖可知,S32K328的最大主頻為 240MHz 2:S32K328時鐘樹配置 2.1 system clock node 節點說明 根據《S32K3xx Reference Manual》資料說明 Table 143 各個 系統時鐘節點 的最大頻率如下所示&#…

wordpress小語種網站模板

wordpress朝鮮語模板 紫色風格的韓語wordpress主題,適合做韓國、朝鮮的外貿公司官方網站使用。 https://www.jianzhanpress.com/?p8486 wordpress日文模板 綠色的日語wordpress外貿主題,用來搭建日文外貿網站很實用。 https://www.jianzhanpress.co…

網絡:Wireshark解析https協議,firefox

文章目錄 問題瀏覽器訪問的解決方法python requests問題 現在大部分的網站已經切到https,很多站點即使開了80的端口,最終還是會返回301消息,讓客戶端轉向到https的一個地址。 所以在使用wireshark進行問題分析的時候,解析tls上層的功能,是必不可少的,但是這個安全交換的…

ollama部署開源大模型

1. 技術概述 Spring AI:Spring 官方推出的 AI 框架,簡化大模型集成(如文本生成、問答系統),支持多種 LLM 提供商。Olama:開源的本地 LLM 推理引擎,支持量化模型部署,提供 REST API …

Kafka 可靠性保障:消息確認與事務機制(二)

Kafka 事務機制 1. 冪等性與事務的關系 在深入探討 Kafka 的事務機制之前,先來了解一下冪等性的概念。冪等性,簡單來說,就是對接口的多次調用所產生的結果和調用一次是一致的。在 Kafka 中,冪等性主要體現在生產者端&#xff0c…

使用 React.Children.map遍歷或修改 children

使用場景: 需要對子組件進行統一處理(如添加 key、包裹額外元素、過濾特定類型等)。 動態修改 children 的 props 或結構。 示例代碼:遍歷并修改 children import React from react;// 一個組件,給每個子項添加邊框…

智能體三階:LLM→Function Call→MCP

哈嘍,我是老劉 老劉是個客戶端開發者,目前主要是用Flutter進行開發,從Flutter 1.0開始到現在已經6年多了。 那為啥最近我對MCP和AI這么感興趣的呢? 一方面是因為作為一個在客戶端領域實戰多年的程序員,我覺得客戶端開發…

flutter的常規特征

前言 Flutter 是由 Google 開發的開源 UI 軟件開發工具包,用于構建跨平臺的高性能、美觀且一致的應用程序。 一、跨平臺開發能力 1.多平臺支持:Flutter 支持構建 iOS、Android、Web、Windows、macOS 和 Linux 應用,開發者可以使用一套代碼庫在…

【Git】代碼托管服務

博主:👍不許代碼碼上紅 歡迎:🐋點贊、收藏、關注、評論。 格言: 大鵬一日同風起,扶搖直上九萬里。 文章目錄 Git代碼托管服務概述Git核心概念主流Git托管平臺Git基礎配置倉庫創建方式Git文件狀態管理常用…

Android 網絡請求的選擇邏輯(Connectivity Modules)

代碼分析 ConnectivityManager packages/modules/Connectivity/framework/src/android/net/ConnectivityManager.java 許多APN已經棄用,應用層統一用 requestNetwork() 來請求網絡。 [ConnectivityManager] example [ConnectivityManager] requestNetwork() [Connectivi…

C#建立與數據庫連接(版本問題的解決方案)踩坑總結

1.如何優雅的建立數據庫連接 今天使用這個deepseek寫代碼,主要就是建立數據庫的鏈接,包括這個建庫建表啥的都是他整得,我就是負責執行,然后解決這個里面遇到的一些問題; 其實我學習這個C#不過是短短的4天的時間&…

FastAPI的初步學習(Django用戶過來的)

我一直以來是Django重度用戶。它有清晰的MVC架構模式、多應用組織結構。它內置用戶認證、數據庫ORM、數據庫遷移、管理后臺、日志等功能,還有強大的社區支持。再搭配上Django REST framework (DRF) ,開發起來效率極高。主打功能強大、易于使用。 曾經也…

提升IT運維效率 貝銳向日葵推出自動化企業腳本功能

在企業進行遠程IT運維管理的過程中,難免會涉及很多需要批量操作下發指令的場景,包括但不限于下列場景: ● ?規模設備部署與初始化、設備配置更新 ● 業務軟件安裝與系統維護,進行安全加固或執行問題修復命令 ● 遠程設備監控與…

最簡單的遠程桌面連接方法是什么?系統自帶內外網訪問實現

在眾多遠程桌面連接方式中,使用 Windows 系統自帶的遠程桌面連接功能是較為簡單的方法之一,無論是在局域網內還是通過公網進行遠程連接,都能輕松實現。 一、局域網內連接步驟 1、 開啟目標計算機遠程桌面功能:在目標計算機&…

JVM(2)——垃圾回收算法

本文將穿透式解析JVM垃圾回收核心算法,涵蓋7大基礎算法4大現代GC實現3種內存分配策略,通過15張動態示意圖GC日志實戰分析,帶您徹底掌握JVM內存自動管理機制。 一、GC核心概念體系 1.1 對象存亡判定法則 引用計數法致命缺陷: // …

基于Spring Boot+Vue的“暖寓”宿舍管理系統設計與實現(源碼及文檔)

基于Spring BootVue的“暖寓”宿舍管理系統設計與實現 第 1 章 緒論 1.1 論文研究主要內容 1.1.1 系統概述 1.1.2 系統介紹 1.2 國內外研究現狀 第 2 章 關鍵技術介紹 2.1 關鍵性開發技術的介紹 2.1.1 Java簡介 2.1.2 Spring Boot框架 2.2 其他相關技術 2.2.1 Vue.J…

基于Java的不固定長度字符集在指定寬度和自適應模型下圖片繪制生成實戰

目錄 前言 一、需求介紹 1、指定寬度生成 2、指定列自適應生成 二、Java生成實現 1、公共方法 2、指定寬度生成 3、指定列自適應生成 三、總結 前言 在當今數字化與信息化飛速發展的時代,圖像的生成與處理技術正日益成為眾多領域關注的焦點。從創意設計到數…