基于yolov5和desnet的貓咪識別模型

前言

前段時間給學校的貓咪小程序搭建了識貓模型,可以通過貓咪的照片辨別出是那只貓貓,這里分享下具體的方案,先看效果圖:

源代碼在文末

模型訓練

在訓練服務器(或你的個人PC)上拉取本倉庫代碼。

圖片數據準備


進入`data`目錄,執行`npm install`安裝依賴。(需要 Node.js 環境,不確定老版本 Node.js 兼容性,建議使用最新版本。)


復制`config.demo.ts`文件并改名為`config.ts`,填寫Laf云環境的`LAF_APPID`;


執行`npm start`,腳本將根據小程序數據庫記錄拉取小程序云存儲中的圖片。

如果不打算從laf拉取數據,也可以自己制作數據集,只要保證文件格式如下就可以

catface文件下面的data文件中的photos中有若干個文件夾,每個文件夾名稱為id,文件夾下為圖片。

環境搭建


返回倉庫根目錄,執行`python -m pip install -r requirements.txt`安裝依賴。(需要Python>=3.8。不建議使用特別新版本的 Python,可能有兼容性問題。)


如果是linux系統,可以直接執行`bash prepare_yolov5.sh`拉取YOLOv5目標檢測模型所需的代碼,然后下載并預處理模型數據。如果是windows系統可以自己手動從gihub上拉取yolov5的模型。


執行`python3 data_preprocess.py`,腳本將使用YOLOv5從`data/photos`的圖片中識別出貓貓并截取到`data/crop_photos`目錄。

開始訓練

執行`python3 main.py`,使用默認參數訓練一個識別貓貓圖片的模型。(你可以通過`python3 main.py --help`查看幫助來自定義一些訓練參數。)程序運行結束時,你應當看到目錄的export文件夾下存在`cat.onnx`和`cat.json`兩個文件。(訓練數據使用TensorBoard記錄在`lightning_logs`文件夾下。若要查看準確率等信息,請自行運行TensorBoard。)


執行`python3 main.py --data data/photos --size 224 --name fallback`,使用修改后的參數訓練一個在YOLOv5無法找到貓貓時使用的全圖識別模型。程序運行結束時,你應當看到目錄的export文件夾下存在`fallback.onnx`和`fallback.json`兩個文件。

這里介紹下模型類的代碼,我們定義了學習率,網絡指定為densenet21

import torch
import torch.nn as nn
from torchvision import models
import torch.optim as optim
from pytorch_lightning import LightningModule
import torchmetrics
from typing import Tupleclass CatFaceModule(LightningModule):def __init__(self, num_classes: int, lr: float):super(CatFaceModule, self).__init__()self.save_hyperparameters()self.net = models.densenet121(num_classes=num_classes)self.loss_func = nn.CrossEntropyLoss()def forward(self, x: torch.Tensor) -> torch.Tensor:return self.net(x)def training_step(self, batch: Tuple[torch.Tensor, torch.LongTensor], batch_idx: int) -> torch.Tensor:loss, acc = self.do_step(batch)self.log('train/loss', loss, on_step=True, on_epoch=True)self.log('train/acc', acc, on_step=True, on_epoch=True)return lossdef validation_step(self, batch, batch_idx: int):loss, acc = self.do_step(batch)self.log('val/loss', loss, on_step=False, on_epoch=True)self.log('val/acc', acc, on_step=False, on_epoch=True)def do_step(self, batch: Tuple[torch.Tensor, torch.LongTensor]) -> Tuple[torch.Tensor, torch.Tensor]:# shape: x (B, C, H, W), y (B), w (B)x, y = batch# shape: out (B, num_classes)out = self.net(x)loss = self.loss_func(out, y)with torch.no_grad():# 每個類別分別計算準確率,以平衡地綜合考慮每只貓的準確率accuracy_per_class = torchmetrics.functional.accuracy(out, y, task="multiclass", num_classes=self.hparams['num_classes'], average=None)# 去掉batch中沒有出現的類別,這些位置為nannan_mask = accuracy_per_class.isnan()accuracy_per_class = accuracy_per_class.masked_fill(nan_mask, 0)# 剩下的位置取均值acc = accuracy_per_class.sum() / (~nan_mask).sum()return loss, accdef configure_optimizers(self) -> optim.Optimizer:return optim.Adam(self.parameters(), lr=self.hparams['lr'])

在模型訓練完畢后可以運行我編寫的modelTest,在這個文件中替換圖片為自己的圖片,觀察輸出是否正常,正常輸出是這樣的:

在這個輸出中,通過yolo檢測了圖片中是否含有貓咪,通過densenet對圖片所屬于的類進行概率計算,概率和id按照概率從大到小排序返回。

接口實現

我們訓練了兩個densenet模型,一個是全圖像的輸入為228的模型a,一個是輸入圖像為128的模型b,當請求打到服務器時,應用程序會先通過yolo檢測是否有貓,有的話就截取貓咪圖像,使用模型b;否則不截取,使用模型a。

以下是代碼:

from typing import Any
from werkzeug.datastructures import FileStorageimport torch
from PIL import Image
import numpy as np
import onnxruntime
from flask import Flask, request
from dotenv import load_dotenv
import os
import json
import time
from base64 import b64encode
from hashlib import sha256load_dotenv("./env", override=True)HOST_NAME = os.environ['HOST_NAME']
PORT = int(os.environ['PORT'])SECRET_KEY = os.environ['SECRET_KEY']
TOLERANT_TIME_ERROR = int(os.environ['TOLERANT_TIME_ERROR']) # 可以容忍的時間戳誤差(s)IMG_SIZE = int(os.environ['IMG_SIZE'])
FALLBACK_IMG_SIZE = int(os.environ['FALLBACK_IMG_SIZE'])CAT_BOX_MAX_RET_NUM = int(os.environ['CAT_BOX_MAX_RET_NUM']) # 最多可以返回的貓貓框個數
RECOGNIZE_MAX_RET_NUM = int(os.environ['RECOGNIZE_MAX_RET_NUM']) # 最多可以返回的貓貓識別結果個數print("==> loading models...")
assert os.path.isdir("export"), "*** export directory not found! you should export the training checkpoint to ONNX model."crop_model = torch.hub.load('yolov5', 'custom', 'yolov5/yolov5m.onnx', source='local')with open("export/cat.json", "r") as fp:cat_ids = json.load(fp)
cat_model = onnxruntime.InferenceSession("export/cat.onnx", providers=["CPUExecutionProvider"])with open("export/cat.json", "r") as fp:fallback_ids = json.load(fp)
fallback_model = onnxruntime.InferenceSession("export/cat.onnx", providers=["CPUExecutionProvider"])print("==> models are loaded.")app = Flask(__name__)
# 限制post大小為10MB
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024def wrap_ok_return_value(data: Any) -> str:return json.dumps({'ok': True,'message': 'OK','data': data})def wrap_error_return_value(message: str) -> str:return json.dumps({'ok': False,'message': message,'data': None})def check_signature(photo: FileStorage, timestamp: int, signature: str) -> bool:if abs(timestamp - time.time()) > TOLERANT_TIME_ERROR:return FalsephotoBase64 = b64encode(photo.read()).decode()photo.seek(0) # 重置讀取位置,避免影響后續操作signatureData = (photoBase64 + str(timestamp) + SECRET_KEY).encode()return signature == sha256(signatureData).hexdigest()@app.route("/recognizeCatPhoto", methods=["POST"])
@app.route("/recognizeCatPhoto/", methods=["POST"])
def recognize_cat_photo():try:photo = request.files['photo']timestamp = int(request.form['timestamp'])signature = request.form['signature']if not check_signature(photo, timestamp=timestamp, signature=signature):return wrap_error_return_value("fail signature check.")src_img = Image.open(photo).convert("RGB")# 使用 YOLOv5 進行目標檢測,結果為[{xmin, ymin, xmax, ymax, confidence, class, name}]格式results = crop_model(src_img).pandas().xyxy[0].to_dict('records')# 過濾非cat目標cat_results = list(filter(lambda target: target['name'] == 'cat', results))if len(cat_results) >= 1:cat_idx = int(request.form['catIdx']) if 'catIdx' in request.form and int(request.form['catIdx']) < len(cat_results) else 0# 裁剪出(指定的)catcat_result = cat_results[cat_idx]crop_box = cat_result['xmin'], cat_result['ymin'], cat_result['xmax'], cat_result['ymax']# 裁剪后直接resize到正方形src_img = src_img.crop(crop_box).resize((IMG_SIZE, IMG_SIZE))# 輸入到cat模型img = np.array(src_img, dtype=np.float32).transpose((2, 0, 1)) / 255scores = cat_model.run([node.name for node in cat_model.get_outputs()], {cat_model.get_inputs()[0].name: img[np.newaxis, :]})[0][0].tolist()# 按概率排序cat_id_with_score = sorted([dict(catID=cat_ids[i], score=scores[i]) for i in range(len(cat_ids))], key=lambda item: item['score'], reverse=True)else:# 沒有檢測到cat# 整張圖片直接resize到正方形src_img = src_img.resize((FALLBACK_IMG_SIZE, FALLBACK_IMG_SIZE))img = np.array(src_img, dtype=np.float32).transpose((2, 0, 1)) / 255scores = fallback_model.run([node.name for node in fallback_model.get_outputs()], {fallback_model.get_inputs()[0].name: img[np.newaxis, :]})[0][0].tolist()# 按概率排序cat_id_with_score = sorted([dict(catID=fallback_ids[i], score=scores[i]) for i in range(len(fallback_ids))], key=lambda item: item['score'], reverse=True)return wrap_ok_return_value({'catBoxes': [{'xmin': item['xmin'],'ymin': item['ymin'],'xmax': item['xmax'],'ymax': item['ymax']} for item in cat_results][:CAT_BOX_MAX_RET_NUM],'recognizeResults': cat_id_with_score[:RECOGNIZE_MAX_RET_NUM]})except BaseException as err:return wrap_error_return_value(str(err))if __name__ == "__main__":app.run(host=HOST_NAME, port=PORT, debug=False)

我們可以在本地運行,如果想測試的小伙伴可以把接口中密鑰校驗的代碼刪除,然后直接發送post請求即可。

源碼鏈接

cat-face: 貓臉識別程序,使用yolov5和densenet分類

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/15834.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/15834.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/15834.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[力扣題解] 200. 島嶼數量

題目&#xff1a;200. 島嶼數量 思路 深度優先搜索、廣度優先搜索、并查集&#xff1b; 代碼 廣度優先搜索 class Solution { public:int dir[4][2] {{0, 1}, {0, -1}, {1, 0}, {-1, 0}};queue<pair<int, int>> que;void bfs(vector<vector<char>&g…

10款免費黑科技軟件,強烈推薦!

1.AI視頻生成——巨日祿 網頁版https://aitools.jurilu.com/ "巨日祿 "是一款功能強大的文本視頻生成器&#xff0c;可以快速將文本內容轉換成極具吸引力的視頻。操作簡單&#xff0c;用戶只需輸入文字&#xff0c;選擇喜歡的樣式和模板&#xff0c; “巨日祿”就會…

Day39貪心算法part06

LC738單調遞增的數字&#xff08;未掌握&#xff09; 思路分析&#xff1a;一旦出現strNum[i - 1] > strNum[i]的情況&#xff08;非單調遞增&#xff09;&#xff0c;首先想讓strNum[i - 1]–&#xff0c;然后strNum[i]給為9字符串是不可變的&#xff0c;不可以使用s.char…

嵌入式交叉編譯:OpenCV

編譯ffmpeg 嵌入式交叉編譯&#xff1a;ffmpeg及相關庫-CSDN博客 下載 LINUX編譯opencv_linux 編譯opencv 模塊-CSDN博客 解壓編譯 penCV自帶編譯配置&#xff0c;十分方便。 BUILD_DIR${HOME}/build_libsCROSS_NAMEaarch64-mix210-linuxFFMPEG_DIR${BUILD_DIR}/libmkdir…

樹莓派學習筆記——樹莓派的三種GPIO編碼方式

1、板載編碼&#xff08;Board pin numbering&#xff09;: 板載編碼是樹莓派上的一種GPIO引腳編號方式&#xff0c;它指的是按照引腳在樹莓派主板上的物理位置來編號。這種方式對于初學者來說可能比較直觀&#xff0c;因為它允許你直接根據引腳在板上的位置來編程。 2、BCM編…

Linux gurb2簡介

文章目錄 前言一、GRUB 2簡介二、GRUB 2相關文件/文件夾2.1 /etc/default/grub文件2.2 /etc/grub.d/文件夾2.3 /boot/grub/grub.cfg文件 三、grubx64.efi參考資料 前言 簡單來說&#xff0c;引導加載程序&#xff08;boot loader&#xff09;是計算機啟動時運行的第一個軟件程…

一起學習大模型 - 從底層了解Token Embeddings的原理(2)

文章目錄 前言4. Token Embeddings綜合運用演示4.1 Token Embeddings處理4.2 偽代碼示例4.3 計算cat和dog兩個詞的相近程序4.3.1 計算方法4.3.2 例子4.3.3 輸出結果 前言 上一篇文章了解了Token Embeddings的原理&#xff0c;這一篇&#xff0c;我們一起來綜合運用學到的知識來…

純干貨分享 機器學習7大方面,30個硬核數據集

在剛剛開始學習算法的時候&#xff0c;大家有沒有過這種感覺&#xff0c;最最重要的那必須是算法本身&#xff01; 其實在一定程度上忽略了數據的重要性。 而事實上一定是&#xff0c;質量高的數據集可能是最重要的&#xff01; 數據集在機器學習算法項目中具有非常關鍵的重…

文章解讀與仿真程序復現思路——電力系統保護與控制EI\CSCD\北大核心《計及溫控厭氧發酵和階梯碳交易的農村綜合能源低碳經濟調度》

本專欄欄目提供文章與程序復現思路&#xff0c;具體已有的論文與論文源程序可翻閱本博主免費的專欄欄目《論文與完整程序》 論文與完整源程序_電網論文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 電網論文源程序-CSDN博客電網論文源…

網絡域名是什么意思

網絡域名&#xff0c;顧名思義&#xff0c;就是網絡上的名字&#xff0c;類似于現實中的地址或姓名一樣&#xff0c;用來標識網絡上的一個或一組計算機或服務器的位置&#xff0c;以及它們的相應服務資源。網絡域名是互聯網上最基礎的基礎設施之一&#xff0c;是網絡通信的“標…

【mysql】更新操作是如何執行的

現有一張表&#xff0c;建表語句如下&#xff1a; mysql> create table T(ID int primary key, c int);如果要將 ID2 這一行的a字段值加 1&#xff0c;SQL語句會這么寫&#xff1a; mysql> update T set c c 1 where ID 2;上面這條sql執行時&#xff0c;分析器會通過詞…

Nacos 微服務管理

Nacos 本教程將為您提供Nacos的基本介紹&#xff0c;并帶您完成Nacos的安裝、服務注冊與發現、配置管理等功能。在這個過程中&#xff0c;您將學到如何使用Nacos進行微服務管理。下方是官方文檔&#xff1a; Nacos官方文檔 1. Nacos 簡介 Nacos&#xff08;Naming and Confi…

操作符詳解(上)(新手向)

操作符詳解&#xff08;上&#xff09; 一&#xff0c;算術操作符&#xff08;雙目操作符&#xff09;1:‘’,‘-’,‘*’2&#xff1a;‘/’&#xff0c;‘%’ 一&#xff0c;單目操作符1:‘’,‘-’2&#xff1a;‘!’3&#xff1a;‘&’4&#xff1a;‘*’5&#xff1a;…

linux 排查java內存溢出(持續更新中)

場景 tone.jar 啟動后內存溢出,假設pid 為48044 排查 1.確定java程序的pid(進程id) ps 或 jps 都可以 ps -ef | grep tone jps -l 2.查看堆棧信息 jmap -heap 48044 3.查看對象的實例數量顯示前30 jmap -histo:live 48044 | head -n 30 4.查看線程狀態 jstack 48044

Spring 事件監聽

參考&#xff1a;Spring事件監聽流程分析【源碼淺析】_private void processbean(final string beanname, fi-CSDN博客 一、簡介 Spring早期通過實現ApplicationListener接口定義監聽事件&#xff0c;Spring 4.2開始通過EventListener注解實現監聽事件 FunctionalInterface p…

Rustdesk客戶端源碼編譯

1.安裝VCPKG windows平臺vcpkg安裝-CSDN博客 2.使用VCPKG安裝: windows平臺vcpkg安裝-CSDN博客 配置VCPKG_ROOT環境變量: 安裝靜態庫: ./vcpkg install libvpx:x64-windows-static libyuv:x64-windows-static opus:x64-windows-static aom:x64-windows-static 靜態庫安裝成…

【C語言深度解剖】(15):動態內存管理和柔性數組

&#x1f921;博客主頁&#xff1a;醉竺 &#x1f970;本文專欄&#xff1a;《C語言深度解剖》 &#x1f63b;歡迎關注&#xff1a;感謝大家的點贊評論關注&#xff0c;祝您學有所成&#xff01; ??&#x1f49c;&#x1f49b;想要學習更多C語言深度解剖點擊專欄鏈接查看&…

I.MX6ULL的官方 SDK 移植實驗

系列文章目錄 I.MX6ULL的官方 SDK 移植實驗 I.MX6ULL的官方 SDK 移植實驗 系列文章目錄一、前言二、I.MX6ULL 官方 SDK 包簡介三、硬件原理圖四、試驗程序編寫4.1 SDK 文件移植4.2 創建 cc.h 文件4.3 編寫實驗代碼 五、編譯下載驗證5.1編寫 Makefile 和鏈接腳本5.2編譯下載 一、…

列表元素添加的藝術:從單一到批量

新書上架~&#x1f447;全國包郵奧~ python實用小工具開發教程http://pythontoolsteach.com/3 歡迎關注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目錄 一、引言 二、向列表中添加單一元素 1. append方法 2. insert方法 三、向列表中添加批量…

MySQL 存儲過程(實驗報告)

一、實驗名稱&#xff1a; 存儲過程 二、實驗日期&#xff1a; 2024 年5 月 25 日 三、實驗目的&#xff1a; 掌握MySQL存儲過程的創建及調用&#xff1b; 四、實驗用的儀器和材料&#xff1a; 硬件&#xff1a;PC電腦一臺&#xff1b; 配置&#xff1a;內存&#xff0…