深度學習筆記11-優化器對比實驗(Tensorflow)

  • 🍨 本文為🔗365天深度學習訓練營中的學習記錄博客
  • 🍖 原作者:K同學啊

目錄

一、導入數據并檢查

二、配置數據集

三、數據可視化

四、構建模型

五、訓練模型

六、模型對比評估

七、總結


一、導入數據并檢查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
關于image_dataset_from_directory()的詳細介紹可以參考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置數據集

AUTOTUNE = tf.data.AUTOTUNE
#歸一化處理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 這里可以設置預處理函數
#     .batch(batch_size)           # 在image_dataset_from_directory處已經設置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 這里可以設置預處理函數
#     .batch(batch_size)         # 在image_dataset_from_directory處已經設置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、數據可視化

plt.figure(figsize=(10, 8))  # 圖形的寬為10高為5
plt.suptitle("數據展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 顯示圖片plt.imshow(images[i])# 顯示標簽plt.xlabel(class_names[labels[i]-1])plt.show()

四、構建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加載預訓練模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含頂層的全連接層input_shape=(img_width, img_height, 3),pooling='avg')#平均池化層替代頂層的全連接層for layer in vgg16_base_model.layers:layer.trainable = False  #將 trainable屬性設置為 False 意味著在訓練過程中,這些層的權重不會更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神經元數量等于類別數vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#隨機梯度下降(SGD)優化器的
model2.summary()

五、訓練模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型對比評估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #圖片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 設置刻度間隔,x軸每1一個刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 設置刻度間隔,x軸每1一個刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在這個實例中,Adam優化器的效果優于SGD優化器

七、總結

? ? ? 通過本次實驗,學會了比較不同優化器(Adam和SGD)在訓練過程中的性能表現,可視化訓練過程的損失曲線和準確率等指標。這是一項非常重要的技能,在研究論文中,可以通過這些優化方法可以提高工作量。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/65673.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/65673.shtml
英文地址,請注明出處:http://en.pswp.cn/web/65673.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

FFmpeg Muxer HLS

使用FFmpeg命令來研究它對HLS協議的支持程度是最好的方法: ffmpeg -h muxerhls Muxer HLS Muxer hls [Apple HTTP Live Streaming]:Common extensions: m3u8.Default video codec: h264.Default audio codec: aac.Default subtitle codec: webvtt. 這里面告訴我…

Apache和PHP:構建動態網站的黃金組合

在當今的互聯網世界,網站已經成為了企業、個人和機構展示自己、與用戶互動的重要平臺。而在這些動態網站的背后,Apache和PHP無疑是最受開發者青睞的技術組合之一。這一組合提供了高效、靈活且可擴展的解決方案,幫助您快速搭建出強大的網站&am…

git相關操作筆記

git相關操作筆記 1. git init git init 是一個 Git 命令,用于初始化一個新的 Git 倉庫。執行該命令后,Git 會在當前目錄創建一個 .git 子目錄,這是 Git 用來存儲所有版本控制信息的地方。 使用方法如下: (1&#xff…

Docker Desktop 構建java8基礎鏡像jdk安裝配置失效解決

Docker Desktop 構建java8基礎鏡像jdk安裝配置失效解決 文章目錄 1.問題2.解決方法3.總結 1.問題 之前的好幾篇文章中分享了在Linux(centOs上)和windows10上使用docker和docker Desktop環境構建java8的最小jre基礎鏡像,前幾天我使用Docker Desktop環境重新構建了一個…

VUE + pdfh5 實現pdf 預覽,主要用來uniappH5實現嵌套預覽PDF

1. 安裝依賴 npm install pdfh5 2. pdfh5 預覽(移動端,h5) npm install pdfh5 , (會報錯,需要其他依賴,不能直接用提示的語句直接npm下載,依舊會報錯,npm報錯:These dependencies were not fou…

Node.js——fs(文件系統)模塊

個人簡介 👀個人主頁: 前端雜貨鋪 🙋?♂?學習方向: 主攻前端方向,正逐漸往全干發展 📃個人狀態: 研發工程師,現效力于中國工業軟件事業 🚀人生格言: 積跬步…

Microsoft Azure Cosmos DB:全球分布式、多模型數據庫服務

目錄 前言1. Azure Cosmos DB 簡介1.1 什么是 Azure Cosmos DB?1.2 核心技術特點 2. 數據模型與 API 支持2.1 文檔存儲(Document Store)2.2 圖數據庫(Graph DBMS)2.3 鍵值存儲(Key-Value Store)…

springboot項目讀取resources目錄下文件

要用以下這種方式讀取 classPathResource new ClassPathResource("template/test.docx");不能用以下這種獲取絕對路徑的方式,idea調試正常,但是部署window和linux的目錄結構不一樣,部署后會找不到文件,另外window直接…

Ruby語言的軟件開發工具

Ruby語言的軟件開發工具概述 引言 Ruby是一種簡單且功能強大的編程語言,它以優雅的語法和靈活性而聞名。自1995年首次發布以來,Ruby已經被廣泛應用于各種開發領域,特別是Web開發。隨著Ruby語言的普及,相關的開發工具也日益豐富。…

C++例程:使用I/O模擬IIC接口(6)

完整的STM32F405代碼工程I2C驅動源代碼跟蹤 一)myiic.c #include "myiic.h" #include "delay.h" #include "stm32f4xx_rcc.h" //初始化IIC void IIC_Init(void) { GPIO_InitTypeDef GPIO_InitStructure;RCC_AHB1PeriphCl…

CNN-BiLSTM-Attention模型詳解及應用分析

CNN-BiLSTM-Attention結構 CNN-BiLSTM-Attention結構是一種強大的深度學習架構,巧妙地結合了三種不同的技術優勢:卷積神經網絡(CNN)、雙向長短期記憶網絡(BiLSTM)和注意力機制(Attention)。這種創新性的組合使得模型能夠在處理復雜序列數據時表現出色,尤其適用于自然…

2025年華為OD上機考試真題(Java)——整數對最小和

題目: 給定兩個整數數組array1、array2,數組元素按升序排列。假設從array1、array2中分別取出一個元素可構成一對元素,現在需要取出k對元素,并對取出的所有元素求和,計算和的最小值。 注意:兩對元素如果對應…

【Java知識】Groovy 一個兼容java的編程語言

groovy語言介紹 概述一、基本特點二、主要特性三、應用領域四、與Java的比較 基本語法特性一、基本語法二、數據類型三、運算符四、字符串五、方法六、閉包七、類與對象八、異常處理九、其他特性 集成到springboot項目1. 創建Spring Boot項目2. 添加Groovy依賴3. 編寫Groovy類4…

Python網絡爬蟲:從入門到實戰

Python以其簡潔易用和強大的庫支持成為網絡爬蟲開發的首選語言。本文將系統介紹Python網絡爬蟲的開發方法,包括基礎知識、常用工具以及實戰案例,幫助讀者從入門到精通。 什么是網絡爬蟲? 網絡爬蟲(Web Crawler)是一種…

【vLLM 學習】安裝

vLLM 是一款專為大語言模型推理加速而設計的框架,實現了 KV 緩存內存幾乎零浪費,解決了內存管理瓶頸問題。 更多 vLLM 中文文檔及教程可訪問 →https://vllm.hyper.ai/ vLLM 是一個 Python 庫,包含預編譯的 C 和 CUDA (12.1) 二進制文件。 …

npm : 無法加載文件 D:\SoftFile\npm.ps1,因為在此系統上禁止運行腳本。

這個錯誤是由于 Windows PowerShell 的執行策略禁止執行腳本,導致無法運行 npm 命令。你可以通過以下步驟來解決這個問題: 以管理員身份運行 PowerShell: 點擊“開始”菜單,搜索“PowerShell”,然后右鍵點擊“Windows …

7 分布式定時任務調度框架

先簡單介紹下分布式定時任務調度框架的使用場景和功能和架構,然后再介紹世面上常見的產品 我們在大型的復雜的系統下,會有大量的跑批,定時任務的功能,如果在獨立的子項目中單獨去處理這些任務,隨著業務的復雜度的提高…

網絡安全 | 網絡安全法規:GDPR、CCPA與中國網絡安全法

網絡安全 | 網絡安全法規:GDPR、CCPA與中國網絡安全法 一、前言二、歐盟《通用數據保護條例》(GDPR)2.1 背景2.2 主要內容2.3 特點2.4 實施效果與影響 三、美國《加利福尼亞州消費者隱私法案》(CCPA)3.1 背景3.2 主要內…

Elixir語言的計算機基礎

Elixir語言的計算機基礎 引言 在當今這個快速發展的技術時代,編程語言層出不窮。Elixir作為一種較新的編程語言,以其高并發、低延遲和強大的容錯能力受到越來越多開發者的青睞。它基于Erlang虛擬機(BEAM),自然繼承了…

mysql的mvcc理解

人閱讀 一、說到mvcc就少不了事務隔離級別(大白話解釋) 序列化(SERIALIZABLE):事務之間完全隔離,當成一個序列,一個一個執行。 1 可重復讀(REPEATABLE READ)&#xff…