一、生活中的 "分類難題" 與 k 近鄰的靈感
你有沒有這樣的經歷:在超市看到一種從沒見過的水果,表皮黃黃的,拳頭大小,形狀圓滾滾。正當你猶豫要不要買時,突然想起外婆家的橘子好像就是這個樣子 —— 黃色、圓形、大小和拳頭差不多。于是你推斷:"這應該是橘子吧!"
其實,這個看似平常的判斷過程,竟然藏著機器學習中最經典的分類算法 ——k 近鄰(k-Nearest Neighbors,簡稱 kNN)的核心思想!
1.1 現實中的解法拆解
當我們判斷未知水果時,大腦會自動完成三個步驟:
- 收集特征:觀察顏色(黃色)、形狀(圓形)、大小(拳頭大)
- 匹配經驗:調動記憶中 "橘子" 的特征庫(黃色、圓形、拳頭大)
- 做出判斷:因為新水果的特征和記憶中的橘子最像,所以歸類為橘子
這和 k 近鄰算法的工作流程驚人地相似!唯一的區別是:計算機需要我們把這些 "看得到的特征" 變成 "算得出的數據"。
1.2 k 近鄰算法的有趣靈魂
k 近鄰算法的有趣之處在于它的 "懶惰" 和 "實在":
- 懶惰:它不像其他算法那樣先總結規律(比如 "黃色圓形水果都是橘子"),而是等到需要判斷時才去比對已知數據
- 實在:它的判斷邏輯簡單粗暴 ——"少數服從多數",看新樣本周圍最像的 k 個樣本里哪種類型占多數
就像你糾結新水果是橘子還是蘋果時,會找 5 個見過這兩種水果的人投票,哪種意見多就信哪種。
二、從生活到代碼:k 近鄰算法的實現之路
我們用一個具體案例來實現:根據 "顏色深度"(0-10,數值越大越黃)和 "大小"(0-10,數值越大越大)兩個特征,判斷水果是橘子(標簽 1)還是蘋果(標簽 0)。
2.1 準備數據:把生活觀察變成數字
首先,我們需要把已知的水果數據整理成計算機能理解的格式:
# 導入必要的庫
import numpy as np # 用于數值計算
import matplotlib.pyplot as plt # 用于畫圖
# 已知水果數據:[顏色深度, 大小],標簽:0=蘋果,1=橘子
# 想象這些數據來自我們之前見過的水果:
# 蘋果通常偏紅(顏色深度小),大小不一;橘子偏黃(顏色深度大)
known_fruits = np.array([
[2, 3], # 蘋果:顏色偏紅(2),小個(3)
[3, 4], # 蘋果:顏色較紅(3),中個(4)
[1, 5], # 蘋果:顏色很紅(1),大個(5)
[7, 6], # 橘子:顏色較黃(7),中個(6)
[8, 5], # 橘子:顏色很黃(8),中個(5)
[9, 4] # 橘子:顏色極黃(9),小個(4)
])
# 對應的標簽:0代表蘋果,1代表橘子
labels = np.array([0, 0, 0, 1, 1, 1])
# 未知水果:顏色深度6,大小5(就是我們在超市看到的那個)
unknown_fruit = np.array([6, 5])
2.2 數據可視化:讓計算機 "看見" 差異
我們用散點圖把數據畫出來,直觀感受蘋果和橘子的特征差異:
# 繪制已知水果
plt.scatter(known_fruits[labels==0, 0], known_fruits[labels==0, 1],
color='red', marker='o', label='蘋果') # 蘋果標為紅色圓點
plt.scatter(known_fruits[labels==1, 0], known_fruits[labels==1, 1],
color='orange', marker='o', label='橘子') # 橘子標為橙色圓點
# 繪制未知水果(用五角星標記)
plt.scatter(unknown_fruit[0], unknown_fruit[1],
color='purple', marker='*', s=200, label='未知水果') # 紫色五角星,放大顯示
# 加上坐標軸標簽和標題
plt.xlabel('顏色深度(0-10,數值越大越黃)')
plt.ylabel('大小(0-10,數值越大越大)')
plt.title('水果特征分布圖')
plt.legend() # 顯示圖例
plt.show() # 展示圖像
運行這段代碼,你會看到:紅色圓點(蘋果)集中在左側(顏色偏紅),橙色圓點(橘子)集中在右側(顏色偏黃),而紫色五角星(未知水果)剛好在橘子群附近 —— 這就是我們肉眼判斷的依據!
三、k 近鄰算法的核心步驟:用數學實現 "投票選舉"
計算機怎么判斷未知水果的類別呢?它會執行四個關鍵步驟,我們一步步用代碼實現:
3.1 第一步:計算距離(誰離我最近?)
生活中我們靠 "感覺" 判斷相似,計算機則靠 "距離" 計算。最常用的是歐氏距離(就像直尺測量兩點距離):
\(distance = \sqrt{(x_1-x_2)^2 + (y_1-y_2)^2}\)
用代碼實現這個計算:
def calculate_distance(known_point, unknown_point):
"""
計算兩個點之間的歐氏距離
參數:
known_point:已知點的特征(如[2,3])
unknown_point:未知點的特征(如[6,5])
返回:
兩點之間的距離
"""
# 計算每個特征的差值平方,再求和,最后開平方
squared_diff = (known_point[0] - unknown_point[0])**2 + (known_point[1] - unknown_point[1])** 2
distance = np.sqrt(squared_diff)
return distance
# 計算未知水果與每個已知水果的距離
distances = []
for fruit in known_fruits:
dist = calculate_distance(fruit, unknown_fruit)
distances.append(dist)
# 打印計算過程,方便理解
print(f"已知水果特征{fruit}與未知水果的距離:{dist:.2f}")
運行后會得到類似這樣的結果:
已知水果特征[2 3]與未知水果的距離:4.47
已知水果特征[3 4]與未知水果的距離:3.16
已知水果特征[1 5]與未知水果的距離:5.10
已知水果特征[7 6]與未知水果的距離:1.41 # 這個最近!
已知水果特征[8 5]與未知水果的距離:2.00
已知水果特征[9 4]與未知水果的距離:3.61
3.2 第二步:找鄰居(選 k 個最像的)
k 近鄰算法中的 "k" 就是要選的鄰居數量。比如 k=3,就是找距離最近的 3 個已知水果:
# 把距離和對應的標簽組合起來,方便排序
distance_with_label = list(zip(distances, labels))
# 按距離從小到大排序
sorted_distance = sorted(distance_with_label, key=lambda x: x[0])
# 選擇k=3個最近的鄰居
k = 3
nearest_neighbors = sorted_distance[:k]
print(f"\n距離最近的{k}個鄰居是:")
for dist, label in nearest_neighbors:
fruit_type = "橘子" if label == 1 else "蘋果"
print(f"距離{dist:.2f},類別:{fruit_type}")
此時會輸出:
距離最近的3個鄰居是:
距離1.41,類別:橘子
距離2.00,類別:橘子
距離3.16,類別:蘋果
3.3 第三步:投票表決(少數服從多數)
看看這 3 個鄰居里哪種水果占多數:
# 提取鄰居的標簽
neighbor_labels = [label for (dist, label) in nearest_neighbors]
# 統計每個標簽出現的次數
label_counts = np.bincount(neighbor_labels)
# 找到出現次數最多的標簽
predicted_label = np.argmax(label_counts)
# 輸出結果
if predicted_label == 1:
print("\n根據k近鄰算法判斷,這個未知水果是:橘子!")
else:
print("\n根據k近鄰算法判斷,這個未知水果是:蘋果!")
最終結果會顯示 "橘子",和我們的直覺判斷完全一致!
四、完整代碼:可直接運行的 k 近鄰分類器
把上面的步驟整合起來,再加上一些優化,就得到了一個完整的 k 近鄰分類器:
import numpy as np
import matplotlib.pyplot as plt
class SimpleKNN:
"""簡單的k近鄰分類器"""
def __init__(self, k=3):
"""
初始化分類器
參數:
k:要選擇的鄰居數量,默認3個
"""
self.k = k
self.known_data = None # 用于存儲已知數據
self.known_labels = None # 用于存儲已知標簽
def fit(self, X, y):
"""
訓練模型(其實就是記住已知數據)
參數:
X:已知樣本的特征數據,形狀為[樣本數, 特征數]
y:已知樣本的標簽,形狀為[樣本數]
"""
self.known_data = X
self.known_labels = y
print(f"模型訓練完成,記住了{len(X)}個樣本")
def predict(self, X):
"""
預測新樣本的類別
參數:
X:新樣本的特征數據,形狀為[特征數]
返回:
預測的標簽
"""
# 計算與所有已知樣本的距離
distances = []
for data in self.known_data:
# 計算歐氏距離
dist = np.sqrt(np.sum((data - X) **2))
distances.append(dist)
# 把距離和標簽綁定,按距離排序
distance_with_label = list(zip(distances, self.known_labels))
sorted_distance = sorted(distance_with_label, key=lambda x: x[0])
# 取前k個鄰居的標簽
nearest_labels = [label for (dist, label) in sorted_distance[:self.k]]
# 少數服從多數
return np.argmax(np.bincount(nearest_labels))
# ----------------------
# 用水果數據測試我們的分類器
# ----------------------
if __name__ == "__main__":
# 已知水果特征:[顏色深度, 大小]
fruits = np.array([
[2, 3], [3, 4], [1, 5], # 蘋果(標簽0)
[7, 6], [8, 5], [9, 4] # 橘子(標簽1)
])
labels = np.array([0, 0, 0, 1, 1, 1])
# 創建分類器,選擇5個鄰居(試試把k改成1或5,看結果會不會變)
knn = SimpleKNN(k=5)
# 訓練模型(其實就是記住數據)
knn.fit(fruits, labels)
# 要預測的未知水果:顏色深度6,大小5
unknown_fruit = np.array([6, 5])
prediction = knn.predict(unknown_fruit)
# 輸出結果
fruit_names = {0: "蘋果", 1: "橘子"}
print(f"\n未知水果的特征:顏色深度{unknown_fruit[0]},大小{unknown_fruit[1]}")
print(f"預測結果:這是一個{fruit_names[prediction]}!")
# 畫圖展示
plt.scatter(fruits[labels==0, 0], fruits[labels==0, 1],
color='red', marker='o', label='蘋果')
plt.scatter(fruits[labels==1, 0], fruits[labels==1, 1],
color='orange', marker='o', label='橘子')
plt.scatter(unknown_fruit[0], unknown_fruit[1],
color='purple', marker='*', s=200, label='未知水果')
plt.xlabel('顏色深度(0-10,越大越黃)')
plt.ylabel('大小(0-10,越大越大)')
plt.title(f'k={knn.k}的k近鄰分類結果')
plt.legend()
plt.show()
五、k 近鄰算法的關鍵知識點
5.1 如何選擇最佳的 k 值?
k 值是 k 近鄰算法中最重要的參數:
- k 太小:容易被噪聲干擾(比如剛好有個奇怪的蘋果長得像橘子)
- k 太大:會把不相關的樣本也算進來(比如遠在天邊的蘋果也參與投票)
一個簡單的方法是:從 k=3 開始嘗試,逐漸增大,看哪個 k 值的預測效果最好。
5.2 特征需要 "標準化"
生活中如果特征的單位不一樣(比如一個特征是厘米,一個是千克),會影響距離計算。解決辦法是標準化:
# 標準化特征:讓每個特征的平均值為0,標準差為1
def standardize(X):
return (X - np.mean(X, axis=0)) / np.std(X, axis=0)
5.3 k 近鄰的優缺點
優點:
- 簡單易懂,幾乎不用數學基礎就能理解
- 不需要提前訓練模型,拿到新數據可以直接用
- 可以處理多種類型的數據
缺點:
- 數據量大的時候,計算距離會很慢
- 對特征的數量敏感(特征太多時會 "迷路")
六、動手實踐:用 scikit-learn 實現更專業的 k 近鄰
真實項目中,我們會用成熟的庫來實現 k 近鄰。試試用 scikit-learn(Python 最流行的機器學習庫)重寫上面的水果分類:
# 安裝scikit-learn(如果沒安裝的話)
# !pip install scikit-learn
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
# 數據準備(和之前一樣)
fruits = np.array([
[2, 3], [3, 4], [1, 5], # 蘋果
[7, 6], [8, 5], [9, 4] # 橘子
])
labels = np.array([0, 0, 0, 1, 1, 1])
# 創建k近鄰分類器,k=3
knn = KNeighborsClassifier(n_neighbors=3)
# 訓練模型
knn.fit(fruits, labels)
# 預測未知水果
unknown_fruit = np.array([[6, 5]]) # 注意這里要寫成二維數組
prediction = knn.predict(unknown_fruit)
print("scikit-learn預測結果:", "橘子" if prediction[0]==1 else "蘋果") # 輸出"橘子"
是不是更簡單了?這就是專業庫的力量!
七、總結:一篇博客掌握 k 近鄰
通過辨別水果的例子,我們學會了:
- k 近鄰算法的核心思想:"看鄰居投票"
- 實現步驟:計算距離→找鄰居→投票表決
- 關鍵參數 k 的選擇方法
- 如何用代碼實現(從手寫簡單版本到專業庫)
k 近鄰就像機器學習世界的 "Hello World",它簡單卻蘊含了機器學習的基本思想 ——從數據中找規律。下一次當你在超市辨別水果時,不妨想想:"這個過程如果寫成代碼,應該怎么實現呢?"
現在就動手修改代碼里的參數(比如 k 值、水果特征),看看會得到什么有趣的結果吧!
祝你的機器學習之旅,從這個甜甜的 "橘子分類器" 開始,越來越精彩!
??還想看更多,來啦!!!