第38周:貓狗識別 (Tensorflow實戰第八周)

目錄

前言

一、前期工作

1.1 設置GPU

1.2 導入數據

輸出

二、數據預處理

2.1 加載數據

2.2 再次檢查數據

2.3 配置數據集

2.4 可視化數據

三、構建VGG-16網絡

3.1 VGG-16網絡介紹

3.2 搭建VGG-16模型

四、編譯

五、訓練模型

六、模型評估

七、預測

總結


前言

  • 🍨 本文為中的學習記錄博客
  • 🍖 原作者:

說在前面

1)本周任務:了解model.train_on_batch()并運用;了解tqdm,并使用tqdm實現可視化進度條;

2)運行環境:Python3.6、Pycharm2020、tensorflow2.4.0


一、前期工作

1.1 設置GPU

代碼如下:

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["TF_CPP_MIN_LOG_LEVEL"]='3' # 忽略 Error
#隱藏警告
import warnings
warnings.filterwarnings('ignore')
# 1.1 設置GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #設置GPU顯存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")
# 打印顯卡信息,確認GPU可用
print(gpus)

輸出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

??????前期我沒有使用GPU就采用的CPU訓練速度很慢,雖然安裝了tensorflow-gpu但還是用的CPU因為我的cudnn和cudatoolkit之前沒配置成功,然后我補充安裝。這里出線會打印很多關于gpu調用的日志信息,會很影響我們對訓練過程和打印信息的關注度,這里我在import tensorflow之前先通過下面的設置來控制打印的內容

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["TF_CPP_MIN_LOG_LEVEL"]='3'?

TF_CPP_MIN_LOG_LEVEL 取值 0 : 0也是默認值,輸出所有信息
TF_CPP_MIN_LOG_LEVEL 取值 1 : 屏蔽通知信息
TF_CPP_MIN_LOG_LEVEL 取值 2 : 屏蔽通知信息和警告信息
TF_CPP_MIN_LOG_LEVEL 取值 3 : 屏蔽通知信息、警告信息和報錯信息? ? ? ? ? ? ? ? ?
參考自:https://blog.csdn.net/xiaoqiaoliushuiCC/article/details/124435241

1.2 導入數據

代碼如下:

# 1.2 導入數據
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號
import os,PIL,pathlib
data_dir = "./data"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("圖片總數為:",image_count)

輸出

圖片總數為:??3400

二、數據預處理

2.1 加載數據

使用image_dataset_from_directory方法將磁盤中的數據加載到tf.data.Dataset,tf.keras.preprocessing.image_dataset_from_directory():是 TensorFlow 的 Keras 模塊中的一個函數,用于從目錄中創建一個圖像數據集(dataset)。這個函數可以以更方便的方式加載圖像數據,用于訓練和評估神經網絡模型

測試集與驗證集的關系:

  • 驗證集并沒有參與訓練過程梯度下降過程的,狹義上來講是沒有參與模型的參數訓練更新的。
  • 但是廣義上來講,驗證集存在的意義確實參與了一個“人工調參”的過程,我們根據每一個epoch訓練之后模型在valid data上的表現來決定是否需要訓練進行early stop,或者根據這個過程模型的性能變化來調整模型的超參數,如學習率,batch_size等等。因此,我們也可以認為,驗證集也參與了訓練,但是并沒有使得模型去overfit驗證集
  • 因此,我們也可以認為,驗證集也參與了訓練,但是并沒有使得模型去overfit驗證集

代碼如下:

# 二、數據預處理
# 2.1 加載數據
batch_size = 8
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

輸出如下:

['cat', 'dog']

2.2 再次檢查數據

代碼如下:

# 2.2 再次檢查數據
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

輸出:

(8, 224, 224, 3)
(8,)

2.3 配置數據集

代碼如下:

# 2.3 配置數據集
AUTOTUNE = tf.data.AUTOTUNEdef preprocess_image(image,label):return (image/255.0,label)
# 歸一化處理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

2.4 可視化數據

代碼如下:

plt.figure(figsize=(15, 10))  # 圖形的寬為15高為10
for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1)plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")

輸出:

三、構建VGG-16網絡

3.1 VGG-16網絡介紹

結構說明:

  • 13個卷積層(Convolutional Layer),分別用blockX_convX表示
  • 3個全連接層(Fully connected Layer),分別用fcXpredictions表示
  • 5個池化層(Pool layer),分別用blockX_pool表示

網絡結構圖如下(包含了16個隱藏層--13個卷積層和3個全連接層,故稱為VGG-16)

???

?

3.2 搭建VGG-16模型

代碼如下:

# 三、構建VGG-16網絡
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropoutdef VGG16(nb_classes, input_shape):input_tensor = Input(shape=input_shape)# 1st blockx = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)# 2nd blockx = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)# 3rd blockx = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)# 4th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)# 5th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)# full connectionx = Flatten()(x)x = Dense(4096, activation='relu',  name='fc1')(x)x = Dense(4096, activation='relu', name='fc2')(x)output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)model = Model(input_tensor, output_tensor)return modelmodel=VGG16(1000, (img_width, img_height, 3))
model.summary()

模型結構打印如下:

?Model: "model"
_________________________________________________________________
Layer (type) ? ? ? ? ? ? ? ? Output Shape ? ? ? ? ? ? ?Param # ??
=================================================================
input_1 (InputLayer) ? ? ? ? [(None, 224, 224, 3)] ? ? 0 ? ? ? ??
_________________________________________________________________
block1_conv1 (Conv2D) ? ? ? ?(None, 224, 224, 64) ? ? ?1792 ? ? ?
_________________________________________________________________
block1_conv2 (Conv2D) ? ? ? ?(None, 224, 224, 64) ? ? ?36928 ? ??
_________________________________________________________________
block1_pool (MaxPooling2D) ? (None, 112, 112, 64) ? ? ?0 ? ? ? ??
_________________________________________________________________
block2_conv1 (Conv2D) ? ? ? ?(None, 112, 112, 128) ? ? 73856 ? ??
_________________________________________________________________
block2_conv2 (Conv2D) ? ? ? ?(None, 112, 112, 128) ? ? 147584 ? ?
_________________________________________________________________
block2_pool (MaxPooling2D) ? (None, 56, 56, 128) ? ? ? 0 ? ? ? ??
_________________________________________________________________
block3_conv1 (Conv2D) ? ? ? ?(None, 56, 56, 256) ? ? ? 295168 ? ?
_________________________________________________________________
block3_conv2 (Conv2D) ? ? ? ?(None, 56, 56, 256) ? ? ? 590080 ? ?
_________________________________________________________________
block3_conv3 (Conv2D) ? ? ? ?(None, 56, 56, 256) ? ? ? 590080 ? ?
_________________________________________________________________
block3_pool (MaxPooling2D) ? (None, 28, 28, 256) ? ? ? 0 ? ? ? ??
_________________________________________________________________
block4_conv1 (Conv2D) ? ? ? ?(None, 28, 28, 512) ? ? ? 1180160 ??
_________________________________________________________________
block4_conv2 (Conv2D) ? ? ? ?(None, 28, 28, 512) ? ? ? 2359808 ??
_________________________________________________________________
block4_conv3 (Conv2D) ? ? ? ?(None, 28, 28, 512) ? ? ? 2359808 ??
_________________________________________________________________
block4_pool (MaxPooling2D) ? (None, 14, 14, 512) ? ? ? 0 ? ? ? ??
_________________________________________________________________
block5_conv1 (Conv2D) ? ? ? ?(None, 14, 14, 512) ? ? ? 2359808 ??
_________________________________________________________________
block5_conv2 (Conv2D) ? ? ? ?(None, 14, 14, 512) ? ? ? 2359808 ??
_________________________________________________________________
block5_conv3 (Conv2D) ? ? ? ?(None, 14, 14, 512) ? ? ? 2359808 ??
_________________________________________________________________
block5_pool (MaxPooling2D) ? (None, 7, 7, 512) ? ? ? ? 0 ? ? ? ??
_________________________________________________________________
flatten (Flatten) ? ? ? ? ? ?(None, 25088) ? ? ? ? ? ? 0 ? ? ? ??
_________________________________________________________________
fc1 (Dense) ? ? ? ? ? ? ? ? ?(None, 4096) ? ? ? ? ? ? ?102764544?
_________________________________________________________________
fc2 (Dense) ? ? ? ? ? ? ? ? ?(None, 4096) ? ? ? ? ? ? ?16781312 ?
_________________________________________________________________
predictions (Dense) ? ? ? ? ?(None, 1000) ? ? ? ? ? ? ?4097000 ??
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0

四、編譯

代碼如下:

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

五、訓練模型

代碼如下:

# 五、訓練模型
from tqdm import tqdm
import tensorflow.keras.backend as Kepochs = 10
lr = 1e-4# 記錄訓練數據,方便后面的分析
history_train_loss = []
history_train_accuracy = []
history_val_loss = []
history_val_accuracy = []
for epoch in range(epochs):train_total = len(train_ds)val_total = len(val_ds)with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=1, ncols=100) as pbar:lr = lr * 0.92K.set_value(model.optimizer.lr, lr)for image, label in train_ds:history = model.train_on_batch(image, label)train_loss = history[0]train_accuracy = history[1]pbar.set_postfix({"loss": "%.4f" % train_loss,"accuracy": "%.4f" % train_accuracy,"lr": K.get_value(model.optimizer.lr)})pbar.update(1)history_train_loss.append(train_loss)history_train_accuracy.append(train_accuracy)print('開始驗證!')with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=0.3, ncols=100) as pbar:for image, label in val_ds:history = model.test_on_batch(image, label)val_loss = history[0]val_accuracy = history[1]pbar.set_postfix({"loss": "%.4f" % val_loss,"accuracy": "%.4f" % val_accuracy})pbar.update(1)history_val_loss.append(val_loss)history_val_accuracy.append(val_accuracy)print('結束驗證!')print("驗證loss為:%.4f" % val_loss)print("驗證準確率為:%.4f" % val_accuracy)

打印訓練過程:

?

六、模型評估

代碼如下:

epochs_range = range(epochs)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

訓練結果可視化如下:

???

七、預測

代碼如下:

# 七、預測
import numpy as np
# 采用加載的模型(new_model)來看預測結果
plt.figure(figsize=(18, 3))  # 圖形的寬為18高為5
plt.suptitle("預測結果展示")
for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(1, 8, i + 1)# 顯示圖片plt.imshow(images[i].numpy())# 需要給圖片增加一個維度img_array = tf.expand_dims(images[i], 0)# 使用模型預測圖片中的人物predictions = model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

輸出:

1/1 [==============================] - 0s 129ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 18ms/step
1/1 [==============================] - 0s 18ms/step
1/1 [==============================] - 0s 17ms/step
1/1 [==============================] - 0s 18ms/step
1/1 [==============================] - 0s 17ms/step
1/1 [==============================] - 0s 17ms/step


總結

  • Tensorflow訓練過程中打印多余信息的處理,并且引入了進度條的顯示方式,更加方便及時查看模型訓練過程中的情況,可以及時打印各項指標
  • 修改了以往的model.fit()訓練方法,改用model.train_on_batch方法。兩種方法的比較:model.fit():用起來十分簡單,對新手非常友好;model.train_on_batch():封裝程度更低,可以玩更多花樣
  • 完成了VGG-16基于Tensorflow下的搭建、訓練等工作,對比分析了pytorch和tensorflow兩個框架下實現同種任務的異同;
  • 完成VGG-16對貓狗圖片的高精度識別

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/66958.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/66958.shtml
英文地址,請注明出處:http://en.pswp.cn/web/66958.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

我的2024年年度總結

序言 在前不久(應該是上周)的博客之星入圍賽中鎩羽而歸了。雖然心中頗為不甘,覺得這一年兢兢業業,每天都在發文章,不應該是這樣的結果(連前300名都進不了)。但人不能總抱怨,總要向前…

Trimble三維激光掃描-地下公共設施維護的新途徑【滬敖3D】

三維激光掃描技術生成了復雜隧道網絡的高度詳細的三維模型 項目背景 紐約州北部的地下通道網絡已有100年歷史,其中包含供暖系統、電線和其他公用設施,現在已經開始顯露出老化跡象。由于安全原因,第三方的進入受到限制,在沒有現成紙…

QT 中 UDP 的使用

目錄 一、UDP 簡介 二、QT 中 UDP 編程的基本步驟 (一)包含頭文件 (二)創建 UDP 套接字對象 (三)綁定端口 (四)發送數據 (五)接收數據 三、完整示例代…

開源鴻蒙開發者社區記錄

lava鴻蒙社區可提問 Laval社區 開源鴻蒙項目 OpenHarmony 開源鴻蒙開發者論壇 OpenHarmony 開源鴻蒙開發者論壇

Git上傳了秘鑰如何徹底修改包括歷史記錄【從安裝到實戰詳細版】

使用 BFG Repo-Cleaner 清除 Git 倉庫中的敏感信息 1. 背景介紹 在使用 Git 進行版本控制時,有時會不小心將敏感信息(如 API 密鑰、密碼等)提交到倉庫中。即使后續刪除,這些信息仍然存在于 Git 的歷史記錄中。本文將介紹如何使用…

多層 RNN原理以及實現

數學原理 多層 RNN 的核心思想是堆疊多個 RNN 層,每一層的輸出作為下一層的輸入,從而逐層提取更高層次的抽象特征。 1. 單層 RNN 的數學表示 首先,單層 RNN 的計算過程如下。對于一個時間步 t t t,單層 RNN 的隱藏狀態 h t h_t…

RNA 測序技術概覽(RNA-seq)

前言 轉錄組測序(RNA-seq)是當下最流行的二代測序(NGS)方法之一,使科研工作者實現在轉錄水平上定量、定性的研究,它的出現已經革命性地改變了人們研究基因表達調控的方式。然而,轉錄組測序&…

C語言練習(16)

猴子吃桃問題。猴子第一天摘下若干個桃子,當即吃了一半,還不過癮,又多吃了一個。第二天早上又將剩下的桃子吃掉一半,又多吃了一個。以后每天早上都吃了前一天剩下的一半加一個。到第10天早上想再吃時,見只剩一個桃子了…

【機器學習】自定義數據集使用框架的線性回歸方法對其進行擬合

一、使用框架的線性回歸方法 1. 基礎原理 在自求導線性回歸中,我們需要先自定義參數,并且需要通過數學公式來對w和b進行求導,然后在反向傳播過程中通過梯度下降的方式來更新參數,從而降低損失值。 2. 實現步驟 ① 散點輸入 有一…

pytest執行報錯:found no collectors

今天在嘗試使用pytest運行用例的時候出現報錯:found no collectors;從兩個方向進行排查,一是看文件名和函數名是不是符合規范,命名要是"test_*"格式;二是是否存在修改文件名的情況,如果修改過文件…

mysql-06.JDBC

目錄 什么是JDBC: 為啥存在JDBC: JDBC工作原理: JDBC的優勢: 下載mysql驅動包: 用java程序操作數據庫 1.創建dataSource: 2.與服務端建立連接 3.構造sql語句 4.執行sql 5.關閉連接,釋放資源 參考代碼: 插…

微信小程序wxs實現UTC轉北京時間

微信小程序實現UTC轉北京時間 打臉一刻:最近在迭代原生微信小程序,好一段時間沒寫原生的,有點不習慣; 咦,更新數據咋不生效呢?原來還停留在 this.xxx; 喲,事件又沒反應了&#xff1f…

機器學習-線性回歸(對于f(x;w)=w^Tx+b理解)

一、𝑓(𝒙;𝒘) 𝒘T𝒙的推導 學習線性回歸,我們那先要對于線性回歸的表達公示,有所認識。 我們先假設空間是一組參數化的線性函數: 其中權重向量𝒘 ∈ R𝐷 …

R語言學習筆記之語言入門基礎

一、R語言基礎 快速熟悉R語言中的基本概念&#xff0c;先入個門。 1、運算符 運算符含義例子加1 1-減3 - 2*乘3 * 2/除9 / 3^(**)乘方2 ^ 3 2 ** 3%%取余5 %% 2%/%取整5 %/% 2 2、賦值符號 等號a 1三者等價&#xff1a;把1賦值給變量a左箭頭<?a <- 1右箭頭?&g…

計算機網絡三張表(ARP表、MAC表、路由表)總結

參考&#xff1a; 網絡三張表&#xff1a;ARP表, MAC表, 路由表&#xff0c;實現你的網絡自由&#xff01;&#xff01;_mac表、arp表、路由表-CSDN博客 網絡中的三張表&#xff1a;ARP表、MAC表、路由表 首先要明確一件事&#xff0c;如果一個主機要發送數據&#xff0c;那么必…

【Nomoto 船舶模型】

【Nomoto 船舶模型】 1. Nomoto 船舶模型簡介2. 來源及發展歷程3. 構建 一階模型Nomoto 船舶模型3.1 C 實現3.2 Python 實現3.3 說明 5. 參數辨識方法5.1 基于最小二乘法的參數辨識5.2 數學推導5.3 Python 實現5.4 說明 4. 結論參考文獻 1. Nomoto 船舶模型簡介 Nomoto 模型是…

差分進化算法 (Differential Evolution) 算法詳解及案例分析

差分進化算法 (Differential Evolution) 算法詳解及案例分析 目錄 差分進化算法 (Differential Evolution) 算法詳解及案例分析1. 引言2. 差分進化算法 (DE) 算法原理2.1 基本概念2.2 算法步驟3. 差分進化算法的優勢與局限性3.1 優勢3.2 局限性4. 案例分析4.1 案例1: 單目標優化…

深入理解GPT底層原理--從n-gram到RNN到LSTM/GRU到Transformer/GPT的進化

從簡單的RNN到復雜的LSTM/GRU,再到引入注意力機制,研究者們一直在努力解決序列建模的核心問題。每一步的進展都為下一步的突破奠定了基礎,最終孕育出了革命性的Transformer架構和GPT大模型。 1. 從n-gram到循環神經網絡(RNN)的誕生 1.1 N-gram 模型 在深度學習興起之前,處理…

【JWT】jwt實現HS、RS、ES、ED簽名與驗簽

JWT 實現 HS、RS、ES 和 ED 簽名與驗簽 簽名方式算法密鑰類型簽名要點驗簽要點HSHMAC-SHA256對稱密鑰- 使用 crypto/hmac 和對稱密鑰生成 HMAC 簽名- 將 header.payload 作為數據輸入- 使用同一密鑰重新計算 HMAC 簽名- 比較計算結果與接收到的簽名是否一致RSRSA-SHA256公鑰 …

地址欄信息location

獲取信息 頁面跳轉 location.href當前地址欄信息 location.assign()設置跳轉新的頁面 location.replace() location.reload()刷新頁面