AI技術實戰:從零搭建圖像分類系統全流程詳解
人工智能學習 https://www.captainbed.cn/ccc
前言
本文將以圖像分類任務為切入點,手把手教你完成AI模型從數據準備到工業部署的全鏈路開發。通過一個完整的Kaggle貓狗分類項目(代碼兼容PyTorch/TensorFlow),覆蓋以下核心技能:
- 數據清洗與增強的工程化實現
- 模型構建與訓練技巧
- 模型壓縮與TensorRT部署優化
- 可視化監控與性能調優
所有代碼均提供可運行的Colab鏈接,建議邊閱讀邊實踐。
目錄
-
環境搭建與數據準備
- 1.1 本地/云端開發環境配置
- 1.2 數據爬取與清洗腳本開發
- 1.3 自動化標注工具實戰
-
圖像分類模型實戰
- 2.1 手寫CNN模型構建(帶可運行代碼)
- 2.2 遷移學習Fine-tuning技巧
- 2.3 訓練過程可視化監控
-
模型優化與部署
- 3.1 模型剪枝與量化壓縮
- 3.2 ONNX格式轉換與TensorRT加速
- 3.3 RESTful API服務封裝
-
工業級增強技巧
- 4.1 解決類別不平衡問題
- 4.2 應對小樣本學習的策略
- 4.3 模型熱更新方案
1. 環境搭建與數據準備
1.1 開發環境配置(PyTorch示例)
# 創建虛擬環境
conda create -n ai_tutorial python=3.8
conda activate ai_tutorial# 安裝核心依賴
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python albumentations pandas
1.2 數據爬取實戰
# 使用Bing圖片下載API批量獲取數據
import requestsdef download_images(keyword, count=100):headers = {'Ocp-Apim-Subscription-Key': 'YOUR_API_KEY'}params = {'q': keyword, 'count': count}response = requests.get('https://api.bing.microsoft.com/v7.0/images/search', headers=headers, params=params)for idx, img in enumerate(response.json()['value']):img_data = requests.get(img['contentUrl']).contentwith open(f'dataset/{keyword}_{idx}.jpg', 'wb') as f:f.write(img_data)# 執行下載
download_images('cat')
download_images('dog')
1.3 自動化數據清洗
# 使用OpenCV過濾損壞圖片
import cv2
import osdef clean_dataset(folder):valid_extensions = ['.jpg', '.jpeg', '.png']for filename in os.listdir(folder):filepath = os.path.join(folder, filename)try:img = cv2.imread(filepath)if img is None or img.size == 0:os.remove(filepath)elif os.path.splitext(filename)[1].lower() not in valid_extensions:os.remove(filepath)except Exception as e:print(f"刪除損壞文件: {filename}")os.remove(filepath)clean_dataset('dataset/train')
2. 圖像分類模型實戰
2.1 自定義CNN模型(PyTorch實現)
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self, num_classes=2):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, padding=1), # 輸入3通道nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Sequential(nn.Linear(128 * 28 * 28, 512), # 根據輸入尺寸調整nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)return self.classifier(x)
2.2 遷移學習實戰(ResNet50微調)
from torchvision import models# 加載預訓練模型
model = models.resnet50(pretrained=True)# 替換最后一層全連接
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, 2)
)# 凍結早期層參數
for param in model.parameters():param.requires_grad = False
for param in model.layer4.parameters():param.requires_grad = True
2.3 訓練過程可視化(TensorBoard集成)
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter()for epoch in range(epochs):# 訓練代碼...writer.add_scalar('Loss/train', loss.item(), epoch)writer.add_scalar('Accuracy/train', acc, epoch)# 可視化特征圖if epoch % 10 == 0:writer.add_images('Feature Maps', model.features[0](images[:4]), epoch)
3. 模型優化與部署
3.1 模型剪枝實戰
import torch.nn.utils.prune as prune# 對卷積層進行L1非結構化剪枝
parameters_to_prune = ((model.features[0], 'weight'),(model.features[3], 'weight'),
)prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2, # 剪枝20%的權重
)
3.2 TensorRT加速部署
# 導出ONNX模型
torch.onnx.export(model, dummy_input, "model.onnx",opset_version=11)# 使用TensorRT轉換
trt_cmd = f"""
trtexec --onnx=model.onnx \--saveEngine=model.trt \--fp16 \--workspace=2048
"""
os.system(trt_cmd)
3.3 封裝Flask API服務
from flask import Flask, request
import trt_inference # 自定義TRT推理模塊app = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():file = request.files['image']img = preprocess(file.read())output = trt_inference.run(img)return {'class_id': int(output.argmax())}if __name__ == '__main__':app.run(host='0.0.0.0', port=5000)
4. 工業級增強技巧
4.1 類別不平衡解決方案
# 使用加權采樣器
from torch.utils.data import WeightedRandomSamplerclass_counts = [num_cat, num_dog]
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
samples_weights = weights[labels]sampler = WeightedRandomSampler(weights=samples_weights,num_samples=len(samples_weights),replacement=True
)
4.2 小樣本學習方案
# 使用MixUp數據增強
def mixup_data(x, y, alpha=0.2):lam = np.random.beta(alpha, alpha)batch_size = x.size()[0]index = torch.randperm(batch_size)mixed_x = lam * x + (1 - lam) * x[index]y_a, y_b = y, y[index]return mixed_x, y_a, y_b, lam# 修改損失函數
criterion = nn.CrossEntropyLoss()
loss = lam * criterion(output, y_a) + (1 - lam) * criterion(output, y_b)