文件目錄
- 引言
- 一、環境準備
- 二、數據預處理
- 訓練集預處理說明:
- 驗證集預處理說明:
- 三、自定義數據集類
- 四、設備選擇
- 五、CNN模型構建
- 六、模型加載與評估
- 1. 加載預訓練模型
- 2. 準備測試數據
- 3. 測試函數
- 4. 計算準確率
- 七、完整代碼
- 八、總結
引言
本文將詳細介紹如何使用PyTorch框架構建一個完整的食物圖像分類系統,包含數據預處理、模型構建、訓練優化以及模型保存等關鍵環節。與上一篇博客介紹的版本相比,本版本增加了使用最優模型這一流程。
一、環境準備
首先,我們需要導入必要的Python庫:
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
這些庫中:
torch
和torchvision
是PyTorch的核心庫Dataset
和DataLoader
用于數據加載和處理transforms
提供圖像預處理功能PIL
用于圖像處理numpy
用于數值計算
二、數據預處理
數據預處理是深度學習項目中至關重要的一環。PyTorch提供了transforms
模塊來方便地進行圖像預處理:
data_transforms = {'train': transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}
訓練集預處理說明:
Resize([300,300])
:將圖像調整為300×300像素RandomRotation(45)
:隨機旋轉圖像(-45°到45°之間)CenterCrop(256)
:從中心裁剪256×256的區域RandomHorizontalFlip(p=0.5)
:以50%概率水平翻轉圖像RandomVerticalFlip(p=0.5)
:以50%概率垂直翻轉圖像ColorJitter
:隨機調整亮度、對比度、飽和度和色調RandomGrayscale(p=0.1)
:以10%概率將圖像轉為灰度ToTensor()
:將PIL圖像轉為PyTorch張量Normalize
:標準化處理(使用ImageNet的均值和標準差)
驗證集預處理說明:
驗證集的預處理相對簡單,只包括調整大小、轉為張量和標準化,因為驗證階段不需要數據增強。
三、自定義數據集類
PyTorch的Dataset
類允許我們自定義數據加載方式。我們創建了一個food_dataset
類:
class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label
這個類的主要功能:
__init__
:初始化函數,讀取包含圖像路徑和標簽的文本文件__len__
:返回數據集大小__getitem__
:根據索引返回圖像和對應的標簽
四、設備選擇
PyTorch支持在CPU、GPU(CUDA)和蘋果M系列芯片(MPS)上運行。我們使用以下代碼自動選擇可用設備:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
五、CNN模型構建
我們構建了一個簡單的CNN模型,包含三個卷積塊和一個全連接層:
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output
模型結構說明:
conv1
:輸入3通道,輸出16通道,5×5卷積核,ReLU激活,2×2最大池化conv2
:輸入16通道,輸出32通道,同上結構conv3
:輸入32通道,輸出64通道,同上結構out
:全連接層,將64×32×32的特征圖映射到20個類別
六、模型加載與評估
1. 加載預訓練模型
model = CNN().to(device)
model.load_state_dict(torch.load("best2025-04.pth"))
model.eval()
2. 準備測試數據
test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)
3. 測試函數
result = []
labels = []def Test_true(dataloader, model):model.eval()with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)result.append(pred.argmax(1).item())labels.append(y.item())Test_true(test_dataloader, model)
4. 計算準確率
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(labels, result)
print(f"準確率:{accuracy:.2%}")
七、完整代碼
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([ #對圖片預處理的組合transforms.Resize([300,300]), #對數據進行改變大小transforms.RandomRotation(45), #隨機旋轉,-45到45之間隨機選transforms.CenterCrop(256), #從中心開始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),#隨機水平翻轉,p是指選擇一個概率翻轉,p=0.5表示百分之50transforms.RandomVerticalFlip(p=0.5),#隨機垂直翻轉transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),transforms.RandomGrayscale(p=0.1),#概率轉換成灰度率,3通道就是R=G=Btransforms.ToTensor(),#數據轉換為tensortransforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#標準化,均值,標準差]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 標準化,均值,標準差]),
}#Dataset是用來處理數據的
class food_dataset(Dataset): # food_dataset是自己創建的類名稱,可以改為你需要的名稱def __init__(self,file_path,transform=None): #類的初始化,解析數據文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的圖片路徑保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path) #圖像的路徑self.labels.append(label) #標簽,還不是tensor# 初始化:把圖片目錄加到selfdef __len__(self): #類實例化對象后,可以使用len函數測量對象的個數return len(self.imgs)#training_data[1]def __getitem__(self, idx): #關鍵,可通過索引的形式獲取每一個圖片的數據及標簽image = Image.open(self.imgs[idx]) #讀取到圖片數據,還不是tensor,BGRif self.transform: #將PIL圖像數據轉換為tensorimage = self.transform(image) #圖像處理為256*256,轉換為tensorlabel = self.labels[idx] #label還不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64)) #label也轉換為tensorreturn image,label'''判斷當前設備是否支持GPU,其中mps是蘋果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device") #字符串的格式化,CUDA驅動軟件的功能:pytorch能夠去執行cuda的命令
# 神經網絡的模型也需要傳入到GPU,1個batch_size的數據集也需要傳入到GPU,才可以進行訓練''' 定義神經網絡 類的繼承這種方式'''
class CNN(nn.Module): #通過調用類的形式來使用神經網絡,神經網絡的模型,nn.mdouledef __init__(self): #輸入大小:(3,256,256)super(CNN,self).__init__() #初始化父類self.conv1 = nn.Sequential( #將多個層組合成一起,創建了一個容器,將多個網絡組合在一起nn.Conv2d( # 2d一般用于圖像,3d用于視頻數據(多一個時間維度),1d一般用于結構化的序列數據in_channels=3, # 圖像通道個數,1表示灰度圖(確定了卷積核 組中的個數)out_channels=16, # 要得到多少個特征圖,卷積核的個數kernel_size=5, # 卷積核大小 3×3stride=1, # 步長padding=2, # 一般希望卷積核處理后的結果大小與處理前的數據大小相同,效果會比較好), # 輸出的特征圖為(16,256,256)nn.ReLU(), # Relu層,不會改變特征圖的大小nn.MaxPool2d(kernel_size=2), # 進行池化操作(2×2操作),輸出結果為(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2), #輸出(32,128,128)nn.ReLU(), #Relu層 (32,128,128)nn.MaxPool2d(kernel_size=2), #池化層,輸出結果為(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2), # 輸出(64,64,64)nn.ReLU(), # Relu層 (64,64,64)nn.MaxPool2d(kernel_size=2), # 池化層,輸出結果為(64,32,32))self.out = nn.Linear(64*32*32,20) # 全連接層得到的結果def forward(self,x): #前向傳播,你得告訴它 數據的流向 是神經網絡層連接起來,函數名稱不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1) # flatten操作,結果為:(batch_size,32 * 64 * 64)output = self.out(x)return output
# 提取模型的2種方法:
# 1、讀取參數的方法
model = CNN().to(device) #初始化模型,w都是隨機初始化的
model.load_state_dict(torch.load("best2025-04.pth"))
# 2、讀取完整模型的方法,無需提前創建model
# model = CNN().to(device)
# model = torch.load('best.pt')#w,b,cnn
# 模型保存的對不對?
model.eval() #固定模型參數和數據,防止后面被修改
print(model)test_data = food_dataset(file_path='test.txt', transform = data_transforms['valid'])
test_dataloader = DataLoader(test_data,batch_size=1,shuffle=True)result = [] #保存的預測的結果
labels = [] #真實結果def Test_true(dataloader,model):model.eval() #測試,w就不能再更新with torch.no_grad(): #一個上下文管理器,關閉梯度計算。當你確認不會調用Tensor.backward()的時候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X) #預測之后的結果result.append(pred.argmax(1).item())labels.append(y.item())
Test_true(test_dataloader,model)
print('預測值:\t',result)
print('真實值:\t',labels)from sklearn.metrics import accuracy_score
accuracy = accuracy_score(labels,result)
print(f"準確率:{accuracy:.2%}")
八、總結
本文詳細介紹了使用PyTorch實現圖像分類任務的完整流程,包括:
- 數據預處理與增強
- 自定義數據集類
- CNN模型構建
- 模型加載與評估
關鍵點:
- 數據增強可以提高模型的泛化能力
- 自定義Dataset類可以靈活處理不同格式的數據
- CNN是圖像分類任務的經典模型結構
- 模型評估需要使用
eval()
模式和torch.no_grad()
上下文
通過這個示例,讀者可以掌握PyTorch進行圖像分類的基本方法,并可以根據自己的需求調整模型結構和數據處理方式。