部署深度學習模型:Flask API 服務端與客戶端通信實戰
在這篇文章中,我們將探討如何使用 Flask 框架部署一個深度學習模型,并通過客戶端與服務端進行通信。我們將通過一個實際的例子,展示如何構建服務端和客戶端,以及如何處理圖像預測請求。
環境準備
首先,確保你已經安裝了以下庫:
- Flask
- PyTorch
- torchvision
- Pillow
如果尚未安裝,可以通過以下命令安裝:
pip install flask torch torchvision pillow
服務端代碼
服務端代碼的主要功能是加載預訓練的深度學習模型,接收客戶端發送的圖像數據,進行預測,并將結果返回給客戶端。
# 導入所需的庫
import io # 用于處理二進制數據
import flask # Flask框架,用于搭建Web服務
import torch # PyTorch庫,用于深度學習模型的加載和推理
import torch.nn.functional as F # PyTorch的神經網絡函數模塊,用于softmax等操作
from PIL import Image # Python圖像處理庫,用于圖像的讀取和預處理
from torch import nn # PyTorch的神經網絡模塊
from torchvision import transforms, models # torchvision庫,用于圖像預處理和加載預訓練模型# 初始化Flask app
app = flask.Flask(__name__) # 創建一個新的Flask應用程序實例
# __name__參數通常被傳遞給Flask應用程序來定位應用程序的根路徑,這樣Flask就可以知道在哪里找到模板、靜態文件等。
# 總體來說app = flask.Flask(__name__)是FLasK應用程序的起點。它初始化了一個新的FLasK應用程序實例。
model = None # 初始化模型變量為None
use_gpu = False # 初始化是否使用GPU的標志為False# 定義加載模型的函數
def load_model():# """Load the pre-trained model, you can use your model just as easily."""global model # 聲明使用全局變量model# 加載resnet18網絡model = models.resnet18() # 加載預訓練的resnet18模型num_ftrs = model.fc.in_features # 獲取全連接層的輸入特征數量model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 修改全連接層,輸出為102個類別(根據具體任務修改類別數)# print(model)checkpoint = torch.load('best.pth') # 加載訓練好的模型權重model.load_state_dict(checkpoint['state_dict']) # 將權重加載到模型中# 將模型指定為測試格式model.eval() # 將模型設置為評估模式# 是否使用gpuif use_gpu:model.cuda() # 如果使用GPU,則將模型移動到GPU上# 定義數據預處理函數
def prepare_image(image, target_size):"""Do image preprocessing before prediction on any data.param image : original imageparam target_size : target image sizereturn : preprocessed image"""# 針對不同模型,image的格式不同,但需要統一到RGB格式if image.mode != 'RGB':image = image.convert('RGB') # 如果圖像不是RGB格式,則轉換為RGB格式# Resize the input image and preprocess it.(按照所使用的模型將輸入圖片的尺寸修改)image = transforms.Resize(target_size)(image) # 調整圖像大小為目標尺寸image = transforms.ToTensor()(image) # 將圖像轉換為Tensor# Convert to Torch. Tensor and normalize. mean與stdimage = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) # 對圖像進行標準化處理# Add batch size axis 增加一個維度,用于按batch測試image = image[None] # 增加一個維度if use_gpu:image = image.cuda() # 如果使用GPU,則將圖像移動到GPU上return torch.tensor(image) # 返回預處理后的圖像Tensor# 定義一個裝飾器,用于將指定的URL路徑與一個函數關聯起來,并指定該函數響應的HTTP方法
@app.route('/predict', methods=['POST'])
def predict(): # 當客戶端發送請求時# 做一個標志,剛開始無圖像傳入時為false,傳入圖像時為truedata = {'success': False} # 初始化返回數據字典,初始值為Falseif flask.request.method == 'POST': # 如果收到POST請求if flask.request.files.get('image'): # 判斷是否有圖像文件image = flask.request.files['image'].read() # 將收到的圖像進行讀取,內容為二進制image = Image.open(io.BytesIO(image)) # 將二進制圖像數據轉換為PIL圖像對象# 利用上面的預處理函數將讀入的圖像進行預處理image = prepare_image(image, target_size=(224, 224)) # 對圖像進行預處理,目標尺寸為224x224preds = F.softmax(model(image), dim=1) # 得到各個類別的概率results = torch.topk(preds.cpu().data, k=3, dim=1) # 概率最大的前3個結果# torch.topk用于返回輸入張量中每行最大的k個元素及其對應的索引results = (results[0].cpu().numpy(), results[1].cpu().numpy()) # 將結果轉換為numpy數組# 將data字典增加一個key,value,其中value為list格式data['prediction'] = list() # 初始化預測結果列表for prob, label in zip(results[0][0], results[1][0]): # 遍歷概率和標簽# Label name = idx2labellstr(label)]r = {'label': str(label), 'probability': float(prob)} # 創建一個字典,包含標簽和概率# 將預測結果添加至data字典data['prediction'].append(r) # 將預測結果添加到列表中# Indicate that the request was a success.data['success'] = True # 將請求成功的標志設置為Truereturn flask.jsonify(data) # 返回預測結果的JSON格式數據# 主程序入口
if __name__ == '__main__': # 判斷是否是主程序運行print('Loading PyTorch model and Flask starting server ...') # 打印加載模型和啟動服務器的信息print('Please wait until server has fully started') # 提示用戶等待服務器啟動完成load_model() # 先加載模型# 再開啟服務app.run(port='5012') # 啟動Flask應用,監聽5012端口
客戶端代碼
客戶端代碼負責發送圖像數據到服務端,并接收預測結果。
import requests # 導入requests庫,用于發送HTTP請求# 定義Flask服務的URL地址
flask_url = 'http://127.0.0.1:5012/predict' # Flask服務的地址,運行在本地主機的5012端口# 定義一個函數,用于發送圖像到Flask服務并獲取預測結果
def predict_result(image_path):# 打開圖像文件并讀取其內容image = open(image_path, 'rb').read() # 以二進制模式打開圖像文件并讀取內容payload = {'image': image} # 將圖像內容封裝為一個字典,作為請求的文件數據# 使用requests庫發送POST請求到Flask服務r = requests.post(flask_url, files=payload).json() # 發送POST請求,并將返回的JSON數據解析為字典# 檢查請求是否成功if r['success']:# 如果請求成功,遍歷預測結果并打印for (i, result) in enumerate(r['prediction']):print('{}. 預測類別為{}的概率:{}'.format(i + 1, result['label'], result['probability']))else:# 如果請求失敗,打印失敗信息print('request failed')# 主程序入口
if __name__ == '__main__':# 定義要預測的圖像路徑image_path = r'D:\Users\妄生\PycharmProjects\人工智能\深度學習\模型部署\flower_data\flower_data\val_filelist\image_00059.jpg'# 調用predict_result函數,對指定圖像進行預測predict_result(image_path)
運行與測試
-
啟動服務端:
-
啟動客戶端:
網絡問題處理
如果在運行過程中遇到網絡問題,例如無法訪問 http://127.0.0.1:5012/predict
,這可能是由于以下原因:
- 服務端未正確啟動或端口被占用。
- 本地網絡配置問題。
解決方法:
- 確保服務端正確啟動,并監聽在正確的端口上。
- 檢查防火墻或安全軟件設置,確保沒有阻止訪問該端口。
- 嘗試重新啟動服務端或計算機。
總結
通過這篇文章,我們展示了如何使用 Flask 部署一個深度學習模型,并通過客戶端與服務端進行通信。我們詳細解釋了服務端和客戶端的代碼,并提供了運行和測試的步驟。希望這能幫助你理解如何將深度學習模型部署為 Web 服務,并處理可能遇到的網絡問題。