1. 項目概述
在這個項目中,我們將使用PyTorch框架構建一個卷積神經網絡(CNN)來實現食物圖像分類任務。我們的數據集包含20種不同的食物類別,包括八寶粥、巴旦木、白蘿卜、板栗等常見食物。本文將詳細介紹從數據準備、模型構建到訓練和評估的完整流程。
2. 環境準備
首先,我們需要導入必要的Python庫:
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
3. 數據準備與預處理
3.1 數據轉換定義
我們定義了兩組數據轉換,分別用于訓練集和驗證集:
data_transforms = {'train': transforms.Compose([transforms.Resize([256,256]), # 統一圖像大小為256x256transforms.ToTensor(), # 轉換為PyTorch張量]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}
3.2?數據文件準備
我們編寫了一個函數來生成包含圖像路徑和標簽的文本文件:
def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root, dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots, file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()# 調用函數生成訓練集和測試集文件
root = r'.\食物分類\food_dataset'
train_dir = 'train'
test_dir = 'test'
train_test_file(root, train_dir)
train_test_file(root, test_dir)
- 生成的 txt 文件如下所示,包含圖片的路徑、類別和對應的標簽
3.3 自定義數據集類
我們創建了一個繼承自torch.utils.data.Dataset
的自定義數據集類:
class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label
3.4 創建數據加載器
training_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)
4. 模型構建
我們定義了一個包含三個卷積層的CNN模型:
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 第一個卷積塊self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2), # (16,256,256)nn.ReLU(),nn.MaxPool2d(2), # (16,128,128))# 第二個卷積塊self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), # (32,128,128)nn.ReLU(),nn.MaxPool2d(2), # (32,64,64))# 第三個卷積塊self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2), # (64,64,64)nn.ReLU(),nn.MaxPool2d(2), # (64,32,32))# 全連接層self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1) # 展平操作output = self.out(x)return output