EMO實戰:使用EMO實現圖像分類任務(一)

文章目錄

  • 摘要
  • 安裝包
    • 安裝timm
    • 安裝 grad-cam
    • 安裝einops
  • 數據增強Cutout和Mixup
  • EMA
  • 項目結構
  • 計算mean和std
  • 生成數據集

摘要

論文翻譯:https://blog.csdn.net/m0_47867638/article/details/132034098?spm=1001.2014.3001.5501
官方源碼:https://github.com/zhangzjn/EMO

EMO是高效、輕量級的模型,以在參數、FLOPs和性能之間實現平衡,適用于密集預測任務。文章從倒立殘差塊(IRB)和Transformer的有效組件的統一角度出發,將基于CNN的IRB擴展到基于注意力的模型,并抽象出一個用于輕量級模型設計的單殘留元移動塊(MMB)。

作者提出了反向殘差移動塊(iRMB),并根據簡單而有效的設計準則構建了一個只有iRMB的類ResNet高效模型(EMO)用于下游任務。實驗結果表明,EMO在ImageNet-1K、COCO2017和ADE20K基準測試上表現出優異的性能,超過了SOTA的CNN和基于注意力的模型。EMO-1m/2M/5M達到71.5、75.1和78.4 Top-1,同時實現了良好的參數效率與精度權衡,運行速度比iPhone14上的EdgeNeXt快2.8-4.0倍。

EMO為輕量級模型設計提供了一個新的思路,通過將CNN和Transformer的有效組件統一起來,實現了高效的模型性能。大量實驗驗證了所提出的方法的有效性和優越性,為相關領域的研究提供了有益的參考。總的來說,文章提出的方法在參數效率、性能和計算成本之間實現了良好的平衡,具有廣泛的應用前景。

在這里插入圖片描述

這篇文章使用EMO完成植物分類任務,模型采用EMO_1M向大家展示如何使用EMO。EMO_1M在這個數據集上實現了96+%的ACC,如下圖:

請添加圖片描述

請添加圖片描述

通過這篇文章能讓你學到:

  1. 如何使用數據增強,包括transforms的增強、CutOut、MixUp、CutMix等增強手段?
  2. 如何實現EMO模型實現訓練?
  3. 如何使用pytorch自帶混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多顯卡訓練?
  6. 如何繪制loss和acc曲線?
  7. 如何生成val的測評報告?
  8. 如何編寫測試腳本測試測試集?
  9. 如何使用余弦退火策略調整學習率?
  10. 如何使用AverageMeter類統計ACC和loss等自定義變量?
  11. 如何理解和統計ACC1和ACC5?
  12. 如何使用EMA?
  13. 如果使用Grad-CAM 實現熱力圖可視化?

如果基礎薄弱,對上面的這些功能難以理解可以看我的專欄:經典主干網絡精講與實戰
這個專欄,從零開始時,一步一步的講解這些,讓大家更容易接受。

安裝包

安裝timm

使用pip就行,命令:

pip install timm

mixup增強和EMA用到了timm

安裝 grad-cam

pip install grad-cam

安裝einops

pip install einops

數據增強Cutout和Mixup

為了提高成績我在代碼中加入Cutout和Mixup這兩種增強方式。實現這兩種增強需要安裝torchtoolbox。安裝命令:

pip install torchtoolbox

Cutout實現,在transforms中。

from torchtoolbox.transform import Cutout
# 數據預處理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要導入包:from timm.data.mixup import Mixup,

定義Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

參數詳解:

mixup_alpha (float): mixup alpha 值,如果 > 0,則 mixup 處于活動狀態。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 處于活動狀態。

cutmix_minmax (List[float]):cutmix 最小/最大圖像比率,cutmix 處于活動狀態,如果不是 None,則使用這個 vs alpha。

如果設置了 cutmix_minmax 則cutmix_alpha 默認為1.0

prob (float): 每批次或元素應用 mixup 或 cutmix 的概率。

switch_prob (float): 當兩者都處于活動狀態時切換cutmix 和mixup 的概率 。

mode (str): 如何應用 mixup/cutmix 參數(每個’batch’,‘pair’(元素對),‘elem’(元素)。

correct_lam (bool): 當 cutmix bbox 被圖像邊框剪裁時應用。 lambda 校正

label_smoothing (float):將標簽平滑應用于混合目標張量。

num_classes (int): 目標的類數。

EMA

EMA(Exponential Moving Average)是指數移動平均值。在深度學習中的做法是保存歷史的一份參數,在一定訓練階段后,拿歷史的參數給目前學習的參數做一次平滑。具體實現如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 訓練過程中,更新完參數后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 將model_ema傳入驗證函數中
val(model_ema.ema, DEVICE, test_loader)

針對沒有預訓練的模型,容易出現EMA不上分的情況,這點大家要注意啊!

項目結構

EMO_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  ├─_emo_ios.py
│  ├─basic_modules.py
│  ├─cls_factory.py
│  └─emo.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

models:來源官方代碼,對面的代碼做了一些適應性修改。
mean_std.py:計算mean和std的值。
makedata.py:生成數據集。
ema.py:EMA腳本
train.py:訓練InceptionNext模型
cam_image.py:熱力圖可視化

計算mean和std

為了使模型更加快速的收斂,我們需要計算出mean和std的值,新建mean_std.py,插入代碼:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

數據集結構:

image-20220221153058619

運行結果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把這個結果記錄下來,后面要用!

生成數據集

我們整理還的圖像分類的數據集結構是這樣的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默認加載方式是ImageNet數據集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式轉化腳本makedata.py,插入代碼:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#刪除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的內容就可以開啟訓練和測試了。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/41034.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/41034.shtml
英文地址,請注明出處:http://en.pswp.cn/news/41034.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

shell的兩種屬性: 交互(interactive)與登錄(login)

1. 背景 在看shell變量的時候引起了興趣: 局部變量,全局變量,環境變量,shell的配置文件,參考博客: http://c.biancheng.net/view/773.html 2. 交互式與非交互式 參考博客: shell的兩個屬性:是否交互式(interactive), 是否登錄…

生產環境下的終極指南:使用 Docker 部署 Nacos 集群和 MySQL

🌷🍁 博主貓頭虎 帶您 Go to New World.?🍁 🦄 博客首頁——貓頭虎的博客🎐 🐳《面試題大全專欄》 文章圖文并茂🦕生動形象🦖簡單易學!歡迎大家來踩踩~🌺 &a…

01-集群安裝JDK(普通用戶)

機器部署 集群規劃 我們準備三臺服務器kk01、kk02、kk03,內存4G、硬盤50G、處理器4核心2內核(總8) kk01使用 192.168.188.128 kk02使用 192.168.188.129 kk03使用 192.168.188.130 模板機準備 我們先創建一臺作為模板機,后…

C++ 11 新特性 學習筆記

1、字符串原始字面量 R“()”用于取消轉義,可用于路徑表示 運行成功 這兩個RawValue起到描述作用(可以不寫),并不參與輸出 注意,這里輸出中文亂碼 2、nullptr NULL在C中表示0,在非C中表示萬能指針 nullpt…

Vue3 使用json編輯器

安裝 npm install json-editor-vue3 main中引入 main.js 中加入下面代碼 import "jsoneditor";不然會有報錯&#xff0c;如jsoneditor does not provide an export named ‘default’。 圖片信息來源-github 代碼示例 <template><json-editor-vue class…

SQL | 分組數據

10-分組數據 兩個新的select子句&#xff1a;group by子句和having子句。 10.1-數據分組 上面我們學到了&#xff0c;使用SQL中的聚集函數可以匯總數據&#xff0c;這樣&#xff0c;我們就能夠對行進行計數&#xff0c;計算和&#xff0c;計算平均數。 目前為止&#xff0c…

ESP-C3入門21. I2C接口點亮1306驅動的OLED屏

ESP-C3入門21. 點亮1306驅動的OLED屏 一、Espressif/ssd1306 驅動簡介1. 驅動介紹2. OLED充電泵概念 二、I2C 通訊步驟1. 初始化 I2C 總線 (i2c_master_init()函數)&#xff1a;2. 創建 I2C 命令句柄 (i2c_cmd_handle_t cmd i2c_cmd_link_create())&#xff1a;3. 發送啟動信號…

【C#】獲取電腦CPU、內存、屏幕、磁盤等信息

通過WMI類來獲取電腦各種信息&#xff0c;參考文章&#xff1a;WMI_04_常見的WMI類的屬性_wmi scsilogicalunit_fantongl的博客-CSDN博客 自己整理了獲取電腦CPU、內存、屏幕、磁盤等信息的代碼 #region 系統信息/// <summary>/// 電腦信息/// </summary>public p…

flinksql報錯 Cannot determine simple type name “org“

flink版本 1.15 報錯內容 2023-08-17 15:46:02 java.lang.RuntimeException: Could not instantiate generated class WatermarkGenerator$0at org.apache.flink.table.runtime.generated.GeneratedClass.newInstance(GeneratedClass.java:74)at org.apache.flink.table.runt…

低功耗、5Mbps、RS-422 接口電路MS2583/MS2583M

MS2583/MS2583M 是一款低功耗、 5Mbps 、高 ESD 能力的 RS422 通訊接口電路。 在接收狀態下&#xff0c;其功耗僅為 0.3mA 左右。 A/B 端 ESD 耐壓可達 15kV &#xff0c;且無自激現象。當輸出短路發生大電 流導致電路溫度過高時&#xff0c;開啟內部過溫保護電路&…

go 使用 make 初始化 slice 切片使用注意

go 使用 make 初始化 slice 切片 時指定長度和不指定長度的情況 指定長度 package mainimport "fmt"func main() {s1 : make([]int, 5)data : []int{1, 2, 3}for _, v : range data {s1 append(s1, v)}fmt.Println(s1) }// 以上代碼會輸出 // [0 0 0 0 0 1 2 3] //…

vue中的路由緩存和解決方案

路由緩存的原因 解決方法 推薦方案二&#xff0c;使用鉤子函數beforeRouteUpdate&#xff0c;每次路由更新前執行

手寫spring筆記

手寫spring筆記 《Spring 手擼專欄》筆記 IoC部分 Bean初始化和屬性注入 Bean的信息封裝在BeanDefinition中 /*** 用于記錄Bean的相關信息*/ public class BeanDefinition {/*** Bean對象的類型*/private Class beanClass;/*** Bean對象中的屬性信息*/private PropertyVal…

MFC第三十天 通過CToolBar類開發文字工具欄和工具箱、GDI+邊框填充以及基本圖形的繪制方法、圖形繪制過程的反色線模型和實色模型

文章目錄 CControlBar通過CToolBar類開發文字工具欄和工具箱CMainFrame.hCAppCMainFrm.cppCMainView.hCMainView.cppCEllipse.hCEllipse.cppCLine.hCLine.cppCRRect .hCRRect .cpp CControlBar class AFX_NOVTABLE CControlBar : public CWnd{DECLARE_DYNAMIC(CControlBar)pro…

OC調用Swift編寫的framework

一、前言 隨著swift趨向穩定&#xff0c;越來越多的公司都開始用swift來編寫蘋果相關的業務了&#xff0c;關于swift的利弊這里就不多說了。這里詳細介紹OC調用swift編寫的framework庫的步驟 二、制作framework 1、新建項目&#xff0c;選擇framework 2、填寫framework的名稱…

AutoHotkey:定時刪除目錄下指定分鐘以前的文件,帶UI界面

刪除指定目錄下&#xff0c;所有在某個指定分鐘以前的文件&#xff0c;可以用來清理經常生成很多文件的目錄&#xff0c;但又需要保留最新的一部分文件 支持拖放目錄到界面 能夠記憶設置&#xff0c;下次啟動后不用重新設置&#xff0c;可以直接開始 應用場景比如&#xff1a…

WinForm內嵌Unity3D

Unity3D可以C#腳本進行開&#xff0c;使用vstu2013.msi插件&#xff0c;可以實現在VS2013中的調試。在開發完成后&#xff0c;由于項目需要&#xff0c;需要將Unity3D嵌入到WinForm中。WinForm中的UnityWebPlayer Control可以載入Unity3D。先看效果圖。 一、為了能夠動態設置ax…

【boost網絡庫從青銅到王者】第五篇:asio網絡編程中的同步讀寫的客戶端和服務器示例

文章目錄 1、簡介2、客戶端設計3、服務器設計3.1、session函數3.2、StartListen函數3、總體設計 4、效果測試5、遇到的問題5.1、服務器遇到的問題5.1.1、不用顯示調用bind綁定和listen監聽函數5.1.2、出現 Error occured!Error code : 10009 .Message: 提供的文件句柄無效。 [s…

Failed to execute goal org.apache.maven.plugins

原因&#xff1a; 這個文件D:\java\maven\com\ruoyi\pg-student\maven-metadata-local.xml出了問題 解決&#xff1a; 最簡單的直接刪除D:\java\maven\com\ruoyi\pg-student\maven-metadata-local.xml重新打包 或者把D:\java\maven\com\ruoyi\pg-student這個目錄下所有文件…

性能測試場景設計

性能測試場景設計&#xff0c;是性能測試中的重要概念&#xff0c;性能測試場景設計&#xff0c;目的是要描述如何執行性能測試。 通常來講&#xff0c;性能測試場景設計主要會涉及以下部分&#xff1a; 并發用戶數是多少&#xff1f; 測試剛開始時&#xff0c;以什么樣的速率…