前言
在計算機視覺領域,圖像分類是一個基礎且重要的任務。本文將介紹如何使用MobileNetV3預訓練模型來訓練一個水果分類模型,并通過Flask框架進行部署。MobileNetV3作為輕量級網絡,在保持較高精度的同時,具有較快的推理速度,非常適合實際應用場景。
環境準備
首先,我們需要準備以下環境:
# 主要依賴包
torch>=1.7.0
torchvision>=0.8.0
flask>=2.0.0
pillow>=8.0.0
numpy>=1.19.0
requests>=2.25.0 # 用于數據采集
matplotlib>=3.3.0 # 用于繪制訓練曲線
數據集準備
1. 數據采集
我們使用百度圖片API來采集水果圖片數據。以下是數據采集的代碼實現:
import requests
import osdef get_images(keyword, page_num):headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.81 Safari/537.36'}url = 'https://image.baidu.com/search/acjson?'# 設置圖片保存路徑download_path = os.path.join("./data", keyword)if not os.path.exists(download_path):os.makedirs(download_path)# 構造請求參數params = {'tn': 'resultjson_com','word': keyword,'pn': 0,'rn': 30,# ... 其他參數}# 下載圖片for i in range(page_num):params["pn"] = i*30response = requests.get(url, params=params, headers=headers)# 處理返回結果并保存圖片
2. 數據集組織
將采集到的圖片按照以下結構組織:
data/├── apple/│ ├── 0.jpg│ ├── 1.jpg│ └── ...├── banana/│ ├── 0.jpg│ ├── 1.jpg│ └── ...└── ...
模型訓練
1. 數據加載和預處理
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader# 圖像預處理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加載數據集
dataset = ImageFolder("data", transform=transform)# 保存類別標簽
with open("label.txt", "w", encoding="UTF-8") as f:for line in dataset.classes:f.write(line + "\n")# 劃分訓練集和測試集
train_ratio = 0.8
train_size = int(len(dataset) * train_ratio)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])# 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
2. 模型定義
from torchvision import models
import torch.nn as nn# 使用MobileNetV3-Small預訓練模型
model = models.mobilenet_v3_small(pretrained=True)
# 修改最后的分類層
model.classifier[3] = nn.Linear(in_features=1024, out_features=5) # 5個類別# 如果有GPU則使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
3. 訓練過程
# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練參數
num_epochs = 20
best_valid_acc = 0
best_model = None# 記錄訓練過程
train_losses = []
valid_losses = []
train_accs = []
valid_accs = []for epoch in range(num_epochs):# 訓練階段model.train()train_loss = 0.0train_acc = 0.0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()_, predicted = torch.max(outputs, 1)train_acc += (predicted == labels).sum().item()total += len(labels)
Flask部署
1. 創建Flask應用
2. 實現預測接口
@app.route('/predict', methods=['POST'])
def predict():if 'image' not in request.files:return render_template('index.html', prediction=None)image_file = request.files['image']image_data = image_file.read()# 圖像預處理img = Image.open(io.BytesIO(image_data))img = transform(img)img = torch.unsqueeze(img, dim=0)# 模型預測with torch.no_grad():prediction = model(img)prediction = F.softmax(prediction, dim=1)# 獲取預測結果pred_label = class_labels[torch.argmax(prediction).item()]confidence = torch.max(prediction).item()return render_template('index.html', prediction=pred_label,confidence=confidence)
部署步驟
- 確保服務器已安裝Python環境
- 安裝所需依賴包:
pip install -r requirements.txt
- 將模型文件、Flask應用和模板文件上傳到服務器
- 運行Flask應用:
python app.py
總結
本文詳細介紹了使用MobileNetV3訓練水果分類模型并用Flask部署的完整流程。通過使用預訓練模型,我們可以在較小的數據集上獲得不錯的分類效果。Flask框架的輕量級特性使得部署變得簡單快捷。在實際應用中,可以根據具體需求進行進一步的優化和改進。
參考資料
- MobileNetV3論文:Searching for MobileNetV3
- Flask官方文檔:https://flask.palletsprojects.com/
- PyTorch官方文檔:https://pytorch.org/docs/stable/index.html
- 百度圖片API文檔