[tensorflow、神經網絡] - 使用tf和mnist訓練一個識別手寫數字模型,并測試

  • 參考
  • 包含: 1.層級的計算、2.訓練的整體流程、3.tensorboard畫圖、4.保存/使用模型、5.總體代碼(含詳細注釋)

1. 層級的計算

在這里插入圖片描述
如上圖,mnist手寫數字識別的訓練集提供的圖片是 28 * 28 * 1的手寫圖像,初始識別的時候,并不知道一次要訓練多少個數據,因此輸入的規模為 [None, 784]. 由于最終的標簽輸出的是10個數據,因此輸出的規模為[None, 10], 中間采取一個簡單的全連接層作為隱藏層,規模為[784, 10]

2. 訓練的整體流程

在這里插入圖片描述

  • 1.首先定義占位符:
# 訓練集數據
x = tf.placehodler(tf.float32,	[None, 784])
# 訓練集標簽
y_true = tf.placeholder(rf.int32, [None, 10])
  • 2.建立模型
# 隨機生成權重矩陣和偏置
# 權重
weight = tf.Variable(tf.random_normal([784, 10], mean =0.0, stddev=1.0), name="weight")
# 偏置
bias = tf.Variable(tf.constant(0.0, shape=[10]))
# 預測
y_predict = tf.matmul(x, weight) + bias
  • 3.計算平均損失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
  • 4.優化方案(梯度下降)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
  • 5.計算損失率
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

3. tensorboard使用

# 按作用域命名
with tf.variable_scope("data"):passwith tf.variable_scope("full_layer"):pass
# 收集變量(單維度)
tf.summary.scalar("losses", loss)
tf.summary.scalar("acc", accuracy)# 收集變量(多維度)
tf.summary.histogram("weightes", weight)
tf.summary.histogram("biases", bias)# 將訓練的每一步寫入
with tf.Session() as sess:# 建立events文件,然后寫入filewriter = tf.summary.FileWriter("./tmp/", graph=sess.graph)for i in range(5000):# 寫入每步訓練的值summary = sess.run(merged, feed_dict={x: mnist_x, y_true: mnist_y})filewriter.add_summary(summary, i)
  • 4.模型的保存/使用
# 模型的初始化(一般寫在Session上面)
saver = tf.train.Saver()# Session中為模型保存分配資源
with tf.Session() as sess:# 保存模型saver.save(sess, "./tmp/ckpt/fc_model")# 加載模型saver.restore(sess, "./tmp/ckpt/fc_model")# 預測for i in range(100):x, y = mnist.test.next_batch(1)predict = tf.argmax(sess.run(y_predict, feed_dict={x: x_test, y_true: y_test}), 1).eval()

5.總體代碼

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_dataFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_integer("is_train", 1, "0: 預測, 1: 訓練")"""單層(全連接層)實現手寫數字識別特征值[None, 784]  目標值[None, 10]1、 定義數據占位符特征值[None, 784]  目標值[None, 10]2、 建立模型隨機初始化權重和偏置w[784, 10]  by_predict = tf.matmul(x, w) + b3、 計算損失loss: 平均樣本的損失  4、 梯度下降優化5、 準確率計算:equal_list = tf.equal(tf.argmax(y, 1), tf.argmax(y_label, 1))    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
"""def ful_connected():# 讀取數據mnist = input_data.read_data_sets("./data/mnist/input_data/", one_hot=True)# 1、 建立數據的占位符 x [None, 784]  y_true  [None, 10]with tf.variable_scope("data"):x = tf.placeholder(tf.float32, [None, 784])y_true = tf.placeholder(tf.int32, [None, 10])# 2、 建立一個全連接層的神經網絡 w [784, 10]  b [10]with tf.variable_scope("full_layer"):# 隨機初始化權重和偏置weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name="weight")bias = tf.Variable(tf.constant(0.0, shape=[10]))# 預測None個樣本的輸出結果  [None, 784] * [784, 10]  + [10] = [None, 10]y_predict = tf.matmul(x, weight) + bias# 3、 求出所有樣本的損失,然后求平均值with tf.variable_scope("softmax"):# 求平均交叉熵損失loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))# 4、 梯度下降求出損失with tf.variable_scope("optimizer"):train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)# 5、 計算準確率with tf.variable_scope("count_acc"):equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))# equal_list  None個樣本  [1, 0, 1, 0, 1, 1, ....]accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))# 收集變量(單維度)tf.summary.scalar("losses", loss)tf.summary.scalar("acc", accuracy)# 收集變量(高維度)tf.summary.histogram("weightes", weight)tf.summary.histogram("biases", bias)# 定義一個初始化變量的opinit_op = tf.global_variables_initializer()# 定義合并變量merged = tf.summary.merge_all()# 保存模型saver = tf.train.Saver()# 開啟會話訓練with tf.Session() as sess:# 初始化變量sess.run(init_op)# 建立events文件,然后寫入filewriter = tf.summary.FileWriter("./tmp/", graph=sess.graph)if FLAGS.is_train == 0:# 迭代步驟去訓練,更新參數預測for i in range(5000):# 取出真實存在的特征值 和 目標值mnist_x, mnist_y = mnist.train.next_batch(50)# 運行train_op訓練sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y})# 寫入每步訓練的值summary = sess.run(merged, feed_dict={x: mnist_x, y_true: mnist_y})filewriter.add_summary(summary, i)# 打印損失print("訓練第%d步,準確率為:%f" % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y})))# 保存模型saver.save(sess, "./tmp/ckpt/fc_model")else:# 加載模型saver.restore(sess, "./tmp/ckpt/fc_model")# 預測for i in range(100):# 每次測試一張圖片x_test, y_test = mnist.test.next_batch(1)print("第%d張圖片是: %d,預測結果是:%d" % (i,tf.argmax(y_test, 1).eval(),tf.argmax(sess.run(y_predict, feed_dict={x: x_test, y_true: y_test}), 1).eval()))return Noneif __name__ == "__main__":ful_connected()

6. cnn版本的mnist

import tensorflow as  tf
from tensorflow.examples.tutorials.mnist import input_data"""使用卷積神經網絡實現 mnist的手寫數據集識別
""""""input: [None, 784]output: [784, 10]進入卷積時,首先需要改變圖片的形狀 [None, 784] --> [None, 28, 28, 1]卷積網絡設計:· 第一層卷積層: 32 * core(5*5)、strides(1)、padding="SAME"· 此時大小為: [None, 28, 28, 32]· 激活· 池化: 2*2、 strides(2)、 padding="SAME"· 此時大小為: [None, 14, 14, 32]· 第二層卷積層: 64 * core(5*5)、 strides(1)、 padding="SAME"· 此時大小為: [None, 14, 14, 64]· 激活· 池化: 2*2、 strides(2)、 padding="SAME"· 此時大小為: [None, 7, 7, 64]· 全連接層: [None, 7*7*64] * [7*7*64, 10] + bias = [None, 10]
"""# 定義個初始化權重的函數
def weight_variable(shape):w = tf.Variable(tf.random_normal(shape=shape, mean=0.0, stddev=1.0))return w# 定義一個初始化偏置的函數
def bias_variables(shape):b = tf.Variable(tf.constant(0.0, shape=shape))return bdef model():"""自定義的卷積模型:return:"""# 1、準備數據的占位符 x [None, 784] 、 y_true [None, 10]with tf.variable_scope("data"):x = tf.placeholder(tf.float32, [None, 784])y_true = tf.placeholder(tf.int32, [None, 10])# 2、一卷積層 卷積: 5*5*1, 32個, strides = 1 、激活、池化with tf.variable_scope("conv1"):# 隨機初始化權重,偏置[32]w_conv1 = weight_variable([5, 5, 1, 32])b_conv1 = bias_variables([32])# 對x進行形狀的改變 [None, 784] -> [None, 28, 28, 1]x_reshape = tf.reshape(x, [-1, 28, 28, 1])# [None, 28, 28, 1] -> [None, 28, 28, 32]x_relu1 = tf.nn.relu(tf.nn.conv2d(x_reshape, w_conv1, strides=[1, 1, 1, 1], padding="SAME") + b_conv1)# 池化 2*2, strides2 [None, 28, 28, 32] -> [None, 14, 14, 32]x_pool1 = tf.nn.max_pool(x_relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")# 3、二卷積層  5*5*32, 64個filter, strides= 1with tf.variable_scope("conv2"):w_conv2 = weight_variable([5, 5, 32, 64])b_conv2 = bias_variables([64])# 卷積、激活、池化計算# [None, 14, 14, 32] -> [None, 14, 14, 64]x_relu2 = tf.nn.relu(tf.nn.conv2d(x_pool1, w_conv2, strides=[1, 1, 1, 1], padding="SAME") + b_conv2)# 池化 2*2, strides2 [None, 14, 14, 64] -> [None, 7, 7, 64]x_pool2 = tf.nn.max_pool(x_relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")# 4、全連接層 [None, 7, 7, 64] --> [None, 7*7*64] * [7*7*64, 10] + [10] = [None, 10]# 隨機初始化權重和偏置w_fc = weight_variable([7 * 7 * 64, 10])b_fc = bias_variables([10])# 修改形狀: [None, 7, 7, 64] -> [None, 7*7*64]x_fc_reshape = tf.reshape(x_pool2, [-1, 7 * 7 * 64])# 矩陣運算,得出每個樣本的10個結果y_predict = tf.matmul(x_fc_reshape, w_fc) + b_fcreturn x, y_true, y_predictdef conf_fc():# 1、 讀取數據mnist = input_data.read_data_sets("./data/mnist/input_data/", one_hot=True)# 2、 定義模型,得出輸出x, y_true, y_predict = model()# 3、 求出所有的損失,然后求平均值with tf.variable_scope("soft_cross"):# 求平均交叉熵損失loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))# 4、 梯度下降求出損失with tf.variable_scope("optimizer"):train_op = tf.train.GradientDescentOptimizer(0.00005).minimize(loss)# 5、 計算準確率with tf.variable_scope("acc"):equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))# 定義一個初始變量opinit_op = tf.global_variables_initializer()# 開啟會話運行with tf.Session() as sess:sess.run(init_op)# 循環去訓練for i in range(1000):# 取出真實存在的特征值和目標值mnist_x, mnist_y = mnist.train.next_batch(50)# 運行train_op訓練sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y})# 打印損失print("訓練第%d步,準確率為:%f" % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y})))return Noneif __name__ == "__main__":conf_fc()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250199.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250199.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250199.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

面向過程、面向函數、面向對象的區別淺談

Python的面向過程、面向函數、面向對象的區別淺談 轉自--獵奇古今,加上其他 有人之前私信問我,python編程有面向過程、面向函數、面向對象三種,那么他們區別在哪呢? 面向過程就是將編程當成是做一件事,要按步驟完成&am…

js算法初窺06(算法模式03-函數式編程)

在解釋什么是函數式編程之前,我們先要說下什么是命令式編程,它們都屬于編程范式的一種。命令式編程其實就是一塊一塊的代碼,其中包括了我們要執行的邏輯或者判斷或者一些運算。也就是按部就班的一步一步完成我們所需要的邏輯。而函數式編程則…

[mmdetection] - win10配置mmdetection(1.1和2.0) + 訓練網絡(faster-rcnn、mask-rcnn)

pytorch配置 - 參考 mmdetextion 配置(win10) mmdetection訓練faster-rcnn (win10) mmdetection訓練mask-rcnn (win10) mmdetection 2.0配置(win10) mmdetection 2.0訓練Faster-RCNN(win10) mmdetection 2.0全家桶訓練(終結版) labelme安裝教程 l…

13、Spring Boot 2.x 多數據源配置

1.13 Spring Boot 2.x 多數據源配置 完整源碼: Spring-Boot-Demos轉載于:https://www.cnblogs.com/Grand-Jon/p/9999779.html

[pytorch、學習] - 3.5 圖像分類數據集

參考 3.5. 圖像分類數據集 在介紹shftmax回歸的實現前我們先引入一個多類圖像分類數據集 本章開始使用pytorch實現啦~ 本節我們將使用torchvision包,它是服務于PyTorch深度學習框架的,主要用來構建計算機視覺模型。torchvision主要由以下幾部分構成: torchvision.datasets: …

python自動化第三周---文件讀寫

1.python文件對象提供了三個“讀”方法: read()、readline() 和 readlines()。每種方法可以接受一個變量以限制每次讀取的數據量。 read() 每次讀取整個文件,它通常用于將文件內容放到一個字符串變量中。如果文件大于可用內存,為了保險起見&a…

最詳細的java泛型詳解

來源:最詳細的java泛型詳解 對java的泛型特性的了解僅限于表面的淺淺一層,直到在學習設計模式時發現有不了解的用法,才想起詳細的記錄一下。 本文參考java 泛型詳解、Java中的泛型方法、 java泛型詳解 1. 概述 泛型在java中有很重要的地位&a…

[pytorch、學習] - 3.6 softmax回歸的從零開始實現

參考 3.6 softmax回歸的從零開始實現 import torch import torchvision import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.6.1. 獲取和讀取數據 batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_si…

Django基礎必備三件套: HttpResponse render redirect

1. HttpResponse : 它的作用是內部傳入一個字符串參數, 然后發給瀏覽器 def index(request):return HttpResponse(ok) 2. render : 可以接收三個參數, 一是request參數, 二是待渲染的 html 模板文件, 三是保存具體數據的字典參數 def index(request):return render(request, …

React 簡單實例 (React-router + webpack + Antd )

React Demo Github 地址 經過React Native 的洗禮之后,寫了這個 demo ;React 是為了使前端的V層更具組件化,能更好的復用,同時可以讓你從操作dom中解脫出來,只需要操作數據就會改變相應的dom; 而React Nat…

[pytorch、學習] - 3.7 softmax回歸的簡潔實現

參考 3.7. softmax回歸的簡潔實現 使用pytorch實現softmax import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.7.1. 獲取和讀取數據 batch_size 256 train_iter…

【模板】NTT

NTT模板 #include<bits/stdc.h> using namespace std; #define LL long long const int MAXL22; const int MAXN1<<MAXL; const int Mod998244353; int rev[MAXN],A[MAXN],B[MAXN],C[MAXN]; int fast_pow(int a,int b){int ans1;while(b){if(b&1)ans1ll*ans*a%…

centos 7 php7 yum源

rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpmrpm -Uvh https://mirror.webtatic.com/yum/el7/webtatic-release.rpm 轉載于:https://www.cnblogs.com/myJuly/p/10008252.html

[pytorch、學習] - 3.9 多重感知機的從零開始實現

參考 3.9 多重感知機的從零開始實現 import torch import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.9.1. 獲取和讀取數據 batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)3.9.2. 定義模型參…

C語言逗號運算符和逗號表達式基礎總結

逗號運算符的作用&#xff1a; 1&#xff0c;起分隔符的作用&#xff1a; 定義變量用于分隔變量&#xff1a;int a,b輸入或輸出時用于分隔輸出表列 printf("%d%d",a,b) 2,用于逗號表達式的順序運算符 語法&#xff1a;表達式1&#xff0c;表達式2&#xff0c;...,表達…

java基礎-泛型舉例詳解

泛型 泛型是JDK5.0增加的新特性&#xff0c;泛型的本質是參數化類型&#xff0c;即所操作的數據類型被指定為一個參數。這種類型參數可以在類、接口、和方法的創建中&#xff0c;分別被稱為泛型類、泛型接口、泛型方法。 一、認識泛型 在沒有泛型之前,通過對類型Object的引用來…

MySQL數據庫視圖(view),視圖定義、創建視圖、修改視圖

原文鏈接&#xff1a;https://blog.csdn.net/moxigandashu/article/details/63254901轉載于:https://www.cnblogs.com/chrdai/p/9131881.html

[pytorch、學習] - 3.10 多重感知機的簡潔實現

參考 3.10. 多重感知機的簡潔實現 import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.10.1. 定義模型 num_inputs, num_outputs, num_hiddens 784, 10, 256 # 參…

【匯編語言】——第三章課后總結

第三章 的書本上主要有以下幾個內容&#xff1a; 1.內存中字的存儲 字單元&#xff1a;即存放一個字型數據&#xff08;16位&#xff09;的內存單元&#xff0c;由兩個地址連續的內存單元組成。 小端法&#xff1a;高地址內存單元中存放字型數據的高位字節&#xff0c;低地址內…

如何從 Android 手機免費恢復已刪除的通話記錄/歷史記錄?

有一個有合作意向的人給我打電話&#xff0c;但我沒有接聽。更糟糕的是&#xff0c;我錯誤地將其刪除&#xff0c;認為這是一個騷擾電話。那么有沒有辦法從 Android 手機恢復已刪除的通話記錄呢&#xff1f;” 塞繆爾問道。如何在 Android 上恢復已刪除的通話記錄&#xff1f;如…