從 0 到 1 實現 PyTorch 食物圖像分類:核心知識點與完整實

食物圖像分類是計算機視覺的經典任務之一,其核心是讓機器 “看懂” 圖像中的食物類別。隨著深度學習的發展,卷積神經網絡(CNN)憑借強大的特征提取能力,成為圖像分類的主流方案。本文將基于 PyTorch 框架,從代碼實戰出發,拆解食物圖像分類項目中的核心知識點,包括環境搭建、數據預處理、數據集構建、CNN 模型設計、模型訓練與測試、單圖預測等,帶大家從零搭建一個能識別 20 類食物的分類系統。

# 導入必要的庫
import torch  # PyTorch核心庫,用于構建和訓練神經網絡
from torch import nn  # 神經網絡模塊,包含各種層和損失函數
from torch.utils.data import Dataset, DataLoader  # 數據集和數據加載器,用于數據處理
import numpy as np  # 數值計算庫,可用于數據預處理等
from PIL import Image  # 圖像處理庫,用于讀取和處理圖像
from torchvision import transforms  # 圖像轉換工具,用于數據增強和預處理
import os  # 操作系統接口,用于文件路徑處理等# 定義數據轉換策略:訓練集使用數據增強,驗證集/測試集保持一致的基礎轉換
data_transforms = {'train':  # 訓練集轉換(包含數據增強,增加樣本多樣性)transforms.Compose([transforms.Resize([300, 300]),  # 先將圖像調整為300x300transforms.RandomRotation(45),  # 隨機旋轉(-45~45度),增強旋轉不變性transforms.CenterCrop(256),  # 中心裁剪到256x256,去除旋轉后的黑邊transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻轉transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻轉transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # 隨機調整亮度、對比度、飽和度和色調transforms.RandomGrayscale(p=0.1),  # 10%概率轉為灰度圖transforms.ToTensor(),  # 轉為Tensor格式([C, H, W]),并將像素值歸一化到[0,1]transforms.Normalize(  # 使用ImageNet的均值和標準差進行標準化[0.485, 0.456, 0.406],  # 均值(RGB三個通道)[0.229, 0.224, 0.225]  # 標準差(RGB三個通道))]),'valid':  # 驗證集/測試集轉換(無增強,保持數據一致性)transforms.Compose([transforms.Resize([256, 256]),  # 調整為256x256,與訓練集裁剪后尺寸一致transforms.ToTensor(),  # 轉為Tensor]),
}# -------------------------- 2. 自定義數據集類 --------------------------
class food_dataset(Dataset):"""自定義食物圖像數據集類,繼承自PyTorch的Dataset用于加載圖像路徑和對應標簽,并進行預處理"""def __init__(self, file_path, transform=None):"""初始化數據集:param file_path: 存儲圖像路徑和標簽的文本文件路徑:param transform: 圖像轉換函數(預處理/數據增強)"""self.file_path = file_path  # 文本文件路徑self.transform = transform  # 轉換函數self.imgs = []  # 存儲所有圖像路徑self.labels = []  # 存儲對應標簽# 讀取文件列表(每行格式:圖片路徑 數字標簽)with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip()  # 去除首尾空格和換行符if not line:  # 跳過空行continue# 按空格分割路徑和標簽(假設格式嚴格,無多余空格)img_path, label = line.split(' ')self.imgs.append(img_path)self.labels.append(label)def __len__(self):"""返回數據集樣本數量"""return len(self.imgs)def __getitem__(self, index):"""根據索引獲取單個樣本(圖像和標簽):param index: 樣本索引:return: 處理后的圖像張量和標簽張量"""# 讀取圖片并強制轉為RGB(避免灰度圖導致的通道數不匹配問題)try:image = Image.open(self.imgs[index]).convert('RGB')  # 確保3通道輸入except Exception as e:# 捕獲讀取錯誤,便于調試raise ValueError(f"讀取圖片 {self.imgs[index]} 失敗:{e}")# 應用轉換(預處理/數據增強)if self.transform:image = self.transform(image)# 處理標簽:轉為整數類型的張量(PyTorch分類任務要求標簽為long類型)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label# 加載數據集
# 注意:需確保train.txt和test.txt文件存在,每行格式為「圖片路徑 數字標簽」
try:# 加載訓練集(使用訓練集轉換)training_data = food_dataset(file_path='./train.txt', transform=data_transforms['train'])# 加載測試集(使用驗證集轉換)test_data = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])
except FileNotFoundError:# 捕獲文件不存在錯誤,提示用戶raise FileNotFoundError("請確保 train.txt 和 test.txt 文件在當前目錄下")# 創建數據加載器(批量加載數據,支持打亂和多進程)
train_dataloader = DataLoader(training_data,batch_size=8,  # 批大小:每次加載8張圖片shuffle=True  # 訓練時打亂數據順序,增強訓練效果
)
test_dataloader = DataLoader(test_data,batch_size=8,  # 測試時也用相同批大小shuffle=True  # 測試時打亂不影響結果,主要便于觀察不同樣本
)# 設備配置:優先使用GPU(cuda),其次是Apple M系列芯片(mps),最后是CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')  # 打印使用的設備# 定義CNN模型(卷積神經網絡)
class CNN(nn.Module):"""自定義卷積神經網絡模型,用于食物圖像分類包含4個卷積塊和1個全連接輸出層"""def __init__(self):super().__init__()  # 調用父類nn.Module的初始化方法# 第一個卷積塊:1次卷積 + ReLU激活 + 最大池化self.conv1 = nn.Sequential(# 卷積層:輸入3通道(RGB),輸出16通道,卷積核5x5,步長1,填充2(保持尺寸)nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),  # 激活函數,引入非線性nn.MaxPool2d(kernel_size=2),  # 最大池化:尺寸減半(256→128))# 第二個卷積塊:2次卷積 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),  # 輸入16通道,輸出32通道nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),  # 輸入32通道,輸出32通道nn.ReLU(),nn.MaxPool2d(2),  # 尺寸減半(128→64))# 第三個卷積塊:2次卷積 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 輸入32通道,輸出64通道nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),  # 輸入64通道,輸出128通道nn.ReLU(),nn.MaxPool2d(2),  # 尺寸減半(64→32))# 第四個卷積塊:1次卷積 + ReLU(無池化,保持尺寸)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),  # 輸入128通道,輸出128通道nn.ReLU(),  # 輸出尺寸:32×32,通道數128)# 全連接輸出層:將特征映射到20個類別(食物種類)# 輸入尺寸計算:128通道 × 32高 × 32寬(經多次池化后的特征圖尺寸)self.out = nn.Linear(128 * 32 * 32, 20)  # 20類:需與標簽數量一致def forward(self, x):"""前向傳播:定義數據在網絡中的流動路徑:param x: 輸入張量,形狀為[batch_size, 3, 256, 256]:return: 輸出張量,形狀為[batch_size, 20](各類別的預測分數)"""x = self.conv1(x)  # 經第一個卷積塊處理x = self.conv2(x)  # 經第二個卷積塊處理x = self.conv3(x)  # 經第三個卷積塊處理x = self.conv4(x)  # 經第四個卷積塊處理x = x.view(x.size(0), -1)  # 展平特征圖:[batch_size, 128*32*32]output = self.out(x)  # 經全連接層輸出預測結果return output# -------------------------- 訓練與測試函數 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""訓練模型的函數:param dataloader: 訓練數據集加載器:param model: 待訓練的模型:param loss_fn: 損失函數(用于計算預測誤差):param optimizer: 優化器(用于更新模型參數)"""model.train()  # 開啟訓練模式(啟用Dropout、BatchNorm等訓練特定行為)batch_size_num = 1  # 記錄當前批次編號for X, y in dataloader:# 將數據移動到指定設備(GPU/CPU)X, y = X.to(device), y.to(device)# 前向傳播:計算模型預測結果pred = model(X)# 計算損失(預測值與真實標簽的差距)loss = loss_fn(pred, y)# 反向傳播與參數更新optimizer.zero_grad()  # 清空上一輪的梯度(避免梯度累積)loss.backward()  # 反向傳播計算梯度optimizer.step()  # 根據梯度更新模型參數# 打印損失(每2個batch打印一次,便于監控訓練過程)loss_val = loss.item()  # 獲取損失的標量值if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f}  [batch: {batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):model.eval()  # 開啟評估模式(關閉Dropout、固定BatchNorm參數等)size = len(dataloader.dataset)  # 測試集總樣本數num_batches = len(dataloader)  # 測試集批次數test_loss, correct = 0, 0  # 總損失和正確預測數# 關閉梯度計算(測試時不需要更新參數,節省計算資源)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)  # 數據移至設備pred = model(X)  # 預測test_loss += loss_fn(pred, y).item()  # 累加損失# 統計正確預測數:取預測概率最大的類別與真實標簽比較correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算平均損失和準確率test_loss /= num_batches  # 平均損失correct /= size  # 準確率print(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")# -------------------------- 單張圖片預測函數 --------------------------
def predict_single_image(image_path, model, transform, device, label_map):"""對單張圖片進行預測:param image_path: 圖片路徑:param model: 訓練好的模型:param transform: 圖像預處理函數(與測試集一致):param device: 計算設備:param label_map: 標簽映射字典(數字標簽→食物名稱):return: 預測的食物名稱"""# 讀取并預處理圖片(與測試集預處理一致)image = Image.open(image_path).convert('RGB')  # 確保3通道image = transform(image)  # 應用預處理(Resize和ToTensor)# 增加batch維度(模型要求輸入格式為[batch, C, H, W],這里batch=1)image = image.unsqueeze(0).to(device)# 模型預測model.eval()  # 開啟評估模式with torch.no_grad():  # 關閉梯度計算pred_logits = model(image)  # 得到預測分數(logits)# 取概率最大的類別標簽(argmax(1)按行取最大值索引)pred_label = pred_logits.argmax(1).item()# 映射為食物名稱if pred_label not in label_map:raise KeyError(f"預測標簽 {pred_label} 不在標簽映射字典中")return label_map[pred_label]# -------------------------- 主程序 --------------------------
if __name__ == "__main__":# 初始化模型、損失函數、優化器model = CNN().to(device)  # 創建模型并移至設備loss_fn = nn.CrossEntropyLoss()  # 多分類問題常用交叉熵損失# Adam優化器:自適應學習率,訓練效果較好,學習率0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 訓練模型(100輪)epochs = 100for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n----------------------------")train(train_dataloader, model, loss_fn, optimizer)print("Training Done!")# 測試模型在測試集上的性能test(test_dataloader, model, loss_fn)# 定義標簽映射字典:數字標簽→食物名稱# 需與數據集的標簽完全對應(順序和數量一致)label_to_food = {0: "八寶粥", 1: "巴旦木", 2: "白蘿卜", 3: "板栗", 4: "菠蘿",5: "草莓", 6: "蛋", 7: "蛋撻", 8: "骨肉相連", 9: "瓜子",10: "哈密瓜", 11: "漢堡", 12: "胡蘿卜", 13: "火龍果", 14: "雞翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯條", 19: "炸雞"}# 輸入圖片路徑并預測image_path = input("請輸入圖片路徑:")  # 用戶輸入待預測圖片路徑true_food = input("請輸入該圖片的真實食物名稱:")  # 用戶輸入真實標簽(用于對比)# 執行預測并輸出結果predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food)# 輸出對比結果print("\n" + "-" * 50)print(f"預測結果:{predicted_food}")print(f"真實結果:{true_food}")print(f"判斷:{'預測正確' if predicted_food == true_food else '預測錯誤'}")print("-" * 50)

二、數據預處理:讓數據 “適配” 模型

在深度學習中,數據預處理的質量直接影響模型性能。原始圖像可能存在尺寸不一、像素值范圍差異大、樣本數量不足等問題,需通過預處理將其轉化為模型可接受的格式,并通過數據增強提升模型泛化能力。

本項目的預處理邏輯集中在data_transforms字典中,分 “訓練集” 和 “驗證集 / 測試集” 兩種策略,我們逐一拆解其設計思路。

2.1 為什么要區分訓練集與驗證集預處理?

  • 訓練集:需要通過 “數據增強” 增加樣本多樣性,避免模型過擬合(即模型只記住訓練樣本,對新樣本識別能力差)。
  • 驗證集 / 測試集:需保持數據的 “真實性”,僅進行基礎預處理(如 Resize、ToTensor),確保評估結果能反映模型的實際泛化能力。

2.2 訓練集數據增強:每一步的作用與原理

訓練集的預處理鏈為:Resize → RandomRotation → CenterCrop → RandomHorizontalFlip → RandomVerticalFlip → ColorJitter → RandomGrayscale → ToTensor → Normalize,我們逐個解析:

(1)Resize ([300, 300]):統一初始尺寸

將所有圖像調整為 300×300 像素。為什么不直接調整為最終的 256×256?因為后續會進行旋轉和裁剪,預留一定尺寸可避免旋轉后出現黑邊。

(2)RandomRotation (45):隨機旋轉

隨機將圖像旋轉 - 45°~45°。食物在拍攝時可能有不同角度(如躺著的漢堡、豎放的胡蘿卜),旋轉增強能讓模型對角度不敏感,提升魯棒性。

(3)CenterCrop (256):中心裁剪

將旋轉后的圖像從中心裁剪為 256×256。旋轉會導致圖像邊緣出現黑邊,裁剪可去除黑邊,同時將圖像尺寸統一為模型輸入尺寸(256×256)。

(4)RandomHorizontalFlip (p=0.5) & RandomVerticalFlip (p=0.5):隨機翻轉
  • 水平翻轉(50% 概率):模擬 “左右鏡像” 的食物(如翻轉后的草莓外觀不變)。
  • 垂直翻轉(50% 概率):模擬 “上下顛倒” 的場景(如掉落的薯條)。
    翻轉操作不改變食物的核心特征,但能增加樣本多樣性,且計算成本低。
(5)ColorJitter (0.1, 0.1, 0.1, 0.1):隨機顏色抖動

調整圖像的亮度、對比度、飽和度、色調,各參數的取值范圍為 0~1(0 表示不調整,1 表示最大調整幅度)。
食物圖像的顏色易受光照影響(如白天和夜晚拍攝的青菜顏色不同),顏色抖動能讓模型對光照變化不敏感。

(6)RandomGrayscale (p=0.1):隨機灰度化

10% 概率將彩色圖像轉為灰度圖。雖然食物的顏色是重要特征,但灰度化能迫使模型關注食物的形狀、紋理等更本質的特征,避免過度依賴顏色信息(如紅色的草莓和紅色的圣女果,需通過形狀區分)。

(7)ToTensor ():轉為 Tensor 格式

將 PIL 圖像(H×W×C,像素值 0~255)轉為 PyTorch Tensor(C×H×W,像素值歸一化到 0~1)。

  • 維度轉換:模型要求輸入為 “通道優先”(C×H×W),而 PIL 圖像是 “高度優先”(H×W×C),需通過 ToTensor 調整。
  • 歸一化:將像素值從 0~255 縮放到 0~1,避免大數值導致模型梯度爆炸。
(8)Normalize ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):標準化

使用 ImageNet 數據集的均值和標準差對 Tensor 進行標準化,公式為:
標準化后像素值 = (原始像素值 - 均值) / 標準差
為什么用 ImageNet 的參數?因為本項目后續可擴展為遷移學習(使用預訓練模型),而預訓練模型是在 ImageNet 上訓練的,使用相同的標準化參數能讓模型更快收斂。

transforms.Compose([transforms.Resize([300, 300]),  # 先將圖像調整為300x300transforms.RandomRotation(45),  # 隨機旋轉(-45~45度),增強旋轉不變性transforms.CenterCrop(256),  # 中心裁剪到256x256,去除旋轉后的黑邊transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻轉transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻轉transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # 隨機調整亮度、對比度、飽和度和色調transforms.RandomGrayscale(p=0.1),  # 10%概率轉為灰度圖transforms.ToTensor(),  # 轉為Tensor格式([C, H, W]),并將像素值歸一化到[0,1]transforms.Normalize(  # 使用ImageNet的均值和標準差進行標準化[0.485, 0.456, 0.406],  # 均值(RGB三個通道)[0.229, 0.224, 0.225]  # 標準差(RGB三個通道))

2.3 驗證集預處理

驗證集的預處理鏈為:Resize([256, 256]) → ToTensor(),僅保留基礎操作:

  • Resize ([256, 256]):直接將圖像調整為 256×256,無需旋轉(避免引入非真實樣本)。
  • ToTensor ():與訓練集一致,確保數據格式統一。
    (注:代碼中驗證集未做 Normalize,實際項目中建議與訓練集保持一致,此處可根據需求調整)

三、自定義 Dataset:PyTorch 數據加載的核心

PyTorch 通過DatasetDataLoader實現數據加載,其中Dataset負責 “定義數據來源和格式”,DataLoader負責 “批量加載和并行處理”。本項目自定義了food_dataset類,用于加載食物圖像和對應標簽,我們詳細解析其實現邏輯。

3.1 Dataset 的核心作用

Dataset是一個抽象類,要求子類必須實現三個方法:

  1. __init__:初始化數據集(讀取文件列表、加載預處理函數)。
  2. __len__:返回數據集的總樣本數。
  3. __getitem__:根據索引返回單個樣本(圖像 + 標簽)。
    這三個方法確保了 PyTorch 能高效地迭代訪問數據。

3.2 food_dataset 類逐方法解析

(1)init:初始化數據列表
def __init__(self, file_path, transform=None):self.file_path = file_path  # 存儲圖像路徑和標簽的txt文件路徑self.transform = transform  # 預處理函數self.imgs = []  # 存儲所有圖像路徑self.labels = []  # 存儲對應標簽# 讀取txt文件,解析圖像路徑和標簽with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip()  # 去除首尾空格和換行符if not line:  # 跳過空行(避免解析錯誤)continueimg_path, label = line.split(' ')  # 按空格分割路徑和標簽self.imgs.append(img_path)self.labels.append(label)

  • txt 文件格式要求:每行需包含 “圖像路徑” 和 “數字標簽”,用空格分隔
    其中 “0” 對應 “八寶粥”,“1” 對應 “巴旦木”,需與后續label_to_food字典一致。
(2)len:返回樣本總數
def __len__(self):return len(self.imgs)

簡單直接,返回self.imgs的長度),DataLoader會通過該方法確定迭代次數。

(3)getitem:返回單個樣本
def __getitem__(self, index):# 讀取圖像并強制轉為RGB(避免灰度圖通道數問題)try:image = Image.open(self.imgs[index]).convert('RGB')except Exception as e:raise ValueError(f"讀取圖片 {self.imgs[index]} 失敗:{e}")# 應用預處理if self.transform:image = self.transform(image)# 處理標簽:轉為int64類型Tensor(PyTorch分類任務要求)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label

這是Dataset的核心方法,需重點關注三個細節:

  1. 強制 RGB 格式convert('RGB')確保所有圖像都是 3 通道(避免部分灰度圖是 1 通道,導致模型輸入維度不匹配)。
  2. 異常處理try-except捕獲圖像讀取錯誤(如路徑錯誤、圖像損壞),并明確提示錯誤位置,便于調試。
  3. 標簽類型:將標簽轉為torch.int64(即 LongTensor),因為 PyTorch 的CrossEntropyLoss要求標簽為 Long 類型。

3.3 如何準備自己的數據集?

  1. 收集圖像:每個食物類別收集至少 100 張圖像(樣本越多,模型性能越好),建議按類別分文件夾存儲
  2. 生成 txt 文件:編寫腳本遍歷圖像文件夾,生成train.txttest.txt

    import os# 數據集根目錄
    train_root = "./dataset/train"
    test_root = "./dataset/test"
    # 標簽映射(與后續一致)
    label_to_food = {0: "八寶粥", 1: "巴旦木", ..., 19: "炸雞"}
    # 反向映射:食物名稱→數字標簽
    food_to_label = {v: k for k, v in label_to_food.items()}# 生成train.txt
    with open("train.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(train_root):food_dir = os.path.join(train_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")# 生成test.txt(邏輯同上)
    with open("test.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(test_root):food_dir = os.path.join(test_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")
    
  3. 檢查路徑:確保 txt 文件中的圖像路徑與實際文件路徑一致

四、DataLoader:批量加載與并行處理

Dataset定義了數據的 “來源”,而DataLoader則負責將數據 “批量加載” 到模型中,并支持并行處理,提升數據加載速度。

4.1 DataLoader 的核心參數解析

本項目的DataLoader初始化代碼如下:

train_dataloader = DataLoader(training_data, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=True)

核心參數含義:

  • batch_size:每次加載的樣本數量(批大小),需根據 GPU 顯存調整
  • shuffle:是否打亂數據順序。訓練集設為True,驗證集可設為False,本項目測試集設為True是為了觀察不同樣本的預測效果。

4.2 DataLoader 與 Dataset 的協作流程

DataLoader的工作流程可概括為:

  1. 調用Dataset.__len__()獲取總樣本數,計算總批次數(總樣本數 //batch_size)。
  2. shuffle=True,則在每個 epoch(訓練輪次)開始前打亂樣本索引。
  3. 對每個批次,根據索引調用Dataset.__getitem__()獲取單個樣本,組裝成一個批次的 Tensor(形狀為 [batch_size, C, H, W])。
  4. 將批次數據移動到指定設備(CUDA/MPS/CPU),供模型訓練或測試。

4.3 數據加載到設備的邏輯

X, y = X.to(device), y.to(device)

其中device是通過以下代碼確定的:

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

  • 為什么要移動設備??模型和數據必須在同一設備上才能進行計算(如模型在 CUDA 上,數據也需在 CUDA 上),否則會報錯。
  • 設備優先級:優先使用 CUDA(NVIDIA GPU),其次是 MPS(Apple M 系列),最后是 CPU

五、CNN 模型構建:從卷積到全連接的特征提取

卷積神經網絡(CNN)是圖像分類的核心,其通過 “卷積層提取局部特征→池化層降維→全連接層分類” 的流程,實現對圖像的識別。本項目的 CNN 模型包含 4 個卷積塊和 1 個全連接層,我們逐一解析其設計思路和尺寸計算。

5.1 CNN 的核心組件與作用

在解析代碼前,先回顧 CNN 的三個核心組件:

  1. 卷積層(Conv2d):通過卷積核滑動提取圖像的局部特征(如邊緣、紋理、形狀),輸出 “特征圖”(Feature Map)。
  2. 激活函數(ReLU):引入非線性,讓模型能擬合復雜的特征關系(避免線性模型的表達能力不足)。
  3. 池化層(MaxPool2d):對特征圖進行下采樣,降低維度和計算量,同時增強模型對特征位置的魯棒性。

5.2 模型代碼逐塊解析

模型定義代碼如下,我們按 “卷積塊 1→卷積塊 2→卷積塊 3→卷積塊 4→全連接層” 的順序解析:

class CNN(nn.Module):def __init__(self):super().__init__()# 卷積塊1:1次卷積 + ReLU + 最大池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 池化后尺寸:256→128)# 卷積塊2:2次卷積 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 尺寸:128→64)# 卷積塊3:2次卷積 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 尺寸:64→32)# 卷積塊4:1次卷積 + ReLU(無池化)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),nn.ReLU(),  # 輸出尺寸:32×32,通道數128)# 全連接層:映射到20類self.out = nn.Linear(128 * 32 * 32, 20)
池化層參數計算:尺寸減半

MaxPool2d(kernel_size=2)表示池化核大小為 2×2,步長默認等于核大小(即 2),因此輸出尺寸為輸入尺寸的 1/2:

  • conv1池化前尺寸:256×256 → 池化后:128×128
  • conv2池化前尺寸:128×128 → 池化后:64×64
  • conv3池化前尺寸:64×64 → 池化后:32×32
各卷積塊的輸出特征
  • conv1:輸出特征圖形狀為 [batch_size, 16, 128, 128],提取的是圖像的低級特征(如邊緣、顏色塊)。
  • conv2:輸出形狀為 [batch_size, 32, 64, 64],通過 2 次卷積提取更復雜的特征(如食物的局部輪廓)。
  • conv3:輸出形狀為 [batch_size, 128, 32, 32],通道數增加到 128,特征更抽象(如食物的結構特征)。
  • conv4:輸出形狀為 [batch_size, 128, 32, 32],無池化,進一步細化特征(避免池化導致的特征損失)。
(4)全連接層:從特征到分類

全連接層self.out的輸入維度是128×32×32,這是由conv4的輸出特征圖形狀決定的:

  • conv4輸出:[batch_size, 128, 32, 32] → 展平后為 [batch_size, 128×32×32](展平操作在forward中通過x.view(x.size(0), -1)實現)。
  • 輸出維度:20,對應 20 種食物類別,每個維度輸出該類別的 “預測分數”(后續通過argmax取分數最高的類別作為預測結果)。

5.3 forward 方法:定義數據流動路徑

def forward(self, x):x = self.conv1(x)  # 經卷積塊1處理x = self.conv2(x)  # 經卷積塊2處理x = self.conv3(x)  # 經卷積塊3處理x = self.conv4(x)  # 經卷積塊4處理x = x.view(x.size(0), -1)  # 展平:[batch_size, 128*32*32]output = self.out(x)  # 全連接層輸出return output
  • 展平操作x.view(x.size(0), -1)將 4 維特征圖(batch, C, H, W)轉為 2 維張量(batch, C×H×W),因為全連接層僅接受 2 維輸入。
  • 數據流動:輸入圖像([batch, 3, 256, 256])→ 卷積塊 1→2→3→4 → 展平 → 全連接層 → 輸出([batch, 20])。

5.4 模型初始化與設備移動

模型初始化代碼如下:

model = CNN().to(device)
  • CNN()創建模型實例,to(device)將模型參數移動到指定設備(CUDA/MPS/CPU),確保模型和數據在同一設備上計算。

六、模型訓練與測試:從損失下降到性能評估

模型構建完成后,需通過訓練讓模型 “學習” 食物特征,再通過測試評估模型的泛化能力。本項目定義了traintest兩個函數,分別實現訓練和測試邏輯。

6.1 訓練函數:讓模型 “學習”

訓練函數的核心是 “前向傳播計算損失→反向傳播更新參數”,代碼如下:

def train(dataloader, model, loss_fn, optimizer):model.train()  # 開啟訓練模式(啟用Dropout、BatchNorm訓練行為)batch_size_num = 1  # 批次編號,用于打印損失for X, y in dataloader:# 數據移動到設備X, y = X.to(device), y.to(device)# 1. 前向傳播:計算預測結果pred = model(X)# 2. 計算損失:預測值與真實標簽的差距loss = loss_fn(pred, y)# 3. 反向傳播與參數更新optimizer.zero_grad()  # 清空上一輪梯度(避免累積)loss.backward()        # 反向傳播計算梯度optimizer.step()       # 根據梯度更新模型參數# 打印損失(每2個批次打印一次)loss_val = loss.item()if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f}  [batch: {batch_size_num}]")batch_size_num += 1
(1)關鍵步驟解析
  1. model.train():開啟訓練模式,對含有 Dropout、BatchNorm 的模型至關重要:

    • Dropout:訓練時隨機 “關閉” 部分神經元,防止過擬合;測試時不關閉。
    • BatchNorm:訓練時使用批次的均值和方差歸一化;測試時使用訓練階段累積的均值和方差。
  2. 前向傳播(Forward Pass)

    • pred = model(X):將批次數據輸入模型,得到預測結果([batch, 20])。
    • loss = loss_fn(pred, y):計算損失,本項目使用CrossEntropyLoss(多分類任務的常用損失函數)。
  3. 反向傳播(Backward Pass)與參數更新

    • optimizer.zero_grad():清空梯度。若不清空,梯度會累積到上一輪,導致參數更新錯誤。
    • loss.backward():根據損失計算各參數的梯度(
    • optimizer.step():根據梯度更新模型參數
(2)損失函數與優化器選擇

本項目使用的損失函數和優化器如下:

loss_fn = nn.CrossEntropyLoss()  # 多分類交叉熵損失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam優化器
  • CrossEntropyLoss:適用于多分類任務
  • Adam 優化器:自適應學習率優化器,收斂速度快,對學習率不敏感,是深度學習中最常用的優化器之一。lr=0.001是常用的初始學習率,可根據訓練情況調整

6.2 測試函數:評估模型泛化能力

測試函數的核心是 “計算模型在測試集上的準確率和平均損失”,代碼如下:

def test(dataloader, model, loss_fn):model.eval()  # 開啟評估模式(關閉Dropout、固定BatchNorm)size = len(dataloader.dataset)  # 測試集總樣本數num_batches = len(dataloader)    # 測試集批次數test_loss, correct = 0, 0        # 總損失和正確預測數# 關閉梯度計算(節省資源,避免參數更新)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累加損失test_loss += loss_fn(pred, y).item()# 累加正確預測數correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 計算平均損失和準確率test_loss /= num_batchescorrect /= sizeprint(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")
(1)關鍵步驟解析
  1. model.eval():開啟評估模式,與model.train()對應,確保模型在測試時的行為與訓練時一致(如關閉 Dropout)。
  2. torch.no_grad():上下文管理器,關閉梯度計算。測試時無需更新參數,關閉梯度可大幅減少內存占用和計算時間。
  3. 準確率計算
    • pred.argmax(1):對每個樣本,取預測分數最高的類別(維度 1 是類別維度,[batch, 20]→[batch, 1])。
    • (pred.argmax(1) == y):比較預測類別與真實標簽,得到布爾張量(True = 正確,False = 錯誤)。
    • type(torch.float).sum().item():將布爾張量轉為 float(True=1,False=0),求和得到正確預測數,再轉為 Python 標量。
  4. 結果解讀
    • Accuracy:準確率(正確預測數 / 總樣本數),反映模型的整體識別能力,越高越好。
    • Avg Loss:平均損失,反映模型預測值與真實標簽的平均差距,越低越好。

6.3 訓練流程與輪次設置

訓練流程代碼如下:

# 訓練輪次(epochs)
epochs = 100
for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Training Done!")# 測試模型
test(test_dataloader, model, loss_fn)
  • epochs(訓練輪次):表示模型將遍歷整個訓練集的次數。本項目設為 100,可根據實際情況調整:
    • 若訓練損失仍在下降,可增加 epochs;
    • 若訓練損失下降但測試損失上升(過擬合),可減少 epochs 或加入早停機制。
  • 訓練與測試順序:每輪訓練后可加入測試(如在train后調用test),便于監控模型是否過擬合;本項目在所有訓練完成后測試,適用于快速驗證。

七、單張圖片預測:模型的實際應用

訓練完成后,需將模型用于實際場景 —— 對單張食物圖片進行分類。本項目定義了predict_single_image函數,實現從圖像讀取到類別輸出的完整流程。

7.1 預測函數解析

def predict_single_image(image_path, model, transform, device, label_map):# 1. 讀取并預處理圖像(與測試集一致)image = Image.open(image_path).convert('RGB')  # 強制RGBimage = transform(image)  # 應用預處理(Resize + ToTensor)# 2. 增加batch維度(模型要求輸入為[batch, C, H, W])image = image.unsqueeze(0).to(device)# 3. 模型預測model.eval()  # 開啟評估模式with torch.no_grad():pred_logits = model(image)  # 預測分數(logits)pred_label = pred_logits.argmax(1).item()  # 取最高分數類別# 4. 映射為食物名稱if pred_label not in label_map:raise KeyError(f"預測標簽 {pred_label} 不在標簽映射字典中")return label_map[pred_label]
(1)關鍵步驟解析
  1. 圖像預處理一致性:預測時的預處理必須與測試集一致(本項目使用data_transforms['valid']),否則模型輸入格式不匹配,預測結果會失真。
  2. 增加 batch 維度:模型訓練和測試時輸入都是批次數據([batch, C, H, W]),而單張圖片是 [C, H, W],需通過unsqueeze(0)在第 0 維(batch 維)增加一個維度,變為 [1, C, H, W]。
  3. 標簽映射label_map(如label_to_food)將數字標簽(如 0)映射為食物名稱(如 “八寶粥”),讓預測結果更直觀。

7.2 預測實戰與結果展示

預測代碼如下,用戶輸入圖片路徑和真實標簽,模型輸出預測結果并對比:

# 標簽映射字典(與數據集標簽對應)
label_to_food = {0: "八寶粥", 1: "巴旦木", 2: "白蘿卜", 3: "板栗", 4: "菠蘿",5: "草莓", 6: "蛋", 7: "蛋撻", 8: "骨肉相連", 9: "瓜子",10: "哈密瓜", 11: "漢堡", 12: "胡蘿卜", 13: "火龍果", 14: "雞翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯條", 19: "炸雞"
}# 用戶輸入
image_path = input("請輸入圖片路徑:")
true_food = input("請輸入該圖片的真實食物名稱:")# 執行預測
predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food
)# 輸出結果
print("\n" + "-" * 50)
print(f"預測結果:{predicted_food}")
print(f"真實結果:{true_food}")
print(f"判斷:{'預測正確' if predicted_food == true_food else '預測錯誤'}")
print("-" * 50)
(1)預測示例

假設用戶輸入:

  • 圖片路徑:"D:\食物分類\food_dataset\test\八寶粥\img_八寶粥罐_81.jpeg"(一張八寶粥圖片)
  • 真實食物名稱:八寶粥

模型輸出:

請輸入圖片路徑:./test_images/hamburger.jpg
請輸入該圖片的真實食物名稱:漢堡--------------------------------------------------
預測結果:漢堡
真實結果:漢堡
判斷:預測正確
--------------------------------------------------
(2)預測錯誤原因

樣本數量不足

訓練的輪數過少

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/97926.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/97926.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/97926.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Python 值傳遞 (Pass by Value) 和引用傳遞 (Pass by Reference)

Python 值傳遞 {Pass by Value} 和引用傳遞 {Pass by Reference}1. Mutable Objects and Immutable Objects in Python (Python 可變對象和不可變對象)2. Pass by Value and Pass by Reference2.1. What is Pass by Value in Python?2.2. What is Pass by Reference in Python…

aippt自動生成工具有哪些?一文看懂,總有一款適合你!

在當今快節奏的工作與學習環境中,傳統耗時的PPT制作方式已難以滿足高效表達的需求。隨著人工智能技術的發展,AI自動生成PPT工具應運而生,成為提升演示文稿制作效率的利器。這類工具通過自然語言處理和深度學習技術,能夠根據用戶輸…

Langflow 框架中 Prompt 技術底層實現分析

Langflow 框架中 Prompt 技術底層實現分析 1. Prompt 技術概述 Langflow 是一個基于 LangChain 的可視化 AI 工作流構建框架,其 Prompt 技術是整個系統的核心組件之一。Prompt 技術主要負責: 模板化處理:支持動態變量替換的提示詞模板變量驗證…

前端、node跨域問題

前端頁面訪問node后端接口跨域報錯 Access to XMLHttpRequest at http://192.18.31.75/api/get?namess&age19 from origin http://127.0.0.1:5500 has been blocked by CORS policy: No Access-Control-Allow-Origin header is present on the requested resource. 這個報…

超越馬力歐:如何為經典2D平臺游戲注入全新靈魂

在游戲開發的世界里,2D平臺游戲仿佛是一位熟悉的老朋友。從《超級馬力歐兄弟》開啟的黃金時代到現在,這個類型已經經歷了數十年的演變與打磨。當每個基礎設計似乎都已被探索殆盡時,我們如何才能打造出一款令人耳目一新的平臺游戲?…

基于Springboot + vue3實現的時尚美妝電商網站

項目描述本系統包含管理員和用戶兩個角色。管理員角色:商品分類管理:新增、查看、修改、刪除商品分類。商品信息管理:新增、查看、修改、刪除、查看評論商品信息。用戶管理:新增、查看、修改、刪除用戶。管理員管理:查…

網絡協議之https?

寫在前面 https協議還是挺復雜的,本人也是經過了很多次的學習,依然感覺一知半解,無法將所有的知識點串起來,本次學習呢,也是有很多的疑惑點,但是還是盡量的輸出內容,來幫助自己和在看文章的你來…

word運行時錯誤‘53’,文件未找到:MathPage.WLL,更改加載項路徑完美解決

最簡單的方法解決!!!安裝Mathtype之后粘貼顯示:運行時錯誤‘53’,文件未找到:MathPage.WLLwin11安裝mathtype后會有這個錯誤,這是由于word中加載項加載mathtype路徑出錯導致的,這時候…

React實現列表拖拽排序

本文主要介紹一下React實現列表拖拽排序方法,具體樣式如下圖首先,簡單展示一下組件的數據結構 const CodeSetting props > {const {$t, // 國際化翻譯函數vm, // 視圖模型數據vm: {CodeSet: { Enable [], …

將 MySQL 表數據導出為 CSV 文件

目錄 一、實現思路 二、核心代碼 1. 數據庫連接部分 2. 數據導出核心邏輯 3. CSV文件寫入 三、完整代碼實現 五、輸出結果 一、實現思路 建立數據庫連接 查詢目標表的數據總量和具體數據 獲取表的列名作為CSV文件的表頭 將查詢結果轉換為二維數組格式 使用Hutool工具…

一文讀懂RAG:從生活場景到核心邏輯,AI“查資料答題”原來這么簡單

一文讀懂RAG:從生活場景到核心邏輯,AI“查資料答題”原來這么簡單 要理解 RAG(Retrieval-Augmented Generation,檢索增強生成),不需要先背復雜公式,我們可以從一個生活場景切入——它本質是讓AI…

git將當前分支推送到遠端指定分支

在 Git 中&#xff0c;將當前本地分支推送到遠程倉庫的指定分支&#xff0c;可以使用 git push 命令&#xff0c;并指定本地分支和遠程分支的映射關系。 基本語法 git push <遠程名稱> <本地分支名>:<遠程分支名><遠程名稱>&#xff1a;通常是 origin&…

【Linux】線程封裝

提示&#xff1a;文章寫完后&#xff0c;目錄可以自動生成&#xff0c;如何生成可參考右邊的幫助文檔 文章目錄 一、為什么需要封裝線程庫&#xff1f; pthread的痛點&#xff1a; 封裝帶來的好處&#xff1a; 二、線程封裝核心代碼解析 1. 頭文件定義&#xff08;Thread.hpp&a…

智慧交通管理信號燈通信4G工業路由器應用

在交通信號燈管理中傳統的有線通訊&#xff08;光纖、網線&#xff09;存在部署成本高、偏遠區域覆蓋難、故障維修慢等問題&#xff0c;而4G工業路由器憑借無線化、高穩定、強適配的特性&#xff0c;成為信號燈與管控平臺間的數據傳輸核心&#xff0c;適配多場景需求。智慧交通…

《Python Flask 實戰:構建一個可交互的 Web 應用,從用戶輸入到智能響應》

《Python Flask 實戰:構建一個可交互的 Web 應用,從用戶輸入到智能響應》 一、引言:從“Hello, World!”到“你好,用戶” 在 Web 應用的世界里,最打動人心的功能往往不是炫酷的界面,而是人與系統之間的真實互動。一個簡單的輸入框,一句個性化的回應,往往能讓用戶感受…

開發效率翻倍:資深DBA都在用的MySQL客戶端利器

MySQL 連接工具&#xff08;也稱為客戶端或圖形化界面工具&#xff0c;GUI Tools&#xff09;是數據庫開發、管理和運維中不可或缺的利器。它們比命令行更直觀&#xff0c;能極大提高工作效率。以下是一份主流的 MySQL 連接工具清單&#xff0c;并附上了它們的優缺點和適用場景…

基于Docker和Kubernetes的CI/CD流水線架構設計與優化實踐

基于Docker和Kubernetes的CI/CD流水線架構設計與優化實踐 本文分享了在生產環境中基于Docker和Kubernetes構建高效可靠的CI/CD流水線的實戰經驗&#xff0c;包括業務場景、技術選型、詳細方案、踩坑與解決方案&#xff0c;以及最終的總結與最佳實踐&#xff0c;幫助后端開發者快…

Trae x 圖片素描MCP一鍵將普通圖片轉換為多風格素描效果

目錄前言一、核心工具與優勢解析二、操作步驟&#xff1a;從安裝到生成素描效果第一步&#xff1a;獲取MCP配置代碼第二步&#xff1a;下載第三步&#xff1a;在 Trae 中導入 MCP 配置并建立連接第四步&#xff1a;核心功能調用三、三大素描風格差異化應用四.總結前言 在設計創…

2 XSS

XSS的原理 XSS&#xff08;跨站腳本攻擊&#xff09;原理 1. 核心機制 XSS攻擊的本質是惡意腳本在用戶瀏覽器中執行。攻擊者通過向網頁注入惡意代碼&#xff0c;當其他用戶訪問該頁面時&#xff0c;瀏覽器會執行這些代碼&#xff08;沒有對用戶的輸入進行過濾導致用戶輸入的…

GitHub每日最火火火項目(9.3)

1. pedroslopez / whatsapp-web.js 項目名稱&#xff1a;whatsapp-web.js項目介紹&#xff1a;基于 JavaScript 開發&#xff0c;是一個用于 Node.js 的 WhatsApp 客戶端庫&#xff0c;通過 WhatsApp Web 瀏覽器應用進行連接&#xff08;A WhatsApp client library for NodeJS …