一、樸素貝葉斯模型分類與核心原理
樸素貝葉斯算法的核心是基于 “特征條件獨立性假設”,通過貝葉斯公式計算后驗概率實現分類。根據特征數據類型的差異,衍生出三大經典模型,分別適用于不同場景,其核心區別在于對 “特征條件概率” 的計算方式不同。
1.1 多項式樸素貝葉斯(MultinomialNB)
適用場景
- 特征為離散型數據,尤其是文本分類(如統計單詞出現次數、TF-IDF 值)、物品計數等場景。
- 典型案例:垃圾郵件分類(統計 “優惠”“中獎” 等關鍵詞的出現頻次)、新聞主題分類。
核心原理
- 假設特征服從多項式分布,即特征值代表 “事件發生的次數”(如單詞在文本中出現的次數)。
- 計算條件概率時,需引入拉普拉斯平滑(通過
alpha
參數控制),避免因某些特征未出現導致概率為 0 的問題。
關鍵參數(sklearn 實現)
參數 | 類型 | 默認值 | 功能說明 |
---|---|---|---|
alpha | 浮點型 | 1.0 | 拉普拉斯平滑系數:alpha=0 表示不平滑,alpha>0 時,值越大平滑效果越強 |
fit_prior | 布爾型 | True | 是否使用先驗概率:True 時基于數據計算先驗,False 時假設所有類別先驗概率相等 |
class_prior | 數組型 | None | 自定義類別先驗概率:若指定,則忽略fit_prior 的設置 |
1.2 高斯樸素貝葉斯(GaussianNB)
適用場景
- 特征為連續型數據(如身高、體重、溫度等數值型特征),無法通過 “計數” 描述的場景。
- 典型案例:鳶尾花品種分類(花瓣長度、寬度為連續值)、房價區間預測(面積、樓層為連續值)。
核心原理
- 假設特征服從正態分布(高斯分布),通過計算樣本中每個類別下特征的均值和標準差,構建正態分布概率密度函數,進而求解條件概率。
- 無需手動設置平滑參數,模型會自動通過極大似然法估計正態分布的參數(均值、方差)。
關鍵參數(sklearn 實現)
參數 | 類型 | 默認值 | 功能說明 |
---|---|---|---|
priors | 數組型 | None | 自定義類別先驗概率:若為 None,模型通過樣本數據自動計算(極大似然法) |
1.3 伯努利樸素貝葉斯(BernoulliNB)
適用場景
- 特征為二值離散型數據(僅取值 0 或 1),即 “特征是否存在” 而非 “特征出現次數” 的場景。
- 典型案例:文本分類(單詞是否在文本中出現,1 = 出現,0 = 未出現)、用戶行為分析(是否點擊某按鈕,1 = 點擊,0 = 未點擊)。
核心原理
- 假設特征服從伯努利分布(0-1 分布),僅關注特征 “是否發生”,不關注發生次數。
- 需通過
binarize
參數將非二值特征轉換為二值(若特征已二值化,可設為 None),同時支持拉普拉斯平滑(alpha
參數)。
關鍵參數(sklearn 實現)
參數 | 類型 | 默認值 | 功能說明 |
---|---|---|---|
alpha | 浮點型 | 1.0 | 拉普拉斯平滑系數,作用同 MultinomialNB |
binarize | 浮點型 / None | 0 | 特征二值化閾值:若為x ,則特征值 > x 設為 1,否則設為 0;None 表示特征已二值化 |
fit_prior | 布爾型 | True | 是否使用先驗概率,作用同 MultinomialNB |
class_prior | 數組型 | None | 自定義類別先驗概率,作用同 MultinomialNB |
1.4 三大模型對比與選擇指南
模型 | 適用特征類型 | 核心假設 | 關鍵參數 | 典型場景 |
---|---|---|---|---|
多項式樸素貝葉斯 | 離散型(計數數據) | 特征服從多項式分布 | alpha 、fit_prior | 文本分類(單詞頻次)、商品銷量分類 |
高斯樸素貝葉斯 | 連續型數據 | 特征服從正態分布 | priors | 數值型特征分類(鳶尾花、房價) |
伯努利樸素貝葉斯 | 二值離散型(0/1) | 特征服從伯努利分布 | alpha 、binarize | 文本分類(單詞存在性)、用戶行為分析 |
二、樸素貝葉斯模型通用 API(sklearn)
sklearn 中三種樸素貝葉斯模型的接口完全一致,核心方法如下,便于快速切換模型進行實驗:
方法 | 功能描述 |
---|---|
fit(X, y) | 用訓練集(X:特征矩陣,y:標簽)擬合模型,學習概率參數 |
predict(X) | 對測試集 X 進行分類預測,返回每個樣本的類別標簽 |
predict_proba(X) | 返回測試集 X 屬于每個類別的概率(概率和為 1) |
predict_log_proba(X) | 返回predict_proba(X) 的對數形式(避免概率過小導致的數值下溢) |
score(X, y) | 計算模型在數據集(X,y)上的準確率(正確預測數 / 總樣本數) |
三、課后練習:樸素貝葉斯實現手寫數字識別
3.1 任務背景
手寫數字數據集(load_digits
)包含 8×8 像素的灰度圖像(共 1797 個樣本),每個樣本的特征是 64 個連續值(0-16,代表像素灰度),標簽是 0-9 的數字。需選擇合適的樸素貝葉斯模型實現分類,并評估性能。
3.2 模型選擇依據
- 特征類型:64 個像素值為連續型數據(0-16 的整數,本質是連續區間內的離散化表示)。
- 模型匹配:連續型特征適合使用高斯樸素貝葉斯(GaussianNB),無需手動處理特征分布,直接通過正態分布擬合像素值概率。
3.3 完整代碼實現
python
運行
# 1. 導入必要庫
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.naive_bayes import GaussianNB # 選擇高斯樸素貝葉斯
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix# 2. 加載并探索數據集
digits = load_digits()
print("數據集基本信息:")
print(f"樣本數量:{digits.data.shape[0]}, 特征數量(像素數):{digits.data.shape[1]}")
print(f"類別數量(數字0-9):{len(digits.target_names)}")
print(f"像素值范圍:{digits.data.min()} ~ {digits.data.max()}")# 可視化前4個樣本(驗證數據格式)
plt.figure(figsize=(8, 4))
for i in range(4):plt.subplot(1, 4, i+1)# 將64維特征重塑為8×8圖像plt.imshow(digits.images[i], cmap=plt.cm.gray_r)plt.title(f"Label: {digits.target[i]}")plt.axis("off")
plt.show()# 3. 數據預處理:劃分訓練集與測試集
# 隨機劃分,測試集占比30%,固定隨機種子確保結果可復現
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.3, random_state=42
)
print(f"\n訓練集大小:{X_train.shape}, 測試集大小:{X_test.shape}")# 4. 初始化并訓練高斯樸素貝葉斯模型
model = GaussianNB()
model.fit(X_train, y_train) # 擬合模型,學習每個類別下特征的正態分布參數# 5. 模型預測與性能評估
# 5.1 測試集預測
y_pred = model.predict(X_test) # 類別預測
y_pred_proba = model.predict_proba(X_test) # 類別概率預測(可選)# 5.2 核心指標:準確率
accuracy = accuracy_score(y_test, y_pred)
print(f"\n模型在測試集上的準確率:{accuracy:.4f}")# 5.3 詳細評估:分類報告(精確率、召回率、F1分數)
print("\n分類報告(精確率/召回率/F1分數):")
print(classification_report(y_test, y_pred, target_names=[str(i) for i in digits.target_names]
))# 5.4 混淆矩陣:分析各類別預測錯誤情況
conf_matrix = confusion_matrix(y_test, y_pred)
print("\n混淆矩陣(行:真實標簽,列:預測標簽):")
print(conf_matrix)# 6. 錯誤案例分析(可選):查看前3個預測錯誤的樣本
error_indices = np.where(y_pred != y_test)[0][:3] # 前3個錯誤樣本的索引
plt.figure(figsize=(8, 3))
for i, idx in enumerate(error_indices):plt.subplot(1, 3, i+1)plt.imshow(X_test[idx].reshape(8, 8), cmap=plt.cm.gray_r)plt.title(f"True: {y_test[idx]}, Pred: {y_pred[idx]}")plt.axis("off")
plt.show()
3.4 結果分析
1. 基礎性能
- 模型在測試集上的準確率約為0.83-0.85(因隨機劃分可能略有波動,固定
random_state=42
后準確率為 0.8426)。 - 從分類報告可見:數字 “0”“1”“6” 的 F1 分數接近 1.0,分類效果極佳;數字 “8”“9” 的 F1 分數較低(約 0.75),因這兩個數字的像素分布更相似,易混淆。
2. 混淆矩陣解讀
- 對角線元素表示 “正確預測數”,非對角線元素表示 “錯誤預測數”。
- 例如:混淆矩陣中
conf_matrix[8,9]
(真實為 8,預測為 9)的值較大,說明模型易將 “8” 誤判為 “9”,需后續優化(如增加特征工程、換用其他模型)。
3. 錯誤案例可視化
- 錯誤樣本的圖像顯示:“8” 與 “9” 的區別僅在于底部是否有缺口,像素差異小,導致高斯樸素貝葉斯的正態分布假設難以區分,這是模型性能瓶頸的主要原因。
四、學習總結與拓展思考
4.1 核心收獲
- 模型選型邏輯:根據特征類型選擇樸素貝葉斯模型是關鍵 —— 離散計數用多項式、連續值用高斯、二值特征用伯努利,避免 “錯配” 導致的性能損失。
- 高斯模型特點:無需手動處理連續特征的分布,實現簡單、速度快,但對 “特征獨立” 假設敏感(如手寫數字中相鄰像素存在相關性,會影響模型精度)。
- 評估維度:除準確率外,需通過混淆矩陣、分類報告分析 “類別級” 性能,定位易混淆類別,為優化提供方向。
4.2 優化方向
- 特征工程:對像素特征進行預處理(如二值化、邊緣檢測),增強數字間的區分度,可嘗試用伯努利樸素貝葉斯重新實驗。
- 模型改進:若需更高精度,可換用非樸素貝葉斯模型(如 SVM、隨機森林),或通過 “貝葉斯網絡” 放松 “特征獨立” 假設。
- 超參數調優:對高斯樸素貝葉斯的
priors
參數進行自定義(如根據訓練集中各類別樣本占比調整先驗概率),可能提升少數類別的召回率。
通過本次實踐,不僅掌握了三種樸素貝葉斯模型的應用場景與代碼實現,更理解了 “模型假設與數據特性匹配” 的重要性 —— 樸素貝葉斯的 “樸素” 既是其優勢(計算快),也是其局限(依賴獨立假設),需在實際任務中靈活權衡。