基于CNN+ViT的蔬果圖像分類實驗

本文只是做一個簡單融合的實驗,沒有任何新穎,大家看看就行了。

1.數據集

本文所采用的數據集為Fruit-360 果蔬圖像數據集,該數據集由 Horea Mure?an 等人整理并發布于 GitHub(項目地址:Horea94/Fruit-Images-Dataset),廣泛應用于圖像分類和目標識別等計算機視覺任務。該數據集共包含141 類水果和蔬菜圖像,總計 94,110 張圖像,每張圖像的尺寸統一為 100×100 像素,且背景已統一處理為白色背景,以減少背景噪聲對模型訓練的影響。

數據集中涵蓋了大量常見和不常見的果蔬品類,主要包括:

  1. 蘋果(多個品種:如深雪、金蘋果、金紅、青奶奶、粉紅女士、紅蘋果、紅美味等)
  2. 香蕉(黃色、紅色、淑女手指等)
  3. 葡萄(藍色、粉紅色、白色多個品種)
  4. 柑橘類(橙子、檸檬、酸橙、葡萄柚、柑橘等)
  5. 熱帶水果(芒果、木瓜、紅毛丹、百香果、番石榴、荔枝、菠蘿、火龍果等)
  6. 漿果類(藍莓、覆盆子、草莓、黑加侖、紅醋栗、桑葚等)
  7. 核果類與堅果類(桃子、李子、杏、椰子、榛子、核桃、栗子、山核桃等)
  8. 蔬菜類(黃瓜、茄子、胡椒、番茄、洋蔥、花椰菜、甜菜根、玉米、土豆等)
  9. 其他類如:仙人掌果實、楊布拉、姜根、格蘭納迪拉、Physalis(燈籠果)、油桃、佩皮諾、羅望子、大頭菜等。

在數據劃分方面,本研究按照如下比例進行數據集劃分:

(1)訓練集:70,491 張圖像
??????????其中按照 8:2 的比例劃分出驗證集,得到最終:

????????????????????????訓練子集:56,432 張

????????????????????????驗證集:14,059 張

(2)測試集:23,619 張圖像

2.模型簡述

在圖像分類任務中,深度學習方法已經取得了顯著的進展,如殘差神經網絡(ResNet),Vision Transformer展現了較強的性能。ResNet作為CNN下的網絡架構,在局部特征提取方面具有優勢,能夠有效地捕捉圖像中的空間結構信息。而Vision Transformer作為Transformer的變種,在捕捉全局依賴關系和建模長程依賴性方面的具有更好的優勢。

由于CNN的卷積操作本質上能夠生成具有空間局部關聯性的特征圖,實際上可以視為一種變相的patch操作。因此,在將CNN與Transformer相結合時,可以避免傳統ViT中對輸入圖像進行切分patch的操作,只需對圖像進行位置編碼,從而使得Transformer能夠有效處理這些具有空間結構的特征圖。這種設計不僅減少了計算開銷,還使得整個模型在處理圖像時更具效率與準確性。

同時,與原始ViT框架中描述的技術不同,原始框架通常會將一個可學習的位置嵌入向量預先添加到編碼后的patch序列中,作為圖像的位置信息進行表示。然而,為了簡化模型的實現并提高計算效率,本文在架構設計上有所調整,省略了額外的位置編碼步驟。具體來說,本文的模型通過直接輸入編碼后的patch序列到Transformer塊中,跳過了對每個patch進行獨立位置編碼的操作。

基于這一思路,結合了殘差神經網絡(ResNet)和Vision Transformer(ViT)兩種網絡架構,將它們以串行連接的方式進行融合。具體模型架構圖如下圖所示

3.實驗

模型代碼(基于tensorflow2.X)

import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers,models
import warnings
warnings.filterwarnings('ignore')
import osTrain = r"D:\archive (1)\fruits-360_dataset_100x100\fruits-360\Training"
Test = r"D:\archive (1)\fruits-360_dataset_100x100\fruits-360\Test"IMAGE_SIZE = 100
NUM_CLASSES = 141
BATCH_SIZE = 32imagegenerator = ImageDataGenerator(rescale=1.0 / 255.0, validation_split=0.2, rotation_range=10, horizontal_flip=True)# Training and validation data generators
Train_Data = imagegenerator.flow_from_directory(Train,target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=BATCH_SIZE,class_mode='categorical',subset='training'
)
Validation_Data = imagegenerator.flow_from_directory(Train,target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=BATCH_SIZE,class_mode='categorical',subset='validation'
)# Test data generator (no augmentation)
test_imagegenerator = ImageDataGenerator(rescale=1.0 / 255.0)
Test_Data = test_imagegenerator.flow_from_directory(Test,target_size=(IMAGE_SIZE,IMAGE_SIZE),batch_size=BATCH_SIZE,class_mode='categorical',# subset='test'
)
class ResidualBlock(layers.Layer):def __init__(self, filters, kernel_size=(3, 3), strides=1):super(ResidualBlock, self).__init__()self.conv1 = layers.Conv2D(filters, kernel_size, strides=strides, padding="same", activation='relu')self.conv2 = layers.Conv2D(filters, kernel_size, strides=1, padding='same', activation='relu')self.shortcut = layers.Conv2D(filters, (1, 1), strides=strides, padding='same', activation='relu')self.bn1 = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.relu = layers.ReLU()def call(self, inputs):x = self.conv1(inputs)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)shortcut = self.shortcut(inputs)x = layers.add([x, shortcut])x = self.relu(x)return x# ResNet Model definition
class ResNetModel(layers.Layer):def __init__(self):super(ResNetModel, self).__init__()self.conv1 = layers.Conv2D(32, (5, 5), activation='relu', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),padding='same')self.maxpool1 = layers.MaxPooling2D((2, 2))# Residual Blocksself.resblock1 = ResidualBlock(32,strides=1)self.resblock2 = ResidualBlock(64,strides=2)self.resblock3 = ResidualBlock(128,strides=2)self.resblock4 = ResidualBlock(256, strides=2)# self.global_avg_pool = layers.GlobalAveragePooling2D()def call(self, inputs):print(inputs.shape)x = self.conv1(inputs)print(x.shape)x = self.maxpool1(x)print(x.shape)# Apply Residual Blocksx = self.resblock1(x)print(x.shape)x = self.resblock2(x)print(x.shape)# x = self.resblock3(x)# print(x.shape)# x = self.resblock4(x)# x = self.global_avg_pool(x)# print(x.shape)return x
class TransformerEncoder(layers.Layer):def __init__(self, num_heads=8, key_dim=64, ff_dim=256, dropout_rate=0.1):super(TransformerEncoder, self).__init__()self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)self.dropout1 = layers.Dropout(dropout_rate)self.norm1 = layers.LayerNormalization()self.ff = layers.Dense(ff_dim, activation='relu')self.ff_output = layers.Dense(key_dim*num_heads)self.dropout2 = layers.Dropout(dropout_rate)self.norm2 = layers.LayerNormalization()def call(self, x):# Multi-head self-attentionattention_output = self.attention(x, x)attention_output = self.dropout1(attention_output)x = self.norm1(attention_output + x)  # Residual connection# Feed Forward Networkff_output = self.ff(x)ff_output = self.ff_output(ff_output)ff_output = self.dropout2(ff_output)x = self.norm2(ff_output + x)  # Residual connectionreturn x# Vision Transformer (ViT) 模型
class VisionTransformer(models.Model):def __init__(self, input_shape=(100, 100, 3), num_classes=141, num_encoders=3, patch_size=8, num_heads=16,key_dim=4, ff_dim=256, dropout_rate=0.2):super(VisionTransformer, self).__init__()self.patch_size = patch_size#Resnetself.resnet=ResNetModel()# Patch Embeddingself.conv = layers.Conv2D(64, (patch_size, patch_size), strides=(patch_size, patch_size), padding='valid')self.reshape = layers.Reshape((-1, 64))self.norm = layers.LayerNormalization()# 位置編碼層self.position_encoding = self.add_weight("position_encoding", shape=(1, 625, 64))# Stack multiple Transformer Encoder layersself.encoders = [TransformerEncoder(num_heads=num_heads, key_dim=key_dim, ff_dim=ff_dim, dropout_rate=dropout_rate) for _ inrange(num_encoders)]# Global Average Poolingself.global_avg_pooling = layers.GlobalAveragePooling1D()# Fully connected layerself.fc1 = layers.Dense(256, activation='relu')self.dropout = layers.Dropout(0.2)self.fc2 = layers.Dense(num_classes, activation='softmax')def call(self, inputs):#resnetx = self.resnet(inputs)# print("===========================")# print(x.shape)# Patch Embeddingx = self.reshape(x)# 添加位置編碼x = x + self.position_encoding  # 將位置編碼加到Patch嵌入向量中# print(x.shape)# x = self.norm(x)# Apply multiple Transformer encodersfor encoder in self.encoders:x = encoder(x)# Global Average Poolingx = self.global_avg_pooling(x)# Fully connected layersx = self.fc1(x)x = self.dropout(x)x = self.fc2(x)return x
# 構建 Vision Transformer 模型vit_model = VisionTransformer(input_shape=(100, 100, 3), num_classes=141, num_encoders=3)
vit_model.build(input_shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3))  # 手動構建模型
# 打印模型摘要
vit_model.summary()
# 編譯模型
vit_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),loss='categorical_crossentropy',metrics=['accuracy']
)
checkpoint_path = "training_checkpoints_1/vit_model_checkpoint_epoch_{epoch:02d}.h5"# 創建ModelCheckpoint回調
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,monitor='val_accuracy',  # 你可以選擇監控驗證集的損失或準確度save_best_only=True,  # 只保存驗證集損失最小的模型save_weights_only=True,  # 只保存權重(而不是整個模型)verbose=1  # 打印日志
)
# 檢查是否有保存的模型權重文件
checkpoint_dir = "training_checkpoints_1/"
# 查找所有的 .h5 文件
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "vit_model_checkpoint_epoch_*.h5"))
# print(latest_checkpoint)
if checkpoint_files:# 使用 os.path.getctime() 獲取文件創建時間(或者使用 getmtime() 獲取修改時間)latest_checkpoint = max(checkpoint_files, key=os.path.getctime)print(f"Loading model from checkpoint: {latest_checkpoint}")# 加載模型權重vit_model.load_weights(latest_checkpoint)
else:print("No checkpoint found, starting from scratch.")# 訓練模型
history = vit_model.fit(Train_Data,epochs=20,validation_data=Validation_Data,shuffle=True,callbacks=[checkpoint_callback]
)# 評估模型
test_loss, test_acc = vit_model.evaluate(Test_Data)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_acc}")
# 訓練和驗證的準確率和損失歷史記錄
def plot_training_history(history):# 創建子圖plt.figure(figsize=(14, 6))# 準備訓練準確率和驗證準確率的圖plt.subplot(1, 2, 1)plt.title('Accuracy History')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.plot(history.history['accuracy'], label='Training Accuracy', marker='o')plt.plot(history.history['val_accuracy'], label='Validation Accuracy', color='green', marker='o')plt.legend()# 準備訓練損失和驗證損失的圖plt.subplot(1, 2, 2)plt.title('Loss History')plt.xlabel('Epochs')plt.ylabel('Loss')plt.plot(history.history['loss'], label='Training Loss', marker='o')plt.plot(history.history['val_loss'], label='Validation Loss', color='green', marker='o')plt.legend()# 顯示圖形plt.tight_layout()plt.show()# 繪制訓練過程
plot_training_history(history)
for i in range(16):# 獲取測試數據的下一個批次img_batch, labels_batch = Test_Data.next()img = img_batch[0]  # 獲取當前批次的第一張圖像true_label_idx = np.argmax(labels_batch[0])  # 獲取真實標簽的索引# 獲取真實標簽的名稱true_label = [key for key, value in Train_Data.class_indices.items() if value == true_label_idx]# 擴展維度以匹配模型輸入EachImage = np.expand_dims(img, axis=0)# 進行預測prediction = vit_model.predict(EachImage)# 獲取預測標簽predicted_label = [key for key, value in Train_Data.class_indices.items() ifvalue == np.argmax(prediction, axis=1)[0]]# 獲取預測的概率predicted_prob = np.max(prediction, axis=1)[0]# 繪制圖像plt.subplot(4, 4, i + 1)plt.imshow(img)plt.title(f"True: {true_label[0]} \nPred: {predicted_label[0]} \nProb: {predicted_prob:.2f}")plt.axis('off')plt.tight_layout()
plt.show()

做了如下參數實驗

ResNet層數

Encoder層數

num_heads

test_accuracy

2(32,64)

3

4

92.14%

3(32,64,128)

3

4

94.53%

2(32,64)

3

8

96.19%

3(32,64,128)

3

8

97.46%

2(32,64)

3

16

93.32%

3(32,64,128)

3

16

93.17%

?分類效果圖

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/76982.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/76982.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/76982.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Ubuntu24.04安裝libgl1-mesa-glx 報錯,軟件包缺失

在 Ubuntu 24.04 系統中,您遇到的 libgl1-mesa-glx 軟件包缺失問題可能是由于該包在最新的 Ubuntu 版本中被重命名為 libglx-mesa0。以下是針對該問題的詳細解決方案: 1. 問題原因分析 包名稱變更:在 Ubuntu 24.04 中,libgl1-me…

webpack vite

? 1、webpack webpack打包工具(重點在于配置和使用,原理并不高優。只在開發環境應用,不在線上環境運行),壓縮整合代碼,讓網頁加載更快。 前端代碼為什么要進行構建和打包? 體積更好&#x…

如何在爬蟲中合理使用海外代理?在爬蟲中合理使用海外ip

我們都知道,爬蟲工作就是在各類網頁中游走,快速而高效地采集數據。然而如果目標網站分布在多個國家或者存在區域性限制,那靠普通的網絡訪問可能會帶來諸多阻礙。而這時,“海外代理”儼然成了爬蟲工程師們的得力幫手! …

數據倉庫分層存儲設計:平衡存儲成本與查詢效率

數據倉庫分層存儲不僅是一個技術問題,更是一種藝術:如何在有限的資源下,讓數據既能快速響應查詢,又能以最低的成本存儲? 目錄 一、什么是數據倉庫分層存儲? 二、分層存儲的體系架構 1. 數據源層(ODS,Operational Data Store) 2. 數據倉庫層(DW,Data Warehouse)…

YOLO學習筆記 | 基于YOLOv8的植物病害檢測系統

以下是基于YOLOv8的植物病害檢測系統完整技術文檔,包含原理分析、數學公式推導及代碼實現框架。 基于YOLOv8的智能植物病害檢測系統研究 摘要 針對傳統植物病害檢測方法存在的效率低、泛化性差等問題,本研究提出一種基于改進YOLOv8算法的智能檢測系統。通過設計輕量化特征提…

高級語言調用C接口(二)回調函數(4)Python

前面2篇分別說了java和c#調用C接口,參數為回調函數,回調函數中參數是結構體指針。 接下來說下python的調用方法。 from ctypes import * import sysclass stPayResult(Structure):_pack_ 4 # 根據實際C結構體的對齊方式設置(常見值為1,4,…

springboot啟動動態定時任務

1.自定義定時任務線程池 package com.x.devicetcpserver.global.tcp.tcpscheduler;import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotatio…

pytorch框架認識--手寫數字識別

手寫數字是機器學習中非常經典的案例,本文將通過pytorch框架,利用神經網絡來實現手寫數字識別 pytorch中提供了手寫數字的數據集,我們可以直接從pytorch中下載 MNIST中包含70000張手寫數字圖像:60000張用于訓練,10000…

WPF 使用依賴注入后關閉窗口程序不結束

原因是在ViewModel中在構造函數中注入了Window 對象,即使沒有使用,主窗口關閉程序不會退出,即使 ViewModel 是 AddTransient 注入的。 解決方法:不使用構造函數注入Window,通過GetService獲取Window 通過注入對象調用…

用戶管理(添加和刪除,查詢信息,切換用戶,查看登錄用戶,用戶組,配置文件)

目錄 添加和刪除用戶 查詢用戶信息 切換用戶 查看當前的操作用戶是誰 查看首次登錄的用戶是誰 用戶組(對屬于同個角色的用戶統一管理) 新增組 刪除組 添加用戶的同時,指定組 修改用戶的組 組的配置文件(/etc/group&…

PyTorch學習-小土堆教程

網絡搭建torch.nn.Module 卷積操作 torch.nn.functional.conv2d(input, weight, biasNone, stride1, padding0, dilation1, groups1) 神經網絡-卷積層

MVCC詳細介紹及面試題

目錄 1.什么是mvcc? 2.問題引入 3. MVCC實現原理? 3.1 隱藏字段 3.2 undo log 日志 3.2.1 undo log版本鏈 3.3 readview 3.3.1 當前讀 ?編輯 3.3.2 快照讀 3.3.3 ReadView中4個核心字段 3.3.4 版本數據鏈訪問的規則(了解&#x…

企業級Active Directory架構設計與運維管理白皮書

企業級Active Directory架構設計與運維管理白皮書 第一章 多域架構設計與信任管理 1.1 企業域架構拓撲設計 1.1.1 林架構設計規范 林根域規劃原則: 采用三段式域名結構(如corp.enterprise.com),避免使用不相關的頂級域名架構主…

android11 DevicePolicyManager淺析

目錄 📘 簡單定義 📘應用啟用設備管理者 📂 文件位置 🧠 DevicePolicyManager 功能分類舉例 🛡? 1. 安全策略控制 📷 2. 控制硬件功能 🧰 3. 應用管理 🔒 4. 用戶管理 &am…

Java學習手冊:Java線程安全與同步機制

在Java并發編程中,線程安全和同步機制是確保程序正確性和數據一致性的關鍵。當多個線程同時訪問共享資源時,如果不加以控制,可能會導致數據不一致、競態條件等問題。本文將深入探討Java中的線程安全問題以及解決這些問題的同步機制。 線程安…

PyTorch核心函數詳解:gather與where的實戰指南

PyTorch中的torch.gather和torch.where是處理張量數據的關鍵工具,前者實現基于索引的靈活數據提取,后者完成條件篩選與動態生成。本文通過典型應用場景和代碼演示,深入解析兩者的工作原理及使用技巧,幫助開發者提升數據處理的靈活…

聲學測溫度原理解釋

已知聲速,就可以得到溫度。 不同溫度下的勝訴不同。 25度的聲速大約346m/s 絕對溫度-273度 不同溫度下的聲速。 FPGA 通過測距雷達測溫度,固定測量距離,或者可以測出當前距離。已知距離,然后雷達發出聲波到接收到回波的時間&a…

【網絡篇】UDP協議的封裝分用全過程

大家好呀 我是浪前 今天講解的是網絡篇的第二章:UDP協議的封裝分用 我們的協議最開始是OSI七層網絡協議 這個OSI 七層網絡協議 是計算機的大佬寫的,但是這個協議一共有七層,太多了太麻煩了,于是我們就把這個七層網絡協議就簡化為…

spring-ai-alibaba使用Agent實現智能機票助手

示例目標是使用 Spring AI Alibaba 框架開發一個智能機票助手,它可以幫助消費者完成機票預定、問題解答、機票改簽、取消等動作,具體要求為: 基于 AI 大模型與用戶對話,理解用戶自然語言表達的需求支持多輪連續對話,能…

嵌入式C語言高級編程:OOP封裝、TDD測試與防御性編程實踐

一、面向對象編程(OOP) 盡管 C 語言并非面向對象編程語言,但借助一些編程技巧,也能實現面向對象編程(OOP)的核心特性,如封裝、繼承和多態。 1.1 封裝 封裝是把數據和操作數據的函數捆綁在一起,對外部隱藏…