用KNN實現手寫數字識別:基于 OpenCV 和 scikit-learn 的實戰教學
在這篇文章中,我們將使用 KNN(K-Nearest Neighbors)算法對手寫數字進行分類識別。我們會用 OpenCV 讀取圖像并預處理數據,用 scikit-learn 構建并訓練模型,最終識別新的數字圖像。
為什么像素可以被代碼讀取為數據:
圖像的本質:像素的數字矩陣
任何數字圖像(如照片、截圖、手寫數字圖片)都是由無數個微小的 “像素點”(Pixel)組成的
每個像素點的數值含義:
- 對于灰度圖(如代碼中的手寫數字),每個像素用一個 0-255 的整數表示亮度:0 代表純黑,255 代表純白,中間值表示不同深淺的灰色。
- 對于彩色圖(如 RGB 格式),每個像素由三個數值(R、G、B)組成,分別對應紅、綠、藍三種顏色的亮度,組合后呈現出各種顏色。
(可以在調試時看一看代碼里的‘gray’參數里面 單個數字圖像的矩陣)
此20×20 像素的數字圖像就是一個數字矩陣顯示出的一個大大的 0
使用的數據集
我們使用的是一個包含 5000 個手寫數字(0-9) 的圖像文件(digits5000.png
),每種數字500個,總共10類。圖像被排布成了一個 50 行 × 100 列 的網格,每個小格是一個 20×20 像素的數字圖像。
數據圖像:
保存下面代碼所需的三張圖片:
?
上面的 ‘3’,‘6’ 圖片可在‘開始’里的‘畫圖’中,可以創建我們想要自定義的數字圖片:
然后使用畫筆寫一個數字后(筆粗一點):
把比例調到? ?20×20 像素
即可得到。
?完整代碼:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import cv2# 讀取包含5000個手寫數字的大圖,每個數字為20x20像素
img = cv2.imread('digits5000.png')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 轉換為灰度圖# 將大圖分割為50行100列的小單元格,每個單元格包含一個手寫數字
cells = [np.hsplit(row,100) for row in np.vsplit(gray,50)]# 將單元格列表轉換為四維數組: (50行, 100列, 20像素高, 20像素寬)
x = np.array(cells)# 準備訓練集和測試集數據
# 前50列作為訓練集,后50列作為測試集
# 數據重塑為: (樣本數, 特征數),每個樣本包含400個像素值
train = x[:,:50].reshape(-1,400)
test = x[:,50:].reshape(-1,400)# 創建標簽數據
n = np.arange(10) # 創建數字0-9的數組
# 每個數字對應250個樣本(50行×5列),生成訓練標簽
tags = np.repeat(n,250)
tag = tags[:,np.newaxis] # 轉換為二維數組,形狀為(2500, 1)# 創建并訓練KNN分類器,使用k=5最近鄰
knn = KNeighborsClassifier(n_neighbors=5) #給初學者的建議:在此處設置斷點 或 直接在import處設置斷點,開啟調試一行一行地運行,更好的看到每一個參數的變化
knn.fit(train, tag)# 評估模型在訓練集上的準確率
predictions_train = knn.predict(train) #此行結果沒有運行,建議在開啟調試,可看到
accuracy_train = knn.score(train, tag)
print(f"訓練集準確率: {accuracy_train:.4f}")# 評估模型在測試集上的準確率
predictions_test = knn.predict(test)
accuracy_test = knn.score(test, tag)
print(f"測試集準確率: {accuracy_test:.4f}")# 預測外部數字3的圖像
digit3 = cv2.imread('digit3.png')
digit3gray = cv2.cvtColor(digit3, cv2.COLOR_BGR2GRAY)
digit3test = digit3gray.reshape(-1,400) # 重塑為模型期望的輸入格式
predictions_digit3 = knn.predict(digit3test)
print(f"預測數字3的結果: {predictions_digit3}")# 預測外部數字6的圖像
digit6 = cv2.imread('digit6.png')
digit6gray = cv2.cvtColor(digit6, cv2.COLOR_BGR2GRAY)
digit6test = digit6gray.reshape(-1,400) # 重塑為模型期望的輸入格式
predictions_digit6 = knn.predict(digit6test)
print(f"預測數字6的結果: {predictions_digit6}")
紅色斷點:
調試:
點擊 單步執行 或 其他? ?,多嘗試嘗試
點擊“作為...查看”即可清晰查看
所需庫導入
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import cv2
numpy
: 用于矩陣操作。sklearn.neighbors.KNeighborsClassifier
: 實現KNN分類器。cv2
: OpenCV庫,用于圖像讀取與處理。
圖像加載與預處理
img = cv2.imread('digits5000.png')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
cv2.imread()
讀取圖像。cv2.cvtColor()
將圖像從彩色(BGR)轉換為灰度(GRAY),便于處理。
圖像分割成小數字圖塊
cells = [np.hsplit(row,100) for row in np.vsplit(gray,50)]
x = np.array(cells)
使用
np.vsplit()
將圖像豎直切成50行(每行包含100個數字)。對每行使用
np.hsplit()
水平切分成100列,最終每個小格是一個20x20的數字圖像。得到的
x
是一個形狀為(50, 100, 20, 20)
的數組。
構建訓練集與測試集
train = x[:,:50].reshape(-1,400)
test = x[:,50:].reshape(-1,400)
將每個 20×20 圖像展開為 1×400 的一維向量。
前 50 列作為訓練集,后 50 列作為測試集。
train
和test
的形狀均為(2500, 400)
。
構造標簽
n = np.arange(10) #(0123456789)
tags = np.repeat(n,250) #每個數字重復250次 -> [0,...0,1,...,1,...9,...9]
tag = tags[:,np.newaxis] #添加新維度,變成列向量
# test_tag = np.repeat(n,250)[:np.newaxis]
tags
是長度為 2500 的一維標簽數組。tag
是形狀為(2500, 1)
的列向量,作為訓練和測試的真實標簽。
為什么重復250次,因為我們把數據從中間對半切開,每個數字有500個,左二百五十為訓練集,右二百五十個為測試集。
所以其實訓練集和測試集的標簽是一樣的,標簽就是每一個20x20數字圖像所顯示的數字。
訓練模型并評估準確率
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train, tag)predictions_train = knn.predict(train)
accuracy_train = knn.score(train, tag)
print(accuracy_train)predictions_test = knn.predict(test)
accuracy_test = knn.score(test, tag)
print(accuracy_test)
初始化一個 KNN 分類器,選擇
k=5
。使用
.fit()
訓練模型。使用
.predict()
和.score()
對訓練集和測試集進行預測和評分。打印準確率:理論上訓練集精度應很高,測試集略低(但通常也超過 97%)。
識別自定義手寫數字
digit3 = cv2.imread('digit3.png') #圖片中的數字是3
digit3gray = cv2.cvtColor(digit3, cv2.COLOR_BGR2GRAY)
digit3test = digit3gray.reshape(-1,400)
predictions_digit3 = knn.predict(digit3test)
print(predictions_digit3)digit6 = cv2.imread('digit6.png') #圖片中的數字是6
digit6gray = cv2.cvtColor(digit6, cv2.COLOR_BGR2GRAY)
digit6test = digit6gray.reshape(-1,400)
predictions_digit6 = knn.predict(digit6test)
print(predictions_digit6)
讀取兩張額外圖像
digit3.png
和digit6.png
。將其轉換為灰度、再reshape成與訓練數據一致的
(1, 400)
形狀。使用模型預測數字類別。
總結
我們用 KNN 成功實現了手寫數字的分類識別,關鍵步驟包括:
圖像預處理和切分
標簽構造與數據 reshape
使用
KNeighborsClassifier
建模預測未知圖像的數字類別