【TensorFlow學習筆記:神經網絡八股】(實現MNIST數據集手寫數字識別分類以及FASHION數據集衣褲識別分類)

課程來源:人工智能實踐:Tensorflow筆記2

文章目錄

  • 前言
  • 一、搭建網絡八股sequential
    • 1.函數介紹
    • 2.6步法實現鳶尾花分類
  • 二、搭建網絡八股class
    • 1.創建自己的神經網絡模板:
    • 2.調用自己創建的model對象
  • 三、MNIST數據集
    • 1.用sequential搭建網絡實現手寫數字識別
    • 2.用類搭建網絡實現手寫數字識別
  • 四、FASHION數據集
    • 用sequential搭建網絡實現衣褲識別
  • 總結


前言

本講目標:使用八股搭建神經網絡 神經網絡搭建八股 iris代碼復現 MNIST數據集 訓練MNIST數據集 Fashion數據集

一、搭建網絡八股sequential

使用六步法,使用TensorFlow的API: tf.keras搭建網絡八股
1、import 導入相關模塊
2、train、test 告知要喂入網絡的訓練集、測試集是什么,也就是要指定訓練集、測試集的輸入特征和訓練集的標簽
3、model = tf.keras.models.Sequential 在sequential()中搭建網絡結構,逐層描述每層網絡,相當于走了一遍前向傳播
4、model.compile 在compile中配置訓練方法,告知訓練時選擇哪種優化器,選擇哪個損失函數,選擇哪種評測指標
5、model.fit 在fit中執行訓練過程,告知訓練集和測試集的輸入特征和標簽,告知每個batch是多少,告知要迭代多少次數據集
6、model.summary 用summary打印出網絡的結構和參數統計

1.函數介紹

sequential()用法:
model = tf.keras.models.Sequential([網絡結構]) #描述各層網絡
網絡結構舉例:
拉直層:tf.keras.layers.Flatter()
全連接層:tf.keras.layers.Dense(神經元個數,activation=“激活函數”,kernel_regularizer=哪種正則化)
activation(字符串給出) 可選:relu、softmax、signoid、tanh
kernel_regularizer可選:kernel_regularizer.l1()、kernel_regularizer.l2()
卷積層:tf.keras.layers.Conv2D(filters=卷積核個數,kernel_size=卷積核尺寸,strides=卷積步長,padding=“vaild” or “same”)
LSTM層:tf.kreas.layers.LSTM()

compile() 用法:
model.compile(optimizer =優化器,loss =損失函數,metrics=[“準確率”])
optimizer 可選:
‘sgd’ or tf.keras.optimizers.SGD(lr=學習率,momentum=動量參數)
‘adagrad’ or tf.keras.optimizers.Adagrad(lr=學習率)
‘adadelta’ or tf.keras.optimizers.Adadelta(lr=學習率)
‘adam’ or tf.keras.optimizers.Adam(lr=學習率,beta_1=0.9,beta_2=0.999)
loss 可選:
‘mse’ or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits =False)
(有的神經網絡的輸出是經過了softmax等函數的概率分布,有些則不經概率分布直接輸出,from_logits 是在詢問是否是原始輸出)
Metrics 可選:
‘accuracy’:y_pred和y都是數值,如y_pred=[1] y=[1]
‘categorical_accuracy’:y_pred 和 y 都是獨熱碼(概率分布),如y_pred=[0,1,0], y=[0.5,0.5,0.5]
‘sparse_categorical_accuracy’:y_pred是數值,y是獨熱編碼,y_pred=[1],y=[0.5,0.5,0.5]

fit()用法:
model.fit(訓練集的輸入特征,訓練集的標簽,
? batch_size = ,epochs = ,
? validation_data =(測試集的輸入特征,測試集的標簽),
? validation_split =從訓練集劃分多少比例給測試集,
? validation_freq =多少次epoch測試一次)

model()用法:
model.summary()
在這里插入圖片描述

2.6步法實現鳶尾花分類

代碼如下:

import tensorflow as tf
from sklearn import datasets
import numpy as np#由于這里是選擇從訓練集劃分出測試集,所以不需要單獨導入test
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
#打亂順序
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
#3個神經元,softmax激活,L2正則化
model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])
#SGD優化器、學習率0.1,使用SparseCategoricalCrossentropy作為損失函數,由于神經網絡末端使用softmax函數,輸出為概率分布,所以from_logits為false
#鳶尾花數據集給的標簽為0,1,2,神經網絡前向傳播的輸出是概率分布,使用sparse_categorical_accuracy作為準確率
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
#輸入訓練數據,一次喂入32組數據,迭代500次,從訓練集中劃分出20%作為測試集,每迭代20次訓練集就要在測試集中驗證一次準確率
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
#打印網絡結構和參數統計
model.summary()

打印結果如下:
在這里插入圖片描述

二、搭建網絡八股class

用sequential可以搭建出上層輸出就是下層輸入的順序網絡結構,但是無法寫出一些帶有跳連的非順序網絡結構。這時我們可以選擇用類class搭建神經網絡結構
使用六步法,使用TensorFlow的API: tf.keras搭建網絡八股
1、import
2、train、test
3、class MyMode(Model) model=MyModel
4、model.compile
5、model.fit
6、model.summary

1.創建自己的神經網絡模板:

偽代碼如下:

class MyModel(Model):def _init_(self):super(MyModel,self).init_()定義網絡結構塊def call(self,x):調用網絡結構塊,實現前向傳播return ymodel=MyModel()

代碼如下:

class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()#鳶尾花分類的單層網絡是含有3個神經元的全連接self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return y
#實例化名為model的對象
model = IrisModel()

2.調用自己創建的model對象

代碼如下:

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

打印結果如下:
在這里插入圖片描述

三、MNIST數據集

MNIST數據集

提供6萬張28x28像素點的0~9手寫數字圖片和標簽,用于訓練。
提供1萬張28x28像素點的0~9手寫數字圖片和標簽,用于測試。

導入數據集

mnist =tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()

作為輸入特征,輸入神經網絡時,將數據拉伸為一維數組

tf.keras.layers.Flatter()

1.用sequential搭建網絡實現手寫數字識別

code:

import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#對輸入網絡的特征進行歸一化,使原本0~255的灰度值轉化為0~1的小數。
#把輸入特征的值變小更有利于神經網絡吸收
x_train, x_test = x_train / 255.0, x_test / 255.0
#用Sequential搭建網絡
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),                          #把輸入特征拉直為1維數組,即784(28*28)個數值tf.keras.layers.Dense(128, activation='relu'),      #定義第一層網絡有128個神經元,relu為激活函數tf.keras.layers.Dense(10, activation='softmax')     #定義第二層網絡有10個神經元,softmax使輸出符合概率分布
])
#用compile配置訓練方法
model.compile(optimizer='adam',                         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
#每一輪訓練集迭代,執行一次測試集評測,隨著迭代輪數增加,手寫數字識別準確率不斷提升(使用測試集)
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

print result:

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [] - 4s 62us/sample - loss: 0.2589 - sparse_categorical_accuracy: 0.9262 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9607
Epoch 2/5
60000/60000 [] - 2s 40us/sample - loss: 0.1114 - sparse_categorical_accuracy: 0.9676 - val_loss: 0.1027 - val_sparse_categorical_accuracy: 0.9699
Epoch 3/5
60000/60000 [] - 3s 43us/sample - loss: 0.0762 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0898 - val_sparse_categorical_accuracy: 0.9722
Epoch 4/5
60000/60000 [] - 2s 41us/sample - loss: 0.0573 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9752
Epoch 5/5
60000/60000 [] - 2s 41us/sample - loss: 0.0450 - sparse_categorical_accuracy: 0.9858 - val_loss: 0.0846 - val_sparse_categorical_accuracy: 0.9738
Model: “sequential”
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) multiple 0
dense (Dense) multiple 100480
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0


可以觀察到,隨著迭代輪數增加,準確率也不斷提升。訓練的參數也是極其多的,達到10萬多個。

2.用類搭建網絡實現手寫數字識別

只是實例化model的方法不同,其他與用sequential搭建網絡實現手寫數字識別一致。
init函數中定義了call函數中所用到的層,call函數中從輸入x到輸出y走過一次前向傳播,返回輸出y

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return ymodel = MnistModel()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

四、FASHION數據集

FASHION數據集

提供6萬張 28x28像素點的衣褲等圖片和標簽,用于訓練.
提供1萬張28x28像素點的衣褲等圖片和標簽,用于測試。
在這里插入圖片描述

導入數據集

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()

用sequential搭建網絡實現衣褲識別

加載數據需要較長時間,需耐心等待

import tensorflow as tffashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

用類的方法也可以實現,這里不做重復展開,套用八股模板即可。


總結

這個單元將整個訓練的構架走了一遍,并且以八股的形式做了總結,收獲很大。

課程鏈接:MOOC人工智能實踐:TensorFlow筆記2

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/378292.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/378292.shtml
英文地址,請注明出處:http://en.pswp.cn/news/378292.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

c語言 在執行區域沒有空格,C語言上機操作指導之TurboC.doc

C語言上機操作指導之 -------- Turbo C程序設計是實踐性很強的過程,任何程序都必須在計算機上運行,以檢驗程序的正確與否。因此在學習程序設計中,一定要重視上機實踐環節,通過上機可以加深理解 C語言的有關概念,以鞏固…

java 根據類名示例化類_Java即時類| from()方法與示例

java 根據類名示例化類即時類from()方法 (Instant Class from() method) from() method is available in java.time package. from()方法在java.time包中可用。 from() method is used to return a copy of the Instant from the given TemporalAccessor object. from()方法用于…

第十二章 圖形用戶界面

第十二章 圖形用戶界面 GUI就是包含按鈕、文本框等控件的窗口 Tkinter是事實上的Python標準GUI工具包 創建GUI示例應用程序 初探 導入tkinter import tkinter as tk也可導入這個模塊的所有內容 from tkinter import *要創建GUI,可創建一個將充當主窗口的頂級組…

Sqlserver 2005 配置 數據庫鏡像:數據庫鏡像期間可能出現的故障:鏡像超時機制

數據庫鏡像期間可能出現的故障 SQL Server 2005其他版本更新日期: 2006 年 7 月 17 日 物理故障、操作系統故障或 SQL Server 故障都可能導致數據庫鏡像會話失敗。數據庫鏡像不會定期檢查 Sqlservr.exe 所依賴的組件來驗證組件是在正常運行還是已出現故障。但對于某…

江西理工大學期末試卷c語言,2016年江西理工大學信息工程學院計算機應用技術(加試)之C語言程序設計復試筆試最后押題五套卷...

一、選擇題1. 設有函數定義:( )。A. B. C. D. 答:A則以下對函數sub 的調用語句中,正確的是【解析】函數的參數有兩個,第一個是整型,第二個是字符類型,在調用函數時,實參必須一個是整型&#xff…

第十三章 數據庫支持

第十三章 數據庫支持 本章討論Python數據庫API(一種連接到SQL數據庫的標準化方式),并演示如何使用這個API來執行一些基本的SQL。最后,本章將討論其他一些數據庫技術。 關Python支持的數據庫清單 Python數據庫API 標準數據庫API…

【神經網絡八股擴展】:自制數據集

課程來源:人工智能實踐:Tensorflow筆記2 文章目錄前言1、文件一覽2、將load_data()函數替換掉2、調用generateds函數4、效果總結前言 本講目標:自制數據集,解決本領域應用 將我們手中的圖片和標簽信息制作為可以直接導入的npy文件。 1、文件一覽 首先看…

java 批量處理 示例_Java中異常處理的示例

java 批量處理 示例Here, we will analyse some exception handling codes, to better understand the concepts. 在這里,我們將分析一些異常處理代碼 ,以更好地理解這些概念。 Try to find the errors in the following code, if any 嘗試在以下代碼中…

hdu 1465 不容易系列之一

http://acm.hdu.edu.cn/showproblem.php?pid1465 今天立神和我們講了錯排,才知道錯排原來很簡單,從第n個推起: 當n個編號元素放在n個編號位置,元素編號與位置編號各不對應的方法數用M(n)表示,那么M(n-1)就表示n-1個編號元素放在n-1個編號位置…

第十四章 網絡編程

第十四章 網絡編程 本章首先概述Python標準庫中的一些網絡模塊。然后討論SocketServer和相關的類,并介紹同時處理多個連接的各種方法。最后,簡單地說一說Twisted,這是一個使用Python編寫網絡程序的框架,功能豐富而成熟。 幾個網…

c語言輸出11258循環,c/c++內存機制(一)(轉)

一:C語言中的內存機制在C語言中,內存主要分為如下5個存儲區:(1)棧(Stack):位于函數內的局部變量(包括函數實參),由編譯器負責分配釋放,函數結束,棧變量失效。(2)堆(Heap):由程序員用…

【神經網絡八股擴展】:數據增強

課程來源:人工智能實踐:Tensorflow筆記2 文章目錄前言TensorFlow2數據增強函數數據增強網絡八股代碼:總結前言 本講目標:數據增強,增大數據量 關于我們為何要使用數據增強以及常用的幾種數據增強的手法,可以看看下面的文章&#…

C++:從C繼承的標準庫

C從C繼承了的標準庫 &#xff0c; 這就意味著 C 中 可以使用的標準庫函數 在C 中都可以使用 &#xff0c; 但是需要注意的是 &#xff0c; 這些標準庫函數在C中不再以 <xxx.h> 命名 &#xff0c; 而是變成了 <cxxx> 。 例如 &#xff1a; 在C中操作字符串的…

分享WCF聊天程序--WCFChat

無意中在一個國外的站點下到了一個利用WCF實現聊天的程序&#xff0c;作者是&#xff1a;Nikola Paljetak。研究了一下&#xff0c;自己做了測試和部分修改&#xff0c;感覺還不錯&#xff0c;分享給大家。先來看下運行效果&#xff1a;開啟服務&#xff1a;客戶端程序&#xf…

c# uri.host_C#| 具有示例的Uri.Equality()運算符

c# uri.hostUri.Equality()運算符 (Uri.Equality() Operator) Uri.Equality() Operator is overloaded which is used to compare two Uri objects. It returns true if two Uri objects contain the same Uri otherwise it returns false. Uri.Equality()運算符已重載&#xf…

第六章至第九章的單元測試

1,?助劑與纖維作用力大于纖維分子之間的作用力,則該助劑最好用作() 纖維增塑膨化劑。 2,助劑擴散速率快,優先占領纖維上的染座,但助劑與纖維之間作用力小于染料與纖維之間作用力,該助劑可以作為() 勻染劑。 3,助劑占領纖維上的染座,但助劑與纖維之間作用力大于染…

【神經網絡擴展】:斷點續訓和參數提取

課程來源&#xff1a;人工智能實踐:Tensorflow筆記2 文章目錄前言斷點續訓主要步驟參數提取主要步驟總結前言 本講目標:斷點續訓&#xff0c;存取最優模型&#xff1b;保存可訓練參數至文本 斷點續訓主要步驟 讀取模型&#xff1a; 先定義出存放模型的路徑和文件名&#xff0…

開發DBA(APPLICATION DBA)的重要性

開發DBA是干什么的&#xff1f; 1. 審核開發人員寫的SQL&#xff0c;并且糾正存在性能問題的SQL ---非常重要 2. 編寫復雜業務邏輯SQL&#xff0c;因為復雜業務邏輯SQL開發人員寫出的SQL基本上都是有性能問題的&#xff0c;與其讓開發人員寫&#xff0c;不如DBA自己寫。---非常…

javascript和var之間的區別?

You can define your variables in JavaScript using two keywords - the let keyword and the var keyword. The var keyword is the oldest way of defining and declaring variables in JavaScript whereas the let is fairly new and was introduced by ES15. 您可以使用兩…

小米手環6NFC安裝太空人表盤

以前看我室友峰哥、班長都有手環&#xff0c;一直想買個手環&#xff0c;不舍得&#xff0c;然后今年除夕的時候降價&#xff0c;一狠心&#xff0c;入手了&#xff0c;配上除夕的打年獸活動還有看春晚京東敲鼓領的紅包和這幾年攢下來的京東豆豆&#xff0c;原價279的小米手環6…