基于卷積神經網絡的圖像二分類檢測模型訓練與推理實現教程 | 幽絡源

前言

對于本教程,說白了,就是期望能通過一個程序判斷一張圖片是否為某個物體,或者說判斷一張圖片是否為某個缺陷。因為本教程是針對二分類問題,因此主要處理 是 與 不是 的問題,比如我的模型是判斷一張圖片是否為蘋果,那么拿一張圖片給模型去推理,他會得出這張圖是蘋果的概率,如果概率大于0.5(這個概率在0~1之間),那么就判斷為是蘋果。

教程內容

使用了Python的 TensorFlow 和 Keras 庫 構建卷積神經網絡來完成二分類模型訓練,以及使用模型完成對一張圖片的推理。原文鏈接:基于卷積神經網絡的圖像二分類檢測模型訓練與推理實現教程 | 幽絡源

大致步驟

1.確定環境與庫

2.準備數據集并且劃分

3.數據集的命名問題注意事項

4.編寫訓練代碼完成模型訓練

5.編寫推理代碼

6.測試二分類檢測結果

7.根據結果優化數據集

步驟1.確定環境與庫

Python環境是必備的,我這里所使用的Python版本為3.12.3

其次還需要以下庫,依次執行如下命令即可

pip install tensorflow
pip install pillow
pip install scipy

如圖

1

2

步驟2.準備數據集并且劃分

我這里以判斷圖片是否為沖溝缺陷 來準備數據集,首先創建數據集的目錄結構,結構如下

data/train/true_sample/ false_sample/  val/true_sample/false_sample/

QQ_1734065732662

目錄解釋:

data:作為數據集的根目錄

train和val分別為訓練集、驗證集目錄

true_sample:正類樣本,也就是我這里需要把含有沖溝缺陷的圖放到這個目錄

false_sample:負類樣本,也就是這里需要將不含有沖溝缺陷的圖片放進這個目錄

如圖,我向train和val的true_sample目錄加入了一些含有沖溝缺陷的圖片

3

對于負類樣本,也不是無腦的只要不是沖溝就往里面放,而是放置你認為訓練出的模型可能會將什么識別為正類樣本。比如滑坡和沖溝其實是有聯系的,但不完全等同于,所以我需要將滑坡相關的,但是沒有沖溝情況的圖片放入false_sample中,期望模型不要誤判。再比如一個蘋果,你可能需要把紅色氣球作為父類樣本,防止模型將紅氣球判斷為是蘋果,如圖是我的負類樣本

4

步驟3.數據集的命名問題注意事項

關于數據集的命名,這里其實有一個坑,但是先說避免坑的做法:就像步驟2一樣,你的正類樣本所放置的目錄命名為true_sample、負類樣本所放置的目錄命名為false_sample就行了。(如果看不懂下面的解釋,按照這里做法做就是了)

然后我來解釋下是什么坑,對于這個二分類模型訓練,訓練出來的模型,無非是識別 是 與 不是 的問題,但是模型怎么區分我的哪個目錄放置的為是,哪個目錄放置的為不是呢,步驟4會給出訓練代碼,訓練代碼中的加載數據集時有一行如下代碼

class_mode='binary'  # 二分類(沖溝缺陷 vs. 非沖溝缺陷)

這表示我們要做二分類模型訓練,加上這行代碼,在加載數據集時,Keras 會自動將這些文件夾的名稱作為標簽,分別命名為1 和 0,如果被命名為標簽1 的目錄,則在推理時,概率越接近于1,則越表示是標為1的目錄的樣本,反之概率越接近于0,則越表示是標為0的目錄的樣本。而keras自動命名標簽1和0時是根據目錄名首字母的順序來的字,字母靠前的標為0,后者為1,true_sample的首字母為t,false_sample的首字母為f,因此false_sample標為0,true_sample標為1,這是符合我們的正常預期的。

反面例子:

如果我把正類樣本放置于名為defect的目錄,負類樣本放置于no_defect目錄會怎樣呢,按照如上解釋,defect目錄會被標為0,no_defect目錄會被標為1,這就和我們預期相反了,什么意思呢。我把正類樣本放置defect目錄中,其推理結果將會是越接近0,則越表示為正類了,因此這里特別需要注意(如果你要自定義目錄名的話)。

步驟4.編寫訓練代碼完成模型訓練

先直接上訓練代碼

from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tftrain_dir='data/train'
val_dir='data/val'# 設置圖像的尺寸和批量大小,不用改,保持150是最平衡的
IMG_HEIGHT = 150
IMG_WIDTH = 150
BATCH_SIZE = 12# 數據預處理與增強
train_datagen = ImageDataGenerator(rescale=1./255,  # 將像素值歸一化到 [0, 1] 區間shear_range=0.2,zoom_range=0.2,horizontal_flip=True
)validation_datagen = ImageDataGenerator(rescale=1./255)# 加載訓練和驗證數據
train_generator = train_datagen.flow_from_directory(train_dir,  # 訓練數據目錄target_size=(IMG_HEIGHT, IMG_WIDTH),  # 圖像尺寸batch_size=BATCH_SIZE,class_mode='binary'  # 二分類(沖溝缺陷 vs. 非沖溝缺陷)
)train_class_labels = train_generator.class_indices
print("訓練集自動標簽映射關系為:"+str(train_class_labels))validation_generator = validation_datagen.flow_from_directory(val_dir,  # 驗證數據目錄target_size=(IMG_HEIGHT, IMG_WIDTH),batch_size=BATCH_SIZE,class_mode='binary'
)val_class_labels = validation_generator.class_indices
print("測試集自動標簽映射關系為:"+str(val_class_labels))# 將數據生成器轉換為 tf.data.Dataset 并應用 repeat() 方法
train_dataset = tf.data.Dataset.from_generator(lambda: train_generator,output_signature=(tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),tf.TensorSpec(shape=(None,), dtype=tf.int32))
)
train_dataset = train_dataset.repeat()  # 確保數據重復validation_dataset = tf.data.Dataset.from_generator(lambda: validation_generator,output_signature=(tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),tf.TensorSpec(shape=(None,), dtype=tf.int32))
)
validation_dataset = validation_dataset.repeat()  # 確保數據重復# 構建模型
model = models.Sequential([layers.InputLayer(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),  # 添加 Input 層layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(1, activation='sigmoid')  # 輸出層,二分類問題
])# 編譯模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 訓練模型
model.fit(train_dataset,steps_per_epoch=train_generator.samples // BATCH_SIZE,epochs=30,validation_data=validation_dataset,validation_steps=validation_generator.samples // BATCH_SIZE
)# 保存模型
model.save('defect_detector_model.keras')  # 使用 .keras 格式保存模型

使用這段代碼訓練數據集你唯一需要注意的是保持代碼文件于數據集文件在同一目錄,或者使用絕對路徑,如圖

QQ_1734070943655

我們啟動訓練代碼,可以看到控制臺在按照規定的輪次30在訓練中,而且可以看到我在訓練代碼中加入了輸出標簽映射關系來確保正類與負類的映射關系正確,如圖

QQ_1734071390702

訓練后,你會得到一個名為defect_detector_nodel.keras的文件,推理時會使用該模型進行推理

步驟5.編寫推理代碼

代碼如下:

import os
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image# 加載訓練好的模型
model = load_model('defect_detector_model.keras')  # 注意加載的是 .keras 格式# 設置輸入圖像的目標尺寸(與訓練時相同)
IMG_HEIGHT = 150
IMG_WIDTH = 150# 定義函數來加載并預測圖像
def predict_image(img_path):# 加載圖像并進行預處理img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))img_array = image.img_to_array(img)  # 將圖像轉換為數組img_array = np.expand_dims(img_array, axis=0)  # 擴展維度,成為一個 batchimg_array = img_array / 255.0  # 歸一化處理(與訓練時一致)# 預測圖像類別prediction = model.predict(img_array)  # 返回的是一個包含概率的數組return prediction[0][0]  # 提取預測的概率值picPath=r"測試圖.jpg"
confidence = predict_image(picPath)
print("有沖溝缺陷的概率為:"+str(confidence))

這段推理代碼中,我們加載了剛才訓練出的模型,然后使用了一張名為測試圖.jpg的圖片來進行推理,然后輸出他有缺陷的概率

步驟6.測試二分類檢測結果

我這里就不用一張圖片來測試了,我這里指定一個目錄,進行整個目錄來測試里面的圖片,還是附上我這個推理代碼吧

import os
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image# 加載訓練好的模型
model = load_model('defect_detector_model.keras')  # 注意加載的是 .keras 格式# 設置輸入圖像的目標尺寸(與訓練時相同)
IMG_HEIGHT = 150
IMG_WIDTH = 150# 定義函數來加載并預測圖像
def predict_image(img_path):# 加載圖像并進行預處理img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))img_array = image.img_to_array(img)  # 將圖像轉換為數組img_array = np.expand_dims(img_array, axis=0)  # 擴展維度,成為一個 batchimg_array = img_array / 255.0  # 歸一化處理(與訓練時一致)# 預測圖像類別prediction = model.predict(img_array)  # 返回的是一個包含概率的數組return prediction[0][0]  # 提取預測的概率值# 測試目錄,包含要進行推理的圖像
testDir = r"D:\virtualTemp\pythonProject\CNN分類檢測\data\train\true_sample"
pics = os.listdir(testDir)
# 遍歷目錄中的所有圖片并進行預測
for pic in pics:picPath = os.path.join(testDir, pic)  # 獲取圖片的完整路徑# 獲取預測結果的置信度confidence = predict_image(picPath)# 輸出圖像的置信度和類別print(f"{pic} 置信度: {confidence:.4f}, 預測結果: {'有缺陷' if confidence >= 0.5 else '無缺陷'}")

我先使用正類樣本來測試,先看看拿訓練的數據如何,然后再用另外的圖片來測試

結果如下圖,正類樣本中只有一張圖判定為了無沖溝,但是我正類樣本中其實都應當是沖溝,而我有101張圖,因此這里正確率為99.009%

QQ_1734071615033

拿訓練的數據來說話可能沒有說服力,現在我使用爬圖器來批量的爬取一些圖片,需要的可以這里拿=>?幽絡源爬圖器

如圖我爬取了3輪橋梁破損圖,2輪沖溝地貌圖,對于沖溝圖,最好是手動刪一些莫名奇妙的圖,便于驗證

QQ_1734072068792

QQ_1734072170259

ok,然后先測試橋梁破損,如果足夠符合預期,足夠表示模型很好,那么推理出的有缺陷數量應該沒有或者很少才對,結果如下

QQ_1734072462049

看起來結果并不好,90張圖中,居然有44張判定為了有沖溝缺陷,正確率只有46/90=51.11%,再測試下正類檢測呢,如圖48張圖中只有11張判定為了無,還是不錯的。

步驟7.根據結果優化數據集

在步驟6的測試中可知,所訓練的模型對正類比較適應,對負類的學習還有所欠缺,處理方法有如下

1.調整判定指標confidence,一般為0.5,可以調大以提高正確率,但是不推薦這么做

2.加大訓練輪次

3.訓練時的父類樣本圖片多加一些

ok,方法1我不是很推薦,現在首先加大訓練次數到100,然后多爬取一些非沖溝圖加入到負類樣本之中,當然,橋梁破損的圖也放進去一些,然后重新訓練獲取模型。

訓練完后還是按照步驟6中來測試橋梁破損,如圖,這一次,90張圖中判定為有缺陷的只有7個了,非常不錯,正確率提高到了82/90=91.11%

QQ_1734073419635

結語

以上是幽絡源的基于卷積神經網絡的圖像二分類檢測模型訓練與推理實現教程,對Python、Java感興趣的小伙伴可加群交流

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/62811.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/62811.shtml
英文地址,請注明出處:http://en.pswp.cn/web/62811.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

安全見聞全解析

跟隨 瀧羽sec團隊學習 聲明! 學習視頻來自B站up主 瀧羽sec 有興趣的師傅可以關注一下,如涉及侵權馬上刪除文章,筆記只是方便各位師傅的學習和探討,文章所提到的網站以及內容,只做學習交流,其他均與本人以及…

代碼隨想錄-算法訓練營-番外(圖論02:島嶼數量,島嶼的最大面積)

day02 圖論part02 今日任務:島嶼數量,島嶼的最大面積 都是一個模子套出來的 https://programmercarl.com/kamacoder/0099.島嶼的數量深搜.html#思路往日任務: day01 圖論part01 今日任務:圖論理論基礎/所有可到達的路徑 代碼隨想錄圖論視頻部分還沒更新 https://programmercar…

RabbitMQ個人理解與基本使用

目錄 一. 作用: 二. RabbitMQ的5中隊列模式: 1. 簡單模式 2. Work模式 3. 發布/訂閱模式 4. 路由模式 5. 主題模式 三. 消息持久化: 消息過期時間 ACK應答 四. 同步接收和異步接收: 應用場景 五. 基本使用 &#xff…

前端怎么預覽pdf

1.背景 后臺返回了一個在線的pdf地址,需要我這邊去做一個pdf的預覽(需求1),并且支持配置是否可以下載(需求2),需要在當前頁就能預覽(需求3)。之前我寫過一篇預覽pdf的文…

Python 參數配置使用 XML 文件的教程:輕松管理你的項目配置

Python 參數配置使用 XML 文件的教程:輕松管理你的項目配置 一句話總結:當配置項存儲在外部文件(如 XML、JSON)時,修改配置無需重新編譯和發布代碼。通過更新 XML 文件即可調整參數,無需更改源代碼&#xf…

解決 MySQL 啟動失敗與大小寫問題,重置數據庫

技術文檔:解決 MySQL 啟動失敗與大小寫問題,重置數據庫 1. 問題背景 在使用 MySQL 時,可能遇到以下問題: MySQL 啟動失敗,日志顯示 “permission denied” 或 “Can’t create directory” 錯誤。MySQL 在修改配置文…

python webdriver-manager 實現selenium 免下載安裝webdriver

python webdriver-manager 實現selenium 免下載安裝webdriver selenium在自動化測試中,通常需要使用瀏覽器驅動來與瀏覽器進行交互。然而,手動下載、安裝、以及管理這些驅動非常麻煩,尤其是當驅動版本頻繁更新時。為此,webdriver-manager庫提供了一個極簡的方案,自動幫我…

滑動窗口算法專題

滑動窗口簡介 滑動窗口就是利用單調性,配合同向雙指針來優化暴力枚舉的一種算法。 該算法主要有四個步驟 1. 先進進窗口 2. 判斷條件,后續根據條件來判斷是出窗口還是進窗口 3. 出窗口 4.更新結果,更新結果這個步驟是不確定的&#xff0c…

C# 中的Task

文章目錄 前言一、Task 的基本概念二、創建 Task使用異步方法使用 Task.Run 方法 三、等待 Task 完成使用 await 關鍵字使用 Task.Wait 方法 四、處理 Task 的異常使用 try-catch 塊使用 Task.Exception 屬性 五、Task 的延續使用 ContinueWith 方法使用 await 關鍵字和異步方法…

【AIGC】如何高效使用ChatGPT挖掘AI最大潛能?26個Prompt提問秘訣幫你提升300%效率的!

還記得第一次使用ChatGPT時,那種既興奮又困惑的心情嗎?我是從一個對AI一知半解的普通用戶,逐步成長為現在的“ChatGPT大神”。這一過程并非一蹴而就,而是通過不斷的探索和實踐,掌握了一系列高效使用的技巧。今天&#…

浩辰CAD教程004:柱梁板

文章目錄 柱梁板標準柱角柱構造柱柱齊墻邊繪制梁繪制樓板 柱梁板 標準柱 繪制標準柱: ①:點選插入柱子②:沿著一根軸線布置柱子③:指定的矩形區域內的軸線交點插入柱子 替換現有柱子:選擇替換之后的柱子形狀&#x…

UNIX數據恢復—UNIX系統常見故障問題和數據恢復方案

UNIX系統常見故障表現: 1、存儲結構出錯; 2、數據刪除; 3、文件系統格式化; 4、其他原因數據丟失。 UNIX系統常見故障解決方案: 1、檢測UNIX系統故障涉及的設備是否存在硬件故障,如果存在硬件故障&#xf…

橋接模式的理解和實踐

橋接模式(Bridge Pattern),又稱橋梁模式,是一種結構型設計模式。它的核心思想是將抽象部分與實現部分分離,使它們可以獨立地進行變化,從而提高系統的靈活性和可擴展性。本文將詳細介紹橋接模式的概念、原理…

HTML綜合

一.HTML的初始結構 <!DOCTYPE html> <html lang"en"><head><!-- 設置文本字符 --><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><!-- 設置網頁…

二維碼數據集,使用yolov,voc,coco標注,3044張各種二維碼原始圖片(未圖像增強)

二維碼數據集&#xff0c;使用yolov&#xff0c;voc&#xff0c;coco標注&#xff0c;3044張各種二維碼原始圖片&#xff08;未圖像增強&#xff09; 數據集分割 訓練組70&#xff05; 2132圖片 有效集20&#xff05; 607圖片 測試集10&#xff05; 305圖…

Python爬蟲技術的最新發展

在互聯網的海洋中&#xff0c;數據就像是一顆顆珍珠&#xff0c;而爬蟲技術就是我們手中的潛水艇。2024年&#xff0c;爬蟲技術有了哪些新花樣&#xff1f;讓我們一起潛入這個話題&#xff0c;看看最新的發展和趨勢。 1. 異步爬蟲&#xff1a;速度與激情 隨著現代Web應用的復…

用豆包MarsCode IDE,從0到1畫出精美數據大屏!

豆包MarsCode IDE 是一個云端 AI IDE 平臺&#xff0c;通過內置的 AI 編程助手&#xff0c;開箱即用的開發環境&#xff0c;可以幫助開發者更專注于各類項目的開發。 作為一名前端開發工程師&#xff0c;今天想嘗試利用豆包MarsCode IDE&#xff0c;選擇 Vue Echarts 創建一個…

游戲引擎學習第42天

倉庫: https://gitee.com/mrxiao_com/2d_game 簡介 目前我們正在研究的內容是如何構建一個基本的游戲引擎。我們將深入了解游戲開發的每一個環節&#xff0c;從最基礎的技術實現到高級的游戲編程。 角色移動代碼 我們主要討論的是角色的移動代碼。我一直希望能夠使用一些基…

Redis是什么?Redis和MongoDB的區別在那里?

Redis介紹 Redis&#xff08;Remote Dictionary Server&#xff09;是一個開源的、基于內存的數據結構存儲系統&#xff0c;它可以用作數據庫、緩存和消息中間件。以下是關于Redis的詳細介紹&#xff1a; 一、數據結構支持 字符串&#xff08;String&#xff09; 這是Redis最…

計算機網絡中的三大交換技術詳解與實現

目錄 計算機網絡中的三大交換技術詳解與實現1. 計算機網絡中的交換技術概述1.1 交換技術的意義1.2 三大交換技術簡介 2. 電路交換技術2.1 理論介紹2.2 Python實現及代碼詳解2.3 案例分析 3. 分組交換技術3.1 理論介紹3.2 Python實現及代碼詳解3.3 案例分析 4. 報文交換技術4.1 …