概述
本代碼實現了一個基于PyTorch的圖像特征提取與分類模型訓練流程。核心功能包括:
使用預訓練ResNet18模型進行圖像特征提取
將提取的特征保存為標準化格式
基于提取的特征訓練分類模型
代碼結構詳解?
1. 庫導入
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
import numpy as np
import os
from ml.model_trainer import ModelTrainer
-
關鍵庫說明:
-
torch
:PyTorch核心庫 -
torch.nn
:神經網絡模塊 -
torchvision
:計算機視覺專用模塊 -
numpy
:數值計算庫 -
os
:文件系統操作 -
ModelTrainer
:自定義模型訓練類(需另行實現)
-
2. 特征提取器類(FeatureExtractor)
初始化方法 __init__
def __init__(self):self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model = torchvision.models.resnet18(weights='IMAGENET1K_V1')self.model = nn.Sequential(*list(self.model.children())[:-1])self.model = self.model.to(self.device).eval()self.transform = transforms.Compose([...])
-
功能說明:
-
設備檢測:自動選擇GPU/CPU
-
模型加載:使用ImageNet預訓練的ResNet18
-
模型修改:移除最后的全連接層(保留卷積特征提取器)
-
預處理設置:標準化圖像尺寸和顏色空間
-
特征提取方法 extract_features
def extract_features(self, data_dir):full_dataset = datasets.ImageFolder(...)loader = DataLoader(...)features = []labels = []with torch.no_grad():for inputs, targets in loader:inputs = inputs.to(self.device)outputs = self.model(inputs)features.append(outputs.squeeze().cpu().numpy())labels.append(targets.numpy())features = np.concatenate(...)labels = np.concatenate(...)return features, labels, full_dataset.classes
-
關鍵參數:
-
data_dir
:包含分類子目錄的圖像數據集路徑 -
batch_size=32
:平衡內存使用與處理效率 -
num_workers=4
:多線程數據加載
-
-
處理流程:
-
創建ImageFolder數據集
-
使用DataLoader批量加載
-
禁用梯度計算加速推理
-
特征維度壓縮(squeeze)
-
設備間數據傳輸(GPU->CPU)
-
合并所有批次數據
-
3. 主執行流程
參數配置
DATA_DIR = "/home/.../data" # 實際數據路徑
SAVE_PATH = "./features.npz" # 特征保存路徑
特征提取與保存?
extractor = FeatureExtractor()
if not os.path.exists(SAVE_PATH):features, labels, classes = extractor.extract_features(DATA_DIR)np.savez(SAVE_PATH, features=features, labels=labels, classes=classes)
else:data = np.load(SAVE_PATH)features = data['features']labels = data['labels']
-
文件結構:
-
features: [N_samples, 512] 的特征矩陣
-
labels: [N_samples] 的標簽數組
-
classes: 類別名稱列表
-
模型訓練與保存
X, y = features, labels
trainer = ModelTrainer()
model = trainer.train_model(X, y)
joblib.dump(model, 'pest_classifier.pkl')
?
-
假設條件:
-
ModelTrainer需實現訓練邏輯(如SVM、隨機森林等)
-
默認使用全部數據進行訓練(建議實際添加數據分割)
-
技術細節說明
1. 圖像預處理流程
2. 特征維度分析
-
ResNet18最后層輸出:512維特征向量
-
假設1000張圖像:
-
原始圖像:1000×3×224×224 (約150MB)
-
提取特征:1000×512 (約2MB) → 顯著降維
-
3. 性能優化策略
-
GPU加速:自動檢測CUDA設備
-
批量處理:32張/批平衡效率與內存
-
緩存機制:避免重復特征提取
-
梯度禁用:減少內存消耗
?
?
?
?
?