Gradio 是一個用于快速創建機器學習模型和用戶界面之間交互的 Python 庫。它允許你無需編寫大量前端代碼,就能將機器學習模型部署為可交互的網頁應用。以下是一個基于 Gradio 可視化部署機器學習應用的基本步驟:
-
安裝 Gradio:
首先,你需要安裝 Gradio 庫。你可以使用 pip 來安裝:pip install gradio
-
導入 Gradio 并定義界面:
在你的 Python 腳本中,導入 Gradio,并定義輸入和輸出的組件。這些組件將構成你的交互界面的基礎。 -
加載機器學習模型:
加載你已經訓練好的機器學習模型。這可以是一個 scikit-learn 模型、TensorFlow 模型、PyTorch 模型等。 -
定義預測函數:
創建一個函數,該函數接受 Gradio 界面上的輸入,使用加載的模型進行預測,并返回預測結果。 -
創建 Gradio 接口:
使用 Gradio 的Interface
類(或其簡寫形式gr.Interface
)來創建交互界面。你需要指定輸入組件、輸出組件以及預測函數。 -
啟動 Gradio 應用:
調用launch()
方法來啟動 Gradio 應用。默認情況下,它將在本地服務器上運行,并在瀏覽器中自動打開。
以下是一個簡單的示例,展示了如何使用 Gradio 部署一個基于 scikit-learn 的鳶尾花分類模型:
import gradio as gr
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler# 加載數據集
iris = load_iris()
X, y = iris.data, iris.target# 劃分數據集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 數據標準化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 訓練模型
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)# 定義預測函數
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float) -> str:X_new = [[sepal_length, sepal_width, petal_length, petal_width]]X_new_scaled = scaler.transform(X_new)prediction = model.predict(X_new_scaled)return iris.target_names[prediction[0]].capitalize()# 創建 Gradio 接口
iface = gr.Interface(fn=predict,inputs=[gr.inputs.NumberBox(label="Sepal Length"),gr.inputs.NumberBox(label="Sepal Width"),gr.inputs.NumberBox(label="Petal Length"),gr.inputs.NumberBox(label="Petal Width")],outputs=gr.outputs.Textbox(label="Predicted Iris Species"))# 啟動 Gradio 應用
iface.launch()
在這個示例中,我們創建了一個簡單的 Gradio 界面,用戶可以通過輸入鳶尾花的四個特征(花萼長度、花萼寬度、花瓣長度、花瓣寬度)來預測鳶尾花的種類。預測結果將以文本形式顯示。
你可以根據自己的需求調整輸入和輸出組件,以及預測函數。Gradio 支持多種類型的輸入和輸出組件,如文本框、下拉菜單、圖像上傳、滑塊等,使得創建復雜的交互界面變得非常容易。