二、訓練fashion_mnist數據集

一、加載fashion_mnist數據集

fashion_mnist數據集中數據為28*28大小的10分類衣物數據集
其中訓練集60000張,測試集10000張

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npfashion_mnist = keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()print(train_images.shape)
"""
(60000, 28, 28)
"""
print(test_images.shape)
"""
(10000, 28, 28)
"""
print(train_labels.shape)
"""
(60000,)
"""
print(test_labels.shape)
"""
(60000,)
"""

光看像素值是不是能猜到這個圖片是啥了?

print(train_images[0])#看一下訓練集第一張圖片28*28像素點的值
"""
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   1   0   0  13  73   0   0   1   4   0   0   0   0   1   1   0][  0   0   0   0   0   0   0   0   0   0   0   0   3   0  36 136 127  62  54   0   0   0   1   3   4   0   0   3][  0   0   0   0   0   0   0   0   0   0   0   0   6   0 102 204 176 134 144 123  23   0   0   0   0  12  10   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0 155 236 207 178 107 156 161 109  64  23  77 130  72  15][  0   0   0   0   0   0   0   0   0   0   0   1   0  69 207 223 218 216 216 163 127 121 122 146 141  88 172  66][  0   0   0   0   0   0   0   0   0   1   1   1   0 200 232 232 233 229 223 223 215 213 164 127 123 196 229   0][  0   0   0   0   0   0   0   0   0   0   0   0   0 183 225 216 223 228 235 227 224 222 224 221 223 245 173   0][  0   0   0   0   0   0   0   0   0   0   0   0   0 193 228 218 213 198 180 212 210 211 213 223 220 243 202   0][  0   0   0   0   0   0   0   0   0   1   3   0  12 219 220 212 218 192 169 227 208 218 224 212 226 197 209  52][  0   0   0   0   0   0   0   0   0   0   6   0  99 244 222 220 218 203 198 221 215 213 222 220 245 119 167  56][  0   0   0   0   0   0   0   0   0   4   0   0  55 236 228 230 228 240 232 213 218 223 234 217 217 209  92   0][  0   0   1   4   6   7   2   0   0   0   0   0 237 226 217 223 222 219 222 221 216 223 229 215 218 255  77   0][  0   3   0   0   0   0   0   0   0  62 145 204 228 207 213 221 218 208 211 218 224 223 219 215 224 244 159   0][  0   0   0   0  18  44  82 107 189 228 220 222 217 226 200 205 211 230 224 234 176 188 250 248 233 238 215   0][  0  57 187 208 224 221 224 208 204 214 208 209 200 159 245 193 206 223 255 255 221 234 221 211 220 232 246   0][  3 202 228 224 221 211 211 214 205 205 205 220 240  80 150 255 229 221 188 154 191 210 204 209 222 228 225   0][ 98 233 198 210 222 229 229 234 249 220 194 215 217 241  65  73 106 117 168 219 221 215 217 223 223 224 229  29][ 75 204 212 204 193 205 211 225 216 185 197 206 198 213 240 195 227 245 239 223 218 212 209 222 220 221 230  67][ 48 203 183 194 213 197 185 190 194 192 202 214 219 221 220 236 225 216 199 206 186 181 177 172 181 205 206 115][  0 122 219 193 179 171 183 196 204 210 213 207 211 210 200 196 194 191 195 191 198 192 176 156 167 177 210  92][  0   0  74 189 212 191 175 172 175 181 185 188 189 188 193 198 204 209 210 210 211 188 188 194 192 216 170   0][  2   0   0   0  66 200 222 237 239 242 246 243 244 221 220 193 191 179 182 182 181 176 166 168  99  58   0   0][  0   0   0   0   0   0   0  40  61  44  72  41  35   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
"""

輸出以下這個照片

plt.imshow(train_images[0])

在這里插入圖片描述

二、開始訓練模型

model = keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),#照片完全展平,一維數組形式keras.layers.Dense(128,activation=tf.nn.relu),#128個神經元keras.layers.Dense(10,activation=tf.nn.softmax)#輸出層0-9,一共十個
])

查看模型的結構
第一層784個,flatten層將輸入的2828圖像進行展開,排列成一行,2828=784

第二層128個,128個神經元;100480個參數,第一層的784和第二層的128全排列,784*128=100352,每一個都有一個bias偏置項,100352+128=100480

第三層10個,也就是10分類,10個不同的類別,到時候輸出10個概率值,哪個大就是哪一類;1290個參數,第二層128個神經元,分別于10進行全排列,128*10=1280,每一個都有一個bias偏置項,1280+10=1290

model.summary()
"""
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
"""

為了使得效果更好,將數據集中的圖像像素值都歸一化到0-1之間

train_images_y = train_images/255#對訓練圖像歸一化

訓練50次

model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=['accuracy'])#指定優化方法和損失函數
model.fit(train_images_y,train_labels,epochs=50)#訓練

因為模型訓練的時候傳入的時訓練集歸一化之后的圖像
故,模型評估的時候也需要對測試集進行歸一化圖像

test_images_y = test_images/255#測試評估的時候需要對測試圖像也要歸一化
model.evaluate(test_images_y,test_labels)#evaluate評估效果
"""
[0.5110174604289234, 0.8845]
"""

從測試集中挑選幾個進行測試,實際上會輸出10個值,也就是可能性的概率值,最大的就是預測的類別

model.predict([[test_images[0]/255]])
"""
array([[2.2063166e-16, 1.1835037e-17, 7.4574429e-23, 2.0577940e-22,4.3680589e-17, 2.7080047e-08, 3.8249505e-15, 3.4797877e-06,1.4701404e-10, 9.9999654e-01]], dtype=float32)
"""

篩選模型預測出的值最大的那個

print(np.argmax(model.predict([[test_images[0]/255]])))
"""
9
"""

看下這個圖片的實際標簽

print(test_labels[0])
"""
9
"""

預測值和實際值一樣,說明預測對了

展示下這個圖片

plt.imshow(train_images[0])

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/377649.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/377649.shtml
英文地址,請注明出處:http://en.pswp.cn/news/377649.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

jquerymobile 切換頁面時候閃爍問題

https://github.com/jquery/jquery-mobile/commit/acbec71e29b6acec6cd2087e84e8434fecc0053f 可以修改css好像是個bug -4,9 4,10 * Dual licensed under the MIT (MIT-LICENSE.txt) or GPL (GPL-LICENSE.txt) licenses.*/.spin {--webkit-animation-name: spin;--webkit-an…

二分法:兩個有序數組長度為N,找到第N、N+1大的數

題目 兩個有序數組長度為N,找到第N、N1大的數 思路1:雙指針,O(N)復雜度 簡述思路: 如果當前A指針指向的數組A的內容小于B指針指向的數組B的內容,那么A指針往右移動,然后nums(當前已經遍歷過的數字個數)也…

Javascript -- In

http://www.caveofprogramming.com/articles/javascript-2/javascript-in-using-the-in-operator-to-iterate-through-arrays-and-objects/ http://msdn.microsoft.com/en-us/library/ie/9k25hbz2(vvs.94).aspx轉載于:https://www.cnblogs.com/daishuguang/p/3392310.html

三、自動終止訓練

有時候,當模型損失函數值預期的效果時,就可以結束訓練了,一方面節約時間,另一方面防止過擬合 此時,設置損失函數值小于0.4,訓練停止 from tensorflow import keras import tensorflow as tf import matplo…

矩陣形狀| 使用Python的線性代數

Prerequisite: Linear Algebra | Defining a Matrix 先決條件: 線性代數| 定義矩陣 In the python code, we will add two Matrices. We can add two Matrices only and only if both the matrices have the same dimensions. Therefore, knowing the dimensions o…

[數據庫]oracle客戶端連服務器錯誤

昨天晚上和今天上午用11g客戶端連同事10g服務器,報錯: The Network Adapter could not establish the connection 檢查嘗試了好多次都沒好。 用程序連,依舊是報這個錯,所以一查就解決了! 參考:http://apps…

ASP.NET 抓取網頁內容

(轉)ASP.NET 抓取網頁內容 ASP.NET 抓取網頁內容-文字 ASP.NET 中抓取網頁內容是非常方便的,而其中更是解決了 ASP 中困擾我們的編碼問題。 需要三個類:WebRequest、WebResponse、StreamReader。 WebRequest、WebRespo…

leetcode 53. 最大子序和 動態規劃解法、貪心法以及二分法

題目 給定一個整數數組 nums ,找到一個具有最大和的連續子數組(子數組最少包含一個元素),返回其最大和。 示例: 輸入: [-2,1,-3,4,-1,2,1,-5,4] 輸出: 6 解釋: 連續子數組 [4,-1,2,1] 的和最大,為 6。 進階: 如果你…

四、卷積神經網絡(Convolution Neural Networks)

一、CNN(Convolution Neural Networks) 卷積神經網絡基本思想:識別物體的特征,來進行判斷物體 卷積Convolution:過濾器filter中的數值與圖片像素值對應相乘再相加,6 * 6卷積一次(步數為1)變成4 * 4 Max Pooling:對卷積…

POJ3096Surprising Strings(map)

題意:輸入很多字符串,以星號結束。判斷每個字符串是不是“Surprising Strings”,判斷方法是:以“ZGBG”為例,“0-pairs”是ZG,GB,BG,這三個子串不相同,所以是“0-unique”…

vs助手使用期過 編譯CEGUI的問題:error C2061: 語法錯誤: 標識符“__RPC__out_xcount_part” VS2010...

第一個問題,下一個破解版的VX_A.dll,將其覆蓋以前的dll即可, 但是目錄有所要求,如下: XP系統:系統盤\Documents and Settings\用戶名\Local Settings\Application win7或者vistaData\Microsoft\VisualStud…

五、項目實戰---識別人和馬

一、準備訓練數據 下載數據集 validation驗證集 train訓練集 數據集結構如下: 將數據集解壓到自己選擇的目錄下就行 最后的結構效果如下: 二、構建模型 ImageDataGenerator 真實數據中,往往圖片尺寸大小不一,需要裁剪成一樣…

leetcode 122. 買賣股票的最佳時機 II 思考分析

目錄題目貪心法題目 給定一個數組,它的第 i 個元素是一支給定股票第 i 天的價格。 設計一個算法來計算你所能獲取的最大利潤。你可以盡可能地完成更多的交易(多次買賣一支股票)。 注意:你不能同時參與多筆交易(你必…

css設置a連接禁用樣式_使用CSS禁用鏈接

css設置a連接禁用樣式Question: 題: Links are one of the most essential aspects of any web page or website. They play a very important role in making our website or web page quite responsive or interactive. So the topic for discussion is quite pe…

服務器出現 HTTP 錯誤代碼,及解決方法

HTTP 400 - 請求無效 HTTP 401.1 - 未授權:登錄失敗 HTTP 401.2 - 未授權:服務器配置問題導致登錄失敗 HTTP 401.3 - ACL 禁止訪問資源 HTTP 401.4 - 未授權:授權被篩選器拒絕 HTTP 401.5 - 未授權:ISAPI 或 CGI 授權失敗 HTTP 40…

leetcode 55. 跳躍游戲 思考分析

題目 給定一個非負整數數組,你最初位于數組的第一個位置。 數組中的每個元素代表你在該位置可以跳躍的最大長度。 判斷你是否能夠到達最后一個位置。 示例1: 輸入: [2,3,1,1,4] 輸出: true 解釋: 我們可以先跳 1 步,從位置 0 到達 位置 1…

六、項目實戰---識別貓和狗

一、準備數據集 kagglecatsanddogs網上一搜一大堆,這里我就不上傳了,需要的話可以私信 導包 import os import zipfile import random import shutil import tensorflow as tf from tensorflow.keras.optimizers import RMSprop from tensorflow.kera…

修改shell終端提示信息

PS1:就是用戶平時的提示符。PS2:第一行沒輸完,等待第二行輸入的提示符。公共設置位置:/etc/profile echo $PS1可以看到當前提示符設置例如:顯示綠色,并添加時間和shell版本export PS1"\[\e[32m\][\uyou are right…

java 字謎_計算字謎的出現次數

java 字謎Problem statement: 問題陳述: Given a string S and a word C, return the count of the occurrences of anagrams of the word in the text. Both string and word are in lowercase letter. 給定一個字符串S和一個單詞C ,返回該單詞在文本…

Origin繪制熱重TG和微分熱重DTG曲線

一、導入數據 二、傳到Origin中 三、熱重TG曲線 temp為橫坐標、mass為縱坐標 繪制折線圖 再稍微更改下格式 字體加粗,Times New Roman 曲線寬度設置為2 橫縱坐標數值格式為Times New Roman 根據實際情況改下橫縱坐標起始結束位置 四、微分熱重DTG曲線 點擊曲線…