好萊塢明星識別

?一、前期工作


1. 設置GPU

from tensorflow       import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow        as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多個GPU,僅使用第0個GPUtf.config.experimental.set_memory_growth(gpu0, True)  #設置GPU顯存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")gpus


如果使用的是CPU可以忽略這步



2. 導入數據

data_dir = "./46-data/"data_dir = pathlib.Path(data_dir)




3. 查看數據

?

image_count = len(list(data_dir.glob('*/*/*.jpg')))print("圖片總數為:",image_count)


?

圖片總數為: 578
roses = list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(roses[0]))

YAIRI

output_11_0.png



二、數據預處理

1. 加載數據

使用image_dataset_from_directory方法將磁盤中的數據加載到tf.data.Dataset中
●tf.keras.preprocessing.image_dataset_from_directory():是 TensorFlow 的 Keras 模塊中的一個函數,用于從目錄中創建一個圖像數據集(dataset)。這個函數可以以更方便的方式加載圖像數據,用于訓練和評估神經網絡模型。


測試集與驗證集的關系:

1驗證集并沒有參與訓練過程梯度下降過程的,狹義上來講是沒有參與模型的參數訓練更新的。
2但是廣義上來講,驗證集存在的意義確實參與了一個“人工調參”的過程,我們根據每一個epoch訓練之后模型在valid data上的表現來決定是否需要訓練進行early stop,或者根據這個過程模型的性能變化來調整模型的超參數,如學習率,batch_size等等。
3因此,我們也可以認為,驗證集也參與了訓練,但是并沒有使得模型去overfit驗證集

batch_size = 32
img_height = 224
img_width = 224



如果準備嘗試 categorical_crossentropy損失函數,下面的代碼遇到變動哈,變動細節將在下一周博客內公布。
?

"""
關于image_dataset_from_directory()的詳細介紹可以參考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory("./46-data/train/",seed=123,image_size=(img_height, img_width),batch_size=batch_size)

Found 502 files belonging to 2 classes.
"""
關于image_dataset_from_directory()的詳細介紹可以參考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory("./46-data/test/",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 76 files belonging to 2 classes.




我們可以通過class_names輸出數據集的標簽。標簽將按字母順序對應于目錄名稱。
?

class_names = train_ds.class_names
print(class_names)
['adidas', 'nike']



2. 可視化數據

?

plt.figure(figsize=(20, 10))for images, labels in train_ds.take(1):for i in range(20):ax = plt.subplot(5, 10, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

output_22_0.png



3. 再次檢查數據

?

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(32, 224, 224, 3)
(32,)


●Image_batch是形狀的張量(32,224,224,3)。這是一批形狀224x224x3的32張圖片(最后一維指的是彩色通道RGB)。
●Label_batch是形狀(32,)的張量,這些標簽對應32張圖片

4. 配置數據集

●shuffle() :打亂數據,關于此函數的詳細介紹可以參考:數據集shuffle方法中buffer_size的理解 - 知乎
●prefetch() :預取數據,加速運行

prefetch()功能詳細介紹:CPU 正在準備數據時,加速器處于空閑狀態。相反,當加速器正在訓練模型時,CPU 處于空閑狀態。因此,訓練所用的時間是 CPU 預處理時間和加速器訓練時間的總和。prefetch()將訓練步驟的預處理和模型執行過程重疊到一起。當加速器正在執行第 N 個訓練步時,CPU 正在準備第 N+1 步的數據。這樣做不僅可以最大限度地縮短訓練的單步用時(而不是總用時),而且可以縮短提取和轉換數據所需的時間。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分時間都處于空閑狀態:

image.png


使用prefetch()可顯著減少空閑時間:

image.png


●cache() :將數據集緩存到內存當中,加速運行

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)



三、構建CNN網絡

卷積神經網絡(CNN)的輸入是張量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了圖像高度、寬度及顏色信息。不需要輸入batch size。color_channels 為 (R,G,B) 分別對應 RGB 的三個顏色通道(color channel)。在此示例中,我們的 CNN 輸入的形狀是 (224, 224, 3)即彩色圖像。我們需要在聲明第一層時將形狀賦值給參數input_shape。

網絡結構圖(可單擊放大查看):

image.png

"""
關于卷積核的計算不懂的可以參考文章:https://blog.csdn.net/qq_38251616/article/details/114278995layers.Dropout(0.4) 作用是防止過擬合,提高模型的泛化能力。
關于Dropout層的更多介紹可以參考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷積層1,卷積核3*3  layers.AveragePooling2D((2, 2)),               # 池化層1,2*2采樣layers.Conv2D(32, (3, 3), activation='relu'),  # 卷積層2,卷積核3*3layers.AveragePooling2D((2, 2)),               # 池化層2,2*2采樣layers.Dropout(0.3),  layers.Conv2D(64, (3, 3), activation='relu'),  # 卷積層3,卷積核3*3layers.Dropout(0.3),  layers.Flatten(),                       # Flatten層,連接卷積層與全連接層layers.Dense(128, activation='relu'),   # 全連接層,特征進一步提取layers.Dense(len(class_names))               # 輸出層,輸出預期結果
])model.summary()  # 打印網絡結構
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 222, 222, 16)      448       
_________________________________________________________________
average_pooling2d (AveragePo (None, 111, 111, 16)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 109, 109, 32)      4640      
_________________________________________________________________
average_pooling2d_1 (Average (None, 54, 54, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 54, 54, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 52, 52, 64)        18496     
_________________________________________________________________
dropout_1 (Dropout)          (None, 52, 52, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 173056)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               22151296  
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 258       
=================================================================
Total params: 22,175,138
Trainable params: 22,175,138
Non-trainable params: 0
_________________________________________________________________




四、訓練模型

在準備對模型進行訓練之前,還需要再對其進行一些設置。以下內容是在模型的編譯步驟中添加的:

●損失函數(loss):用于衡量模型在訓練期間的準確率。
●優化器(optimizer):決定模型如何根據其看到的數據和自身的損失函數進行更新。
●指標(metrics):用于監控訓練和測試步驟。以下示例使用了準確率,即被正確分類的圖像的比率。

1.設置動態學習率

📮 ExponentialDecay函數:
tf.keras.optimizers.schedules.ExponentialDecay是 TensorFlow 中的一個學習率衰減策略,用于在訓練神經網絡時動態地降低學習率。學習率衰減是一種常用的技巧,可以幫助優化算法更有效地收斂到全局最小值,從而提高模型的性能。

🔎 主要參數:
●initial_learning_rate(初始學習率):初始學習率大小。
●decay_steps(衰減步數):學習率衰減的步數。在經過 decay_steps 步后,學習率將按照指數函數衰減。例如,如果 decay_steps 設置為 10,則每10步衰減一次。
●decay_rate(衰減率):學習率的衰減率。它決定了學習率如何衰減。通常,取值在 0 到 1 之間。
●staircase(階梯式衰減):一個布爾值,控制學習率的衰減方式。如果設置為 True,則學習率在每個 decay_steps 步之后直接減小,形成階梯狀下降。如果設置為 False,則學習率將連續衰減。

# 設置初始學習率
initial_learning_rate = 0.1lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=10,      # 敲黑板!!!這里是指 steps,不是指epochsdecay_rate=0.92,     # lr經過一次衰減就會變成 decay_rate*lrstaircase=True)# 將指數衰減學習率送入優化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])



注:這里設置的動態學習率為:指數衰減型(ExponentialDecay)。在每一個epoch開始前,學習率(learning_rate)都將會重置為初始學習率(initial_learning_rate),然后再重新開始衰減。計算公式如下:

learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

學習率大與學習率小的優缺點分析:

學習率大

● 優點:
○1、加快學習速率。
○2、有助于跳出局部最優值。
● 缺點:
○1、導致模型訓練不收斂。
○2、單單使用大學習率容易導致模型不精確。

學習率小

● 優點:
○1、有助于模型收斂、模型細化。
○2、提高模型精度。
● 缺點:
○1、很難跳出局部最優值。
○2、收斂緩慢。

2.早停與保存最佳模型參數

EarlyStopping()參數說明:

●monitor: 被監測的數據。
●min_delta: 在被監測的數據中被認為是提升的最小變化, 例如,小于 min_delta 的絕對變化會被認為沒有提升。
●patience: 沒有進步的訓練輪數,在這之后訓練就會被停止。
●verbose: 詳細信息模式。
●mode: {auto, min, max} 其中之一。 在 min 模式中, 當被監測的數據停止下降,訓練就會停止;在 max 模式中,當被監測的數據停止上升,訓練就會停止;在 auto 模式中,方向會自動從被監測的數據的名字中判斷出來。
●baseline: 要監控的數量的基準值。 如果模型沒有顯示基準的改善,訓練將停止。
●estore_best_weights: 是否從具有監測數量的最佳值的時期恢復模型權重。 如果為 False,則使用在訓練的最后一步獲得的模型權重。

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingepochs = 50# 保存最佳模型參數
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True,save_weights_only=True)# 設置早停
earlystopper = EarlyStopping(monitor='val_accuracy', min_delta=0.001,patience=20, verbose=1)



3. 模型訓練
?

history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer, earlystopper])
Epoch 1/50
16/16 [==============================] - 4s 31ms/step - loss: 3.5439 - accuracy: 0.4721 - val_loss: 0.6931 - val_accuracy: 0.5789Epoch 00001: val_accuracy improved from -inf to 0.57895, saving model to best_model.h5
Epoch 2/50
16/16 [==============================] - 0s 12ms/step - loss: 0.6929 - accuracy: 0.5279 - val_loss: 0.6891 - val_accuracy: 0.6447......Epoch 00040: val_accuracy did not improve from 0.89474
Epoch 41/50
16/16 [==============================] - 0s 12ms/step - loss: 0.0931 - accuracy: 0.9841 - val_loss: 0.3837 - val_accuracy: 0.8816Epoch 00041: val_accuracy did not improve from 0.89474
Epoch 42/50
16/16 [==============================] - 0s 12ms/step - loss: 0.0871 - accuracy: 0.9801 - val_loss: 0.3834 - val_accuracy: 0.8816Epoch 00042: val_accuracy did not improve from 0.89474
Epoch 00042: early stopping



五、模型評估

1. Loss與Accuracy圖
?

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(len(loss))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

output_51_0.png


2. 指定圖片進行預測
?

from PIL import Image
import numpy as np# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  #這里選擇你需要預測的圖片
img = Image.open("./46-data/test/nike/1.jpg")  #這里選擇你需要預測的圖片
image = tf.image.resize(img, [img_height, img_width])img_array = tf.expand_dims(image, 0) #/255.0  # 記得做歸一化處理(與訓練集處理方式保持一致)predictions = model.predict(img_array) # 這里選用你已經訓練好的模型
print("預測結果為:",class_names[np.argmax(predictions)])

預測結果為: nike

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/208268.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/208268.shtml
英文地址,請注明出處:http://en.pswp.cn/news/208268.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

動態規劃——完全背包問題(公式推導,組合、排列)

本文章是對于完全背包 一些題型(如題目所示,組合、排列和最小值類型)的總結和理解,依次記錄一下,方便回顧與復習。 本文章是基于個人所總結 實現的,但在其中遇到了一些疑惑與困難,所以總結一篇與完全背包相關的問題。 …

Spring基于注解開發

Component的使用 基本Bean注解&#xff0c;主要是使用注解的方式替代原有的xml的<bean>標簽及其標簽屬性的配置&#xff0c;使用Component注解替代<bean>標簽中的id以及class屬性&#xff0c;而對于是否延遲加載或是Bean的作用域&#xff0c;則是其他注解 xml配置…

IntelliJ IDEA 的 HTTP 客戶端的高級用法

本心、輸入輸出、結果 文章目錄 IntelliJ IDEA 的 HTTP 客戶端的高級用法前言HTTP 請求對 gRPC 請求的支持對 GraphQL 和 WebSocket 請求的支持環境文件OpenAPI 補全用于持續集成的 HTTP 客戶端 CLI花有重開日,人無再少年實踐是檢驗真理的唯一標準IntelliJ IDEA 的 HTTP 客戶端…

keepalived 高可用主備

實驗采用兩臺centos9 nginxkeepalived 一共兩臺&#xff0c;進行主備切換 主服務器 192.168.100.105 備用 192.168.100.106 虛擬ip 192.168.100.200 安裝 dnf install vim wget curl vim net-tools nginx keepalivedUndefined nginx 配置需要更改為虛擬ip server {listen …

四招打造完美分層自動化測試框架,讓測試更高效!

寫在前面 我們剛開始做自動化測試&#xff0c;可能寫的代碼都是基于原生寫的代碼&#xff0c;看起來特別不美觀&#xff0c;而且感覺特別生硬。 來看下面一段代碼&#xff1a; 具體表現如下&#xff1a; driver對象在測試類中顯示 定位元素的value值在測試類中顯示 定位元素…

Navicat 技術指引 | 適用于 GaussDB 分布式的用戶/權限功能

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式數據庫。GaussDB 分布式模式更適合對系統可用性和數據處理能力要求較高的場景。Navicat 工具不僅提供可視化數據查看和編輯功能&#xff0c;還提供強大的高階功能&#xff08;如模型、結…

干貨:軟文推廣中的關鍵詞類別有哪些?

軟文推廣如果想要增加文案曝光率&#xff0c;seo是其主要的傳播方式之一&#xff0c;因而好的關鍵詞十分重要&#xff0c;這里的關鍵詞指得是針對搜索引擎而言&#xff0c;由用戶輸入搜索引擎框中的提示性文字&#xff0c;只要關鍵詞設置得好&#xff0c;軟文就能通過搜索引擎精…

因為 postman環境變量全局變量設置好兄弟被公司優化了!

postman環境變量、全局變量設置 在公司中&#xff0c;一般會存在開發環境、測試環境、線上環境等&#xff0c;如果需要在不 同的環境下切換做接口測試&#xff0c;顯然我們需要把所有接口的域名進行修改&#xff0c;如果接 口測試用例較多&#xff0c;那么修改會非常費力&…

springboot(ssm大學生志愿者管理系統 志愿者管理平臺 Java系統

springboot(ssm大學生志愿者管理系統 志愿者管理平臺 Java系統 開發語言&#xff1a;Java 框架&#xff1a;ssm/springboot vue JDK版本&#xff1a;JDK1.8&#xff08;或11&#xff09; 服務器&#xff1a;tomcat 數據庫&#xff1a;mysql 5.7&#xff08;或8.0&#xff…

Python與ArcGIS系列(十五)根據距離抓取字段

目錄 0 簡述1 實例需求2 arcpy開發腳本0 簡述 在處理gis數據的時候,會遇到這種需求:將一個圖層與另一個圖層中相近的要素進行字段賦值。本篇將介紹如何利用arcpy及arcgis的工具箱實現這個功能。 1 實例需求 為了介紹這個功能的實現,我們需要有一個特定的功能需求。在這里選…

視頻號小店怎么選品?選品技巧及思維,教程如下!

我是電商珠珠 開通視頻號小店后&#xff0c;除了定類目之外&#xff0c;最終的就是選品了。 很多人不知道怎么選品&#xff0c;特別是新手小白&#xff0c;做起來比較難一些。店鋪也會很少有流量進入&#xff0c;沒有流量曝光的話&#xff0c;店鋪的銷量就更不用提了。 我做…

L1-019:誰先倒

題目描述 劃拳是古老中國酒文化的一個有趣的組成部分。酒桌上兩人劃拳的方法為&#xff1a;每人口中喊出一個數字&#xff0c;同時用手比劃出一個數字。如果誰比劃出的數字正好等于兩人喊出的數字之和&#xff0c;誰就輸了&#xff0c;輸家罰一杯酒。兩人同贏或兩人同輸則繼續下…

【Android】Java NIO(New I/O)的`Selector`類來實現非阻塞的Socket監聽

如果你不想使用循環來監聽客戶端的連接和數據&#xff0c;你可以使用Java NIO&#xff08;New I/O&#xff09;的Selector類來實現非阻塞的Socket監聽。Selector類提供了一種選擇一組已經就緒的通道的機制&#xff0c;這樣你就不需要使用循環來等待連接和數據。 以下是使用Sel…

Axure網頁端高復用組件庫, 下拉菜單文件上傳穿梭框日期城市選擇器

作品說明 組件數量&#xff1a;共 11 套 兼容軟件&#xff1a;Axure RP 9/10&#xff0c;不支持低版本 應用領域&#xff1a;web端原型設計、桌面端原型設計 作品特色 本作品為「web端組件庫」&#xff0c;高保真高交互 (帶仿真功能效果)&#xff1b;運用了動態面板、中繼…

使用pytorch查看中間層特征矩陣以及卷積核參數

這篇是我對嗶哩嗶哩up主 霹靂吧啦Wz 的視頻的文字版學習筆記 感謝他對知識的分享 1和4是之前講過的alexnet和resnet模型 2是分析中間層特征矩陣的腳本 3是查看卷積核參數的腳本 1設置預處理方法 和圖像訓練的時候用的預處理方法保持一致 2實例化模型 3載入之前的模型參數 4載入…

小白理解GPT的“微調“(fine-tuning)

對于GPT-3.5&#xff0c;我們實際上并不能在OpenAI的服務器上直接訓練它。OpenAI的模型通常是預訓練好的&#xff0c;也就是說&#xff0c;它們已經在大量的語料上進行過訓練&#xff0c;學習到了語言的基本規則和模式。 然而&#xff0c;OpenAI提供了一種叫做"微調"…

Pandas操作數據庫

一&#xff1a;Pandas讀取數據庫數據 二&#xff1a;Pandas讀取海量數據 三&#xff1a;Pandas向數據庫存數據 四&#xff1a;Pandas寫入海量數據

理想中的PC端剪切板工具,應該有哪些功能?

在日常工作中&#xff0c;我們經常需要復制和粘貼文本、圖片和鏈接。 首先&#xff0c;這款剪切板功能應該在不使用時不顯示窗口&#xff0c;以避免干擾我們的工作。它應該在后臺靜默記錄剪切板歷史&#xff0c;以便我們可以隨時查看之前的記錄。 其次&#xff0c;當我們需要…

A類中創建posix線程,線程間如何通信

如果你在類A中使用pthread_create創建了線程B&#xff0c;而線程B需要與類A進行通信&#xff0c;你可以考慮以下兩種方法&#xff1a; 使用回調函數&#xff1a; 在創建線程B時&#xff0c;通過參數傳遞一個回調函數&#xff0c;該回調函數可以在線程B中執行&#xff0c;并在完…

上海寶山區12月8日發生一起火災 火勢已撲滅 揭秘AI如何“救援”

在這個冬日的早晨&#xff0c;上海寶山區的居民經歷了一場驚心動魄的火災。幸運的是&#xff0c;火勢很快就被撲滅了。但這起事件不禁讓我們思考&#xff1a;如何更有效地預防和應對這樣的緊急情況&#xff1f; 這時候&#xff0c;就不得不提到北京富維圖像公司的一項創新技術—…