K-最近鄰(K-Nearest Neighbors, KNN)是一種簡單且直觀的監督學習算法,廣泛應用于分類和回歸任務。本文將介紹KNN算法的基本概念、實現細節以及Python代碼示例。
基本概念
KNN算法的核心思想是:給定一個測試樣本,根據其在特征空間中與訓練樣本的距離,找到距離最近的K個訓練樣本(鄰居),然后通過這些鄰居的標簽來決定測試樣本的標簽。在分類任務中,KNN通過對K個鄰居的標簽進行投票,選擇出現次數最多的標簽作為預測結果;在回歸任務中,KNN通過對K個鄰居的標簽進行平均來預測結果。
算法步驟
- 計算距離:計算測試樣本與每個訓練樣本之間的距離。
- 選擇最近的K個鄰居:根據距離選擇K個最近的訓練樣本。
- 投票:在K個最近鄰居中,選擇出現次數最多的類別作為預測結果。
距離度量
在KNN算法中,通常使用歐氏距離(Euclidean Distance)來度量樣本之間的距離。
實現代碼
下面是一個使用 numpy
實現的 KNN 分類器的示例代碼:
import numpy as np
from collections import Counterclass KNN:def __init__(self, k=3):self.k = kdef fit(self, X_train, y_train):"""訓練KNN分類器,保存訓練數據。參數:- X_train: 訓練樣本特征,形狀 (num_samples, num_features)- y_train: 訓練樣本標簽,形狀 (num_samples,)"""self.X_train = X_trainself.y_train = y_traindef predict(self, X_test):"""對測試樣本進行預測。參數:- X_test: 測試樣本特征,形狀 (num_samples, num_features)返回值:- y_pred: 預測標簽,形狀 (num_samples,)"""y_pred = [self._predict(x) for x in X_test]return np.array(y_pred)def _predict(self, x):"""對單個測試樣本進行預測。參數:- x: 單個測試樣本特征,形狀 (num_features,)返回值:- 預測標簽"""# 計算所有訓練樣本與測試樣本之間的距離distances = np.linalg.norm(self.X_train - x, axis=1)# 獲取距離最近的k個訓練樣本的索引k_indices = np.argsort(distances)[:self.k]# 獲取k個最近鄰居的標簽k_nearest_labels = [self.y_train[i] for i in k_indices]# 返回出現次數最多的標簽most_common = Counter(k_nearest_labels).most_common(1)return most_common[0][0]# 示例用法
if __name__ == "__main__":# 創建示例數據X_train = np.array([[1, 2], [2, 3], [3, 4], [6, 7], [7, 8], [8, 9]])y_train = np.array([0, 0, 0, 1, 1, 1])X_test = np.array([[2, 3], [3, 5], [8, 8]])# 創建KNN實例knn = KNN(k=3)knn.fit(X_train, y_train)predictions = knn.predict(X_test)print("測試樣本預測結果:", predictions)
代碼解釋
-
初始化:
__init__
方法初始化KNN分類器,并設置K值。
-
訓練模型:
fit
方法保存訓練樣本的特征和標簽,供后續預測使用。
-
預測:
predict
方法對一組測試樣本進行預測,返回預測標簽。_predict
方法對單個測試樣本進行預測:- 計算測試樣本與每個訓練樣本之間的歐氏距離。
- 找到距離最近的K個訓練樣本的索引。
- 獲取K個最近鄰居的標簽。
- 返回出現次數最多的標簽作為預測結果。
-
示例用法:
- 創建示例訓練數據和測試數據。
- 實例化KNN分類器,并設置K值為3。
- 調用
fit
方法訓練模型。 - 調用
predict
方法對測試樣本進行預測,并輸出預測結果。
超參數選擇
K值是KNN算法的一個關鍵超參數,其選擇會直接影響模型的性能。一般來說,較小的K值會導致模型對噪聲敏感,而較大的K值會使模型過于平滑,導致欠擬合。可以通過交叉驗證來選擇最優的K值。
優缺點
優點
- 簡單直觀,易于理解和實現。
- 不需要顯式的訓練過程,只需保存訓練數據。
- 對于小規模數據集效果較好。
缺點
- 計算復雜度高,對大規模數據集不適用。
- 對噪聲和不相關特征敏感。
- 需要保存所有訓練數據,存儲開銷大。
總結
K-最近鄰(KNN)是一種經典的機器學習算法,適用于分類和回歸任務。盡管其簡單性和直觀性使其在許多應用中表現良好,但在處理大規模數據集和高維數據時,KNN的計算復雜度和存儲需求成為其主要限制因素。通過合理選擇K值和使用適當的距離度量,KNN可以在許多實際問題中取得令人滿意的效果。