繪制混淆矩陣(Confusion Matrix),用于評估分類模型的性能。混淆矩陣展示了模型預測結果與真實標簽之間的對應關系,能夠直觀地顯示各類別的預測準確性和錯誤類型。
混淆矩陣是評估分類模型性能的基礎工具,特別適用于多分類問題。
你可以使用swanlab.confusion_matrix
來記錄混淆矩陣。
Demo鏈接:ComputeMetrics - SwanLab
基本用法
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab# 加載鳶尾花數據集
iris_data = load_iris()
X = iris_data.data
y = iris_data.target
class_names = iris_data.target_names.tolist()# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 訓練模型
model = xgb.XGBClassifier(objective='multi:softmax', num_class=len(class_names))
model.fit(X_train, y_train)# 獲取預測結果
y_pred = model.predict(X_test)# 初始化SwanLab
swanlab.init(project="Confusion-Matrix-Demo", experiment_name="Confusion-Matrix-Example")# 記錄混淆矩陣
swanlab.log({"confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, class_names)
})swanlab.finish()
使用自定義類別名稱
# 定義自定義類別名稱
custom_class_names = ["類別A", "類別B", "類別C"]# 記錄混淆矩陣
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred, custom_class_names)
swanlab.log({"confusion_matrix_custom": confusion_matrix})
不使用類別名稱
# 不指定類別名稱,將使用數字索引
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred)
swanlab.log({"confusion_matrix_default": confusion_matrix})
二分類示例
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab# 生成二分類數據
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 訓練模型
model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')
model.fit(X_train, y_train)# 獲取預測結果
y_pred = model.predict(X_test)# 記錄混淆矩陣
swanlab.log({"confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, ["負類", "正類"])
})
注意事項
- 數據格式:
y_true
和y_pred
可以是列表或numpy數組 - 多分類支持: 此函數支持二分類和多分類問題
- 類別名稱:
class_names
的長度應該與類別數量一致 - 依賴包: 需要安裝
scikit-learn
和pyecharts
包 - 坐標軸: sklearn的confusion_matrix左上角為(0,0),在pyecharts的heatmap中是左下角,函數會自動處理坐標轉換
- 矩陣解讀: 混淆矩陣中,行表示真實標簽,列表示預測標簽