# 導入必要的庫
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, confusion_matrix,classification_report, ConfusionMatrixDisplay)
from sklearn.preprocessing import StandardScaler# 1. 加載鳶尾花數據集
iris = load_iris()
# 轉換為DataFrame方便查看(特征+標簽)
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['species'] = [iris.target_names[i] for i in iris.target] # 添加花名標簽# 2. 數據基本信息查看
print("數據集形狀:", iris.data.shape) # 150個樣本,4個特征
print("\n特征名稱:", iris.feature_names) # 花萼長度、寬度,花瓣長度、寬度
print("\n類別名稱:", iris.target_names) # 山鳶尾、變色鳶尾、維吉尼亞鳶尾# 3. 數據劃分(特征X和標簽y)
X = iris.data # 特征:4個植物學測量值
y = iris.target # 標簽:0,1,2分別對應三種鳶尾花# 劃分訓練集(80%)和測試集(20%),隨機種子確保結果可復現
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y # stratify=y保持類別比例
)# 4. 特征標準化(邏輯回歸對特征尺度敏感,標準化可提升性能)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train) # 訓練集擬合并標準化
X_test_scaled = scaler.transform(X_test) # 測試集使用相同的標準化參數# 5. 訓練邏輯回歸模型(多分類任務)
model = LogisticRegression(max_iter=200, random_state=42) # 增加迭代次數確保收斂
model.fit(X_train_scaled, y_train)# 6. 模型預測
y_pred = model.predict(X_test_scaled) # 測試集預測標簽
y_pred_proba = model.predict_proba(X_test_scaled) # 預測每個類別的概率# 7. 模型評估
print("\n===== 模型評估結果 =====")
print(f"訓練集準確率:{model.score(X_train_scaled, y_train):.4f}")
print(f"測試集準確率:{accuracy_score(y_test, y_pred):.4f}")print("\n混淆矩陣:")
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)
disp.plot(cmap=plt.cm.Blues)
plt.title("混淆矩陣(測試集)")
plt.show()print("\n分類報告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))# 8. 特征重要性分析(邏輯回歸系數)
feature_importance = pd.DataFrame({'特征': iris.feature_names,'系數絕對值': np.abs(model.coef_).mean(axis=0) # 多分類取各系數的絕對值均值
}).sort_values(by='系數絕對值', ascending=False)print("\n特征重要性(系數絕對值):")
print(feature_importance)# 可視化特征重要性
plt.figure(figsize=(8, 4))
sns.barplot(x='系數絕對值', y='特征', data=feature_importance, palette='coolwarm')
plt.title("特征對分類的重要性")
plt.show()# 9. 新樣本預測示例
# 假設一個新的鳶尾花測量數據(花萼長、花萼寬、花瓣長、花瓣寬)
new_sample = np.array([[5.8, 3.0, 4.9, 1.6]]) # 接近變色鳶尾的特征
new_sample_scaled = scaler.transform(new_sample) # 標準化# 預測結果
predicted_class = model.predict(new_sample_scaled)
predicted_prob = model.predict_proba(new_sample_scaled)print("\n===== 新樣本預測 =====")
print(f"預測類別:{iris.target_names[predicted_class[0]]}")
print("各類別概率:")
for i, prob in enumerate(predicted_prob[0]):print(f"{iris.target_names[i]}: {prob:.4f}")
這段代碼使用邏輯回歸算法對經典的鳶尾花數據集進行分類,是一個完整的機器學習項目流程。
1. 導入必要的庫
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.metrics import (accuracy_score, confusion_matrix, ???????????????????????????? classification_report, ConfusionMatrixDisplay) from sklearn.preprocessing import StandardScaler |
- numpy/pandas:用于數據處理(如矩陣運算、表格操作)。
- matplotlib/seaborn:用于繪制圖表(如混淆矩陣、特征重要性)。
- sklearn:機器學習庫,提供數據集、模型、評估工具。
2. 加載和查看數據
iris = load_iris()? # 加載內置鳶尾花數據集 iris_df = pd.DataFrame(iris.data, columns=iris.feature_names) iris_df['species'] = [iris.target_names[i] for i in iris.target] print("數據集形狀:", iris.data.shape)? # (150, 4) → 150個樣本,4個特征 print("特征名稱:", iris.feature_names)? # 花瓣/花萼的長度、寬度 print("類別名稱:", iris.target_names)? # ['setosa' 'versicolor' 'virginica'] |
- 鳶尾花數據集:包含 150 朵花的數據,分為 3 個品種(每個品種 50 朵)。
- 4 個特征:花瓣長度、花瓣寬度、花萼長度、花萼寬度(都是厘米)。
- 目標:根據這 4 個特征預測花的品種。
3. 數據劃分(訓練集和測試集)
X = iris.data? # 特征(花瓣/花萼的測量值) y = iris.target? # 標簽(0/1/2對應3個品種) X_train, X_test, y_train, y_test = train_test_split( ??? X, y, test_size=0.2, random_state=42, stratify=y ) |
- train_test_split:將數據分為 80% 訓練集和 20% 測試集。
- stratify=y:確保訓練集和測試集中 3 個品種的比例相同(避免數據偏斜)。
- random_state=42:固定隨機種子,確保結果可復現(每次運行劃分結果相同)。
4. 特征標準化
scaler = StandardScaler() X_train_scaled = scaler.fit_transform(X_train)? # 訓練集標準化 X_test_scaled = scaler.transform(X_test)? # 測試集用相同參數標準化 |
- 為什么標準化?:邏輯回歸對特征尺度敏感(例如,如果某個特征的數值范圍很大,會影響模型收斂)。
- StandardScaler:將特征轉換為均值為 0、標準差為 1 的標準正態分布。
- fit_transform:計算訓練集的均值 / 標準差,并應用轉換。
- transform:用訓練集的統計參數(均值 / 標準差)轉換測試集(不能重新計算)。
5. 訓練邏輯回歸模型
model = LogisticRegression(max_iter=200, random_state=42) model.fit(X_train_scaled, y_train) |
- LogisticRegression:邏輯回歸是分類算法(盡管名字帶 “回歸”)。
- max_iter=200:增加最大迭代次數,確保模型收斂(默認 100 可能不夠)。
- fit:用訓練數據學習模型參數(找到最佳分類邊界)。
6. 模型預測
y_pred = model.predict(X_test_scaled)? # 預測類別(0/1/2) y_pred_proba = model.predict_proba(X_test_scaled)? # 預測每個類別的概率 |
- predict:直接輸出預測的類別(例如 1 代表 versicolor)。
- predict_proba:輸出樣本屬于每個類別的概率(例如 [0.01, 0.95, 0.04] 表示 95% 概率是第二類)。
7. 模型評估
print(f"訓練集準確率:{model.score(X_train_scaled, y_train):.4f}") print(f"測試集準確率:{accuracy_score(y_test, y_pred):.4f}") |
- 準確率(Accuracy):預測正確的樣本比例。
- 訓練集準確率:約 0.99(模型對訓練數據的擬合程度)。
- 測試集準確率:約 0.97(模型對新數據的泛化能力)。
混淆矩陣(Confusion Matrix)
cm = confusion_matrix(y_test, y_pred) disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names) disp.plot() |
- 混淆矩陣:可視化分類結果,對角線表示預測正確的樣本數。
- 例如:預測 setosa(0)的樣本全部分類正確;有 1 個 versicolor(1)被誤分類為 virginica(2)。
分類報告(Classification Report)
print(classification_report(y_test, y_pred, target_names=iris.target_names)) |
- 精確率(Precision):預測為某類的樣本中,實際屬于該類的比例。
- 召回率(Recall):實際屬于某類的樣本中,被正確預測的比例。
- F1 分數(F1-score):精確率和召回率的調和平均。
8. 特征重要性分析
feature_importance = pd.DataFrame({ ??? '特征': iris.feature_names, ??? '系數絕對值': np.abs(model.coef_).mean(axis=0) }).sort_values('系數絕對值', ascending=False) |
- 邏輯回歸系數:系數絕對值越大,說明該特征對分類的影響越大。
- 通常petal width(花瓣寬度)和petal length(花瓣長度)對分類最重要。
9. 新樣本預測示例
new_sample = np.array([[5.8, 3.0, 4.9, 1.6]])? # 手動構造一個樣本 new_sample_scaled = scaler.transform(new_sample)? # 標準化 predicted_class = model.predict(new_sample_scaled)? # 預測類別 predicted_prob = model.predict_proba(new_sample_scaled)? # 預測概率 |
- 預測結果:輸出新樣本的預測類別和概率(例如 95% 概率是 versicolor)。
總結
這個代碼展示了一個完整的機器學習流程:
- 數據準備:加載數據、劃分訓練集 / 測試集。
- 特征工程:標準化特征,避免量綱影響。
- 模型訓練:用邏輯回歸學習分類規則。
- 模型評估:用準確率、混淆矩陣等指標衡量性能。
- 預測應用:對新樣本進行分類。
鳶尾花數據集是機器學習的 “Hello World”,適合入門。邏輯回歸是簡單但強大的分類算法,尤其適合特征與類別之間存在線性關系的場景。