基于 Flask 的深度學習模型部署服務端詳解
在深度學習領域,訓練出一個高精度的模型只是第一步,將其部署到生產環境中,為實際業務提供服務才是最終目標。本文將詳細解析一個基于 Flask 和 PyTorch 的深度學習模型部署服務端代碼,幫助你理解如何將訓練好的模型以 API 形式提供給客戶端使用。
一、整體概述
這段代碼的主要功能是搭建一個基于 Flask 的 Web 服務,用于接收客戶端發送的圖像數據,使用預訓練的 PyTorch 模型對圖像進行分類預測,并將預測結果以 JSON 格式返回給客戶端。
二、代碼詳細解析
1. 導入必要的庫
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
io
:用于處理二進制數據,這里主要用于將客戶端發送的圖像二進制數據轉換為圖像對象。flask
:一個輕量級的 Web 框架,用于搭建 Web 服務。torch
和torch.nn.functional
:PyTorch 的核心庫,用于深度學習模型的構建和計算。PIL.Image
:Python Imaging Library(PIL)的一部分,用于處理圖像文件。torch.nn
:用于定義神經網絡的層和模塊。torchvision.transforms
和torchvision.models
:transforms
用于圖像預處理,models
提供了預訓練的深度學習模型。
2. 初始化 Flask 應用和模型相關變量
app = flask.Flask(__name__)
model = None
use_gpu = False
app = flask.Flask(__name__)
:創建一個新的 Flask 應用實例,__name__
參數用于確定應用的根路徑。model
:用于存儲加載的深度學習模型,初始化為None
。use_gpu
:一個布爾變量,用于控制是否使用 GPU 進行模型推理,初始化為False
。
3. 加載模型
def load_model():global modelmodel = models.resnet18()num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))checkpoint = torch.load('best.pth')model.load_state_dict(checkpoint['state_dict'])model.eval()if use_gpu:model.cuda()
global model
:聲明model
為全局變量,以便在函數內部修改它。model = models.resnet18()
:加載預訓練的 ResNet-18 模型。num_ftrs = model.fc.in_features
:獲取 ResNet-18 模型最后一層全連接層的輸入特征數。model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
:修改最后一層全連接層,將輸出維度改為 102,這里的 102 可以根據實際任務的類別數進行調整。checkpoint = torch.load('best.pth')
:從文件best.pth
中加載訓練好的模型參數。model.load_state_dict(checkpoint['state_dict'])
:將加載的參數應用到模型中。model.eval()
:將模型設置為評估模式,關閉一些在訓練時使用的特殊層(如 Dropout)。if use_gpu: model.cuda()
:如果use_gpu
為True
,將模型移動到 GPU 上。
4. 圖像預處理
def prepare_image(image, target_size):if image.mode != 'RGB':image = image.convert('RGB')image = transforms.Resize(target_size)(image)image = transforms.ToTensor()(image)image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)image = image[None]if use_gpu:image = image.cuda()return torch.tensor(image)
if image.mode != 'RGB': image = image.convert('RGB')
:確保輸入圖像為 RGB 格式。image = transforms.Resize(target_size)(image)
:將圖像調整為指定的大小。image = transforms.ToTensor()(image)
:將圖像轉換為 PyTorch 張量。image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
:對圖像進行歸一化處理,使用的均值和標準差是在 ImageNet 數據集上計算得到的。image = image[None]
:增加一個維度,將圖像轉換為批量輸入的格式。if use_gpu: image = image.cuda()
:如果use_gpu
為True
,將圖像移動到 GPU 上。
5. 定義預測接口
@app.route('/predict', methods=['POST'])
def predict():data = {'success': False}if flask.request.method == 'POST':if flask.request.files.get('image'):image = flask.request.files['image'].read()image = Image.open(io.BytesIO(image))image = prepare_image(image, target_size=(224, 224))preds = F.softmax(model(image), dim=1)results = torch.topk(preds.cpu().data, k=3, dim=1)results = (results[0].cpu().numpy(), results[1].cpu().numpy())data['prediction'] = list()for prob, label in zip(results[0][0], results[1][0]):r = {'label': str(label), 'probability': float(prob)}data['prediction'].append(r)data['success'] = Truereturn flask.jsonify(data)
@app.route('/predict', methods=['POST'])
:使用 Flask 的裝飾器定義一個路由,當客戶端向/predict
路徑發送 POST 請求時,會調用predict
函數。data = {'success': False}
:初始化一個字典,用于存儲預測結果和狀態信息,初始狀態為success = False
。if flask.request.method == 'POST'
:檢查請求方法是否為 POST。if flask.request.files.get('image')
:檢查請求中是否包含名為image
的文件。image = flask.request.files['image'].read()
:讀取客戶端發送的圖像文件內容。image = Image.open(io.BytesIO(image))
:將二進制數據轉換為圖像對象。image = prepare_image(image, target_size=(224, 224))
:對圖像進行預處理。preds = F.softmax(model(image), dim=1)
:使用模型進行預測,并通過softmax
函數將輸出轉換為概率分布。results = torch.topk(preds.cpu().data, k=3, dim=1)
:獲取概率最大的前 3 個結果。results = (results[0].cpu().numpy(), results[1].cpu().numpy())
:將結果轉換為 NumPy 數組。data['prediction'] = list()
:初始化一個列表,用于存儲預測結果。for prob, label in zip(results[0][0], results[1][0])
:遍歷前 3 個結果,將標簽和概率封裝成字典,并添加到data['prediction']
列表中。data['success'] = True
:將狀態信息設置為success = True
,表示預測成功。return flask.jsonify(data)
:將結果以 JSON 格式返回給客戶端。
6. 啟動服務
if __name__ == '__main__':print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started')load_model()app.run(host='192.168.1.20', port=5012)
if __name__ == '__main__'
:確保代碼作為主程序運行時才執行以下操作。print('Loading PyTorch model and Flask starting server ...')
和print('Please wait until server has fully started')
:打印啟動信息。load_model()
:調用load_model
函數加載模型。app.run(host='192.168.1.20', port=5012)
:啟動 Flask 服務,監聽192.168.1.20
地址的 5012 端口。運行結果如下
三、總結
通過上述代碼,我們成功搭建了一個基于 Flask 和 PyTorch 的深度學習模型部署服務端。客戶端可以通過向 /predict
路徑發送包含圖像文件的 POST 請求,獲取圖像分類的預測結果。在實際應用中,可以根據需要對代碼進行擴展,如增加更多的模型、優化圖像預處理流程、添加錯誤處理機制等。希望本文能幫助你更好地理解深度學習模型的部署過程。