1 介紹
????????支持向量機(Support Vector Machine,簡稱 SVM)是一種監督學習算法,主要用于分類和回歸問題。SVM 的核心思想是找到一個最優的超平面,將不同類別的數據分開。這個超平面不僅要能夠正確分類數據,還要使得兩個類別之間的間隔(margin)最大化。
1.1 線性可分
????????在二維空間上,兩類點被一條直線完全分開叫做線性可分。
????????樣本中距離超平面最近的一些點,這些點叫做支持向量。
1.2?軟間隔
????????在實際應用中,完全線性可分的樣本是很少的,如果遇到了不能夠完全線性可分的樣本,我們應該怎么辦?比如下面這個:
????????于是我們就有了軟間隔,相比于硬間隔的苛刻條件,我們允許個別樣本點出現在間隔帶里面,比如:
1.3?線性不可分
???????我們剛剛討論的硬間隔和軟間隔都是在說樣本的完全線性可分或者大部分樣本點的線性可分。但我們可能會碰到的一種情況是樣本點不是線性可分的,比如:
????????這種情況的解決方法就是將二維線性不可分樣本映射到高維空間中,讓樣本點在高維空間線性可分,比如:
????????對于在有限維度向量空間中線性不可分的樣本,我們將其映射到更高維度的向量空間里,再通過間隔最大化的方式,學習得到支持向量機,就是非線性 SVM。
1.4?優缺點
????????優點
- 有嚴格的數學理論支持,可解釋性強,不依靠統計方法,從而簡化了通常的分類和回歸問題
- 能找出對任務至關重要的關鍵樣本(即:支持向量)
- 采用核技巧之后,可以處理非線性分類/回歸任務
- 最終決策函數只由少數的支持向量所確定,計算的復雜性取決于支持向量的數目,而不是樣本空間的維數,這在某種意義上避免了“維數災難”。
????????缺點
- 訓練時間長。當采用 SMO 算法時,由于每次都需要挑選一對參數,因此時間復雜度為?O(N2)?,其中 N 為訓練樣本的數量;
- 當采用核技巧時,如果需要存儲核矩陣,則空間復雜度為?O(N2)?;
- 模型預測時,預測時間與支持向量的個數成正比。當支持向量的數量較大時,預測計算復雜度較高。
????????因此支持向量機目前只適合小批量樣本的任務,無法適應百萬甚至上億樣本的任務。
2?使用 Python 實現 SVM
2.1?安裝必要的庫
????????首先,確保你已經安裝了scikit-learn
庫。如果沒有安裝,可以使用以下命令進行安裝:
pip install scikit-learn
2.2?導入庫
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
2.3?加載數據集
????????我們將使用scikit-learn
自帶的鳶尾花(Iris)數據集。
# 加載鳶尾花數據集
iris = datasets.load_iris()
X = iris.data[:, :2] # 只使用前兩個特征
y = iris.target
2.4?劃分訓練集和測試集
# 將數據集劃分為訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
2.5?訓練 SVM 模型
# 創建SVM分類器
clf = svm.SVC(kernel='linear') # 使用線性核函數# 訓練模型
clf.fit(X_train, y_train)
2.6?預測與評估
# 在測試集上進行預測
y_pred = clf.predict(X_test)# 計算準確率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型準確率: {accuracy:.2f}")
2.7?可視化結果
# 繪制決策邊界
def plot_decision_boundary(X, y, model):h = .02 # 網格步長x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h),np.arange(y_min, y_max, h))Z = model.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)plt.contourf(xx, yy, Z, alpha=0.8)plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o')plt.xlabel('Sepal length')plt.ylabel('Sepal width')plt.title('SVM Decision Boundary')plt.show()plot_decision_boundary(X_train, y_train, clf)