KNN 算法–圖像分類算法
找到最近的K個鄰居,在前k個最近樣本中選擇最近的占比最高的類別作為預測類別。
- 給定測試對象,計算它與訓練集中每個對象的距離。
- 圈定距離最近的k個訓練對象,作為測試對象的鄰居。
- 根據這k個緊鄰對象所屬的類別,找到占比最高的那個類別作為測試對象的預測類別。
影響因素:
- 計算測試對象與訓練集中各個對象的距離。
- k的選擇。
import operatorimport numpy as np
import matplotlib.pyplot as pltdef create_data_set():group = np.array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5], [1.1, 1.0], [0.5, 1.5]])labels = np.array(['A', 'A', 'B', 'B', 'A', 'B'])return group, labelsdef knn_classify(k, dis, X_train, x_train, Y_test):assert dis == 'E' or dis == 'M', 'dis must E or M, E 代表歐式距離,M代表曼哈頓距離'num_test = Y_test.shape[0]label_list = []if dis == 'E':for i in range(num_test):distances = np.sqrt(np.sum(((X_train - np.tile(Y_test[i], (X_train.shape[0], 1))) ** 2), axis=1))nearest_k = np.argsort(distances)topK = nearest_k[:k]print(topK)classCount = {}for i in topK:classCount[x_train[i]] = classCount.get(x_train[i], 0) + 1sorted_class_count = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)label_list.append(sorted_class_count[0][0])return np.array(label_list)if __name__ == '__main__':group, labels = create_data_set()plt.scatter(group[labels == 'A', 0], group[labels == 'A', 1], color='r', marker='*')plt.scatter(group[labels == 'B', 0], group[labels == 'B', 1], color='g', marker='+')y_test_pred = knn_classify(1, 'E', group, labels, np.array([[1.0, 2.1], [0.4, 2.0]]))print(y_test_pred)plt.show()