一、項目背景:
本項目是“從零開始學習深度學習”系列中的第二個實戰項目,旨在實現第一個簡易App(圖像分類任務——水果分類),進一步將AI模型落地應用,幫助初學者初步了解模型落地的流程。
本項目是一個基於PyQt5圖形界面的水果圖像分類系統,用戶只需加載模型、選擇圖像,即可一鍵完成圖像識別。
二、項目目標🚀:
基於PyQt5圖形界面實現以下功能:
- 加載本地訓練好的 `.pth` 模型;
- 加載本地圖像並進行展示;
- 自動完成圖像預處理(Resize、ToTensor、Normalize);
- 使用模型完成預測並展示結果;
- 界面美觀,交互友好。
三、適合人群🫵:
- 深度學習零基礎或剛入門的學習者
- 希望通過項目實戰學習BP神經網絡、卷積神經網絡模型搭建的開發者
- 對圖像識別、分類應用感興趣的童鞋
- 想學習通過圖形界面實現AI模型推理的開發者
四、項目實戰:
1.主界面構建
def initUI(self):
    """Build the main window.

    Layout, top to bottom: a model-path row, an image preview area,
    an image-path row and a result label.
    """
    # Window title and initial geometry (x, y, width, height).
    self.setWindowTitle("水果分類應用")
    self.setGeometry(100, 100, 800, 600)

    central = QWidget()
    self.setCentralWidget(central)
    root = QVBoxLayout()

    # Row 1: model path field plus "choose" and "load" buttons.
    model_row = QHBoxLayout()
    model_caption = QLabel("模型路徑:")
    self.model_path_edit = QtWidgets.QLineEdit()
    choose_model_btn = QPushButton("選擇模型")
    choose_model_btn.clicked.connect(self.select_model_path)
    self.load_model_button = QPushButton("加載模型")
    self.load_model_button.clicked.connect(self.load_model)
    # Disabled until select_model_path has filled in a path.
    self.load_model_button.setEnabled(False)
    for widget in (model_caption, self.model_path_edit, choose_model_btn, self.load_model_button):
        model_row.addWidget(widget)
    root.addLayout(model_row)

    # Image preview area.
    self.image_label = QLabel()
    self.image_label.setAlignment(QtCore.Qt.AlignCenter)
    self.image_label.setMinimumSize(600, 400)
    root.addWidget(self.image_label)

    # Row 2: image path field plus "choose" and "predict" buttons.
    image_row = QHBoxLayout()
    image_caption = QLabel("圖像路徑:")
    self.image_path_edit = QtWidgets.QLineEdit()
    choose_image_btn = QPushButton("選擇圖像")
    choose_image_btn.clicked.connect(self.select_image_path)
    self.predict_button = QPushButton("分類預測")
    self.predict_button.clicked.connect(self.predict_image)
    # Disabled until load_model succeeds.
    self.predict_button.setEnabled(False)
    for widget in (image_caption, self.image_path_edit, choose_image_btn, self.predict_button):
        image_row.addWidget(widget)
    root.addLayout(image_row)

    # Result text shown under everything else.
    self.result_label = QLabel("請先加載模型并選擇圖像")
    self.result_label.setAlignment(QtCore.Qt.AlignCenter)
    self.result_label.setStyleSheet("font-size: 20px")
    root.addWidget(self.result_label)

    central.setLayout(root)
2.功能輔助函數
def select_model_path(self):
    """Ask the user for a model file.

    When a file is chosen, its path is put in the model field and the
    load button is enabled.
    """
    file_path, _ = QFileDialog.getOpenFileName(
        self, "選擇模型文件", "", "Pytorch模型 (*.pth);;所有文件(*)"
    )
    if file_path:
        self.model_path_edit.setText(file_path)
        self.load_model_button.setEnabled(True)


def load_model(self):
    """Build the classifier and load its weights from the path in the model field.

    On success:
        The model is moved to ``self.device``, put in eval mode and stored in
        ``self.model``, and the predict button is enabled.
    On any failure:
        The error is shown in the result label, ``self.model`` is reset to
        None and prediction is disabled.
    """
    model_path = self.model_path_edit.text()
    if not model_path:
        return
    try:
        # Model architecture / class count: adjust to match your own checkpoint.
        model = FruitClassificationModelResnet18(4)
        # weights_only=True only unpickles tensors and plain containers.
        # That is all a state_dict needs, and it refuses to run arbitrary
        # code embedded in a .pth file.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        model = model.to(self.device)
        model.eval()
    except Exception as e:
        self.result_label.setText(f"模型加載失敗: {str(e)}")
        self.model = None
        self.predict_button.setEnabled(False)
        return
    # Publish the model only once it is fully loaded.
    self.model = model
    self.result_label.setText("模型加載成功!請選擇圖像進行預測.")
    self.predict_button.setEnabled(True)


def select_image_path(self):
    """Ask the user for an image file; show its path and a preview if one is chosen."""
    file_path, _ = QFileDialog.getOpenFileName(
        self, "選擇圖像文件", "", "圖像文件 (*.bmp *.png *.jpg *.jpeg);;所有文件(*)"
    )
    if file_path:
        self.image_path_edit.setText(file_path)
        self.display_image(file_path)


def display_image(self, file_path):
    """Show the image scaled to the preview label, keeping its aspect ratio.

    If Qt cannot read the file, an error text is shown instead.
    """
    pixmap = QtGui.QPixmap(file_path)
    if pixmap.isNull():
        self.image_label.setText("無法加載圖像")
        return
    scaled_pixmap = pixmap.scaled(
        self.image_label.size(), QtCore.Qt.KeepAspectRatio, QtCore.Qt.SmoothTransformation
    )
    self.image_label.setPixmap(scaled_pixmap)


def preprocess_image(self, image_path):
    """Turn an image file into a normalized 1x3x224x224 tensor on ``self.device``.

    Returns None if the image cannot be opened or transformed. In that case
    the error is shown in the result label.
    """
    try:
        # Steps: resize to 224x224, convert to tensor, normalize with the
        # ImageNet mean/std.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # The with-block closes the file handle. convert() always returns a
        # new, fully loaded image, so it stays valid after the file closes.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # Add the batch dimension and move to the inference device.
        return transform(image).unsqueeze(0).to(self.device)
    except Exception as e:
        self.result_label.setText(f"圖像預處理失敗: {str(e)}")
        return None
3.加載模型
def load_model(self):
    """Build the classifier and load its weights from the path in the model field.

    On success:
        The model is moved to ``self.device``, put in eval mode and stored in
        ``self.model``, and the predict button is enabled.
    On any failure:
        The error is shown in the result label, ``self.model`` is reset to
        None and prediction is disabled.
    """
    model_path = self.model_path_edit.text()
    if not model_path:
        return
    try:
        # Model architecture / class count: adjust to match your own checkpoint.
        model = FruitClassificationModelResnet18(4)
        # weights_only=True only unpickles tensors and plain containers.
        # That is all a state_dict needs, and it refuses to run arbitrary
        # code embedded in a .pth file.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        model = model.to(self.device)
        model.eval()
    except Exception as e:
        self.result_label.setText(f"模型加載失敗: {str(e)}")
        self.model = None
        self.predict_button.setEnabled(False)
        return
    # Publish the model only once it is fully loaded.
    self.model = model
    self.result_label.setText("模型加載成功!請選擇圖像進行預測.")
    self.predict_button.setEnabled(True)
4.預測函數
def predict_image(self):
    """Classify the image named in the image field and show the predicted label.

    Needs a loaded model (``self.model``) and a non-empty image path.
    Otherwise a hint is shown in the result label and no prediction is made.
    """
    if self.model is None:
        self.result_label.setText("請先加載模型")
        return
    image_path = self.image_path_edit.text()
    if not image_path:
        self.result_label.setText("請選擇圖像")
        return
    input_tensor = self.preprocess_image(image_path)
    if input_tensor is None:
        # preprocess_image has already reported the error.
        return
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        outputs = self.model(input_tensor.to(self.device))
    # Index of the highest score along the class dimension.
    class_id = int(outputs.argmax(dim=1).item())
    # Example class names: adjust them (and their order) to match your model.
    class_names = ['Apple', 'Banana', 'Orange', 'Pineapple']
    if class_id < len(class_names):
        self.result_label.setText(f"預測結果: {class_names[class_id]}")
    else:
        self.result_label.setText(f"預測結果: 未知類別 ({class_id})")
5.完整實現代碼
import cv2
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import QFileDialog, QLabel, QPushButton, QVBoxLayout, QWidget, QHBoxLayout
from model import FruitClassificationModelResnet18


class FruitClassificationApp(QtWidgets.QMainWindow):
    """Main window that loads a trained fruit classifier and classifies user-chosen images."""

    # Display names indexed by the model's predicted class id. The model is
    # built with len(CLASS_NAMES) outputs, so the two cannot drift apart.
    # These are example classes: adjust them (and their order) to match how
    # your model was trained.
    CLASS_NAMES = ['Apple', 'Banana', 'Orange', 'Pineapple']

    # ImageNet statistics used to normalize the input tensor.
    NORMALIZE_MEAN = [0.485, 0.456, 0.406]
    NORMALIZE_STD = [0.229, 0.224, 0.225]

    def __init__(self):
        super().__init__()
        self.model = None
        # Use the GPU when CUDA is available, otherwise the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The preprocessing pipeline is fixed, so build it once instead of per prediction.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(self.NORMALIZE_MEAN, self.NORMALIZE_STD),
        ])
        self.initUI()

    def initUI(self):
        """Build the model row, image preview, image row and result label."""
        # Window title and initial geometry (x, y, width, height).
        self.setWindowTitle("水果分類應用")
        self.setGeometry(100, 100, 800, 600)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout()

        # Model selection row.
        model_layout = QHBoxLayout()
        model_label = QLabel("模型路徑:")
        self.model_path_edit = QtWidgets.QLineEdit()
        model_button = QPushButton("選擇模型")
        model_button.clicked.connect(self.select_model_path)
        self.load_model_button = QPushButton("加載模型")
        self.load_model_button.clicked.connect(self.load_model)
        self.load_model_button.setEnabled(False)
        # The path field is editable, so also enable/disable loading when the
        # user types a path by hand instead of using the file dialog.
        self.model_path_edit.textChanged.connect(
            lambda text: self.load_model_button.setEnabled(bool(text))
        )
        model_layout.addWidget(model_label)
        model_layout.addWidget(self.model_path_edit)
        model_layout.addWidget(model_button)
        model_layout.addWidget(self.load_model_button)
        main_layout.addLayout(model_layout)

        # Image preview area.
        self.image_label = QLabel()
        self.image_label.setAlignment(QtCore.Qt.AlignCenter)
        self.image_label.setMinimumSize(600, 400)
        main_layout.addWidget(self.image_label)

        # Image selection row.
        image_layout = QHBoxLayout()
        image_path_label = QLabel("圖像路徑:")
        self.image_path_edit = QtWidgets.QLineEdit()
        image_select_button = QPushButton("選擇圖像")
        image_select_button.clicked.connect(self.select_image_path)
        self.predict_button = QPushButton("分類預測")
        self.predict_button.clicked.connect(self.predict_image)
        # Disabled until load_model succeeds.
        self.predict_button.setEnabled(False)
        image_layout.addWidget(image_path_label)
        image_layout.addWidget(self.image_path_edit)
        image_layout.addWidget(image_select_button)
        image_layout.addWidget(self.predict_button)
        main_layout.addLayout(image_layout)

        # Result display.
        self.result_label = QLabel("請先加載模型并選擇圖像")
        self.result_label.setAlignment(QtCore.Qt.AlignCenter)
        self.result_label.setStyleSheet("font-size: 20px")
        main_layout.addWidget(self.result_label)

        central_widget.setLayout(main_layout)

    def select_model_path(self):
        """Ask the user for a model file and put its path in the model field."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "選擇模型文件", "", "Pytorch模型 (*.pth);;所有文件(*)"
        )
        if file_path:
            self.model_path_edit.setText(file_path)
            self.load_model_button.setEnabled(True)

    def load_model(self):
        """Build the classifier and load its weights from the path in the model field.

        On failure the error is shown, ``self.model`` is reset to None and
        prediction is disabled.
        """
        model_path = self.model_path_edit.text()
        if not model_path:
            return
        try:
            # Architecture choice: adjust to match your own checkpoint.
            model = FruitClassificationModelResnet18(len(self.CLASS_NAMES))
            # weights_only=True only unpickles tensors and plain containers.
            # That is all a state_dict needs, and it avoids running arbitrary
            # code embedded in a .pth file.
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
            model = model.to(self.device)
            model.eval()
        except Exception as e:
            self.result_label.setText(f"模型加載失敗: {str(e)}")
            self.model = None
            self.predict_button.setEnabled(False)
            return
        # Publish the model only once it is fully loaded.
        self.model = model
        self.result_label.setText("模型加載成功!請選擇圖像進行預測.")
        self.predict_button.setEnabled(True)

    def select_image_path(self):
        """Ask the user for an image file; show its path and a preview if one is chosen."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "選擇圖像文件", "", "圖像文件 (*.bmp *.png *.jpg *.jpeg);;所有文件(*)"
        )
        if file_path:
            self.image_path_edit.setText(file_path)
            self.display_image(file_path)

    def display_image(self, file_path):
        """Show the image scaled to the preview label, keeping its aspect ratio."""
        pixmap = QtGui.QPixmap(file_path)
        if pixmap.isNull():
            self.image_label.setText("無法加載圖像")
            return
        scaled_pixmap = pixmap.scaled(
            self.image_label.size(), QtCore.Qt.KeepAspectRatio, QtCore.Qt.SmoothTransformation
        )
        self.image_label.setPixmap(scaled_pixmap)

    def preprocess_image(self, image_path):
        """Turn an image file into a normalized 1x3x224x224 tensor on ``self.device``.

        Returns None (after showing the error) if the image cannot be
        opened or transformed.
        """
        try:
            # The with-block closes the file handle. convert() returns a new,
            # fully loaded image that stays valid after the file closes.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            # Add the batch dimension and move to the inference device.
            return self.transform(image).unsqueeze(0).to(self.device)
        except Exception as e:
            self.result_label.setText(f"圖像預處理失敗: {str(e)}")
            return None

    def predict_image(self):
        """Classify the image named in the image field and show the predicted label."""
        if self.model is None:
            self.result_label.setText("請先加載模型")
            return
        image_path = self.image_path_edit.text()
        if not image_path:
            self.result_label.setText("請選擇圖像")
            return
        input_tensor = self.preprocess_image(image_path)
        if input_tensor is None:
            # preprocess_image has already reported the error.
            return
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(input_tensor)
        class_id = int(outputs.argmax(dim=1).item())
        if class_id < len(self.CLASS_NAMES):
            self.result_label.setText(f"預測結果: {self.CLASS_NAMES[class_id]}")
        else:
            self.result_label.setText(f"預測結果: 未知類別 ({class_id})")


if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    window = FruitClassificationApp()
    window.show()
    sys.exit(app.exec_())
五、學習收獲🎁:
通過本次 PyTorch 與 PyQt5 的項目實戰,不僅鞏固了深度學習模型的使用方法,也系統地學習了如何將模型部署到圖形界面中。以下是我的一些具體收獲:
1️⃣ 深度學習模型部署實踐
學會了如何將 `.pth` 格式的模型加載到推理環境;熟悉了圖像的預處理流程(如Resize、ToTensor、Normalize);
掌握了 `torch.no_grad()` 在推理模式下的使用,避免不必要的梯度計算,從而加速推理並減少內存佔用。
2️⃣ PyQt5 圖形界面開發
掌握了 PyQt5 中常用的控件,如 `QLabel`、`QPushButton`、`QLineEdit` 等;學會了如何使用 `QFileDialog` 實現文件選擇;了解了如何通過 `QPixmap` 加載並展示圖像;熟悉了使用 `QVBoxLayout` 和 `QHBoxLayout` 進行界面布局。
3️⃣ 端到端流程整合
實現了從模型加載 → 圖像讀取 → 圖像預處理 → 推理 → 展示結果 的完整流程;
初步理解了如何將 AI 模型變成一個用戶可交互的軟件;
為後續構建更複雜的推理系統(如視頻流識別、多模型切換)打下了基礎。
注:完整代碼,請私聊,免費獲取。