目錄
一.案例:手寫數字的識別
1.安裝opencv-python庫
2.將大圖分割成100×50個小圖,每份對應一個手寫數字樣品
3.訓練集和測試集
4.為訓練集和測試集準備結果標簽
5.模型訓練與預測
6.計算準確率
7.完整代碼實現
一.案例:手寫數字的識別
現有一張2000×1000像素的手寫數字照片digits.png作為數據集,如下共有100列50行,由此我們可以計算出一個手寫數字的大小是20×20像素
1.安裝opencv-python庫
安裝opencv-python庫指令(可以根據自己的需要指定版本)如下:
pip install opencv-python==3.4.11.45 -i Https://pypi.tuna.tsinghua.edu.cn/simple
上面的digits.png圖片是彩色圖像,由RGB三個通道疊加而成,所以它的本質是三維矩陣
我們需要利用opencv-python這個庫的imread()方法來讀取圖片數據
然后用cv2.cvtColor(img,COLOR_BGR2GRAY)將其轉換為灰度圖,灰度圖僅保留亮度信息轉化為二維矩陣,無彩色通道數據更簡化
import numpy as np
import cv2
img = cv2.imread('digits.png')
gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
2.將大圖分割成100×50個小圖,每份對應一個手寫數字樣品
先利用numpy的vsplit()方法,將大圖在垂直方向上分割成50行
再利用hsplit()方法把分割后的每一行在水平方向上切成100列
這樣就得到了50×100個手寫數字列表
cells = [ np.hsplit(row,100) for row in np.vsplit(gray,50)]
再用numpy庫的array()方法將列表轉化成矩陣,提升處理效率
data = np.array(cells)
3.訓練集和測試集
為了訓練和測試的準確性我們將全部數據的前50列作為訓練集,后50列作為測試集
train=data[:,:50]
test=data[:,50:]
訓練的數據都是一行數據代表一個樣本,前面我們知道一個手寫數字大小是20×20像素,所以我們可以將矩陣reshape為(-1,400)的樣式,這樣一行就是一個手寫數字的400個特征,訓練集和測試集各有2500行
數據類型也得從?uint8
(0-255整數)轉為?float32
,以支持KNN算法中的距離計算(含小數)
train_new=train.reshape(-1,400).astype(np.float32)
test_new=test.reshape(-1,400).astype(np.float32)
4.為訓練集和測試集準備結果標簽
用numpy庫中的arange()方法,生成0到9的數字序列,由于測試集和訓練集中每個數字都各有250個即他們的特征數據都各有250行,所以我們再用repeat()方法將數字序列重復250次
k=np.arange(0,10)
labels=np.repeat(k,250)
再將標簽labels通過np.newaxis 轉為二維列向量(2500×1),與特征數據對齊
train_labels = labels[:,np.newaxis]
test_labels = labels[:,np.newaxis]
5.模型訓練與預測
優先使用OpenCV內置算法(如KNN)以減少依賴庫數量,提升運行效率
使用OpenCV庫的KNN算法:
通過?cv2.ml.KNearest_create()創建v模型,在通過train()方法傳入訓練數據(特征矩陣和標簽),train()方法中的參數?cv2.ml.ROW_SAMPLE
?表示指定每行為一個樣本數據
knn = cv2.ml.KNearest_create()
knn.train(train_new,cv2.ml.ROW_SAMPLE,train_labels)
使用findNearest()方法完成對測試集的預測并指定K值
返回結果result中存放預測的結果
ret,result,neighbors,dist=knn.findNearest(test_new,k=3)
6.計算準確率
由于OpenCV庫的KNN算法中沒有計算準確率的方法所以我們需要自己計算
通過result==test_labels,預測結果與標簽相同則放回True反之返回False,最后返回一個只有True和False的序列
通過np.count_nonzero()方法來計算一共有多少個True
最后直接用True的個數除以總共的個數即為準確率
matches = result==test_labels
correct=np.count_nonzero(matches)
accuracy=correct*100.0/result.size
print("識別手寫數字的準確率為{}%".format(accuracy))
7.完整代碼實現
import numpy as np
import cv2
img = cv2.imread('digits.png')
gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
cells = [ np.hsplit(row,100) for row in np.vsplit(gray,50)]
data = np.array(cells)train=data[:,:50]
test=data[:,50:]train_new=train.reshape(-1,400).astype(np.float32)
test_new=test.reshape(-1,400).astype(np.float32)k=np.arange(0,10)
labels=np.repeat(k,250)
train_labels = labels[:,np.newaxis]
test_labels = labels[:,np.newaxis]knn = cv2.ml.KNearest_create()
knn.train(train_new,cv2.ml.ROW_SAMPLE,train_labels)
ret,result,neighbors,dist=knn.findNearest(test_new,k=3)matches = result==test_labels
correct=np.count_nonzero(matches)
accuracy=correct*100.0/result.size
print("識別手寫數字的準確率為{}%".format(accuracy))