利用knn算法實現手寫數字分類
- 1.作者介紹
- 2.KNN算法
- 2.1KNN(K-Nearest Neighbors)算法核心思想
- 2.2KNN算法的工作流程
- 2.3優缺點
- 2.4 KNN算法圖示介紹
- 3.實驗過程
- 3.1安裝所需庫
- 3.2 MNIST數據集
- 3.3 導入手寫數字圖像進行分類
- 3.4 完整代碼
- 3.5 實驗結果
1.作者介紹
王鵬飛,男,西安工程大學電子信息學院,2024級研究生
研究方向:機器視覺與人工智能
電子郵件:2018659934@QQ.com
王海博, 男 , 西安工程大學電子信息學院, 2024級研究生, 張宏偉人工智能課題組
研究方向:模式識別與人工智能
電子郵件:1137460680@qq.com
2.KNN算法
2.1KNN(K-Nearest Neighbors)算法核心思想
將訓練數據保存下來,對于一個新的數據點,通過查看其在特征空間中最近的K個鄰居來預測其類別或值。針對分類任務:如果K個鄰居中多數屬于某個類別,那么新數據點也被歸為該類別。
2.2KNN算法的工作流程
(1) 數據準備
特征提取:將數據集中的每個樣本表示為特征向量。
數據標準化:由于KNN依賴距離計算,因此需要對特征進行標準化(如歸一化或Z分數標準化),以消除不同特征量綱的影響。
(2) 距離計算
對于一個新的數據點,計算它與數據集中每個點之間的距離。常用的距離度量方式包括:歐氏距離、曼哈頓距離和明可夫斯基距離。
(3) 確定最近鄰
根據計算出的距離,找出與新數據點距離最近的K個點,這K個點稱為“最近鄰”。
K是一個超參數,需要根據具體問題選擇合適的值。K值過小可能導致過擬合,K值過大可能導致模型過于平滑。
(4) 進行預測
分類任務:統計K個最近鄰中每個類別的出現頻率,選擇出現次數最多的類別作為新數據點的預測類別。
2.3優缺點
(1) 優點
簡單易實現:原理直觀,實現代碼簡單。
無需訓練:KNN不需要像其他算法那樣進行復雜的訓練過程,只需在預測時計算距離。
對復雜數據集表現良好:可以很好地處理多類別問題和非線性數據。
(2) 缺點
計算效率低:每次預測都需要計算新數據點與所有訓練數據點之間的距離,計算量大。
存儲需求高:需要存儲整個訓練數據集。
對K值和距離度量敏感:K值的選擇和距離度量方式對模型性能影響較大。
2.4 KNN算法圖示介紹
見上圖所示,五角星為新輸入的數據,原訓練數據有Class A和Class B兩類,對于新輸入的數據,根據特征向量計算新輸入數據點與訓練集數據點之間的距離,根據所選K值確定出,新數據最鄰近K個點,圖示第一次k值選取為3時,其中Class B類占2/3,所以新數據將被分類為Class B類。
當k值選取為6時,見上圖所示,Class A類占4/6,所以此時對于新數據點將被歸為Class A類。由此可見K值的選擇對于分類的結果存在一定的影響,因此k值的選擇對于結果有重要的作用。
3.實驗過程
3.1安裝所需庫
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report
import cv2
from PIL import Image
import matplotlib.pyplot as plt
在編寫代碼前需要安裝上述的庫和所需的函數。
3.2 MNIST數據集
MNIST數據集來自美國國家標準與技術研究所。訓練集由來自250個不同人手寫的數字構成,測試集也是同樣的手寫數字數據,保證了測試集和訓練集的作者集不相交。MNIST數據集一共有7萬張圖片,其中6萬張是訓練集,1萬張是測試集。每張圖片是28 × 28像素 的0 ? 9的手寫數字圖片組成。每個圖片是黑底白字的灰度圖像。MNIST數據集可以導入fetch_openml函數從OpenML平臺加載數據集。
3.3 導入手寫數字圖像進行分類
# 導入自定義圖像并進行預測
def preprocess_image(image_path):image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)image = cv2.resize(image, (28, 28))image = cv2.bitwise_not(image)image = image.reshape(1, -1)image = scaler.transform(image)return image
def predict_image(image_path):image = preprocess_image(image_path)prediction = knn.predict(image)return prediction[0]
print("Testing custom image...")
image_path = "d:/wenjian/1.jpg" #更改為自己的路徑
prediction = predict_image(image_path)
print(f"Predicted digit: {prediction}")
# 顯示圖像
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (28, 28))
image = cv2.bitwise_not(image)
plt.imshow(image, cmap="gray")
plt.title(f"Predicted Digit: {prediction}")
plt.show()
導入一張白底黑字的手寫數字圖像,并對圖像進行預處理使得格式和灰度值與其訓練集相同,本次實驗導入的是白底黑字的手寫數字圖像,因為距離計算是依據灰度圖像的灰度值進行計算,訓練集的圖像是黑底白字的灰度圖像,因此需要對灰度值進行反轉,否則會造成預測誤差較大。導入圖像路徑需更改為自己圖像路徑。
3.4 完整代碼
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report
import cv2
from PIL import Image
import matplotlib.pyplot as plt# 加載MNIST數據集
print("Loading MNIST dataset...")
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint8)
# 數據預處理
print("Preprocessing data...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# 訓練KNN模型
print("Training KNN model...")
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
# 評估模型
print("Evaluating model...")
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")
print(classification_report(y_test, y_pred))
# 導入自定義圖像并進行預測
def preprocess_image(image_path):image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)image = cv2.resize(image, (28, 28))image = cv2.bitwise_not(image)image = image.reshape(1, -1)image = scaler.transform(image)return image
def predict_image(image_path):image = preprocess_image(image_path)prediction = knn.predict(image)return prediction[0]
print("Testing custom image...")
image_path = "d:/wenjian/1.jpg" #更改為自己的路徑
prediction = predict_image(image_path)
print(f"Predicted digit: {prediction}")
# 顯示圖像
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (28, 28))
image = cv2.bitwise_not(image)
plt.imshow(image, cmap="gray")
plt.title(f"Predicted Digit: {prediction}")
plt.show()