基于 Flask的深度學習模型部署服務端詳解

基于 Flask 的深度學習模型部署服務端詳解

在深度學習領域,訓練出一個高精度的模型只是第一步,將其部署到生產環境中,為實際業務提供服務才是最終目標。本文將詳細解析一個基于 Flask 和 PyTorch 的深度學習模型部署服務端代碼,幫助你理解如何將訓練好的模型以 API 形式提供給客戶端使用。

一、整體概述

這段代碼的主要功能是搭建一個基于 Flask 的 Web 服務,用于接收客戶端發送的圖像數據,使用預訓練的 PyTorch 模型對圖像進行分類預測,并將預測結果以 JSON 格式返回給客戶端。

二、代碼詳細解析

1. 導入必要的庫

import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
  • io:用于處理二進制數據,這里主要用于將客戶端發送的圖像二進制數據轉換為圖像對象。
  • flask:一個輕量級的 Web 框架,用于搭建 Web 服務。
  • torchtorch.nn.functional:PyTorch 的核心庫,用于深度學習模型的構建和計算。
  • PIL.Image:Python Imaging Library(PIL)的一部分,用于處理圖像文件。
  • torch.nn:用于定義神經網絡的層和模塊。
  • torchvision.transformstorchvision.modelstransforms 用于圖像預處理,models 提供了預訓練的深度學習模型。

2. 初始化 Flask 應用和模型相關變量

app = flask.Flask(__name__)
model = None
use_gpu = False
  • app = flask.Flask(__name__):創建一個新的 Flask 應用實例,__name__ 參數用于確定應用的根路徑。
  • model:用于存儲加載的深度學習模型,初始化為 None
  • use_gpu:一個布爾變量,用于控制是否使用 GPU 進行模型推理,初始化為 False

3. 加載模型

def load_model():global modelmodel = models.resnet18()num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))checkpoint = torch.load('best.pth')model.load_state_dict(checkpoint['state_dict'])model.eval()if use_gpu:model.cuda()
  • global model:聲明 model 為全局變量,以便在函數內部修改它。
  • model = models.resnet18():加載預訓練的 ResNet-18 模型。
  • num_ftrs = model.fc.in_features:獲取 ResNet-18 模型最后一層全連接層的輸入特征數。
  • model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后一層全連接層,將輸出維度改為 102,這里的 102 可以根據實際任務的類別數進行調整。
  • checkpoint = torch.load('best.pth'):從文件 best.pth 中加載訓練好的模型參數。
  • model.load_state_dict(checkpoint['state_dict']):將加載的參數應用到模型中。
  • model.eval():將模型設置為評估模式,關閉一些在訓練時使用的特殊層(如 Dropout)。
  • if use_gpu: model.cuda():如果 use_gpuTrue,將模型移動到 GPU 上。

4. 圖像預處理

def prepare_image(image, target_size):if image.mode != 'RGB':image = image.convert('RGB')image = transforms.Resize(target_size)(image)image = transforms.ToTensor()(image)image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)image = image[None]if use_gpu:image = image.cuda()return torch.tensor(image)
  • if image.mode != 'RGB': image = image.convert('RGB'):確保輸入圖像為 RGB 格式。
  • image = transforms.Resize(target_size)(image):將圖像調整為指定的大小。
  • image = transforms.ToTensor()(image):將圖像轉換為 PyTorch 張量。
  • image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image):對圖像進行歸一化處理,使用的均值和標準差是在 ImageNet 數據集上計算得到的。
  • image = image[None]:增加一個維度,將圖像轉換為批量輸入的格式。
  • if use_gpu: image = image.cuda():如果 use_gpuTrue,將圖像移動到 GPU 上。

5. 定義預測接口

@app.route('/predict', methods=['POST'])
def predict():data = {'success': False}if flask.request.method == 'POST':if flask.request.files.get('image'):image = flask.request.files['image'].read()image = Image.open(io.BytesIO(image))image = prepare_image(image, target_size=(224, 224))preds = F.softmax(model(image), dim=1)results = torch.topk(preds.cpu().data, k=3, dim=1)results = (results[0].cpu().numpy(), results[1].cpu().numpy())data['prediction'] = list()for prob, label in zip(results[0][0], results[1][0]):r = {'label': str(label), 'probability': float(prob)}data['prediction'].append(r)data['success'] = Truereturn flask.jsonify(data)
  • @app.route('/predict', methods=['POST']):使用 Flask 的裝飾器定義一個路由,當客戶端向 /predict 路徑發送 POST 請求時,會調用 predict 函數。
  • data = {'success': False}:初始化一個字典,用于存儲預測結果和狀態信息,初始狀態為 success = False
  • if flask.request.method == 'POST':檢查請求方法是否為 POST。
  • if flask.request.files.get('image'):檢查請求中是否包含名為 image 的文件。
  • image = flask.request.files['image'].read():讀取客戶端發送的圖像文件內容。
  • image = Image.open(io.BytesIO(image)):將二進制數據轉換為圖像對象。
  • image = prepare_image(image, target_size=(224, 224)):對圖像進行預處理。
  • preds = F.softmax(model(image), dim=1):使用模型進行預測,并通過 softmax 函數將輸出轉換為概率分布。
  • results = torch.topk(preds.cpu().data, k=3, dim=1):獲取概率最大的前 3 個結果。
  • results = (results[0].cpu().numpy(), results[1].cpu().numpy()):將結果轉換為 NumPy 數組。
  • data['prediction'] = list():初始化一個列表,用于存儲預測結果。
  • for prob, label in zip(results[0][0], results[1][0]):遍歷前 3 個結果,將標簽和概率封裝成字典,并添加到 data['prediction'] 列表中。
  • data['success'] = True:將狀態信息設置為 success = True,表示預測成功。
  • return flask.jsonify(data):將結果以 JSON 格式返回給客戶端。

6. 啟動服務

if __name__ == '__main__':print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started')load_model()app.run(host='192.168.1.20', port=5012)
  • if __name__ == '__main__':確保代碼作為主程序運行時才執行以下操作。
  • print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started'):打印啟動信息。
  • load_model():調用 load_model 函數加載模型。
  • app.run(host='192.168.1.20', port=5012):啟動 Flask 服務,監聽 192.168.1.20 地址的 5012 端口。運行結果如下
  • 在這里插入圖片描述

三、總結

通過上述代碼,我們成功搭建了一個基于 Flask 和 PyTorch 的深度學習模型部署服務端。客戶端可以通過向 /predict 路徑發送包含圖像文件的 POST 請求,獲取圖像分類的預測結果。在實際應用中,可以根據需要對代碼進行擴展,如增加更多的模型、優化圖像預處理流程、添加錯誤處理機制等。希望本文能幫助你更好地理解深度學習模型的部署過程。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/82269.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/82269.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/82269.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vue3 + Node.js 實現客服實時聊天系統(WebSocket + Socket.IO 詳解)

Node.js 實現客服實時聊天系統(WebSocket Socket.IO 詳解) 一、為什么選擇 WebSocket? 想象一下淘寶客服的聊天窗口:你發消息,客服立刻就能看到并回復。這種即時通訊效果是如何實現的呢?我們使用 Vue3 作…

MySQL數據庫與表結構操作指南

前言:本文系統梳理MySQL核心操作語句。內容覆蓋建庫建表、結構調整、數據遷移全流程(包含創建/修改/刪除/備份場景)。希望它們能幫你快速解決問題。 庫結構操作 一、庫的創建 一個庫的簡單創建: create database 庫名; 注意&am…

【WEB3】區塊鏈、隱私計算、AI和Web3.0——數據民主化(1)

區塊鏈、隱私計算、AI,是未來Web3.0至關重要的三項技術。 1.數據民主化問題 數據在整個生命周期(生產、傳輸、處理、存儲)內的隱私安全,則是Web3.0在初始階段首要解決的問題。 數據民主化旨在打破數據壟斷,讓個體能…

C語言—指針2

1. const 修飾變量 1.1 const修飾變量 變量被const修飾時,變量此時為常變量,本質為常量,語法上不可被修改,但是如果此時需要修改變量值,可以通過指針的方式修改。 雖然此時通過指針的方式確實修改了變量的值&#xff…

高級架構軟考之網絡OSI網絡模型

高級架構軟考之網絡: 1.OSI網絡模型: a.物理層: a.物理傳輸介質物理連接,負責數據傳輸,并監控數據 b.傳輸單位:bit c.協議: d:對應設備:中繼器、集線器 b.數據鏈路層: a.…

el-table計算表頭列寬,不換行顯示

1、在utils.js中封裝renderHeader方法 2、在el-table-column中引入: 3、頁面展示:

MySQL OCP和Oracle OCP怎么選?

近期oracle 為慶祝 MySQL 數據庫發布 30 周年,Oracle 官方推出限時福利:2025 年 4 月 20 日至 7 月 31 日期間,所有人均可免費報考 MySQL OCP(Oracle Certified Professional)認證考試(具體可查看MySQL OCP…

2025最新免費視頻號下載工具!支持Win/Mac,一鍵解析原畫質+封面

軟件介紹 適用于Windows 2025 最新5月蝴蝶視頻號下載工具,免費使用,無廣告且免費,支持對原視頻和封面進行解析下載,親測可用,現在很多工具都失效了,難得的幾款下載視頻號工具,大家且用且珍…

Python學習之路(八)-多線程和多進程淺析

在 Python 中,多線程(Multithreading) 和 多進程(Multiprocessing) 是實現并發編程的兩種主要方式。它們各有優劣,適用于不同的場景。 一、基本概念 特性多線程(threading)多進程(multiprocessing)并發模型線程共享內存空間每個進程擁有獨立內存空間GIL(全局解釋器鎖…

Spark緩存--persist方法

1. 功能本質 persist:這是一個通用的持久化方法,能夠指定多種不同的存儲級別。存儲級別決定了數據的存儲位置(如內存、磁盤)以及存儲形式(如是否序列化)。 2. 存儲級別指定 persist:可以通過傳入…

裸辭8年前端的面試筆記——JavaScript篇(一)

裸辭后的第二個月開始準備找工作,今天是第三天目前還沒有面試,現在的行情是一言難盡,都在瘋狂的壓價。 下邊是今天復習的個人筆記 一、事件循環 JavaScript 的事件循環(Event Loop)是其實現異步編程的關鍵機制。 從…

什么是死信隊列?死信隊列是如何導致的?

死信交換機(Dead Letter Exchange,DLX) 定義:死信交換機是一種特殊的交換機,專門用于**接收從其他隊列中因特定原因變成死信的消息**。它的本質還是交換機,遵循RabbitMQ中交換機的基本工作原理&#xff0c…

9. 從《蜀道難》學CSS基礎:三種選擇器的實戰解析

引言:當古詩遇上現代網頁設計 今天我們通過李白的經典詩作《蜀道難》來學習CSS的三種核心選擇器。這種古今結合的學習方式,既能感受中華詩詞的魅力,又能掌握實用的網頁設計技能。讓我們開始這場穿越時空的技術之旅吧! 一、HTML骨架…

三角網格減面算法及其代表的算法庫都有哪些?

以下是三角網格減面算法及其代表庫/工具的詳細分類,涵蓋經典算法和現代實現: ??1. 頂點聚類(Vertex Clustering)?? ??原理??:將網格空間劃分為體素柵格,合并每個柵格內的頂點。??特點??&#…

URP - 屏幕圖像(_CameraOpaqueTexture)

首先需要在unity中開啟屏幕圖像開關才可以使用該紋理 同樣只有不透明對象才能被渲染到屏幕圖像中 若想要該對象不被渲染到屏幕圖像中,可以將其Shader的渲染隊列改為 "Queue" "Transparent" 如何在Shader中使用_CameraOpaqueTexture&#xf…

vue 和 html 的區別

使用 Vue.js 和原生 HTML 開發 Web 應用有顯著的區別,主要體現在開發模式、功能擴展、性能優化和維護性等方面。以下是兩者的對比分析: 🧱 原生 HTML(HTML CSS JavaScript) 特點: 靜態結構:H…

LeetCode[226] 翻轉二叉樹

思路: 使用遞歸,歸根結底還是左右節點互相倒,那么肯定需要一個temp節點在中間傳遞,最后就是遞歸,沒什么說的 代碼: /*** Definition for a binary tree node.* public class TreeNode {* int …

冪等的幾種解決方案以及實踐

目錄 什么是冪等? 解決冪等的常見解決方案: 唯一標識符案例 數據庫唯一約束 案例 樂觀鎖案例 分布式鎖(Distributed Locking) 實踐精選方案 首先 為什么不直接使用分布式鎖呢? 自定義實現冪等組件&#xff01…

PowerShell中的Json處理

1.定義JSON字符串變量 PS C:\WINDOWS\system32> $body {"Method": "POST","Body": {"model": "deepseek-r1","messages": [{"content": "why is the sky blue?","role"…

奧威BI:AI+BI深度融合,重塑智能AI數據分析新標桿

在數字化浪潮席卷全球的今天,企業正面臨著前所未有的數據挑戰與機遇。如何高效、精準地挖掘數據價值,已成為推動業務增長、提升競爭力的核心議題。奧威BI,作為智能AI數據分析領域的領軍者,憑借其創新的AIBI融合模式,正…