SVM實戰:從理論到鳶尾花數據集的分類可視化
在機器學習的廣闊領域中,支持向量機(Support Vector Machine,SVM)作為一種經典且強大的分類算法,備受矚目。它憑借獨特的思想和卓越的性能,在模式識別、數據分類等諸多領域發揮著重要作用。本文將結合Python代碼,通過對鳶尾花數據集的處理,帶大家深入了解SVM的工作原理與實戰應用。
一、SVM原理概述
SVM的核心目標是在特征空間中找到一個超平面,該超平面能夠將不同類別的數據點盡可能清晰地分隔開來,并且使分隔的間隔達到最大。這個超平面可以用一個線性方程表示: w T x + b = 0 w^Tx + b = 0 wTx+b=0,其中 w w w是權重向量,決定了超平面的方向, b b b是偏置項,決定了超平面的位置。
在二維空間中,超平面就是一條直線;在三維空間中,超平面是一個平面;而在更高維度的空間中,超平面是一個具有高維幾何特性的決策邊界。那些距離超平面最近且恰好位于間隔邊界上的數據點,被稱為支持向量,它們對超平面的確定起著關鍵作用,決定了超平面的位置和方向。
SVM的優勢在于,它不僅能夠處理線性可分的數據,還能通過核函數將低維空間中線性不可分的數據映射到高維空間,使其在高維空間中變得線性可分,從而實現非線性數據的分類。常見的核函數有線性核函數、多項式核函數、徑向基函數(RBF)等。
二、代碼實現步驟解析
1. 數據讀取與預處理
import pandas as pddata = pd.read_csv("iris.csv", header=None)
上述代碼使用pandas
庫的read_csv
函數讀取鳶尾花數據集。由于數據集中沒有表頭,所以設置header=None
。鳶尾花數據集包含150條記錄,每條記錄有4個特征和1個類別標簽,4個特征分別為花萼長度、花萼寬度、花瓣長度和花瓣寬度,類別標簽表示鳶尾花的品種(山鳶尾、雜色鳶尾、維吉尼亞鳶尾)。
2. 原始數據可視化
import matplotlib.pyplot as pltdata1 = data.iloc[:50, :]
data2 = data.iloc[50:, :]
# 原始數據是四維,無法展示,選擇兩個進行展示
plt.scatter(data1[1], data1[3], marker='+')
plt.scatter(data2[1], data2[3], marker='o')
為了直觀地觀察數據分布,選取數據集中的部分數據(前50條和后100條),并選擇兩個特征(花萼寬度和花瓣寬度)進行可視化。通過matplotlib
庫的scatter
函數繪制散點圖,可以初步看到不同類別數據點在二維平面上的分布情況。
3. SVM模型訓練
from sklearn.svm import SVCX = data.iloc[:, [1, 3]]
y = data.iloc[:, -1]
svm = SVC(kernel='linear', C=float('inf'), random_state=0)
svm.fit(X, y)
從數據集中提取用于訓練的特征矩陣X
(選取花萼寬度和花瓣寬度兩列)和標簽向量y
。使用sklearn
庫中的SVC
類創建SVM模型,這里設置kernel='linear'
表示使用線性核函數,意味著數據在當前二維特征空間中是線性可分的;C=float('inf')
表示對分類錯誤的懲罰力度無窮大,要求模型必須將所有訓練數據正確分類;random_state=0
用于設置隨機種子,保證結果的可重復性。最后調用fit
方法對模型進行訓練,使模型學習到數據的特征與類別之間的關系。
4. SVM結果可視化
# 參數w[原始數據為二維數組]
w = svm.coef_[0]
# 偏置項[原始數據為一維數組]
b = svm.intercept_[0]import numpy as npx1 = np.linspace(0, 7, 300) # 在0~7之間產生300個數據
# 超平面方程
x2 = -(w[0] * x1 + b) / w[1]
# 上超平面方程
x3 = (1 - (w[0] * x1 + b)) / w[1]
# 下超平面方程
x4 = (-1 - (w[0] * x1 + b)) / w[1]
# 可視化原始數據,選取1維核3維的數據進行可視化
# plt.scatter(data1[1],data1[3],marker='+',color='b')
# plt.scatter(data2[1],data2[3],marker='o',color='b')
# 可視化超平面
plt.plot(x1, x2, linewidth=2, color='r')
plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')
# 進行坐標軸限制
plt.xlim(4, 7)
plt.ylim(0, 5)
# 找到支持向量[二維數組]可視化支持向量
vets = svm.support_vectors_
plt.scatter(vets[:, 0], vets[:, 1], c='b', marker='x')plt.show()
訓練完成后,提取模型的權重向量w
和偏置項b
,根據超平面方程計算出一系列點的坐標,用于繪制超平面、上超平面和下超平面。通過matplotlib
的plot
函數將這些平面繪制出來,并設置合適的顏色、線寬和線型。同時,提取支持向量,使用scatter
函數將其可視化,直觀地展示支持向量在分類中的關鍵作用。最后通過xlim
和ylim
函數設置坐標軸的范圍,使可視化效果更加美觀和清晰。
三、結果分析
從可視化結果可以清晰地看到,SVM模型成功地找到了一個線性超平面,將不同類別的鳶尾花數據點分隔開來。支持向量位于間隔邊界上,它們確定了超平面的位置和方向。在這個例子中,由于設置了較大的懲罰參數C
,模型對訓練數據實現了完美分類,所有數據點都被正確劃分到相應的類別區域。
然而,在實際應用中,過高的C
值可能導致模型過擬合,即模型在訓練數據上表現良好,但在新數據上的泛化能力較差。因此,需要根據具體問題和數據特點,合理調整C
值以及選擇合適的核函數,以達到更好的分類效果。
四、總結
通過對鳶尾花數據集的SVM分類實戰,我們深入了解了SVM的原理、代碼實現過程以及結果分析方法。SVM作為一種強大的機器學習算法,在數據分類任務中展現出了優秀的性能和獨特的優勢。希望本文的內容能夠幫助大家更好地理解和應用SVM,在實際的機器學習項目中發揮其價值。在后續的學習和實踐中,我們還可以進一步探索不同核函數、參數調整以及SVM在更多復雜數據集上的應用,不斷拓展對這一算法的認識和掌握程度。
以上博客從多方面解讀了SVM的實戰應用。你若覺得某些部分需要補充、修改,或有其他想法,歡迎隨時和我說。
這篇博客全面剖析了SVM的實戰應用。若你對博客的內容深度、篇幅長度等有新想法,隨時和我交流。