CNN tensorflow 人臉識別

數據材料

這是一個小型的人臉數據庫,一共有40個人,每個人有10張照片作為樣本數據。這些圖片都是黑白照片,意味著這些圖片都只有灰度0-255,沒有rgb三通道。于是我們需要對這張大圖片切分成一個個的小臉。整張圖片大小是1190 × 942,一共有20 × 20張照片。那么每張照片的大小就是(1190 / 20)× (942 / 20)= 57 × 47 (大約,以為每張圖片之間存在間距)。

問題解決:

10類樣本,利用CNN訓練可以分類10類數據的神經網絡,與手寫字符識別類似

olivettifaces.gif

?

復制代碼
#coding=utf-8
#http://www.jianshu.com/p/3e5ddc44aa56
#tensorflow 1.3.1
#python 3.6
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import numpy
from PIL import Image#獲取dataset
def load_data(dataset_path):img = Image.open(dataset_path)# 定義一個20 × 20的訓練樣本,一共有40個人,每個人都10張樣本照片img_ndarray = np.asarray(img, dtype='float64') / 256#img_ndarray = np.asarray(img, dtype='float32') / 32# 記錄臉數據矩陣,57 * 47為每張臉的像素矩陣faces = np.empty((400, 57 * 47))for row in range(20):for column in range(20):faces[20 * row + column] = np.ndarray.flatten(img_ndarray[row * 57: (row + 1) * 57, column * 47 : (column + 1) * 47])label = np.zeros((400, 40))for i in range(40):label[i * 10: (i + 1) * 10, i] = 1# 將數據分成訓練集,驗證集,測試集train_data = np.empty((320, 57 * 47))train_label = np.zeros((320, 40))vaild_data = np.empty((40, 57 * 47))vaild_label = np.zeros((40, 40))test_data = np.empty((40, 57 * 47))test_label = np.zeros((40, 40))for i in range(40):train_data[i * 8: i * 8 + 8] = faces[i * 10: i * 10 + 8]train_label[i * 8: i * 8 + 8] = label[i * 10: i * 10 + 8]vaild_data[i] = faces[i * 10 + 8]vaild_label[i] = label[i * 10 + 8]test_data[i] = faces[i * 10 + 9]test_label[i] = label[i * 10 + 9]train_data = train_data.astype('float32')vaild_data = vaild_data.astype('float32')test_data = test_data.astype('float32')return [(train_data, train_label),(vaild_data, vaild_label),(test_data, test_label)]def convolutional_layer(data, kernel_size, bias_size, pooling_size):kernel = tf.get_variable("conv", kernel_size, initializer=tf.random_normal_initializer())bias = tf.get_variable('bias', bias_size, initializer=tf.random_normal_initializer())conv = tf.nn.conv2d(data, kernel, strides=[1, 1, 1, 1], padding='SAME')linear_output = tf.nn.relu(tf.add(conv, bias))pooling = tf.nn.max_pool(linear_output, ksize=pooling_size, strides=pooling_size, padding="SAME")return poolingdef linear_layer(data, weights_size, biases_size):weights = tf.get_variable("weigths", weights_size, initializer=tf.random_normal_initializer())biases = tf.get_variable("biases", biases_size, initializer=tf.random_normal_initializer())return tf.add(tf.matmul(data, weights), biases)def convolutional_neural_network(data):# 根據類別個數定義最后輸出層的神經元n_ouput_layer = 40kernel_shape1=[5, 5, 1, 32]kernel_shape2=[5, 5, 32, 64]full_conn_w_shape = [15 * 12 * 64, 1024]out_w_shape = [1024, n_ouput_layer]bias_shape1=[32]bias_shape2=[64]full_conn_b_shape = [1024]out_b_shape = [n_ouput_layer]data = tf.reshape(data, [-1, 57, 47, 1])# 經過第一層卷積神經網絡后,得到的張量shape為:[batch, 29, 24, 32]with tf.variable_scope("conv_layer1") as layer1:layer1_output = convolutional_layer(data=data,kernel_size=kernel_shape1,bias_size=bias_shape1,pooling_size=[1, 2, 2, 1])# 經過第二層卷積神經網絡后,得到的張量shape為:[batch, 15, 12, 64]with tf.variable_scope("conv_layer2") as layer2:layer2_output = convolutional_layer(data=layer1_output,kernel_size=kernel_shape2,bias_size=bias_shape2,pooling_size=[1, 2, 2, 1])with tf.variable_scope("full_connection") as full_layer3:# 講卷積層張量數據拉成2-D張量只有有一列的列向量layer2_output_flatten = tf.contrib.layers.flatten(layer2_output)layer3_output = tf.nn.relu(linear_layer(data=layer2_output_flatten,weights_size=full_conn_w_shape,biases_size=full_conn_b_shape))# layer3_output = tf.nn.dropout(layer3_output, 0.8)with tf.variable_scope("output") as output_layer4:output = linear_layer(data=layer3_output,weights_size=out_w_shape,biases_size=out_b_shape)return output;def train_facedata(dataset, model_dir,model_path):# train_set_x = data[0][0]# train_set_y = data[0][1]# valid_set_x = data[1][0]# valid_set_y = data[1][1]# test_set_x = data[2][0]# test_set_y = data[2][1]# X = tf.placeholder(tf.float32, shape=(None, None), name="x-input")  # 輸入數據# Y = tf.placeholder(tf.float32, shape=(None, None), name='y-input')  # 輸入標簽
batch_size = 40# train_set_x, train_set_y = dataset[0]# valid_set_x, valid_set_y = dataset[1]# test_set_x, test_set_y = dataset[2]train_set_x = dataset[0][0]train_set_y = dataset[0][1]valid_set_x = dataset[1][0]valid_set_y = dataset[1][1]test_set_x = dataset[2][0]test_set_y = dataset[2][1]X = tf.placeholder(tf.float32, [batch_size, 57 * 47])Y = tf.placeholder(tf.float32, [batch_size, 40])predict = convolutional_neural_network(X)cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict, labels=Y))optimizer = tf.train.AdamOptimizer(1e-2).minimize(cost_func)# 用于保存訓練的最佳模型saver = tf.train.Saver()#model_dir = './model'#model_path = model_dir + '/best.ckpt'
    with tf.Session() as session:# 若不存在模型數據,需要訓練模型參數if not os.path.exists(model_path + ".index"):session.run(tf.global_variables_initializer())best_loss = float('Inf')for epoch in range(20):epoch_loss = 0for i in range((int)(np.shape(train_set_x)[0] / batch_size)):x = train_set_x[i * batch_size: (i + 1) * batch_size]y = train_set_y[i * batch_size: (i + 1) * batch_size]_, cost = session.run([optimizer, cost_func], feed_dict={X: x, Y: y})epoch_loss += costprint(epoch, ' : ', epoch_loss)if best_loss > epoch_loss:best_loss = epoch_lossif not os.path.exists(model_dir):os.mkdir(model_dir)print("create the directory: %s" % model_dir)save_path = saver.save(session, model_path)print("Model saved in file: %s" % save_path)# 恢復數據并校驗和測試
        saver.restore(session, model_path)correct = tf.equal(tf.argmax(predict,1), tf.argmax(Y,1))valid_accuracy = tf.reduce_mean(tf.cast(correct,'float'))print('valid set accuracy: ', valid_accuracy.eval({X: valid_set_x, Y: valid_set_y}))test_pred = tf.argmax(predict, 1).eval({X: test_set_x})test_true = np.argmax(test_set_y, 1)test_correct = correct.eval({X: test_set_x, Y: test_set_y})incorrect_index = [i for i in range(np.shape(test_correct)[0]) if not test_correct[i]]for i in incorrect_index:print('picture person is %i, but mis-predicted as person %i'%(test_true[i], test_pred[i]))plot_errordata(incorrect_index, "olivettifaces.gif")#畫出在測試集中錯誤的數據
def plot_errordata(error_index, dataset_path):img = mpimg.imread(dataset_path)plt.imshow(img)currentAxis = plt.gca()for index in error_index:row = index // 2column = index % 2currentAxis.add_patch(patches.Rectangle(xy=(47 * 9 if column == 0 else 47 * 19,row * 57),width=47,height=57,linewidth=1,edgecolor='r',facecolor='none'))plt.savefig("result.png")plt.show()def main():dataset_path = "olivettifaces.gif"data = load_data(dataset_path)model_dir = './model'model_path = model_dir + '/best.ckpt'train_facedata(data, model_dir, model_path)if __name__ == "__main__" :main()
復制代碼

?C:\python36\python.exe X:/DeepLearning/code/face/TensorFlow_CNN_face/facerecognition_main.py
valid set accuracy:? 0.825
picture person is 0, but mis-predicted as person 23
picture person is 6, but mis-predicted as person 38
picture person is 8, but mis-predicted as person 34
picture person is 15, but mis-predicted as person 11
picture person is 24, but mis-predicted as person 7
picture person is 29, but mis-predicted as person 7
picture person is 33, but mis-predicted as person 39

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/387221.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/387221.shtml
英文地址,請注明出處:http://en.pswp.cn/news/387221.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

數據結構01緒論

第一章緒論 1.1 什么是數據結構 數據結構是一門研究非數值計算的程序設計問題中,計算機的操作對象以及他們之間的關系和操作的學科。 面向過程程序數據結構算法 數據結構是介于數學、計算機硬件、計算機軟件三者之間的一門核心課程。 數據結構是程序設計、編譯…

css3動畫、2D與3D效果

1.兼容性 css3針對同一樣式在不同瀏覽器的兼容 需要在樣式屬性前加上內核前綴; 谷歌(chrome) -webkit-transition: Opera(歐鵬) -o-transition: Firefox(火狐) -moz-transition Ie -ms-tr…

ES6學習筆記(六)數組的擴展

1.擴展運算符 1.1含義 擴展運算符(spread)是三個點(...)。它好比 rest 參數的逆運算,將一個數組轉為用逗號分隔的參數序列。 console.log(...[1, 2, 3]) // 1 2 3console.log(1, ...[2, 3, 4], 5) // 1 2 3 4 5[...doc…

數據結構02線性表

第二章 線性表 C中STL順序表:vector http://blog.csdn.net/weixin_37289816/article/details/54710677鏈表:list http://blog.csdn.net/weixin_37289816/article/details/54773406在數據元素的非空有限集中: (1)存在唯一一個被稱作“第…

訓練一個神經網絡 能讓她認得我

寫個神經網絡,讓她認得我(?????)(Tensorflow,opencv,dlib,cnn,人臉識別) 這段時間正在學習tensorflow的卷積神經網絡部分,為了對卷積神經網絡能夠有一個更深的了解,自己動手實現一個例程是比較好的方式,所以就選了一個這樣比…

數據結構03棧和隊列

第三章棧和隊列 STL棧:stack http://blog.csdn.net/weixin_37289816/article/details/54773495隊列:queue http://blog.csdn.net/weixin_37289816/article/details/54773581priority_queue http://blog.csdn.net/weixin_37289816/article/details/5477…

Java動態編譯執行

在某些情況下,我們需要動態生成java代碼,通過動態編譯,然后執行代碼。JAVA API提供了相應的工具(JavaCompiler)來實現動態編譯。下面我們通過一個簡單的例子介紹,如何通過JavaCompiler實現java代碼動態編譯…

樹莓派pwm驅動好盈電調及伺服電機

本文講述如何通過樹莓派的硬件PWM控制好盈電調來驅動RC車子的前進后退,以及如何驅動伺服電機來控制車子轉向。 1. 好盈電調簡介 車子上的電調型號為:WP-10BLS-A-RTR,在好盈官網并沒有搜到對應手冊,但找到一份通用RC競速車的電調使…

數據結構04串

第四章 串 STL:string http://blog.csdn.net/weixin_37289816/article/details/54716009計算機上非數值處理的對象基本上是字符串數據。 在不同類型的應用中,字符串具有不同的特點,要有效的實現字符串的處理,必須選用合適的存儲…

CAS單點登錄原理解析

CAS單點登錄原理解析 SSO英文全稱Single Sign On,單點登錄。SSO是在多個應用系統中,用戶只需要登錄一次就可以訪問所有相互信任的應用系統。CAS是一種基于http協議的B/S應用系統單點登錄實現方案,認識CAS之前首先要熟悉http協議、Session與Co…

JDK1.6版添加了新的ScriptEngine類,允許用戶直接執行js代碼。

JDK1.6版添加了新的ScriptEngine類,允許用戶直接執行js代碼。在Java中直接調用js代碼 不能調用瀏覽器中定義的js函數,會拋出異常提示ReferenceError: “alert” is not defined。[java] view plaincopypackage com.sinaapp.manjushri; import javax.sc…

數據結構05數組和廣義表

第五章 數組 和 廣義表 數組和廣義表可以看成是線性表在下述含義上的擴展:表中的數據元素本身也是一個數據結構。 5.1 數組的定義 n維數組中每個元素都受著n個關系的約束,每個元素都有一個直接后繼元素。 可以把二維數組看成是這樣一個定長線性表&…

k8s的ingress使用

ingress 可以配置一個入口來提供k8s上service從外部來訪問的url、負載平衡流量、終止SSL和提供基于名稱的虛擬主機。 配置ingress的yaml: 要求域名解析無誤 要求service對應的pod正常 一、test1.domain.com --> service1:8080 apiVersion: extensions/v1beta1…

JDK1.8中如何用ScriptEngine動態執行JS

JDK1.8中如何用ScriptEngine動態執行JS jdk1.6開始就提供了動態腳本語言諸如JavaScript動態的支持。這無疑是一個很好的功能,畢竟Java的語法不是適合成為動態語言。而JDK通過執行JavaScript腳本可以彌補這一不足。這也符合“Java虛擬機不僅僅是Java一種語言的虛擬機…

數據結構06樹和二叉樹

第六章 樹和二叉樹 6.1 樹的定義和基本術語 樹 Tree 是n個結點的有限集。 任意一棵非空樹中: (1)有且僅有一個特定的稱為根(root)的結點; (2)當n>1時,其余結點可…

2019.03.20 mvt,Django分頁

MVT模式 MVT各部分的功能: M全拼為Model,與MVC中的M功能相同,負責和數據庫交互,進行數據處理。 V全拼為View,與MVC中的C功能相同,接收請求,進行業務處理,返回響應。 T全拼為Tem…

CountDownLatch,CyclicBarrier和Semaphore

在java 1.5中,提供了一些非常有用的輔助類來幫助我們進行并發編程,比如CountDownLatch,CyclicBarrier和Semaphore,今天我們就來學習一下這三個輔助類的用法。以下是本文目錄大綱:一.CountDownLatch用法二.CyclicBarrie…

數據結構07排序

第十章內部排序 10.1 概述 排序就是把一組數據按關鍵字的大小有規律地排列。經過排序的數據更易于查找。 排序前KiKj,且Ki在前: 排序方法是穩定的,若排序后Ki在前; 排序方法是不穩定的,如排序后Kj在前。 分類: 內…

數據結構08查找

第九章 查找 另一種在實際應用中大量使用的數據結構--查找表。 所謂查找,即為在一個含有眾多的數據元素的查找表中找出某個“特定的”數據元素。 查找表 search table 是由同一類型的數據元素構成的集合。集合中的數據元素之間存在著完全松散的關系,故…

下載Centos7 64位鏡像

下載Centos7 64位鏡像 1.打開Centos官網 打開Centos官方網站地址:https://www.centos.org/,點擊Get CentOS Now 2.點擊Minimal ISO鏡像 Minimal ISO鏡像,與DVD ISO鏡像的差別有很多,這里只說兩點 1.Minimal ISO類似于Windows的純凈…