K近鄰:從理論到實踐
文章目錄
- K近鄰:從理論到實踐
- 1. 核心思想
- 2. 距離度量
- 3. k的選擇與誤差分析
- 3.1 近似誤差
- 3.2 估計誤差
- 3.3 總誤差
- 4. kd樹的構造與搜索
- 4.1 kd樹的構造
- 4.2 kd樹的搜索
- 5. 總結
- 6. K近鄰用于iris數據集分類
- 6.1加載數據
- 6.2加載模型并可視化
1. 核心思想
K近鄰(KNN)是一種基于實例的監督學習方法。其基本思想是:
對于一個待分類樣本,根據訓練集中與其“距離”最近的 kk 個鄰居的類別,通過投票或加權投票的方式決定該樣本的類別。
數學表達:
設訓練集為
D={(x1,y1),(x2,y2),…,(xn,yn)},xi∈Rd,yi∈{1,2,…,C}{D} = \{ (x_1,y_1), (x_2,y_2), \dots, (x_n,y_n) \}, \quad x_i \in \mathbb{R}^d, \; y_i \in \{1,2,\dots,C\}D={(x1?,y1?),(x2?,y2?),…,(xn?,yn?)},xi?∈Rd,yi?∈{1,2,…,C}
給定測試樣本x,找到其最近的 kk 個鄰居集合Nk(x){N}_k(x)Nk?(x)。
預測類別為:
y^(x)=arg?max?c∈{1,…,C}∑(xi,yi)∈Nk(x)1(yi=c)\hat{y}(x) = \arg\max_{c \in \{1,\dots,C\}} \sum_{(x_i,y_i) \in \mathcal{N}_k(x)} \mathbf{1}(y_i = c)y^?(x)=argc∈{1,…,C}max?(xi?,yi?)∈Nk?(x)∑?1(yi?=c)
其中,1(?){1}(\cdot)1(?) 是指示函數。
如果采用加權投票(考慮距離遠近),則為:
y^(x)=arg?max?c∈{1,…,C}∑(xi,yi)∈Nk(x)1∥x?xi∥?1(yi=c)\hat{y}(x) = \arg\max_{c \in \{1,\dots,C\}} \sum_{(x_i,y_i) \in \mathcal{N}_k(x)} \frac{1}{\|x - x_i\|} \cdot \mathbf{1}(y_i = c)y^?(x)=argc∈{1,…,C}max?(xi?,yi?)∈Nk?(x)∑?∥x?xi?∥1??1(yi?=c)
2. 距離度量
KNN 依賴距離來衡量樣本相似度。常見的度量方式有:
- 歐氏距離:
d(xi,xj)=∑l=1d(xi(l)?xj(l))2d(x_i, x_j) = \sqrt{\sum_{l=1}^d (x_i^{(l)} - x_j^{(l)})^2}d(xi?,xj?)=l=1∑d?(xi(l)??xj(l)?)2?
- 曼哈頓距離:
d(xi,xj)=∑l=1d∣xi(l)?xj(l)∣d(x_i, x_j) = \sum_{l=1}^d |x_i^{(l)} - x_j^{(l)}|d(xi?,xj?)=l=1∑d?∣xi(l)??xj(l)?∣
- 閔可夫斯基距離(推廣形式):
d(xi,xj)=(∑l=1d∣xi(l)?xj(l)∣p)1/pd(x_i, x_j) = \left( \sum_{l=1}^d |x_i^{(l)} - x_j^{(l)}|^p \right)^{1/p}d(xi?,xj?)=(l=1∑d?∣xi(l)??xj(l)?∣p)1/p
3. k的選擇與誤差分析
KNN 的性能對 k 值選擇敏感,體現了 近似誤差 與 估計誤差 的權衡。
3.1 近似誤差
- 定義:模型表達能力不足,導致預測結果無法逼近真實分布。
- k 較大時:決策邊界過于平滑,難以捕捉復雜模式 → 近似誤差大。
- k 較小時:決策邊界靈活,可以更好地擬合真實模式 → 近似誤差小。
數學上,假設真實函數為 f(x),KNN 的期望預測為:
f^(x)=ED[y^(x)]\hat{f}(x) = \mathbb{E}_{\mathcal{D}}[\hat{y}(x)]f^?(x)=ED?[y^?(x)]
則近似誤差為:
Bias2(x)=(ED[y^(x)]?f(x))2\text{Bias}^2(x) = \big( \mathbb{E}_{\mathcal{D}}[\hat{y}(x)] - f(x) \big)^2Bias2(x)=(ED?[y^?(x)]?f(x))2
3.2 估計誤差
- 定義:模型對有限訓練數據過于依賴,泛化性差,導致預測不穩定。
- k 較小時:極易受噪聲點影響,估計誤差大。
- k 較大時:結果受單個點波動影響小,估計誤差小。
其數學形式為:
Var(x)=ED[(y^(x)?ED[y^(x)])2]\text{Var}(x) = \mathbb{E}_{\mathcal{D}}\big[(\hat{y}(x) - \mathbb{E}_{\mathcal{D}}[\hat{y}(x)])^2\big]Var(x)=ED?[(y^?(x)?ED?[y^?(x)])2]
3.3 總誤差
textMSE(x)=Bias2(x)+Var(x)+σ2text{MSE}(x) = \text{Bias}^2(x) + \text{Var}(x) + \sigma^2textMSE(x)=Bias2(x)+Var(x)+σ2
其中,σ2\sigma^2σ2 是不可約誤差。
因此,選擇合適的 k 值非常重要。
4. kd樹的構造與搜索
由于 KNN 需要計算測試點與所有訓練點的距離,時間復雜度為O(n)。為了加速,可以用 kd樹進行近鄰搜索。
4.1 kd樹的構造
- kd樹是一種對數據進行遞歸二分的空間劃分結構。
- 每次選擇一個維度(通常是方差最大的維度),按照該維度的中位數劃分數據。
- 構造過程:
- 從根節點開始,選擇一個維度作為切分軸;
- 找到該維度的中位數,作為節點存儲值;
- 左子樹存儲小于該值的樣本,右子樹存儲大于該值的樣本;
- 遞歸進行直到樣本數過少或樹深度達到限制。
偽代碼:
function build_kd_tree(points, depth):
if points is empty:
return None
axis = depth mod d
sort points by axis
median = len(points) // 2
node = new Node(points[median])
node.left = build_kd_tree(points[:median], depth+1)
node.right = build_kd_tree(points[median+1:], depth+1)
return node
4.2 kd樹的搜索
kd樹搜索遵循“回溯+剪枝”原則:
- 從根節點開始,遞歸到葉子節點,找到測試點所屬的區域;
- 以該葉子節點為“當前最近鄰”;
- 回溯檢查父節點和另一子樹,若另一子樹中可能存在更近鄰,則遞歸進入;
- 維護一個大小為 kk 的優先隊列,存儲當前最近的 kk 個鄰居;
- 搜索結束時隊列中的點即為近鄰結果。
偽代碼:
function knn_search(node, target, k, depth):
if node is None:
return
axis = depth mod d
if target[axis] < node.point[axis]:
next = node.left
other = node.right
else:
next = node.right
other = node.left
function knn_search(node, target, k, depth):if node is None:returnaxis = depth mod dif target[axis] < node.point[axis]:next = node.leftother = node.rightelse:next = node.rightother = node.leftknn_search(next, target, k, depth+1)update priority queue with node.pointif |target[axis] - node.point[axis]| < current_max_distance_in_queue:knn_search(other, target, k, depth+1)
5. 總結
- 核心思想:KNN 通過尋找最近的 kk 個鄰居來分類或回歸。
- k 的選擇:小 kk → 近似誤差小、估計誤差大(過擬合);大 kk → 近似誤差大、估計誤差小(欠擬合)。
- kd樹:通過空間劃分加速近鄰搜索,提升算法效率。
最終,KNN 的關鍵在于 合適的 k 值選擇 和 高效的搜索結構。
6. K近鄰用于iris數據集分類
6.1加載數據
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_splitiris = load_iris(as_frame=True)
X = iris.data[["sepal length (cm)", "sepal width (cm)"]]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
鳶尾花數據集,as_frame=True
表示返回 pandas DataFrame 而不是 numpy 數組,方便做列選擇。
這個數據集有 150 條樣本,4 個特征:sepal length
, sepal width
, petal length
, petal width
。目標變量 target
有三類 (0=setosa, 1=versicolor, 2=virginica)。
6.2加載模型并可視化
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.inspection import DecisionBoundaryDisplay
import pandas as pd
import time# 1. 載入數據
iris = load_iris(as_frame=True)
X = iris.data[["sepal length (cm)", "sepal width (cm)"]]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0
)# 2. 構建 pipeline:標準化 + KNN
clf = Pipeline(steps=[("scaler", StandardScaler()),("knn", KNeighborsClassifier(n_neighbors=11))]
)# 3. 不同的 weights 和 algorithm 組合
weights_list = ["uniform", "distance"]
algorithms = ["auto", "ball_tree", "kd_tree"]# 定義結果存儲表
results = []# 4. 畫圖:每行一個 weights,每列一個 algorithm
fig, axs = plt.subplots(nrows=len(weights_list), ncols=len(algorithms), figsize=(18, 10)
)for i, weights in enumerate(weights_list):for j, algo in enumerate(algorithms):ax = axs[i, j]# 設置參數并擬合start_train = time.time()clf.set_params(knn__weights=weights, knn__algorithm=algo).fit(X_train, y_train)end_train = time.time()start_pred = time.time()clf.predict(X_test)end_pred = time.time()acc = clf.score(X_test, y_test)results.append({"weights": weights,"algorithm": algo,"accuracy": acc,"train_time (s)": end_train - start_train,"predict_time (s)": end_pred - start_pred})# 決策邊界disp = DecisionBoundaryDisplay.from_estimator(clf,X_test,response_method="predict",plot_method="pcolormesh",xlabel=iris.feature_names[0],ylabel=iris.feature_names[1],shading="auto",alpha=0.5,ax=ax,)# 訓練樣本點scatter = disp.ax_.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors="k")# 圖例disp.ax_.legend(scatter.legend_elements()[0],iris.target_names,loc="lower left",title="Classes",)# 子圖標題ax.set_title(f"k={clf[-1].n_neighbors}, weights={weights}, algo={algo}")plt.tight_layout()
plt.show()
df_results = pd.DataFrame(results)
print(df_results)
weights algorithm accuracy train_time (s) predict_time (s)
0 uniform auto 0.710526 0.003293 0.004401
1 uniform ball_tree 0.710526 0.004864 0.006618
2 uniform kd_tree 0.710526 0.003537 0.004044
3 distance auto 0.631579 0.003269 0.001961
4 distance ball_tree 0.631579 0.003211 0.001694
5 distance kd_tree 0.631579 0.003055 0.001578
不同 algorithm 的表現
auto
、ball_tree
、kd_tree
在相同權重下的 準確率完全一致,訓練預測速度不同,這說明 搜索算法僅影響計算效率,不會改變最終分類結果。- 這和理論一致:算法只是用不同的數據結構加速鄰居查找,不會影響鄰居集合本身。
不同 weights 的表現
uniform
權重下,測試集準確率為 71.05%;distance
權重下,測試集準確率為 63.16%;- 在本實驗中,uniform 明顯優于 distance。
- 這表明在鳶尾花數據的 前兩個特征(花萼長、寬) 上,等權投票比加權投票更適合。可能原因是:
- 特征維度少,距離加權放大了噪聲點或邊界點的影響;
- 類別邊界本身不完全線性,用距離權重反而削弱了多數鄰居的穩定性。
結合可視化
-
從決策邊界圖上可以看到:
uniform
的邊界相對平滑,更符合數據整體分布;
eights 的表現**
-
uniform
權重下,測試集準確率為 71.05%; -
distance
權重下,測試集準確率為 63.16%; -
在本實驗中,uniform 明顯優于 distance。
-
這表明在鳶尾花數據的 前兩個特征(花萼長、寬) 上,等權投票比加權投票更適合。可能原因是:
- 特征維度少,距離加權放大了噪聲點或邊界點的影響;
- 類別邊界本身不完全線性,用距離權重反而削弱了多數鄰居的穩定性。