機器學習入門核心算法:K均值(K-Means)
- 1. 算法邏輯
- 2. 算法原理與數學推導
- 2.1 目標函數
- 2.2 數學推導
- 2.3 時間復雜度
- 3. 模型評估
- 內部評估指標
- 外部評估指標(需真實標簽)
- 4. 應用案例
- 4.1 客戶細分
- 4.2 圖像壓縮
- 4.3 文檔聚類
- 5. 面試題及答案
- 6. 優缺點分析
- **優點**:
- **缺點**:
- 7. 數學證明:為什么均值最小化WCSS?
1. 算法邏輯
K均值是一種無監督聚類算法,核心目標是將n個數據點劃分為k個簇(cluster),使得同一簇內數據點相似度高,不同簇間差異大。算法流程如下:
graph TDA[初始化K個質心] --> B[分配數據點到最近質心]B --> C[重新計算質心位置]C --> D{質心是否變化?}D -- 是 --> BD -- 否 --> E[輸出聚類結果]
2. 算法原理與數學推導
2.1 目標函數
最小化簇內平方和(Within-Cluster Sum of Squares, WCSS):
J = ∑ i = 1 k ∑ x ∈ C i ∥ x ? μ i ∥ 2 J = \sum_{i=1}^k \sum_{x \in C_i} \|x - \mu_i\|^2 J=i=1∑k?x∈Ci?∑?∥x?μi?∥2
其中:
- C i C_i Ci? 表示第i個簇
- μ i \mu_i μi? 表示第i個簇的質心
- ∥ x ? μ i ∥ \|x - \mu_i\| ∥x?μi?∥ 表示歐氏距離
2.2 數學推導
步驟1:初始化質心
隨機選擇k個數據點作為初始質心:
μ 1 ( 0 ) , μ 2 ( 0 ) , . . . , μ k ( 0 ) 其中 μ j ( 0 ) ∈ R d \mu_1^{(0)}, \mu_2^{(0)}, ..., \mu_k^{(0)} \quad \text{其中} \mu_j^{(0)} \in \mathbb{R}^d μ1(0)?,μ2(0)?,...,μk(0)?其中μj(0)?∈Rd
步驟2:分配數據點
對每個數據點 x i x_i xi?,計算到所有質心的距離,分配到最近質心的簇:
C j ( t ) = { x i : ∥ x i ? μ j ( t ) ∥ 2 ≤ ∥ x i ? μ l ( t ) ∥ 2 ? l } C_j^{(t)} = \{ x_i : \|x_i - \mu_j^{(t)}\|^2 \leq \|x_i - \mu_l^{(t)}\|^2 \ \forall l \} Cj(t)?={xi?:∥xi??μj(t)?∥2≤∥xi??μl(t)?∥2??l}
步驟3:更新質心
重新計算每個簇的均值作為新質心:
μ j ( t + 1 ) = 1 ∣ C j ( t ) ∣ ∑ x i ∈ C j ( t ) x i \mu_j^{(t+1)} = \frac{1}{|C_j^{(t)}|} \sum_{x_i \in C_j^{(t)}} x_i μj(t+1)?=∣Cj(t)?∣1?xi?∈Cj(t)?∑?xi?
步驟4:收斂條件
當滿足以下任一條件時停止迭代:
∥ μ j ( t + 1 ) ? μ j ( t ) ∥ < ? 或 J ( t ) ? J ( t + 1 ) < δ \|\mu_j^{(t+1)} - \mu_j^{(t)}\| < \epsilon \quad \text{或} \quad J^{(t)} - J^{(t+1)} < \delta ∥μj(t+1)??μj(t)?∥<?或J(t)?J(t+1)<δ
2.3 時間復雜度
- 每次迭代: O ( k n d ) O(knd) O(knd)
- 總復雜度: O ( t k n d ) O(tknd) O(tknd)
其中k=簇數,n=樣本數,d=特征維度,t=迭代次數
3. 模型評估
內部評估指標
-
輪廓系數(Silhouette Coefficient)
s ( i ) = b ( i ) ? a ( i ) max ? { a ( i ) , b ( i ) } s(i) = \frac{b(i) - a(i)}{\max\{a(i), b(i)\}} s(i)=max{a(i),b(i)}b(i)?a(i)?- a ( i ) a(i) a(i):樣本i到同簇其他點的平均距離
- b ( i ) b(i) b(i):樣本i到最近其他簇的平均距離
- 取值范圍:[-1, 1],越大越好
-
戴維森堡丁指數(Davies-Bouldin Index)
D B = 1 k ∑ i = 1 k max ? j ≠ i ( σ i + σ j d ( μ i , μ j ) ) DB = \frac{1}{k} \sum_{i=1}^k \max_{j \neq i} \left( \frac{\sigma_i + \sigma_j}{d(\mu_i, \mu_j)} \right) DB=k1?i=1∑k?j=imax?(d(μi?,μj?)σi?+σj??)- σ i \sigma_i σi?:簇i內平均距離
- d ( μ i , μ j ) d(\mu_i, \mu_j) d(μi?,μj?):簇中心距離
- 取值越小越好
外部評估指標(需真實標簽)
- 調整蘭德指數(Adjusted Rand Index)
- Fowlkes-Mallows 指數
- 互信息(Mutual Information)
4. 應用案例
4.1 客戶細分
- 場景:電商用戶行為分析
- 特征:購買頻率、客單價、瀏覽時長
- 結果:識別高價值客戶群(K=5),營銷轉化率提升23%
4.2 圖像壓縮
- 原理:將像素顏色聚類為K種代表色
- 步驟:
- 將圖像視為RGB點集(n=像素數,d=3)
- 設置K=256(256色圖像)
- 用簇中心顏色代替原始像素
- 效果:文件大小減少85%且視覺質量可接受
4.3 文檔聚類
- 場景:新聞主題分類
- 特征:TF-IDF向量(d≈10,000)
- 挑戰:高維稀疏數據需先降維(PCA/t-SNE)
5. 面試題及答案
Q1:K均值對初始質心敏感,如何改進?
A:采用K-means++初始化:
- 隨機選第一個質心
- 后續質心以概率 D ( x ) 2 / ∑ D ( x ) 2 D(x)^2 / \sum D(x)^2 D(x)2/∑D(x)2選擇
( D ( x ) D(x) D(x)為點到最近質心的距離)
Q2:如何確定最佳K值?
A:常用方法:
- 肘部法則(Elbow Method):繪制K與WCSS曲線,選拐點
from sklearn.cluster import KMeans distortions = [] for k in range(1,10):kmeans = KMeans(n_clusters=k)kmeans.fit(X)distortions.append(kmeans.inertia_)
- 輪廓系數法:選擇使平均輪廓系數最大的K
Q3:K均值能否處理非凸數據?
A:不能。K均值假設簇是凸形且各向同性。解決方案:
- 使用譜聚類(Spectral Clustering)
- 或DBSCAN等基于密度的算法
6. 優缺點分析
優點:
- 簡單高效:時間復雜度線性增長,適合大規模數據
- 可解釋性強:簇中心代表原型特征
- 易于實現:核心代碼僅需10行
def k_means(X, k):centroids = X[np.random.choice(len(X), k)]while True:labels = np.argmin(np.linalg.norm(X[:,None]-centroids, axis=2), axis=1)new_centroids = np.array([X[labels==j].mean(0) for j in range(k)])if np.allclose(centroids, new_centroids): breakcentroids = new_centroidsreturn labels, centroids
缺點:
- 需預先指定K值:實際應用中K常未知
- 對異常值敏感:均值易受極端值影響
- 僅適用于數值數據:需對類別變量編碼
- 局部最優問題:不同初始化可能產生不同結果
- 假設各向同性:在細長簇或流形數據上效果差
7. 數學證明:為什么均值最小化WCSS?
給定簇 C j C_j Cj?,優化目標:
min ? μ j ∑ x i ∈ C j ∥ x i ? μ j ∥ 2 \min_{\mu_j} \sum_{x_i \in C_j} \|x_i - \mu_j\|^2 μj?min?xi?∈Cj?∑?∥xi??μj?∥2
求導并令導數為零:
? ? μ j ∑ ∥ x i ? μ j ∥ 2 = ? 2 ∑ ( x i ? μ j ) = 0 \frac{\partial}{\partial \mu_j} \sum \|x_i - \mu_j\|^2 = -2 \sum (x_i - \mu_j) = 0 ?μj???∑∥xi??μj?∥2=?2∑(xi??μj?)=0
解得:
μ j = 1 ∣ C j ∣ ∑ x i \mu_j = \frac{1}{|C_j|} \sum x_i μj?=∣Cj?∣1?∑xi?
證畢:均值是簇內平方和的最優解。
💡 關鍵洞察:K均值本質是期望最大化(EM)算法的特例:
- E步:固定質心,分配數據點(期望)
- M步:固定分配,更新質心(最大化)
實際應用時建議:
- 數據標準化:消除量綱影響
- 多次運行:取最佳結果(
n_init=10
)- 結合PCA降維:處理高維數據
- 對分類型數據用K-mode變種