模型部署
- 示例:保存 Scikit-learn 模型
- myapp/views.py
- 全局加載模型
- tasks.py(Celery任務)
- views.py 修改為異步調用
- views.py
- 準備工作
模型保存格式
確保你的模型已保存為可加載的格式:
● TensorFlow/Keras:.h5 或 SavedModel 格式
● PyTorch:.pt 或 .pth 文件
● Scikit-learn:使用 joblib 或 pickle 保存(推薦 joblib)
示例:保存 Scikit-learn 模型
from sklearn.ensemble import RandomForestClassifier
import joblib
model = RandomForestClassifier()
model.fit(X_train, y_train)
joblib.dump(model, ‘my_model.joblib’)
-
項目結構規劃
建議的 Django 項目結構:
myproject/
├── myapp/
│ ├── models/ # 存放模型文件
│ │ └── my_model.joblib
│ ├── views.py # 處理請求和模型調用
│ ├── urls.py # 定義API路由
│ └── …
├── myproject/
│ ├── settings.py
│ └── urls.py # 主路由
└── manage.py -
模型加載與初始化
在 Django 中全局加載模型
在 myapp/apps.py 或 views.py 中初始化模型,避免每次請求重復加載。
myapp/views.py
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
import joblib
import os
全局加載模型
model_path = os.path.join(os.path.dirname(file), ‘models/my_model.joblib’)
model = joblib.load(model_path)
@csrf_exempt # 若需跨域訪問可臨時禁用CSRF(生產環境需謹慎)
def predict(request):
if request.method == ‘POST’:
try:
# 獲取輸入數據(假設發送JSON)
data = json.loads(request.body)
features = data[‘features’]
# 調用模型預測prediction = model.predict([features])[0]return JsonResponse({'prediction': prediction})except Exception as e:return JsonResponse({'error': str(e)}, status=400)
return JsonResponse({'error': '僅支持POST請求'}, status=405)
- 配置路由
在 myapp/urls.py 中添加API路由
from django.urls import path
from . import views
urlpatterns = [
path(‘predict/’, views.predict, name=‘predict’),
]
在項目主路由 myproject/urls.py 中引入
from django.urls import include, path
urlpatterns = [
path(‘api/’, include(‘myapp.urls’)),
]
-
測試API
使用 curl 或 Postman 發送POST請求測試:
curl -X POST http://localhost:8000/api/predict/
-H “Content-Type: application/json”
-d ‘{“features”: [1.2, 3.4, 5.6]}’
預期響應:
{“prediction”: 0} -
高級優化
異步處理(Celery + Redis)
如果模型推理耗時較長,可用 Celery 異步任務避免阻塞請求:
tasks.py(Celery任務)
from celery import shared_task
from myapp.views import model # 復用全局加載的模型
@shared_task
def async_predict(features):
return model.predict([features])[0]
views.py 修改為異步調用
@csrf_exempt
def predict(request):
if request.method == ‘POST’:
data = json.loads(request.body)
task = async_predict.delay(data[‘features’])
return JsonResponse({‘task_id’: task.id}, status=202)
緩存模型輸出
使用 Django 緩存減少重復計算:
from django.core.cache import cache
def predict(request):
data = json.loads(request.body)
features = tuple(data[‘features’]) # 轉換為可哈希類型
# 檢查緩存
if cache.has_key(features):return JsonResponse({'prediction': cache.get(features)})# 計算并緩存
prediction = model.predict([features])[0]
cache.set(features, prediction, timeout=3600) # 緩存1小時
return JsonResponse({'prediction': prediction})
- 關鍵注意事項
- 線程安全:
from threading import Lock
model_lock = Lock()
def predict(request):
with model_lock:
prediction = model.predict(…)
○ 如果模型非線程安全(如某些 TensorFlow 舊版本),需加鎖或使用單例模式。
2. 性能優化:
○ 使用 gunicorn 或 uvicorn 替代 Django 自帶的開發服務器。
○ 啟用 GPU 加速(如 TensorFlow/PyTorch 的 GPU 版本)。
3. 輸入驗證:
def validate_features(features):
if len(features) != 3:
raise ValueError(“必須提供3個特征”)
if not all(isinstance(x, (int, float)) for x in features):
raise ValueError(“特征必須為數字”)
○ 嚴格校驗前端傳入的數據格式和范圍,防止惡意輸入。
4. 依賴管理:
tensorflow2.12.0
scikit-learn1.2.2
joblib==1.2.0
○ 在 requirements.txt 中明確指定模型庫版本:
完整示例:圖像分類模型集成
假設有一個圖像分類模型(如 ResNet),可按以下方式處理文件上傳:
views.py
from django.core.files.storage import default_storage
from tensorflow.keras.preprocessing import image
import numpy as np
def predict_image(request):
if request.method == ‘POST’:
file = request.FILES[‘image’]
file_path = default_storage.save(‘tmp/’ + file.name, file)
# 預處理圖像img = image.load_img(file_path, target_size=(224, 224))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0) / 255.0# 預測prediction = model.predict(img_array)class_idx = np.argmax(prediction)return JsonResponse({'class': class_idx})
通過以上步驟,你可以將訓練好的模型無縫集成到 Django 中,并通過 RESTful API 提供服務。根據實際需求調整代碼結構和優化策略。