深度學習案例之基于 CNN 的 MNIST 手寫數字識別

一、模型結構

本文只涉及利用Tensorflow實現CNN的手寫數字識別,CNN的內容請參考:卷積神經網絡(CNN)

MNIST數據集的格式與數據預處理代碼input_data.py的講解請參考 :Tutorial (2)

二、實驗代碼

# -*- coding:utf-8 -*-
"""@Time  : @Author: Feng Lepeng@File  : mnist_cnn_tf_demo.py@Desc  : 手寫數字識別的CNN網絡 LeNet注意:一般情況下,我們都是直接將網絡結構翻譯成為這個代碼,最多稍微的修改一下網絡中的參數(超參數、窗口大小、步長等信息)https://deeplearnjs.org/demos/model-builder/https://js.tensorflow.org/#getting-started
"""
import math
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data# 數據加載
mnist = input_data.read_data_sets('data/mnist', one_hot=True)# 手寫數字識別的數據集主要包含三個部分:訓練集(5.5w, mnist.train)、測試集(1w, mnist.test)、驗證集(0.5w, mnist.validation)
# 手寫數字圖片大小是28*28*1像素的圖片(黑白),也就是每個圖片由784維的特征描述
train_img = mnist.train.images
train_label = mnist.train.labels
test_img = mnist.test.images
test_label = mnist.test.labels
train_sample_number = mnist.train.num_examples# 相關的參數、超參數的設置
# 學習率,一般學習率設置的比較小
learn_rate_base = 1.0
# 每次迭代的訓練樣本數量
batch_size = 64
# 展示信息的間隔大小
display_step = 1# 輸入的樣本維度大小信息
input_dim = train_img.shape[1]
# 輸出的維度大小信息
n_classes = train_label.shape[1]# 模型構建
# 1. 設置數據輸入的占位符
x = tf.placeholder(tf.float32, shape=[None, input_dim], name='x')
y = tf.placeholder(tf.float32, shape=[None, n_classes], name='y')
learn_rate = tf.placeholder(tf.float32, name='learn_rate')def learn_rate_func(epoch):"""根據給定的迭代批次,更新產生一個學習率的值:param epoch::return:"""return learn_rate_base * (0.9 ** int(epoch / 10))def get_variable(name, shape=None, dtype=tf.float32, initializer=tf.random_normal_initializer(mean=0, stddev=0.1)):"""返回一個對應的變量:param name::param shape::param dtype::param initializer::return:"""return tf.get_variable(name, shape, dtype, initializer)# 2. 構建網絡
def le_net(x, y):# 1. 輸入層with tf.variable_scope('input1'):# 將輸入的x的格式轉換為規定的格式# [None, input_dim] -> [None, height, weight, channels]net = tf.reshape(x, shape=[-1, 28, 28, 1])# 2. 卷積層with tf.variable_scope('conv2'):# 卷積# conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None) => 卷積的API# data_format: 表示的是輸入的數據格式,兩種:NHWC和NCHW,N=>樣本數目,H=>Height, W=>Weight, C=>Channels# input:輸入數據,必須是一個4維格式的圖像數據,具體格式和data_format有關,如果data_format是NHWC的時候,input的格式為: [batch_size, height, weight, channels] => [批次中的圖片數目,圖片的高度,圖片的寬度,圖片的通道數];如果data_format是NCHW的時候,input的格式為: [batch_size, channels, height, weight] => [批次中的圖片數目,圖片的通道數,圖片的高度,圖片的寬度]# filter: 卷積核,是一個4維格式的數據,shape: [height, weight, in_channels, out_channels] => [窗口的高度,窗口的寬度,輸入的channel通道數(上一層圖片的深度),輸出的通道數(卷積核數目)]# strides:步長,是一個4維的數據,每一維數據必須和data_format格式匹配,表示的是在data_format每一維上的移動步長,當格式為NHWC的時候,strides的格式為: [batch, in_height, in_weight, in_channels] => [樣本上的移動大小,高度的移動大小,寬度的移動大小,深度的移動大小],要求在樣本上和在深度通道上的移動必須是1;當格式為NCHW的時候,strides的格式為: [batch,in_channels, in_height, in_weight]# padding: 只支持兩個參數"SAME", "VALID",當取值為SAME的時候,表示進行填充,"在TensorFlow中,如果步長為1,并且padding為SAME的時候,經過卷積之后的圖像大小是不變的";當VALID的時候,表示多余的特征會丟棄;net = tf.nn.conv2d(input=net, filter=get_variable('w', [5, 5, 1, 20]), strides=[1, 1, 1, 1], padding='SAME')net = tf.nn.bias_add(net, get_variable('b', [20]))# 激勵 ReLu# tf.nn.relu => max(fetures, 0)# tf.nn.relu6 => min(max(fetures,0), 6)net = tf.nn.relu(net)# 3. 池化with tf.variable_scope('pool3'):# 和conv2一樣,需要給定窗口大小和步長# max_pool(value, ksize, strides, padding, data_format="NHWC", name=None)# avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None)# 默認格式下:NHWC,value:輸入的數據,必須是[batch_size, height, weight, channels]格式# 默認格式下:NHWC,ksize:指定窗口大小,必須是[batch, in_height, in_weight, in_channels], 其中batch和in_channels必須為1# 默認格式下:NHWC,strides:指定步長大小,必須是[batch, in_height, in_weight, in_channels],其中batch和in_channels必須為1# padding: 只支持兩個參數"SAME", "VALID",當取值為SAME的時候,表示進行填充,;當VALID的時候,表示多余的特征會丟棄;net = tf.nn.max_pool(value=net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')# 4. 卷積with tf.variable_scope('conv4'):net = tf.nn.conv2d(input=net, filter=get_variable('w', [5, 5, 20, 50]), strides=[1, 1, 1, 1], padding='SAME')net = tf.nn.bias_add(net, get_variable('b', [50]))net = tf.nn.relu(net)# 5. 池化with tf.variable_scope('pool5'):net = tf.nn.max_pool(value=net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')# 6. 全連接with tf.variable_scope('fc6'):# 28 -> 14 -> 7(因為此時的卷積不改變圖片的大小)net = tf.reshape(net, shape=[-1, 7 * 7 * 50])net = tf.add(tf.matmul(net, get_variable('w', [7 * 7 * 50, 500])), get_variable('b', [500]))net = tf.nn.relu(net)# 7. 全連接with tf.variable_scope('fc7'):net = tf.add(tf.matmul(net, get_variable('w', [500, n_classes])), get_variable('b', [n_classes]))act = tf.nn.softmax(net)return act# 構建網絡
act = le_net(x, y)# 構建模型的損失函數
# softmax_cross_entropy_with_logits: 計算softmax中的每個樣本的交叉熵,logits指定預測值,labels指定實際值
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=act, labels=y))# 使用Adam優化方式比較多
# learning_rate: 要注意,不要過大,過大可能不收斂,也不要過小,過小收斂速度比較慢
train = tf.train.AdamOptimizer(learning_rate=learn_rate).minimize(cost)# 得到預測的類別是那一個
# tf.argmax:對矩陣按行或列計算最大值對應的下標,和numpy中的一樣
# tf.equal:是對比這兩個矩陣或者向量的相等的元素,如果是相等的那就返回True,反正返回False,返回的值的矩陣維度和A是一樣的
pred = tf.equal(tf.argmax(act, axis=1), tf.argmax(y, axis=1))
# 正確率(True轉換為1,False轉換為0)
acc = tf.reduce_mean(tf.cast(pred, tf.float32))# 初始化
init = tf.global_variables_initializer()with tf.Session() as sess:# 進行數據初始化sess.run(init)# 模型保存、持久化saver = tf.train.Saver()epoch = 0while True:avg_cost = 0# 計算出總的批次total_batch = int(train_sample_number / batch_size)# 迭代更新for i in range(total_batch):# 獲取x和ybatch_xs, batch_ys = mnist.train.next_batch(batch_size)feeds = {x: batch_xs, y: batch_ys, learn_rate: learn_rate_func(epoch)}# 模型訓練sess.run(train, feed_dict=feeds)# 獲取損失函數值avg_cost += sess.run(cost, feed_dict=feeds)# 重新計算平均損失(相當于計算每個樣本的損失值)avg_cost = avg_cost / total_batch# DISPLAY  顯示誤差率和訓練集的正確率以此測試集的正確率if (epoch + 1) % display_step == 0:print("批次: %03d 損失函數值: %.9f" % (epoch, avg_cost))# 這里之所以使用batch_xs和batch_ys,是因為我使用train_img會出現內存不夠的情況,直接就會退出feeds = {x: batch_xs, y: batch_ys, learn_rate: learn_rate_func(epoch)}train_acc = sess.run(acc, feed_dict=feeds)print("訓練集準確率: %.3f" % train_acc)feeds = {x: test_img, y: test_label, learn_rate: learn_rate_func(epoch)}test_acc = sess.run(acc, feed_dict=feeds)print("測試準確率: %.3f" % test_acc)if train_acc > 0.9 and test_acc > 0.9:saver.save(sess, './mnist/model')breakepoch += 1# 模型可視化輸出writer = tf.summary.FileWriter('./mnist/graph', tf.get_default_graph())writer.close()

?

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/453970.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/453970.shtml
英文地址,請注明出處:http://en.pswp.cn/news/453970.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

怎樣獲取linux命令幫助?

獲得命令使用幫助:內部命令:help COMMAND外部命令:COMMAND --help (大多數命令有help選項)命令手冊:manualman [章節號] COMMAND其中man數據庫是分章節的,相同的COMMAND出現在不同的章節表示…

編譯安裝 zbar 時兩次 make 帶來的驚喜

為了裝 php 的條形碼擴展模塊 php-zbarcode,先裝了一天的 ImageMagick 和 zbar。也許和我裝的 Ubuntu 17.10 的有版本兼容問題吧,總之什么毛病都有,apt 不行,PPA 源也不行,編譯安裝還有幾處源代碼出錯,裝不…

python數組的乘法_在Python中乘法非常大的2D數組

我必須在Python中將非常大的2D數組乘以大約100次.每個矩陣由3200032000元素組成.我正在使用np.dot(X,Y),但是每次乘法都需要很長時間…在我的代碼實例下面:import numpy as npX Nonefor i in range(100)multiplying Trueif X None:X generate_large_2darray()mu…

0階指數哥倫布編碼

指數哥倫布編碼 規定語法元素的編解碼模式的描述符如下: 比特串: b(8):任意形式的8比特字節(就是為了說明語法元素是為8個比特,沒有語法上的含義) f(n):n位固定模式比特串(其值固定,如forbidde…

TensorFolw 報錯

1、報錯1&#xff1a;ValueError: Only call softmax_cross_entropy_with_logits with named arguments (labels..., logits..., ...) 提示出錯如下&#xff1a; Traceback (most recent call last):File "/MNIST/softmax.py", line 12, in <module>cross_en…

CentOS7種搭建FTP服務器

安裝vsftpd 首先要查看你是否安裝vsftp [rootlocalhost /]# rpm -q vsftpd vsftpd-3.0.2-10.el7.x86_64 #顯示也就安裝成功了&#xff01; 如果沒有則安裝vsftpd [rootlocalhost/]# yum install -y vsftpd 完成后再檢查一遍 [rootlocalhost /]# whereis vsftpd vsf…

js循環

順序——要加分號結束 分支&#xff1a;讓程序根據條件不同執行不同的代碼 if else語句用來做分支的 if&#xff08;條件&#xff09;{代碼} if&#xff08;條件&#xff09;{代碼}else{代碼} else if&#xff08;條件&#xff09;{代碼} if是嵌套。 switch...case&#xff1…

x264函數調用關系圖

1 encoder 2 slice write 3 analyse FFMPEG中MPEG-2編解碼函數調用關系圖 1 Encoder &#xff08;函數調用從左到右&#xff0c;下同&#xff1b;圖片顯示不全時&#xff0c;請下載顯示&#xff09; 2 P幀運動估計流程圖 3 B幀運動估計流程圖 4 decoder ffmpeg的mpeg2編碼I幀代…

Tensorflow 加載預訓練模型和保存模型

使用tensorflow過程中&#xff0c;訓練結束后我們需要用到模型文件。有時候&#xff0c;我們可能也需要用到別人訓練好的模型&#xff0c;并在這個基礎上再次訓練。這時候我們需要掌握如何操作這些模型數據。看完本文&#xff0c;相信你一定會有收獲&#xff01; 一、Tensorfl…

在 ActiveReports 中嵌入 Spread 控件

Spread 是一款很出色的表格控件&#xff0c;Spread 可以使開發人員把具有兼容 Microsoft Excel 的電子表格添加到程序中。ActiveReports 提供了一個非常靈活的、簡單的報表環境。下面將展示怎樣在 ActiveReports 中使用 Spread for WinForm。和其他三方控件一樣&#xff0c;Spr…

sort()函數、C++

Sort&#xff08;&#xff09;函數是c一種排序方法之一&#xff0c;它使用的排序方法是類似于快排的方法&#xff0c;時間復雜度為n*log2(n) &#xff08;1&#xff09;Sort函數包含在頭文件為#include<algorithm>的c標準庫中。 II&#xff09;Sort函數有三個參數&#x…

python waitkey_python中VideoCapture(),read(),waitKey()的使用

有以下程序import cv2cap cv2.VideoCapture(0)while cap.isOpened():ret,frame cap.read()cv2.imshow(frame,frame)c cv2.waitKey(1)if c 27:breakcap.release()cv2.destroyAllWindows()說明&#xff1a;程序段里&#xff0c;1、cv2.VideoCapture()函數&#xff1a;cap cv…

深度學習案例之 驗證碼識別

本項目介紹利用深度學習技術&#xff08;tensorflow&#xff09;&#xff0c;來識別驗證碼&#xff08;4位驗證碼&#xff0c;具體的驗證碼的長度可以自己生成&#xff0c;可以在自己進行訓練&#xff09; 程序分為四個部分 1、生成驗證碼的程序&#xff0c;可生成數字字母大…

windows下使用pthread庫

最近在看《C多核高級編程》這本書&#xff0c;收集了些有用的東西&#xff0c;方便在windows下使用POSIX標準進行Pthread開發&#xff0c;有利于跨平臺。 -------------------------------------------------- windows下使用pthread庫時間:2010-01-27 07:41來源:羅索工作室 作…

day 05 多行輸出與多行注釋、字符串的格式化輸出、預設創建者和日期

msg"hello1 hello2 hello3 " print(msg) 顯示結果為&#xff1a; # " "只能進行單行的字符串 多行字符串用 ,前面設置變量&#xff0c;可以用 表示多行 msghello1 hello2 hello3print(msg) 顯示結果為&#xff1a; 當然如果沒有設置變量&#xff0c;…

python數值計算guess_【python】猜數字game,旨在提高初學者對Python循環結構的使用...

import random #引入生成隨機數的模塊需求&#xff1a;程序設定生成 1-20 之間的一個隨機數&#xff0c;讓用戶猜日期&#xff1a;2019-10-21作者&#xff1a;xiaoxiaohui目的&#xff1a;猜數字game&#xff0c;旨在提高初學者對Python 變量類型以及循環結構的使用。secretNu…

調試九法-總體規則

調試規則規則1 理解系統規則2 制造失敗規則3 不要想&#xff0c;而要看規則4 分而治之規則5 一次只改一個地方規則6 保持審計跟蹤規則7 檢查插頭規則8 獲得全新觀點規則9 如果你不修復bug&#xff0c;它將依然存在轉載于:https://www.cnblogs.com/uetucci/p/7987805.html

深度學習之循環神經網絡(Recurrent Neural Network,RNN)

遞歸神經網絡和循環神經網絡 循環神經網絡&#xff08;recurrent neural network&#xff09;&#xff1a;時間上的展開&#xff0c;處理的是序列結構的信息&#xff0c;是有環圖遞歸神經網絡&#xff08;recursive neural network&#xff09;&#xff1a;空間上的展開&#…

從北京回來的年輕人,我該告訴你點什么?

前言 就在上周末&#xff0c;我與公眾號里的一個當地粉絲見面了&#xff0c;一起吃了頓飯&#xff0c;順便聊了聊。先來簡單交代下我們這位粉絲&#xff08;以下簡稱小L&#xff09;的經歷以及訴求。 小L之前在北京八維研修學院培訓的PHP&#xff0c;因為家庭原因&#xff0c;沒…

Linphone編譯【轉載】

Linphone依賴太多的庫&#xff0c;以致于稍有疏失&#xff0c;就會在編譯&#xff0c;運行出錯&#xff0c;都是由于依賴庫安裝的問題。 1 基礎知識 1.1 動態庫的連接 很多人安裝完庫后&#xff0c;configure依然報告這個庫沒有。這是對linux動態庫知識匱乏造成&#xff0c;也就…