一、概述
支持向量機(Support Vector Machine,SVM)是一種應用范圍非常廣泛的算法,既可以用于分類,也可以用于回歸。
本文將介紹如何將線性支持向量機應用于二元分類問題,以間隔(margin)最大化為基準,得到更好的決策邊界。雖然該算法的決策邊界與邏輯回歸一樣是線性的,但有時線性支持向量機得到的結果更好。
下面對同一數據分別應用線性支持向量機和邏輯回歸,并比較其結果
左圖是學習前的數據,右圖是從數據中學習后的結果。右圖中用黑色直線標記的是線性支持向量機的決策邊界,用藍色虛線標記的是邏輯回歸的決策邊界。線性支持向量機的分類結果更佳。線性支持向量機的學習方式是:以間隔最大化為基準,讓決策邊界盡可能地遠離數據。下面來看一下線性支持向量機是如何從數據中學習的。
二、算法說明
間隔的定義:
以平面上的二元分類問題為例進行說明,并且假設數據可以完全分類。線性支持向量機通過線性的決策邊界將平面一分為二,據此進行二元分類。此時,訓練數據中最接近決策邊界的數據與決策邊界之間的距離就稱為間隔
右圖的間隔大于左圖的間隔。支持向量機試圖通過增大決策邊界和訓練數據之間的間隔來獲得更合理的邊界。
三、示例代碼
下面生成線性可分的數據,將其分割成訓練數據和驗證數據,使用訓練數據訓練線性支持向量機,使用驗證數據評估正確率。另外,由于使用了隨機數,所以每次運行的結果可能有所不同。
?
"""
LinearSVC:線性支持向量分類模型。
make_blobs:生成模擬數據集。
train_test_split:劃分訓練集和測試集。
accuracy_score:計算分類準確率。
"""from sklearn.svm import LinearSVC
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
"""
生成 50個樣本,分為兩類,每類25個樣本。
數據特征為二維,圍繞兩個中心點 (-1, -0.125) 和 (0.5, 0.5) 生成。
cluster_std=0.3 標準差表示數據點圍繞中心點的分布較集中
"""
centers = [(-1, -0.125), (0.5, 0.5)]
X, y = make_blobs(n_samples=50, n_features=2,
centers=centers, cluster_std=0.3)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) # 15個樣本作為測試集(30%)
model = LinearSVC()
model.fit(X_train, y_train) # 訓練
y_pred = model.predict(X_test)
print(accuracy_score(y_pred, y_test)) """
評估函數說明:
輸入參數:y_pred:模型對測試集的預測結果(由 model.predict(X_test) 生成)
y_test:測試集的真實標簽(實際正確答案)
輸出結果:
返回一個 0~1 之間的浮點數,表示預測正確的樣本比例。
例如:0.85 表示 85% 的測試樣本被正確分類。
"""
四、詳細說明
軟間隔和支持向量
目前我們了解的都是數據可以線性分離的情況,這種不允許數據進入間隔內側的情況稱為硬間隔。但一般來說,數據并不是完全可以線性分離的,所以要允許一部分數據進入間隔內側,這種情況叫作軟間隔。通過引入軟間隔,無法線性分離的數據也可以如圖所示進行學習。
基于線性支持向量機的學習結果,我們可以將訓練數據分為以下3種。
1. 與決策邊界之間的距離比間隔還要遠的數據:間隔外側的數據。
2. 與決策邊界之間的距離和間隔相同的數據:間隔上的數據。
3. 與決策邊界之間的距離比間隔近,或者誤分類的數據:間隔內側的數據。
其中,我們將間隔上的數據和間隔內側的數據特殊對待,稱為支持向量。支持向量是確定決策邊界的重要數據。間隔外側的數據則不會影響決策邊界的形狀。由于間隔內側的數據包含被誤分類的數據,所以乍看起來通過調整間隔,使間隔內側不存在數據的做法更好。但對于線性可分的數據,如果強制訓練數據不進入間隔內側,可能會導致學習結果對數據過擬合。使用由兩個標簽組成的訓練數據訓練線性支持向量機而得到的結果如圖所示。
左圖是不允許數據進入間隔內側的硬間隔的情況,右圖是允許數據進入間隔內側的軟間隔的情況。
另外,在藍色點表示的訓練數據中特意加上了偏離值。
比較兩個結果可以發現,使用了硬間隔的左圖上的決策邊界受偏離值的影響很大;而在引入軟間隔的右圖上的學習結果不容易受到偏離值的影響。在使用軟間隔時,允許間隔內側進入多少數據由超參數決定。
與其他算法一樣,在決定超參數時,需要使用網格搜索(grid search)和隨機搜索(random search)等方法反復驗證后再做決定。
?
?