食物圖像分類是計算機視覺的經典任務之一,其核心是讓機器 “看懂” 圖像中的食物類別。隨著深度學習的發展,卷積神經網絡(CNN)憑借強大的特征提取能力,成為圖像分類的主流方案。本文將基于 PyTorch 框架,從代碼實戰出發,拆解食物圖像分類項目中的核心知識點,包括環境搭建、數據預處理、數據集構建、CNN 模型設計、模型訓練與測試、單圖預測等,帶大家從零搭建一個能識別 20 類食物的分類系統。
# 導入必要的庫
import torch # PyTorch核心庫,用于構建和訓練神經網絡
from torch import nn # 神經網絡模塊,包含各種層和損失函數
from torch.utils.data import Dataset, DataLoader # 數據集和數據加載器,用于數據處理
import numpy as np # 數值計算庫,可用于數據預處理等
from PIL import Image # 圖像處理庫,用于讀取和處理圖像
from torchvision import transforms # 圖像轉換工具,用于數據增強和預處理
import os # 操作系統接口,用于文件路徑處理等# 定義數據轉換策略:訓練集使用數據增強,驗證集/測試集保持一致的基礎轉換
data_transforms = {'train': # 訓練集轉換(包含數據增強,增加樣本多樣性)transforms.Compose([transforms.Resize([300, 300]), # 先將圖像調整為300x300transforms.RandomRotation(45), # 隨機旋轉(-45~45度),增強旋轉不變性transforms.CenterCrop(256), # 中心裁剪到256x256,去除旋轉后的黑邊transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻轉transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻轉transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 隨機調整亮度、對比度、飽和度和色調transforms.RandomGrayscale(p=0.1), # 10%概率轉為灰度圖transforms.ToTensor(), # 轉為Tensor格式([C, H, W]),并將像素值歸一化到[0,1]transforms.Normalize( # 使用ImageNet的均值和標準差進行標準化[0.485, 0.456, 0.406], # 均值(RGB三個通道)[0.229, 0.224, 0.225] # 標準差(RGB三個通道))]),'valid': # 驗證集/測試集轉換(無增強,保持數據一致性)transforms.Compose([transforms.Resize([256, 256]), # 調整為256x256,與訓練集裁剪后尺寸一致transforms.ToTensor(), # 轉為Tensor]),
}# -------------------------- 2. 自定義數據集類 --------------------------
class food_dataset(Dataset):"""自定義食物圖像數據集類,繼承自PyTorch的Dataset用于加載圖像路徑和對應標簽,并進行預處理"""def __init__(self, file_path, transform=None):"""初始化數據集:param file_path: 存儲圖像路徑和標簽的文本文件路徑:param transform: 圖像轉換函數(預處理/數據增強)"""self.file_path = file_path # 文本文件路徑self.transform = transform # 轉換函數self.imgs = [] # 存儲所有圖像路徑self.labels = [] # 存儲對應標簽# 讀取文件列表(每行格式:圖片路徑 數字標簽)with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip() # 去除首尾空格和換行符if not line: # 跳過空行continue# 按空格分割路徑和標簽(假設格式嚴格,無多余空格)img_path, label = line.split(' ')self.imgs.append(img_path)self.labels.append(label)def __len__(self):"""返回數據集樣本數量"""return len(self.imgs)def __getitem__(self, index):"""根據索引獲取單個樣本(圖像和標簽):param index: 樣本索引:return: 處理后的圖像張量和標簽張量"""# 讀取圖片并強制轉為RGB(避免灰度圖導致的通道數不匹配問題)try:image = Image.open(self.imgs[index]).convert('RGB') # 確保3通道輸入except Exception as e:# 捕獲讀取錯誤,便于調試raise ValueError(f"讀取圖片 {self.imgs[index]} 失敗:{e}")# 應用轉換(預處理/數據增強)if self.transform:image = self.transform(image)# 處理標簽:轉為整數類型的張量(PyTorch分類任務要求標簽為long類型)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label# 加載數據集
# 注意:需確保train.txt和test.txt文件存在,每行格式為「圖片路徑 數字標簽」
try:# 加載訓練集(使用訓練集轉換)training_data = food_dataset(file_path='./train.txt', transform=data_transforms['train'])# 加載測試集(使用驗證集轉換)test_data = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])
except FileNotFoundError:# 捕獲文件不存在錯誤,提示用戶raise FileNotFoundError("請確保 train.txt 和 test.txt 文件在當前目錄下")# 創建數據加載器(批量加載數據,支持打亂和多進程)
train_dataloader = DataLoader(training_data,batch_size=8, # 批大小:每次加載8張圖片shuffle=True # 訓練時打亂數據順序,增強訓練效果
)
test_dataloader = DataLoader(test_data,batch_size=8, # 測試時也用相同批大小shuffle=True # 測試時打亂不影響結果,主要便于觀察不同樣本
)# 設備配置:優先使用GPU(cuda),其次是Apple M系列芯片(mps),最后是CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device') # 打印使用的設備# 定義CNN模型(卷積神經網絡)
class CNN(nn.Module):"""自定義卷積神經網絡模型,用于食物圖像分類包含4個卷積塊和1個全連接輸出層"""def __init__(self):super().__init__() # 調用父類nn.Module的初始化方法# 第一個卷積塊:1次卷積 + ReLU激活 + 最大池化self.conv1 = nn.Sequential(# 卷積層:輸入3通道(RGB),輸出16通道,卷積核5x5,步長1,填充2(保持尺寸)nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(), # 激活函數,引入非線性nn.MaxPool2d(kernel_size=2), # 最大池化:尺寸減半(256→128))# 第二個卷積塊:2次卷積 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), # 輸入16通道,輸出32通道nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2), # 輸入32通道,輸出32通道nn.ReLU(),nn.MaxPool2d(2), # 尺寸減半(128→64))# 第三個卷積塊:2次卷積 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2), # 輸入32通道,輸出64通道nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2), # 輸入64通道,輸出128通道nn.ReLU(),nn.MaxPool2d(2), # 尺寸減半(64→32))# 第四個卷積塊:1次卷積 + ReLU(無池化,保持尺寸)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2), # 輸入128通道,輸出128通道nn.ReLU(), # 輸出尺寸:32×32,通道數128)# 全連接輸出層:將特征映射到20個類別(食物種類)# 輸入尺寸計算:128通道 × 32高 × 32寬(經多次池化后的特征圖尺寸)self.out = nn.Linear(128 * 32 * 32, 20) # 20類:需與標簽數量一致def forward(self, x):"""前向傳播:定義數據在網絡中的流動路徑:param x: 輸入張量,形狀為[batch_size, 3, 256, 256]:return: 輸出張量,形狀為[batch_size, 20](各類別的預測分數)"""x = self.conv1(x) # 經第一個卷積塊處理x = self.conv2(x) # 經第二個卷積塊處理x = self.conv3(x) # 經第三個卷積塊處理x = self.conv4(x) # 經第四個卷積塊處理x = x.view(x.size(0), -1) # 展平特征圖:[batch_size, 128*32*32]output = self.out(x) # 經全連接層輸出預測結果return output# -------------------------- 訓練與測試函數 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""訓練模型的函數:param dataloader: 訓練數據集加載器:param model: 待訓練的模型:param loss_fn: 損失函數(用于計算預測誤差):param optimizer: 優化器(用于更新模型參數)"""model.train() # 開啟訓練模式(啟用Dropout、BatchNorm等訓練特定行為)batch_size_num = 1 # 記錄當前批次編號for X, y in dataloader:# 將數據移動到指定設備(GPU/CPU)X, y = X.to(device), y.to(device)# 前向傳播:計算模型預測結果pred = model(X)# 計算損失(預測值與真實標簽的差距)loss = loss_fn(pred, y)# 反向傳播與參數更新optimizer.zero_grad() # 清空上一輪的梯度(避免梯度累積)loss.backward() # 反向傳播計算梯度optimizer.step() # 根據梯度更新模型參數# 打印損失(每2個batch打印一次,便于監控訓練過程)loss_val = loss.item() # 獲取損失的標量值if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f} [batch: {batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):model.eval() # 開啟評估模式(關閉Dropout、固定BatchNorm參數等)size = len(dataloader.dataset) # 測試集總樣本數num_batches = len(dataloader) # 測試集批次數test_loss, correct = 0, 0 # 總損失和正確預測數# 關閉梯度計算(測試時不需要更新參數,節省計算資源)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device) # 數據移至設備pred = model(X) # 預測test_loss += loss_fn(pred, y).item() # 累加損失# 統計正確預測數:取預測概率最大的類別與真實標簽比較correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算平均損失和準確率test_loss /= num_batches # 平均損失correct /= size # 準確率print(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")# -------------------------- 單張圖片預測函數 --------------------------
def predict_single_image(image_path, model, transform, device, label_map):"""對單張圖片進行預測:param image_path: 圖片路徑:param model: 訓練好的模型:param transform: 圖像預處理函數(與測試集一致):param device: 計算設備:param label_map: 標簽映射字典(數字標簽→食物名稱):return: 預測的食物名稱"""# 讀取并預處理圖片(與測試集預處理一致)image = Image.open(image_path).convert('RGB') # 確保3通道image = transform(image) # 應用預處理(Resize和ToTensor)# 增加batch維度(模型要求輸入格式為[batch, C, H, W],這里batch=1)image = image.unsqueeze(0).to(device)# 模型預測model.eval() # 開啟評估模式with torch.no_grad(): # 關閉梯度計算pred_logits = model(image) # 得到預測分數(logits)# 取概率最大的類別標簽(argmax(1)按行取最大值索引)pred_label = pred_logits.argmax(1).item()# 映射為食物名稱if pred_label not in label_map:raise KeyError(f"預測標簽 {pred_label} 不在標簽映射字典中")return label_map[pred_label]# -------------------------- 主程序 --------------------------
if __name__ == "__main__":# 初始化模型、損失函數、優化器model = CNN().to(device) # 創建模型并移至設備loss_fn = nn.CrossEntropyLoss() # 多分類問題常用交叉熵損失# Adam優化器:自適應學習率,訓練效果較好,學習率0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 訓練模型(100輪)epochs = 100for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n----------------------------")train(train_dataloader, model, loss_fn, optimizer)print("Training Done!")# 測試模型在測試集上的性能test(test_dataloader, model, loss_fn)# 定義標簽映射字典:數字標簽→食物名稱# 需與數據集的標簽完全對應(順序和數量一致)label_to_food = {0: "八寶粥", 1: "巴旦木", 2: "白蘿卜", 3: "板栗", 4: "菠蘿",5: "草莓", 6: "蛋", 7: "蛋撻", 8: "骨肉相連", 9: "瓜子",10: "哈密瓜", 11: "漢堡", 12: "胡蘿卜", 13: "火龍果", 14: "雞翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯條", 19: "炸雞"}# 輸入圖片路徑并預測image_path = input("請輸入圖片路徑:") # 用戶輸入待預測圖片路徑true_food = input("請輸入該圖片的真實食物名稱:") # 用戶輸入真實標簽(用于對比)# 執行預測并輸出結果predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food)# 輸出對比結果print("\n" + "-" * 50)print(f"預測結果:{predicted_food}")print(f"真實結果:{true_food}")print(f"判斷:{'預測正確' if predicted_food == true_food else '預測錯誤'}")print("-" * 50)
二、數據預處理:讓數據 “適配” 模型
在深度學習中,數據預處理的質量直接影響模型性能。原始圖像可能存在尺寸不一、像素值范圍差異大、樣本數量不足等問題,需通過預處理將其轉化為模型可接受的格式,并通過數據增強提升模型泛化能力。
本項目的預處理邏輯集中在data_transforms
字典中,分 “訓練集” 和 “驗證集 / 測試集” 兩種策略,我們逐一拆解其設計思路。
2.1 為什么要區分訓練集與驗證集預處理?
- 訓練集:需要通過 “數據增強” 增加樣本多樣性,避免模型過擬合(即模型只記住訓練樣本,對新樣本識別能力差)。
- 驗證集 / 測試集:需保持數據的 “真實性”,僅進行基礎預處理(如 Resize、ToTensor),確保評估結果能反映模型的實際泛化能力。
2.2 訓練集數據增強:每一步的作用與原理
訓練集的預處理鏈為:Resize → RandomRotation → CenterCrop → RandomHorizontalFlip → RandomVerticalFlip → ColorJitter → RandomGrayscale → ToTensor → Normalize
,我們逐個解析:
(1)Resize ([300, 300]):統一初始尺寸
將所有圖像調整為 300×300 像素。為什么不直接調整為最終的 256×256?因為后續會進行旋轉和裁剪,預留一定尺寸可避免旋轉后出現黑邊。
(2)RandomRotation (45):隨機旋轉
隨機將圖像旋轉 - 45°~45°。食物在拍攝時可能有不同角度(如躺著的漢堡、豎放的胡蘿卜),旋轉增強能讓模型對角度不敏感,提升魯棒性。
(3)CenterCrop (256):中心裁剪
將旋轉后的圖像從中心裁剪為 256×256。旋轉會導致圖像邊緣出現黑邊,裁剪可去除黑邊,同時將圖像尺寸統一為模型輸入尺寸(256×256)。
(4)RandomHorizontalFlip (p=0.5) & RandomVerticalFlip (p=0.5):隨機翻轉
- 水平翻轉(50% 概率):模擬 “左右鏡像” 的食物(如翻轉后的草莓外觀不變)。
- 垂直翻轉(50% 概率):模擬 “上下顛倒” 的場景(如掉落的薯條)。
翻轉操作不改變食物的核心特征,但能增加樣本多樣性,且計算成本低。
(5)ColorJitter (0.1, 0.1, 0.1, 0.1):隨機顏色抖動
調整圖像的亮度、對比度、飽和度、色調,各參數的取值范圍為 0~1(0 表示不調整,1 表示最大調整幅度)。
食物圖像的顏色易受光照影響(如白天和夜晚拍攝的青菜顏色不同),顏色抖動能讓模型對光照變化不敏感。
(6)RandomGrayscale (p=0.1):隨機灰度化
10% 概率將彩色圖像轉為灰度圖。雖然食物的顏色是重要特征,但灰度化能迫使模型關注食物的形狀、紋理等更本質的特征,避免過度依賴顏色信息(如紅色的草莓和紅色的圣女果,需通過形狀區分)。
(7)ToTensor ():轉為 Tensor 格式
將 PIL 圖像(H×W×C,像素值 0~255)轉為 PyTorch Tensor(C×H×W,像素值歸一化到 0~1)。
- 維度轉換:模型要求輸入為 “通道優先”(C×H×W),而 PIL 圖像是 “高度優先”(H×W×C),需通過 ToTensor 調整。
- 歸一化:將像素值從 0~255 縮放到 0~1,避免大數值導致模型梯度爆炸。
(8)Normalize ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):標準化
使用 ImageNet 數據集的均值和標準差對 Tensor 進行標準化,公式為:
標準化后像素值 = (原始像素值 - 均值) / 標準差
為什么用 ImageNet 的參數?因為本項目后續可擴展為遷移學習(使用預訓練模型),而預訓練模型是在 ImageNet 上訓練的,使用相同的標準化參數能讓模型更快收斂。
transforms.Compose([transforms.Resize([300, 300]), # 先將圖像調整為300x300transforms.RandomRotation(45), # 隨機旋轉(-45~45度),增強旋轉不變性transforms.CenterCrop(256), # 中心裁剪到256x256,去除旋轉后的黑邊transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻轉transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻轉transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 隨機調整亮度、對比度、飽和度和色調transforms.RandomGrayscale(p=0.1), # 10%概率轉為灰度圖transforms.ToTensor(), # 轉為Tensor格式([C, H, W]),并將像素值歸一化到[0,1]transforms.Normalize( # 使用ImageNet的均值和標準差進行標準化[0.485, 0.456, 0.406], # 均值(RGB三個通道)[0.229, 0.224, 0.225] # 標準差(RGB三個通道))
2.3 驗證集預處理
驗證集的預處理鏈為:Resize([256, 256]) → ToTensor()
,僅保留基礎操作:
- Resize ([256, 256]):直接將圖像調整為 256×256,無需旋轉(避免引入非真實樣本)。
- ToTensor ():與訓練集一致,確保數據格式統一。
(注:代碼中驗證集未做 Normalize,實際項目中建議與訓練集保持一致,此處可根據需求調整)
三、自定義 Dataset:PyTorch 數據加載的核心
PyTorch 通過Dataset
和DataLoader
實現數據加載,其中Dataset
負責 “定義數據來源和格式”,DataLoader
負責 “批量加載和并行處理”。本項目自定義了food_dataset
類,用于加載食物圖像和對應標簽,我們詳細解析其實現邏輯。
3.1 Dataset 的核心作用
Dataset
是一個抽象類,要求子類必須實現三個方法:
__init__
:初始化數據集(讀取文件列表、加載預處理函數)。__len__
:返回數據集的總樣本數。__getitem__
:根據索引返回單個樣本(圖像 + 標簽)。
這三個方法確保了 PyTorch 能高效地迭代訪問數據。
3.2 food_dataset 類逐方法解析
(1)init:初始化數據列表
def __init__(self, file_path, transform=None):self.file_path = file_path # 存儲圖像路徑和標簽的txt文件路徑self.transform = transform # 預處理函數self.imgs = [] # 存儲所有圖像路徑self.labels = [] # 存儲對應標簽# 讀取txt文件,解析圖像路徑和標簽with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip() # 去除首尾空格和換行符if not line: # 跳過空行(避免解析錯誤)continueimg_path, label = line.split(' ') # 按空格分割路徑和標簽self.imgs.append(img_path)self.labels.append(label)
- txt 文件格式要求:每行需包含 “圖像路徑” 和 “數字標簽”,用空格分隔
其中 “0” 對應 “八寶粥”,“1” 對應 “巴旦木”,需與后續label_to_food
字典一致。
(2)len:返回樣本總數
def __len__(self):return len(self.imgs)
簡單直接,返回self.imgs
的長度),DataLoader
會通過該方法確定迭代次數。
(3)getitem:返回單個樣本
def __getitem__(self, index):# 讀取圖像并強制轉為RGB(避免灰度圖通道數問題)try:image = Image.open(self.imgs[index]).convert('RGB')except Exception as e:raise ValueError(f"讀取圖片 {self.imgs[index]} 失敗:{e}")# 應用預處理if self.transform:image = self.transform(image)# 處理標簽:轉為int64類型Tensor(PyTorch分類任務要求)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label
這是Dataset
的核心方法,需重點關注三個細節:
- 強制 RGB 格式:
convert('RGB')
確保所有圖像都是 3 通道(避免部分灰度圖是 1 通道,導致模型輸入維度不匹配)。 - 異常處理:
try-except
捕獲圖像讀取錯誤(如路徑錯誤、圖像損壞),并明確提示錯誤位置,便于調試。 - 標簽類型:將標簽轉為
torch.int64
(即 LongTensor),因為 PyTorch 的CrossEntropyLoss
要求標簽為 Long 類型。
3.3 如何準備自己的數據集?
- 收集圖像:每個食物類別收集至少 100 張圖像(樣本越多,模型性能越好),建議按類別分文件夾存儲
- 生成 txt 文件:編寫腳本遍歷圖像文件夾,生成
train.txt
和test.txt
import os# 數據集根目錄 train_root = "./dataset/train" test_root = "./dataset/test" # 標簽映射(與后續一致) label_to_food = {0: "八寶粥", 1: "巴旦木", ..., 19: "炸雞"} # 反向映射:食物名稱→數字標簽 food_to_label = {v: k for k, v in label_to_food.items()}# 生成train.txt with open("train.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(train_root):food_dir = os.path.join(train_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")# 生成test.txt(邏輯同上) with open("test.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(test_root):food_dir = os.path.join(test_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")
- 檢查路徑:確保 txt 文件中的圖像路徑與實際文件路徑一致
四、DataLoader:批量加載與并行處理
Dataset
定義了數據的 “來源”,而DataLoader
則負責將數據 “批量加載” 到模型中,并支持并行處理,提升數據加載速度。
4.1 DataLoader 的核心參數解析
本項目的DataLoader
初始化代碼如下:
train_dataloader = DataLoader(training_data, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=True)
核心參數含義:
batch_size
:每次加載的樣本數量(批大小),需根據 GPU 顯存調整shuffle
:是否打亂數據順序。訓練集設為True
,驗證集可設為False
,本項目測試集設為True
是為了觀察不同樣本的預測效果。
4.2 DataLoader 與 Dataset 的協作流程
DataLoader
的工作流程可概括為:
- 調用
Dataset.__len__()
獲取總樣本數,計算總批次數(總樣本數 //batch_size)。 - 若
shuffle=True
,則在每個 epoch(訓練輪次)開始前打亂樣本索引。 - 對每個批次,根據索引調用
Dataset.__getitem__()
獲取單個樣本,組裝成一個批次的 Tensor(形狀為 [batch_size, C, H, W])。 - 將批次數據移動到指定設備(CUDA/MPS/CPU),供模型訓練或測試。
4.3 數據加載到設備的邏輯
X, y = X.to(device), y.to(device)
其中device
是通過以下代碼確定的:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
- 為什么要移動設備??模型和數據必須在同一設備上才能進行計算(如模型在 CUDA 上,數據也需在 CUDA 上),否則會報錯。
- 設備優先級:優先使用 CUDA(NVIDIA GPU),其次是 MPS(Apple M 系列),最后是 CPU
五、CNN 模型構建:從卷積到全連接的特征提取
卷積神經網絡(CNN)是圖像分類的核心,其通過 “卷積層提取局部特征→池化層降維→全連接層分類” 的流程,實現對圖像的識別。本項目的 CNN 模型包含 4 個卷積塊和 1 個全連接層,我們逐一解析其設計思路和尺寸計算。
5.1 CNN 的核心組件與作用
在解析代碼前,先回顧 CNN 的三個核心組件:
- 卷積層(Conv2d):通過卷積核滑動提取圖像的局部特征(如邊緣、紋理、形狀),輸出 “特征圖”(Feature Map)。
- 激活函數(ReLU):引入非線性,讓模型能擬合復雜的特征關系(避免線性模型的表達能力不足)。
- 池化層(MaxPool2d):對特征圖進行下采樣,降低維度和計算量,同時增強模型對特征位置的魯棒性。
5.2 模型代碼逐塊解析
模型定義代碼如下,我們按 “卷積塊 1→卷積塊 2→卷積塊 3→卷積塊 4→全連接層” 的順序解析:
class CNN(nn.Module):def __init__(self):super().__init__()# 卷積塊1:1次卷積 + ReLU + 最大池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2), # 池化后尺寸:256→128)# 卷積塊2:2次卷積 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2), # 尺寸:128→64)# 卷積塊3:2次卷積 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2), # 尺寸:64→32)# 卷積塊4:1次卷積 + ReLU(無池化)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),nn.ReLU(), # 輸出尺寸:32×32,通道數128)# 全連接層:映射到20類self.out = nn.Linear(128 * 32 * 32, 20)
池化層參數計算:尺寸減半
MaxPool2d(kernel_size=2)
表示池化核大小為 2×2,步長默認等于核大小(即 2),因此輸出尺寸為輸入尺寸的 1/2:
conv1
池化前尺寸:256×256 → 池化后:128×128conv2
池化前尺寸:128×128 → 池化后:64×64conv3
池化前尺寸:64×64 → 池化后:32×32
各卷積塊的輸出特征
- conv1:輸出特征圖形狀為 [batch_size, 16, 128, 128],提取的是圖像的低級特征(如邊緣、顏色塊)。
- conv2:輸出形狀為 [batch_size, 32, 64, 64],通過 2 次卷積提取更復雜的特征(如食物的局部輪廓)。
- conv3:輸出形狀為 [batch_size, 128, 32, 32],通道數增加到 128,特征更抽象(如食物的結構特征)。
- conv4:輸出形狀為 [batch_size, 128, 32, 32],無池化,進一步細化特征(避免池化導致的特征損失)。
(4)全連接層:從特征到分類
全連接層self.out
的輸入維度是128×32×32
,這是由conv4
的輸出特征圖形狀決定的:
conv4
輸出:[batch_size, 128, 32, 32] → 展平后為 [batch_size, 128×32×32](展平操作在forward
中通過x.view(x.size(0), -1)
實現)。- 輸出維度:20,對應 20 種食物類別,每個維度輸出該類別的 “預測分數”(后續通過
argmax
取分數最高的類別作為預測結果)。
5.3 forward 方法:定義數據流動路徑
def forward(self, x):x = self.conv1(x) # 經卷積塊1處理x = self.conv2(x) # 經卷積塊2處理x = self.conv3(x) # 經卷積塊3處理x = self.conv4(x) # 經卷積塊4處理x = x.view(x.size(0), -1) # 展平:[batch_size, 128*32*32]output = self.out(x) # 全連接層輸出return output
- 展平操作:
x.view(x.size(0), -1)
將 4 維特征圖(batch, C, H, W)轉為 2 維張量(batch, C×H×W),因為全連接層僅接受 2 維輸入。 - 數據流動:輸入圖像([batch, 3, 256, 256])→ 卷積塊 1→2→3→4 → 展平 → 全連接層 → 輸出([batch, 20])。
5.4 模型初始化與設備移動
模型初始化代碼如下:
model = CNN().to(device)
CNN()
創建模型實例,to(device)
將模型參數移動到指定設備(CUDA/MPS/CPU),確保模型和數據在同一設備上計算。
六、模型訓練與測試:從損失下降到性能評估
模型構建完成后,需通過訓練讓模型 “學習” 食物特征,再通過測試評估模型的泛化能力。本項目定義了train
和test
兩個函數,分別實現訓練和測試邏輯。
6.1 訓練函數:讓模型 “學習”
訓練函數的核心是 “前向傳播計算損失→反向傳播更新參數”,代碼如下:
def train(dataloader, model, loss_fn, optimizer):model.train() # 開啟訓練模式(啟用Dropout、BatchNorm訓練行為)batch_size_num = 1 # 批次編號,用于打印損失for X, y in dataloader:# 數據移動到設備X, y = X.to(device), y.to(device)# 1. 前向傳播:計算預測結果pred = model(X)# 2. 計算損失:預測值與真實標簽的差距loss = loss_fn(pred, y)# 3. 反向傳播與參數更新optimizer.zero_grad() # 清空上一輪梯度(避免累積)loss.backward() # 反向傳播計算梯度optimizer.step() # 根據梯度更新模型參數# 打印損失(每2個批次打印一次)loss_val = loss.item()if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f} [batch: {batch_size_num}]")batch_size_num += 1
(1)關鍵步驟解析
model.train():開啟訓練模式,對含有 Dropout、BatchNorm 的模型至關重要:
- Dropout:訓練時隨機 “關閉” 部分神經元,防止過擬合;測試時不關閉。
- BatchNorm:訓練時使用批次的均值和方差歸一化;測試時使用訓練階段累積的均值和方差。
前向傳播(Forward Pass):
pred = model(X)
:將批次數據輸入模型,得到預測結果([batch, 20])。loss = loss_fn(pred, y)
:計算損失,本項目使用CrossEntropyLoss
(多分類任務的常用損失函數)。
反向傳播(Backward Pass)與參數更新:
optimizer.zero_grad()
:清空梯度。若不清空,梯度會累積到上一輪,導致參數更新錯誤。loss.backward()
:根據損失計算各參數的梯度(optimizer.step()
:根據梯度更新模型參數
(2)損失函數與優化器選擇
本項目使用的損失函數和優化器如下:
loss_fn = nn.CrossEntropyLoss() # 多分類交叉熵損失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam優化器
- CrossEntropyLoss:適用于多分類任務
- Adam 優化器:自適應學習率優化器,收斂速度快,對學習率不敏感,是深度學習中最常用的優化器之一。
lr=0.001
是常用的初始學習率,可根據訓練情況調整
6.2 測試函數:評估模型泛化能力
測試函數的核心是 “計算模型在測試集上的準確率和平均損失”,代碼如下:
def test(dataloader, model, loss_fn):model.eval() # 開啟評估模式(關閉Dropout、固定BatchNorm)size = len(dataloader.dataset) # 測試集總樣本數num_batches = len(dataloader) # 測試集批次數test_loss, correct = 0, 0 # 總損失和正確預測數# 關閉梯度計算(節省資源,避免參數更新)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累加損失test_loss += loss_fn(pred, y).item()# 累加正確預測數correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算平均損失和準確率test_loss /= num_batchescorrect /= sizeprint(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")
(1)關鍵步驟解析
- model.eval():開啟評估模式,與
model.train()
對應,確保模型在測試時的行為與訓練時一致(如關閉 Dropout)。 - torch.no_grad():上下文管理器,關閉梯度計算。測試時無需更新參數,關閉梯度可大幅減少內存占用和計算時間。
- 準確率計算:
pred.argmax(1)
:對每個樣本,取預測分數最高的類別(維度 1 是類別維度,[batch, 20]→[batch, 1])。(pred.argmax(1) == y)
:比較預測類別與真實標簽,得到布爾張量(True = 正確,False = 錯誤)。type(torch.float).sum().item()
:將布爾張量轉為 float(True=1,False=0),求和得到正確預測數,再轉為 Python 標量。
- 結果解讀:
- Accuracy:準確率(正確預測數 / 總樣本數),反映模型的整體識別能力,越高越好。
- Avg Loss:平均損失,反映模型預測值與真實標簽的平均差距,越低越好。
6.3 訓練流程與輪次設置
訓練流程代碼如下:
# 訓練輪次(epochs)
epochs = 100
for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Training Done!")# 測試模型
test(test_dataloader, model, loss_fn)
- epochs(訓練輪次):表示模型將遍歷整個訓練集的次數。本項目設為 100,可根據實際情況調整:
- 若訓練損失仍在下降,可增加 epochs;
- 若訓練損失下降但測試損失上升(過擬合),可減少 epochs 或加入早停機制。
- 訓練與測試順序:每輪訓練后可加入測試(如在
train
后調用test
),便于監控模型是否過擬合;本項目在所有訓練完成后測試,適用于快速驗證。
七、單張圖片預測:模型的實際應用
訓練完成后,需將模型用于實際場景 —— 對單張食物圖片進行分類。本項目定義了predict_single_image
函數,實現從圖像讀取到類別輸出的完整流程。
7.1 預測函數解析
def predict_single_image(image_path, model, transform, device, label_map):# 1. 讀取并預處理圖像(與測試集一致)image = Image.open(image_path).convert('RGB') # 強制RGBimage = transform(image) # 應用預處理(Resize + ToTensor)# 2. 增加batch維度(模型要求輸入為[batch, C, H, W])image = image.unsqueeze(0).to(device)# 3. 模型預測model.eval() # 開啟評估模式with torch.no_grad():pred_logits = model(image) # 預測分數(logits)pred_label = pred_logits.argmax(1).item() # 取最高分數類別# 4. 映射為食物名稱if pred_label not in label_map:raise KeyError(f"預測標簽 {pred_label} 不在標簽映射字典中")return label_map[pred_label]
(1)關鍵步驟解析
- 圖像預處理一致性:預測時的預處理必須與測試集一致(本項目使用
data_transforms['valid']
),否則模型輸入格式不匹配,預測結果會失真。 - 增加 batch 維度:模型訓練和測試時輸入都是批次數據([batch, C, H, W]),而單張圖片是 [C, H, W],需通過
unsqueeze(0)
在第 0 維(batch 維)增加一個維度,變為 [1, C, H, W]。 - 標簽映射:
label_map
(如label_to_food
)將數字標簽(如 0)映射為食物名稱(如 “八寶粥”),讓預測結果更直觀。
7.2 預測實戰與結果展示
預測代碼如下,用戶輸入圖片路徑和真實標簽,模型輸出預測結果并對比:
# 標簽映射字典(與數據集標簽對應)
label_to_food = {0: "八寶粥", 1: "巴旦木", 2: "白蘿卜", 3: "板栗", 4: "菠蘿",5: "草莓", 6: "蛋", 7: "蛋撻", 8: "骨肉相連", 9: "瓜子",10: "哈密瓜", 11: "漢堡", 12: "胡蘿卜", 13: "火龍果", 14: "雞翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯條", 19: "炸雞"
}# 用戶輸入
image_path = input("請輸入圖片路徑:")
true_food = input("請輸入該圖片的真實食物名稱:")# 執行預測
predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food
)# 輸出結果
print("\n" + "-" * 50)
print(f"預測結果:{predicted_food}")
print(f"真實結果:{true_food}")
print(f"判斷:{'預測正確' if predicted_food == true_food else '預測錯誤'}")
print("-" * 50)
(1)預測示例
假設用戶輸入:
- 圖片路徑:"D:\食物分類\food_dataset\test\八寶粥\img_八寶粥罐_81.jpeg"(一張八寶粥圖片)
- 真實食物名稱:八寶粥
模型輸出:
請輸入圖片路徑:./test_images/hamburger.jpg
請輸入該圖片的真實食物名稱:漢堡--------------------------------------------------
預測結果:漢堡
真實結果:漢堡
判斷:預測正確
--------------------------------------------------
(2)預測錯誤原因
樣本數量不足
訓練的輪數過少