基于Keras3.x使用CNN實現簡單的貓狗分類

使用CNN實現簡單的貓狗分類
完整代碼見:基于Keras3.x使用CNN實現簡單的貓狗分類,置信度約為:85%

文章目錄

  • 概述
    • 項目整體目錄
    • 環境版本
    • 注意
  • 環境準備
    • 下載miniconda
    • 新建虛擬環境
    • 基于conda虛擬環境新建Pycharm項目
    • 下載分類需要用到的依賴
  • 數據準備
    • 數據目錄結構
    • 挪動圖片可以采用下列代碼
  • config
  • 準備訓練、測試數據集
  • 構建模型
  • 訓練模型
    • 訓練過程
    • 損失和準確度曲線
  • 測試模型
    • 使用帶標簽的測試圖片評估整體準確率
    • 定義模型類
    • 預測單張圖片
    • 將不帶標簽的測試圖片分類并存到對應目錄中
    • 源碼
  • 錯誤記錄
    • 雙重歸一化問題

概述

項目整體目錄

在這里插入圖片描述

  • /data 存放數據集
  • /model 存放訓練好的模型
  • /config.py 存儲一些關鍵模型參數和路徑信息等
  • /dataset.py 返回數據增強后的數據集,用于模型訓練
  • /model.py 定義模型
  • /train.py 訓練模型并繪制訓練損失和準確度曲線
  • /test.py 測試模型精準度

環境版本

  • python 3.11
  • keras 3.9.2
  • tensorflow 2.19.0

注意

本項目使用Keras3.x實現,代碼與keras 2.x有部分不同,請仔細甄別

環境準備

下載miniconda

如果沒有下載conda,參照上一篇文章進行下載配置

新建虛擬環境

創建一個用于貓狗識別的虛擬環境,可以指定py版本

conda create --name catanddog python=3.11

基于conda虛擬環境新建Pycharm項目

依次選擇菜單路徑:
File-NewProject-Pure Python
在彈出的窗口選擇:

  • custom environment
  • Select existing
  • Type:conda
  • 選擇剛剛新建的catanddog虛擬環境
    如圖:
    在這里插入圖片描述

下載分類需要用到的依賴

主要用到:keras3.9.2和tensorflow2.19.0

pip install keras

數據準備

從Kaggle上下載常用的貓狗分類數據集,下載下來后,有訓練數據(共25000張,貓狗各一半,帶標簽,命名示例:dog.0.jpg)和測試數據(共12500張,不帶標簽,命名示例:1.jgp)。
將訓練數據分為兩份,前20000張用于訓練數據,后5000張帶標簽的數據用于預測模型整體準確度。12500張測試數據可以用于單張圖片的模型預測,以及將貓狗分類后放入對應的目錄中,方便查看。
如圖組織數據:

  • test為12500張不帶標簽的測試數據;
  • test2為5000張帶標簽的測試數據;
  • train為20000張訓練數據

數據目錄結構

注意:必須把訓練數據、test2圖片放在新建好的cats和dogs目錄下,模型才能自動推斷標簽
在這里插入圖片描述

挪動圖片可以采用下列代碼

import os, shutil
# 將train_dir_tag_cat后2500張貓圖像移動到test2_dir_tag_cat
cats = ['cat.{}.jpg'.format(i) for i in range(1000)]
for cat in cats:src = os.path.join(train_dir_tag_cat, cat)dst = os.path.join(test2_dir_tag_cat, cat)shutil.move(src, dst)

config

config.py:用于存儲一些關鍵參數和路徑信息等。
訓練的batch為:32
訓練15個EPOCH

"""
@Author      :Ayaki Shi
@Date        :2025/4/18 11:03 
@Description : 配置信息
"""
import os, shutildata_dir = './data'# 訓練集、測試集所在路徑
test_dir = os.path.join(data_dir, 'test')
test2_dir = os.path.join(data_dir, 'test2')
train_dir = os.path.join(data_dir, 'train')# 劃分標簽后的數據路徑
train_dir_tag_cat = os.path.join(train_dir, 'cats')
test_dir_tag_cat = os.path.join(test_dir, 'cats')
test2_dir_tag_cat = os.path.join(test2_dir, 'cats')train_dir_tag_dog = os.path.join(train_dir, 'dogs')
test_dir_tag_dog = os.path.join(test_dir, 'dogs')
test2_dir_tag_dog = os.path.join(test2_dir, 'dogs')# 訓練參數
IMG_SIZE = (256, 256)
BATCH_SIZE = 32
EPOCHS = 15# 模型路徑
MODEL_PATH = './model/CatAndDogClassifier.keras'

準備訓練、測試數據集

dataset.py
訓練數據經過數據增強后返回,用于測試模型整體準確度的test2無需數據增強直接返回。
注意: 這個方法ImageDataGenerator已經不推薦使用了,因此使用image_dataset_from_directory這個方法,可以根據目錄自動推斷標簽,只是數據增強稍微復雜了點

"""
@Author      :Ayaki Shi
@Date        :2025/4/18 11:02
@Description : 返回dataset
"""from keras.api.utils import image_dataset_from_directory
from config import train_dir,test2_dir, BATCH_SIZE,IMG_SIZE
from keras import layers, models
import tensorflow as tf# 數據增強
def create_augmentation_model():return models.Sequential([layers.RandomFlip("horizontal", seed=42),layers.RandomRotation(0.2, fill_mode='nearest', seed=42),layers.RandomZoom(0.2, fill_mode='nearest', seed=42),layers.RandomContrast(0.3, seed=42),layers.RandomTranslation(0.1, 0.1, fill_mode='nearest', seed=42),], name="data_augmentation")def create_train_dataset():train_dataset = image_dataset_from_directory(train_dir,label_mode = 'binary',batch_size = BATCH_SIZE,image_size = IMG_SIZE,shuffle=True,  # 必須啟用 shuffleseed=42)# 創建預處理模型augmentation_model = create_augmentation_model()# 定義預處理函數def preprocess_train(image, label):image = augmentation_model(image, training=True)  # 訓練模式激活增強return image, labeltrain_dataset = train_dataset.map(preprocess_train,num_parallel_calls= tf.data.AUTOTUNE)print('--------------返回增強后的訓練數據集--------------')return train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)def create_test2_dataset():test2_dataset = image_dataset_from_directory(test2_dir,label_mode = 'binary',batch_size = BATCH_SIZE,image_size = IMG_SIZE,shuffle=False)print('--------------返回測試數據集[帶標簽]--------------')return test2_dataset

構建模型

model.py 模型結構:

  • 輸入層:指定輸入數據形狀
  • 數據歸一化
  • 四層卷積層和四層池化層交替
  • 展平層:將輸出的多維特征圖展平為一維向量
  • Dropout防止過擬合
  • 兩個全連接層,用于特征提取和最終分類
"""
@Author      :Ayaki Shi
@Date        :2025/4/18 11:02 
@Description : 創建模型
"""
from keras import layers, models, optimizersfrom config import IMG_SIZEdef create_model():model = models.Sequential([# 輸入層:指定輸入數據形狀layers.Input(shape=(*IMG_SIZE, 3)),layers.Rescaling(1./255),  # 歸一化到 [0,1]# 四層卷積層和四層池化層layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),# 展平層:將輸出的多維特征圖展平為一維向量layers.Flatten(),# 防止過擬合layers.Dropout(0.5),# 兩個全連接層,用于特征提取和最終分類layers.Dense(512, activation='relu'),layers.Dense(1, activation='sigmoid')])# 編譯模型model.compile(loss='binary_crossentropy',  # 損失函數optimizer= optimizers.Adam(learning_rate=1e-4), # 優化器metrics=['accuracy']) # 評估標準:準確率print('--------------構建模型成功--------------')return model

訓練模型

train.py
獲取數據集-創建模型-訓練模型-保存模型-繪制損失和準確度曲線

"""
@Author      :Ayaki Shi
@Date        :2025/4/18 16:08 
@Description : 訓練模型
"""from dataset import create_train_dataset
from model import create_model
from config import EPOCHS, BATCH_SIZE, MODEL_PATH
import matplotlib.pyplot as pltdef train_model():# 獲取datasettrain_dataset = create_train_dataset()# 生成模型model = create_model()# 訓練模型print('--------------開始訓練模型--------------')history = model.fit(train_dataset,epochs = EPOCHS,batch_size = BATCH_SIZE)# 保存模型print('--------------開始保存模型--------------')model.save(MODEL_PATH)print('--------------開始繪制損失和準確性曲線--------------')# 繪制訓練損失曲線plt.figure(figsize=(10, 4))plt.plot(history.history['loss'], label='Training Loss', color='blue', marker='o')plt.title('Training Loss Over Epochs')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.show()# 繪制訓練準確率曲線plt.figure(figsize=(10, 4))plt.plot(history.history['accuracy'], label='Training Accuracy', color='green', marker='s')plt.title('Training Accuracy Over Epochs')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.grid(True)plt.show()if __name__ == '__main__':train_model()

訓練過程

在這里插入圖片描述

損失和準確度曲線

在這里插入圖片描述
在這里插入圖片描述

測試模型

test.py

使用帶標簽的測試圖片評估整體準確率

代碼見后面,可以看到整體準確度為85%左右
在這里插入圖片描述

定義模型類

代碼見后面

預測單張圖片

代碼見后面,預測為狗的概率是73%。
在這里插入圖片描述

將不帶標簽的測試圖片分類并存到對應目錄中

代碼見后面,可以看到大部分測試圖片都被放到了正確的目錄里,但是也有少數錯的

在這里插入圖片描述
在這里插入圖片描述

源碼

from keras import models
import numpy as np
import os, shutil
from keras_preprocessing import image
from config import MODEL_PATH,IMG_SIZE,test_dir,test_dir_tag_cat,test_dir_tag_dogDOG_TAG_STR = 'dog'
CAT_TAG_STR = 'cat'
NUM_IMAGES = 12500             # 測試圖片數class CatAndDogClassifier:def __init__(self):self.model = models.load_model(MODEL_PATH)print("模型加載成功!")def predict_single_image(self, img_path):img = image.load_img(img_path, target_size=IMG_SIZE)img_array = image.img_to_array(img)# 錯誤代碼:雙重歸一化# img_array = np.expand_dims(img_array, axis=0) / 255.0img_array = np.expand_dims(img_array, axis=0)prediction = self.model.predict(img_array)[0][0]print(prediction)return DOG_TAG_STR if prediction > 0.5 else CAT_TAG_STR, predictiondef classify_all_images(self):# 遍歷所有圖片# for i in range(1, NUM_IMAGES + 1):filename = ''for i in range(1, NUM_IMAGES + 1):try:#(文件名為1.jpg到12500.jpg)filename = f"{i}.jpg"src_path = os.path.join(test_dir, filename)# 跳過不存在的文件if not os.path.exists(src_path):print(f"Warning: {filename} 不存在,已跳過")continue# 進行預測label, confidence = self.predict_single_image(src_path)# 確定目標目錄dest_dir = test_dir_tag_dog if label == DOG_TAG_STR else test_dir_tag_catdest_path = os.path.join(dest_dir, filename)# 移動文件shutil.move(src_path, dest_path)if i%500 == 0: # 打印12500行太多了,每500行打印一次print(f"[{i}/12500] {filename} -> {dest_dir} (置信度: {confidence:.2%})")except Exception as e:print(f"處理 {filename} 時發生錯誤: {str(e)}")continuedef evaluate_model():from dataset import create_test2_datasettest2_dataset = create_test2_dataset()model = models.load_model(MODEL_PATH)loss, acc = model.evaluate(test2_dataset)print(f'\nTest accuracy: {acc:.2%}')if __name__ == '__main__':# 初始化分類器classifier = CatAndDogClassifier()# # 評估整體準確率# evaluate_model()# # 單張圖片預測# img_path = os.path.join('./data/train/dogs/dog.100.jpg')# label, prob = classifier.predict_single_image(img_path)# print(f'預測為: {label} (置信度: {prob if label == DOG_TAG_STR else 1 - prob:.2%})')# 將不帶標簽的測試圖片分類放入不同的文件夾classifier.classify_all_images()

錯誤記錄

雙重歸一化問題

在預測單張圖片過程中,出現了不管什么圖片,預測度總是特別低,只有7%左右
在這里插入圖片描述
首先預測結果不對第一時間考慮到是不是模型欠擬合或者過擬合的問題。
但是基于以下兩個原因:

  • 首先訓練過程中記錄的準確度和測試整體準確率都是85%,說明模型大概率是沒有問題的
  • 其次這個置信度已經低的離譜了
    所以考慮是在測試單張圖片對圖片處理出現了問題,經過排查,發現問題出在了,我在對單張圖片進行了歸一化,然后模型中又進行了一次歸一化,導致預測置信度極低。
    test.py
    在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/902762.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/902762.shtml
英文地址,請注明出處:http://en.pswp.cn/news/902762.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

中介者模式:解耦對象間復雜交互的設計模式

中介者模式:解耦對象間復雜交互的設計模式 一、模式核心:用中介者統一管理對象交互,避免兩兩直接依賴 當系統中多個對象之間存在復雜的網狀交互時(如 GUI 界面中按鈕、文本框、下拉框的聯動),對象間直接調…

豆包桌面版 1.47.4 可做瀏覽器,免安裝綠色版

自己動手升級更新辦法: 下載新版本后安裝,把 C:\Users\用戶名\AppData\Local\Doubao\Application 文件夾的文件,拷貝替換 DoubaoPortable\App\Doubao 文件夾的文件,就升級成功了。 再把安裝的豆包徹底卸載就可以。 桌面版比網頁版…

Android PackageManagerService(PMS)框架深度解析

目錄 一、概念與核心作用 二、技術架構與模塊組成 1. 分層架構 1.1 應用層架構細節 1.2 Binder接口層實現 1.3 PMS核心服務層 1.4 底層支持層實現 2. 核心模塊技術要點與工作流程 2.1 PackageParser 2.2 Settings 2.3 PermissionManager 2.4 Installer 2.5 ComponentM…

TensorFlow深度學習實戰(14)——循環神經網絡詳解

TensorFlow深度學習實戰(14)——循環神經網絡詳解 0. 前言1. 基本循環神經網絡單元1.1 循環神經網絡工作原理1.2 時間反向傳播1.3 梯度消失和梯度爆炸問題2. RNN 單元變體2.1 長短期記憶2.2 門控循環單元2.3 Peephole LSTM3. RNN 變體3.1 雙向 RNN3.2 狀態 RNN4. RNN 拓撲結構…

PySide6 GUI 學習筆記——常用類及控件使用方法(常用類矩陣QRectF)

文章目錄 類描述構造方法主要方法1. 基礎屬性2. 邊界操作3. 幾何運算4. 坐標調整5. 轉換方法6. 狀態判斷 類特點總結1. 浮點精度:2. 坐標系統:3. 有效性判斷:4. 幾何運算:5. 類型轉換:6. 特殊處理: 典型應用…

Electron主進程渲染進程間通信的方式

在 Electron 中,主進程和渲染進程之間的通信主要通過 IPC(進程間通信)機制實現。以下是幾種常見的通信方式: 1. 渲染進程向主進程發送消息(單向) 渲染進程可以通過 ipcRenderer.send 向主進程發送消息&am…

【C++基礎知識】C++類型特征組合:`disjunction_v` 和 `conjunction_v` 深度解析

這兩個模板是C17引入的類型特征組合工具,用于構建更復雜的類型判斷邏輯。下面我將從技術實現到實際應用進行全面剖析: 一、基本概念與C引入版本 1. std::disjunction_v (邏輯OR) 引入版本:C17功能:對多個類型特征進行邏輯或運算…

私有知識庫 Coco AI 實戰(二):攝入 MongoDB 數據

在之前的文章中,我們介紹過如何使用《 Logstash 遷移 MongoDB 數據到 Easyseach》,既然 Coco AI 后臺數據存儲也使用 Easysearch,我們能否直接把 MongoDB 的數據遷移到 Coco AI 的 Easysearch,使用 Coco AI 對數據進行檢索呢&…

sql server 與navicat測試后,連接qt

先用Navicat測試和sql的連通性,Navicat和sql連通之后,qt也能和sql連通了。 Navicat和Sqlserver Management 能連上,項目無法連接本地 Navicat 連接SQLServer 數據庫 QT國內鏡像網站 Navicat連接SqlServer的問題點 Sql Server的基本配置以及使…

2025年3月電子學會青少年機器人技術(六級)等級考試試卷-理論綜合

青少年機器人技術等級考試理論綜合試卷(六級) 分數:100 題數:30 一、單選題(共20題,共80分) 1. 2025年初,中國科技初創公司深度求索在大模型領域迅速崛起,其開源的大模型成為全球AI領域的焦…

spark local模式搭建運行示例

Apache Spark 是一個強大的分布式計算框架,但在本地模式下,它也可以作為一個單機程序運行,非常適合開發和測試階段。以下是一個簡單的示例,展示如何在本地模式下搭建和運行 Spark 程序。 一、環境準備 安裝 Java Spark 需要 Java…

【人工智能】解鎖 AI 潛能:DeepSeek 大模型遷移學習與特定領域微調的實踐

《Python OpenCV從菜鳥到高手》帶你進入圖像處理與計算機視覺的大門! 解鎖Python編程的無限可能:《奇妙的Python》帶你漫游代碼世界 隨著大型語言模型(LLMs)的快速發展,遷移學習與特定領域微調成為提升模型性能的關鍵技術。本文深入探討了 DeepSeek 大模型在遷移學習中的…

視頻智能分析平臺EasyCVR無線監控:全流程安裝指南與功能應用解析

在當今數字化安防時代,無線監控系統的安裝與調試對于保障各類場所的安全至關重要。本文將結合EasyCVR視頻監控的強大功能,為您詳細闡述監控系統安裝過程中的關鍵步驟和注意事項,幫助您打造一個高效、可靠的監控解決方案。 一、調試物資準備與…

【k8s系列7-更新中】kubeadm搭建Kubernetes高可用集群-三主兩從

主機準備 結合前面的章節,這里需要5臺機器,可以先創建一臺虛擬機作為基礎虛擬機。優先把5臺機器的公共部分優先在一臺機器上配置好 1、配置好靜態IP地址 2、主機名宇IP地址解析 [root@localhost ~]# cat /etc/hosts 127.0.0.1 localhost localhost.localdomain localhost…

【Java后端】MyBatis 與 MyBatis-Plus 如何防止 SQL 注入?從原理到實戰

在日常開發中,SQL 注入是一種常見但危害巨大的安全漏洞。如果你正在使用 MyBatis 或 MyBatis-Plus 進行數據庫操作,這篇文章將帶你系統了解:這兩個框架是如何防止 SQL 注入的,我們又該如何寫出安全的代碼。 什么是 SQL 注入&#…

數據分析案例:醫療健康數據分析

目錄 數據分析案例:醫療健康數據分析1. 項目背景2. 數據加載與預處理2.1 加載數據2.2 數據清洗3. 探索性數據分析(EDA)3.1 再入院率概覽3.2 按年齡分組的再入院率3.3 住院時長與再入院4. 特征工程與可視化5. 模型構建與評估5.1 數據劃分5.2 訓練邏輯回歸5.3 模型評估6. 業務…

3臺CentOS虛擬機部署 StarRocks 1 FE+ 3 BE集群

背景:公司最近業務數據量上去了,需要做一個漏斗分析功能,實時性要求較高,mysql已經已經不在適用,做了個大數據技術棧選型調研后,決定使用StarRocks StarRocks官網:StarRocks | A High-Performa…

軟件設計師/系統架構師---計算機網絡

概要 什么是計算機網絡? 計算機網絡是指將多臺計算機和其他設備通過通信線路互聯,以便共享資源和信息的系統。計算機網絡可以有不同的規模,從家庭網絡到全球互聯網。它們可以通過有線(如以太網)或無線(如W…

1.5軟考系統架構設計師:架構師的角色與能力要求 - 超簡記憶要點、知識體系全解、考點深度解析、真題訓練附答案及解析

超簡記憶要點 角色職責 需求規劃→架構設計→質量保障 能力要求 技術(架構模式/性能優化) 業務(模型抽象→技術方案) 管理(團隊協作/風險控制) 知識體系 基礎:CAP/設計模式/網絡協議案例&am…

基于STM32的汽車主門電動窗開關系統設計方案

芯片和功能模塊選型 主控芯片 STM32F103C8T6:基于 ARM Cortex - M3 內核,有豐富的 GPIO 接口用于連接各類外設,具備 ADC 模塊可用于電流檢測,還有 CAN 控制器方便實現 CAN 總線通信。它資源豐富、成本低,適合學生進行 DIY 項目開發。按鍵模塊 輕觸按鍵:用于控制車窗的自…