卷積神經網絡(AlexNet)鳥類識別

文章目錄

  • 一、前言
  • 二、前期工作
    • 1. 設置GPU(如果使用的是CPU可以忽略這步)
    • 2. 導入數據
    • 3. 查看數據
  • 二、數據預處理
    • 1. 加載數據
    • 2. 可視化數據
    • 3. 再次檢查數據
    • 4. 配置數據集
  • 三、AlexNet (8層)介紹
  • 四、構建AlexNet (8層)網絡模型
  • 五、編譯
  • 六、訓練模型
  • 七、模型評估
  • 八、保存and加載模型
  • 九、預測

一、前言

我的環境:

  • 語言環境:Python3.6.5
  • 編譯器:jupyter notebook
  • 深度學習環境:TensorFlow2.4.1

往期精彩內容:

  • 卷積神經網絡(CNN)實現mnist手寫數字識別
  • 卷積神經網絡(CNN)多種圖片分類的實現
  • 卷積神經網絡(CNN)衣服圖像分類的實現
  • 卷積神經網絡(CNN)鮮花識別
  • 卷積神經網絡(CNN)天氣識別
  • 卷積神經網絡(VGG-16)識別海賊王草帽一伙
  • 卷積神經網絡(ResNet-50)鳥類識別

來自專欄:機器學習與深度學習算法推薦

二、前期工作

1. 設置GPU(如果使用的是CPU可以忽略這步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #設置GPU顯存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2. 導入數據

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號import os,PIL# 設置隨機種子盡可能使結果可以重現
import numpy as np
np.random.seed(1)# 設置隨機種子盡可能使結果可以重現
import tensorflow as tf
tf.random.set_seed(1)import pathlib
data_dir = "bird_photos"data_dir = pathlib.Path(data_dir)

3. 查看數據

image_count = len(list(data_dir.glob('*/*')))
print("圖片總數為:",image_count)

圖片總數為: 565

二、數據預處理

文件夾數量
Bananaquit166 張
Black Throated Bushtiti111 張
Black skimmer122 張
Cockatoo166張

1. 加載數據

使用image_dataset_from_directory方法將磁盤中的數據加載到tf.data.Dataset

batch_size = 8
img_height = 227
img_width = 227
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 565 files belonging to 4 classes.
Using 452 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 565 files belonging to 4 classes.
Using 113 files for validation.

我們可以通過class_names輸出數據集的標簽。標簽將按字母順序對應于目錄名稱。

class_names = train_ds.class_names
print(class_names)
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

2. 可視化數據

plt.figure(figsize=(10, 5))  # 圖形的寬為10高為5for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.imshow(images[1].numpy().astype("uint8"))

3. 再次檢查數據

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(8, 227, 227, 3)
(8,)
  • Image_batch是形狀的張量(8, 224, 224, 3)。這是一批形狀240x240x3的8張圖片(最后一維指的是彩色通道RGB)。
  • Label_batch是形狀(8,)的張量,這些標簽對應8張圖片

4. 配置數據集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、AlexNet (8層)介紹

AleXNet使用了ReLU方法加快訓練速度,并且使用Dropout來防止過擬合

AleXNet (8層)是首次把卷積神經網絡引入計算機視覺領域并取得突破性成績的模型。獲得了ILSVRC 2012年的冠軍,再top-5項目中錯誤率僅僅15.3%,相對于使用傳統方法的亞軍26.2%的成績優良重大突破。和之前的LeNet相比,AlexNet通過堆疊卷積層使得模型更深更寬。

四、構建AlexNet (8層)網絡模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout,BatchNormalization,Activationimport numpy as np
seed = 7
np.random.seed(seed)def AlexNet(nb_classes, input_shape):input_tensor = Input(shape=input_shape)# 1st blockx = Conv2D(96, (11,11), strides=4, name='block1_conv1')(input_tensor)x = BatchNormalization()(x)x = Activation('relu')(x)x = MaxPooling2D((3,3), strides=2, name = 'block1_pool')(x)# 2nd blockx = Conv2D(256, (5,5), padding='same', name='block2_conv1')(x)x = BatchNormalization()(x)x = Activation('relu')(x)x = MaxPooling2D((3,3), strides=2, name='block2_pool')(x)# 3rd blockx = Conv2D(384, (3,3), activation='relu', padding='same',name='block3_conv1')(x)# 4th blockx = Conv2D(384, (3,3), activation='relu', padding='same',name='block4_conv1')(x)# 5th blockx = Conv2D(256, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = MaxPooling2D((3,3), strides=2, name = 'block5_pool')(x)# full connectionx = Flatten()(x)x = Dense(4096, activation='relu',  name='fc1')(x)x = Dropout(0.5)(x)x = Dense(4096, activation='relu', name='fc2')(x)x = Dropout(0.5)(x)output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)model = Model(input_tensor, output_tensor)return modelmodel=AlexNet(1000, (img_width, img_height, 3))
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 227, 227, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 55, 55, 96)        34944     
_________________________________________________________________
batch_normalization (BatchNo (None, 55, 55, 96)        384       
_________________________________________________________________
activation (Activation)      (None, 55, 55, 96)        0         
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 27, 27, 96)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 27, 27, 256)       614656    
_________________________________________________________________
batch_normalization_1 (Batch (None, 27, 27, 256)       1024      
_________________________________________________________________
activation_1 (Activation)    (None, 27, 27, 256)       0         
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 13, 13, 256)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 13, 13, 384)       885120    
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 13, 13, 384)       1327488   
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 13, 13, 256)       884992    
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 6, 6, 256)         0         
_________________________________________________________________
flatten (Flatten)            (None, 9216)              0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              37752832  
_________________________________________________________________
dropout (Dropout)            (None, 4096)              0         
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
dropout_1 (Dropout)          (None, 4096)              0         
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 62,379,752
Trainable params: 62,379,048
Non-trainable params: 704
_________________________________________________________________

五、編譯

在準備對模型進行訓練之前,還需要再對其進行一些設置。以下內容是在模型的編譯步驟中添加的:

  • 損失函數(loss):用于衡量模型在訓練期間的準確率。
  • 優化器(optimizer):決定模型如何根據其看到的數據和自身的損失函數進行更新。
  • 指標(metrics):用于監控訓練和測試步驟。以下示例使用了準確率,即被正確分類的圖像的比率。
# 設置優化器,我這里改變了學習率。
# opt = tf.keras.optimizers.Adam(learning_rate=1e-7)model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

六、訓練模型

epochs = 20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)
Epoch 1/20
57/57 [==============================] - 5s 30ms/step - loss: 9.2789 - accuracy: 0.2166 - val_loss: 3.2340 - val_accuracy: 0.3363
Epoch 2/20
57/57 [==============================] - 1s 14ms/step - loss: 0.9329 - accuracy: 0.6224 - val_loss: 1.1778 - val_accuracy: 0.5310
Epoch 3/20
57/57 [==============================] - 1s 14ms/step - loss: 0.7438 - accuracy: 0.6747 - val_loss: 1.9651 - val_accuracy: 0.5133
Epoch 4/20
57/57 [==============================] - 1s 14ms/step - loss: 0.8875 - accuracy: 0.7025 - val_loss: 1.5589 - val_accuracy: 0.4602
Epoch 5/20
57/57 [==============================] - 1s 14ms/step - loss: 0.6116 - accuracy: 0.7424 - val_loss: 0.9914 - val_accuracy: 0.4956
Epoch 6/20
57/57 [==============================] - 1s 15ms/step - loss: 0.6258 - accuracy: 0.7520 - val_loss: 1.1103 - val_accuracy: 0.5221
Epoch 7/20
57/57 [==============================] - 1s 13ms/step - loss: 0.5138 - accuracy: 0.8034 - val_loss: 0.7832 - val_accuracy: 0.6726
Epoch 8/20
57/57 [==============================] - 1s 14ms/step - loss: 0.5343 - accuracy: 0.7940 - val_loss: 6.1064 - val_accuracy: 0.4602
Epoch 9/20
57/57 [==============================] - 1s 14ms/step - loss: 0.8667 - accuracy: 0.7606 - val_loss: 0.6869 - val_accuracy: 0.7965
Epoch 10/20
57/57 [==============================] - 1s 16ms/step - loss: 0.5785 - accuracy: 0.8141 - val_loss: 1.3631 - val_accuracy: 0.5310
Epoch 11/20
57/57 [==============================] - 1s 15ms/step - loss: 0.4929 - accuracy: 0.8109 - val_loss: 0.7191 - val_accuracy: 0.7345
Epoch 12/20
57/57 [==============================] - 1s 15ms/step - loss: 0.4141 - accuracy: 0.8507 - val_loss: 0.4962 - val_accuracy: 0.8496
Epoch 13/20
57/57 [==============================] - 1s 15ms/step - loss: 0.2591 - accuracy: 0.9148 - val_loss: 0.8015 - val_accuracy: 0.8053
Epoch 14/20
57/57 [==============================] - 1s 15ms/step - loss: 0.2683 - accuracy: 0.9079 - val_loss: 0.5451 - val_accuracy: 0.8142
Epoch 15/20
57/57 [==============================] - 1s 14ms/step - loss: 0.2925 - accuracy: 0.9096 - val_loss: 0.6668 - val_accuracy: 0.8584
Epoch 16/20
57/57 [==============================] - 1s 14ms/step - loss: 0.4009 - accuracy: 0.8804 - val_loss: 1.1609 - val_accuracy: 0.6372
Epoch 17/20
57/57 [==============================] - 1s 14ms/step - loss: 0.4375 - accuracy: 0.8446 - val_loss: 0.9854 - val_accuracy: 0.7965
Epoch 18/20
57/57 [==============================] - 1s 14ms/step - loss: 0.3085 - accuracy: 0.8926 - val_loss: 0.6477 - val_accuracy: 0.8761
Epoch 19/20
57/57 [==============================] - 1s 15ms/step - loss: 0.1200 - accuracy: 0.9538 - val_loss: 1.8996 - val_accuracy: 0.5398
Epoch 20/20
57/57 [==============================] - 1s 15ms/step - loss: 0.3378 - accuracy: 0.9095 - val_loss: 0.9337 - val_accuracy: 0.8053

七、模型評估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

八、保存and加載模型

 保存模型
model.save('model/my_model.h5')
# 加載模型
new_model = tf.keras.models.load_model('model/my_model.h5')

九、預測

# 采用加載的模型(new_model)來看預測結果plt.figure(figsize=(10, 5))  # 圖形的寬為10高為5for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  # 顯示圖片plt.imshow(images[i].numpy().astype("uint8"))# 需要給圖片增加一個維度img_array = tf.expand_dims(images[i], 0) # 使用模型預測圖片中的人物predictions = new_model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/165853.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/165853.shtml
英文地址,請注明出處:http://en.pswp.cn/news/165853.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

微信小程序image組件圖片設置最大寬度 寬高自適應

問題描述:在使用微信小程序image組件的時候,在不確定圖片寬高情況下 想給一個最大寬度讓圖片自適應,按比例,image的widthfiex和heightFiex并不能滿足(只指定最大寬/高并不會生效) 問題解決:使用…

居家適老化設計第二十九條---衛生間之花灑

無電源 燈光顯示 無障礙扶手型花灑 以上產品圖片均來源于淘寶 侵權聯系刪除 居家適老化衛生間的花灑通常具有以下特點和功能:1. 高度可調節:適老化衛生間花灑可通過調節高度,滿足不同身高的老年人使用需求,避免彎腰或過高伸展造…

【開源】基于Vue.js的固始鵝塊銷售系統

項目編號: S 060 ,文末獲取源碼。 \color{red}{項目編號:S060,文末獲取源碼。} 項目編號:S060,文末獲取源碼。 目錄 一、摘要1.1 項目介紹1.2 項目錄屏 二、功能模塊2.1 數據中心模塊2.2 鵝塊類型模塊2.3 固…

qgis添加xyz柵格瓦片

方式1:手動一個個添加 左側瀏覽器-XYZ Tiles-右鍵-新建連接 例如添加高德瓦片地址 https://wprd01.is.autonavi.com/appmaptile?langzh_cn&size1&style7&x{x}&y{y}&z{z} 雙擊即可呈現 收集到的一些圖源,僅供參考,其中一…

【C++學習手札】模擬實現list

? 🎬慕斯主頁:修仙—別有洞天 ??今日夜電波:リナリア—まるりとりゅうが 0:36━━━━━━?💟──────── 3:51 🔄 ?? ? ??…

聊聊httpclient的staleConnectionCheckEnabled

序 本文主要研究一下httpclient的staleConnectionCheckEnabled staleConnectionCheckEnabled org/apache/http/client/config/RequestConfig.java public class RequestConfig implements Cloneable {public static final RequestConfig DEFAULT new Builder().build();pr…

【ARM 嵌入式 編譯 Makefile 系列 18 -- Makefile 中的 export 命令詳細介紹】

文章目錄 Makefile 中的 export 命令詳細介紹Makefile 使用 export導出與未導出變量的區別示例:導出變量以供子 Makefile 使用 Makefile 中的 export 命令詳細介紹 在 Makefile 中,export 命令用于將變量從 Makefile 導出到由 Makefile 啟動的子進程的環…

qgis添加wms服務

例如添加geoserver的wms服務 左右瀏覽器-WMS/WMTS-右鍵-新建連接 URL添加geoserver的wms地址 http://{ip}:{port}/geoserver/{workspace}/wms 展開wms目錄,雙擊相應圖層即可打開

Spark---基于Yarn模式提交任務

Yarn模式兩種提交任務方式 一、yarn-client提交任務方式 1、提交命令 ./spark-submit --master yarn --class org.apache.spark.examples.SparkPi ../examples/jars/spark-examples_2.11-2.3.1.jar 100 或者 ./spark-submit --master yarn–client --class org.apache.s…

三菱PLC應用[集錦]

三菱PLC應用[集錦] 如何判斷用PNP還是NPN的個人工作心得 10~30VDC接近開關與PLC連接時,如何判斷用PNP還是NPN的個人工作心得: 對于PLC的開關量輸入回路。我個人感覺日本三菱的要好得多,甚至比西門子等赫赫大名的PLC都要實用和可靠&#xff01…

vulnhub4

靶機地址: https://download.vulnhub.com/admx/AdmX_new.7z 信息收集 fscan 掃一下 ┌──(kali?kali)-[~/Desktop/Tools/fscan] └─$ ./fscan_amd64 -h 192.168.120.138 ___ _ / _ \ ___ ___ _ __ __ _ ___| | __ / /_\/____/ __|/ …

LeetCode | 622. 設計循環隊列

LeetCode | 622. 設計循環隊列 OJ鏈接 思路: 我們這里有一個思路: 插入數據,bank往后走 刪除數據,front往前走 再插入數據,就循環了 那上面這個方法可行嗎? 怎么判斷滿,怎么判斷空&#xff1…

模電知識點總結(二)二極管

系列文章目錄 文章目錄 系列文章目錄二極管二極管電路分析方法理想模型恒壓降模型折線模型小信號模型高頻/開關 二極管應用整流限幅/鉗位開關齊納二極管變容二極管肖特基二極管光電器件光電二極管發光二極管激光二極管太陽能電池 二極管 硅二極管:死區電壓&#xf…

今年注冊電氣工程師考試亂象及就業前景分析

1、注冊電氣工程師掛靠價格 # 2011年以前約為5萬一年,2011年開始強制實施注冊電氣執業制度,證書掛靠價格開始了飛漲,2013年達到巔峰,供配電15萬一年,發輸變電20-25萬一年,這哪里是證書,簡直就是…

Docker kill 命令

docker kill:殺死一個或多個正在運行的容器。 語法: docker kill [OPTIONS] CONTAINER [CONTAINER...]OPTIONS說明: -s:向容器發送一個信號 描述: docker kill子命令會殺死一個或多個容器。容器內的主進程被發送S…

C語言數組的距離(ZZULIOJ1200:數組的距離)

題目描述 已知元素從小到大排列的兩個數組x[]和y[], 請寫出一個程序算出兩個數組彼此之間差的絕對值中最小的一個,這叫做數組的距離 。 輸入:第一行為兩個整數m, n(1≤m, n≤1000),分別代表數組f[], g[]的長度。第二行有m個元素&a…

如何在Simulink中使用syms?換個思路解決報錯:Function ‘syms‘ not supported for code generation.

問題描述 在Simulink中的User defined function使用syms函數,報錯simulink無法使用外部函數。 具體來說: 我想在Predefined function定義如下符號函數作為輸入信號,在后續模塊傳入函數參數賦值,以實現一次定義多次使用&#xf…

014:MyString

題目 描述 補足MyString類&#xff0c;使程序輸出指定結果 #include <iostream> #include <string> #include <cstring> using namespace std; class MyString {char * p; public:MyString(const char * s) {if( s) {p new char[strlen(s) 1];strcpy(p,…

最小二乘線性回歸

? 線性回歸&#xff08;linear regression&#xff09;&#xff1a;試圖學得一個線性模型以盡可能準確地預測實際值的輸出。 以一個例子來說明線性回歸&#xff0c;假設銀行貸款會根據 年齡 和 工資 來評估可放款的額度。即&#xff1a; ? 數據&#xff1a;工資和年齡&…

python將模塊進行打包

模塊名稱為&#xff1a;my_module 目錄結構&#xff1a; my_modulemy_module__init__.pymy_module_main.pysetup.pypython setup.pu sdist bdist_wheel生成tar.gz包和whl文件用于安裝 """ python setup.py sdist bdist_wheel """from setuptoo…