YOLOv5、YOLOv8改進:MobileViT:輕量通用且適合移動端的視覺Transformer

MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer

論文:https://arxiv.org/abs/2110.02178

e1a87d4eb4a90e532c099f743d822b53.png

?1簡介

MobileviT是一個用于移動設備的輕量級通用可視化Transformer,據作者介紹,這是第一次基于輕量級CNN網絡性能的輕量級ViT工作,性能SOTA!。性能優于MobileNetV3、CrossviT等網絡。

輕量級卷積神經網絡(CNN)是移動視覺任務的實際應用。他們的空間歸納偏差允許他們在不同的視覺任務中以較少的參數學習表征。然而,這些網絡在空間上是局部的。為了學習全局表征,采用了基于自注意力的Vision Transformer(ViTs)。與CNN不同,ViT是heavy-weight。

在本文中,本文提出了以下問題:是否有可能結合CNN和ViT的優勢,構建一個輕量級、低延遲的移動視覺任務網絡?

為此提出了MobileViT,一種輕量級的、通用的移動設備Vision Transformer。MobileViT提出了一個不同的視角,以Transformer作為卷積處理信息。

98d3914f64b8e17da85f833084a6d45e.png

?

實驗結果表明,在不同的任務和數據集上,MobileViT顯著優于基于CNN和ViT的網絡。

在ImageNet-1k數據集上,MobileViT在大約600萬個參數的情況下達到了78.4%的Top-1準確率,對于相同數量的參數,比MobileNetv3和DeiT的準確率分別高出3.2%和6.2%。

在MS-COCO目標檢測任務中,在參數數量相近的情況下,MobileViT比MobileNetv3的準確率高5.7%。

2.Mobile-ViT

MobileViT Block如下圖所示,其目的是用較少的參數對輸入張量中的局部和全局信息進行建模。

形式上,對于一個給定的輸入張量, MobileViT首先應用一個n×n標準卷積層,然后用一個一個點(或1×1)卷積層產生特征。n×n卷積層編碼局部空間信息,而點卷積通過學習輸入通道的線性組合將張量投影到高維空間(d維,其中d>c)。

?

通過MobileViT,希望在擁有有效感受野的同時,對遠距離非局部依賴進行建模。一種被廣泛研究的建模遠程依賴關系的方法是擴張卷積。然而,這種方法需要謹慎選擇膨脹率。否則,權重將應用于填充的零而不是有效的空間區域。

另一個有希望的解決方案是Self-Attention。在Self-Attention方法中,具有multi-head self-attention的vision transformers(ViTs)在視覺識別任務中是有效的。然而,vit是heavy-weight,并由于vit缺乏空間歸納偏差,表現出較差的可優化性。

下面附上改進代碼

---------------------------------------------分割線--------------------------------------------------

在common中加入如下代碼

需要安裝一個einops模塊

pip --default-timeout=5000 install -i https://pypi.tuna.tsinghua.edu.cn/simple einops

這邊建議直接興建一個?

?

import torch
import torch.nn as nnfrom einops import rearrangedef conv_1x1_bn(inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),nn.SiLU())def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):return nn.Sequential(nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),nn.BatchNorm2d(oup),nn.SiLU())class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.SiLU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b p h n d -> b p n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads, dim_head, dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass MV2Block(nn.Module):def __init__(self, inp, oup, stride=1, expansion=4):super().__init__()self.stride = strideassert stride in [1, 2]hidden_dim = int(inp * expansion)self.use_res_connect = self.stride == 1 and inp == oupif expansion == 1:self.conv = nn.Sequential(# dwnn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# pw-linearnn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)else:self.conv = nn.Sequential(# pwnn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# dwnn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# pw-linearnn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)def forward(self, x):if self.use_res_connect:return x + self.conv(x)else:return self.conv(x)class MobileViTBlock(nn.Module):def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):super().__init__()self.ph, self.pw = patch_sizeself.conv1 = conv_nxn_bn(channel, channel, kernel_size)self.conv2 = conv_1x1_bn(channel, dim)self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)self.conv3 = conv_1x1_bn(dim, channel)self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)def forward(self, x):y = x.clone()# Local representationsx = self.conv1(x)x = self.conv2(x)# Global representations_, _, h, w = x.shapex = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)x = self.transformer(x)x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)# Fusionx = self.conv3(x)x = torch.cat((x, y), 1)x = self.conv4(x)return xclass MobileViT(nn.Module):def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):super().__init__()ih, iw = image_sizeph, pw = patch_sizeassert ih % ph == 0 and iw % pw == 0L = [2, 4, 3]self.conv1 = conv_nxn_bn(3, channels[0], stride=2)self.mv2 = nn.ModuleList([])self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))   # Repeatself.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))self.mvit = nn.ModuleList([])self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0]*2)))self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1]*4)))self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2]*4)))self.conv2 = conv_1x1_bn(channels[-2], channels[-1])self.pool = nn.AvgPool2d(ih//32, 1)self.fc = nn.Linear(channels[-1], num_classes, bias=False)def forward(self, x):x = self.conv1(x)x = self.mv2[0](x)x = self.mv2[1](x)x = self.mv2[2](x)x = self.mv2[3](x)      # Repeatx = self.mv2[4](x)x = self.mvit[0](x)x = self.mv2[5](x)x = self.mvit[1](x)x = self.mv2[6](x)x = self.mvit[2](x)x = self.conv2(x)x = self.pool(x).view(-1, x.shape[1])x = self.fc(x)return xdef mobilevit_xxs():dims = [64, 80, 96]channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)def mobilevit_xs():dims = [96, 120, 144]channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]return MobileViT((256, 256), dims, channels, num_classes=1000)def mobilevit_s():dims = [144, 192, 240]channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]return MobileViT((256, 256), dims, channels, num_classes=1000)def count_parameters(model):return sum(p.numel() for p in model.parameters() if p.requires_grad)if __name__ == '__main__':img = torch.randn(5, 3, 256, 256)vit = mobilevit_xxs()out = vit(img)print(out.shape)print(count_parameters(vit))vit = mobilevit_xs()out = vit(img)print(out.shape)print(count_parameters(vit))vit = mobilevit_s()out = vit(img)print(out.shape)print(count_parameters(vit))

yolo.py中導入并注冊

加入MV2Block, MobileViTBlock

?

?

修改yaml文件

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 1 # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 backbone
backbone:# [from, number, module, args] 640 x 640
#  [[-1, 1, Conv, [32, 6, 2, 2]],  # 0-P1/2  320 x 320[[-1, 1, Focus, [32, 3]],[-1, 1, MV2Block, [32, 1, 2]],  # 1-P2/4[-1, 1, MV2Block, [48, 2, 2]],  # 160 x 160[-1, 2, MV2Block, [48, 1, 2]],[-1, 1, MV2Block, [64, 2, 2]],  # 80 x 80[-1, 1, MobileViTBlock, [64,96, 2, 3, 2, 192]], # 5 out_dim,dim, depth, kernel_size, patch_size, mlp_dim[-1, 1, MV2Block, [80, 2, 2]],  # 40 x 40[-1, 1, MobileViTBlock, [80,120, 4, 3, 2, 480]], # 7[-1, 1, MV2Block, [96, 2, 2]],   # 20 x 20[-1, 1, MobileViTBlock, [96,144, 3, 3, 2, 576]], # 11-P2/4 # 9]# YOLOv5 head
head:[[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [256, False]],  # 13[-1, 1, Conv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [128, False]],  # 17 (P3/8-small)[-1, 1, Conv, [128, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [256, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [256, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [512, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

?

  1. 修改mobilevit.py

?

補充說明

einops.EinopsError: Error while processing rearrange-reduction pattern "b d (h ph) (w pw) -> b (ph pw) (h w) d".

Input tensor shape: torch.Size([1, 120, 42, 42]). Additional info: {'ph': 4, 'pw': 4}

是因為輸入輸出不匹配造成

記得關掉rect哦!一個是在參數里,另一個在下圖。如果要在test或者val中跑,同樣要改

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/34906.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/34906.shtml
英文地址,請注明出處:http://en.pswp.cn/news/34906.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

LeetCode150道面試經典題--單詞規律(簡單)

1.題目 給定一種規律 pattern 和一個字符串 s ,判斷 s 是否遵循相同的規律。 這里的 遵循 指完全匹配,例如, pattern 里的每個字母和字符串 s 中的每個非空單詞之間存在著雙向連接的對應規律。 2.示例 pattern"abba" s "c…

SpingBoot-Vue前后端——實現CRUD

目錄??????? 一、實例需求 ? 二、代碼實現 🏌 數據庫 👀 后端實現 📫 前端實現 🌱 三、源碼下載 👋 一、實例需求 ? 實現一個簡單的CRUD,包含前后端交互。 二、代碼實現 🏌 數…

[樹莓派]ImportError: libcblas.so.3: cannot open shared object file

嘗試在樹莓派4b安裝opencv-python,出現以下錯誤,ImportError: libcblas.so.3: cannot open shared object file: No such file or directory 解決方法,安裝依賴 sudo apt install libatlas-base-dev 再次import cv2就不會報這個錯誤。

約束綜合中的邏輯互斥時鐘(Logically Exclusive Clocks)

注:本文翻譯自Constraining Logically Exclusive Clocks in Synthesis 邏輯互斥時鐘的定義 邏輯互斥時鐘是指設計中活躍(activate)但不彼此影響的時鐘。常見的情況是,兩個時鐘作為一個多路選擇器的輸入,并根據sel信號…

八、解析應用程序——分析應用程序(1)

文章目錄 一、確定用戶輸入入口點1.1 URL文件路徑1.2 請求參數1.3 HTTP消息頭1.4 帶外通道 二、確定服務端技術2.1 提取版本信息2.2 HTTP指紋識別2.3 文件拓展名2.4 目錄名稱2.5 會話令牌2.6 第三方代碼組件 小結 枚舉盡可能多的應用程序內容只是解析過程的一個方面。分析應用程…

小龜帶你敲排序之冒泡排序

冒泡排序 一. 定義二.題目三. 思路分析(圖文結合)四. 代碼演示 一. 定義 冒泡排序(Bubble Sort,臺灣譯為:泡沫排序或氣泡排序)是一種簡單的排序算法。它重復地走訪過要排序的數列,一次比較兩個元…

【深度學習】再談向量化

前言 向量化是一種思想,不僅體現在可以將任意實體用向量來表示,更為突出的表現了人工智能的發展脈絡。向量的演進過程其實都是人工智能向前發展的時代縮影。 1.為什么人工智能需要向量化 電腦如何理解一門語言?電腦的底層是二進制也就是0和1&…

Arduino+esp32學習筆記

學習目標: 使用Arduino配置好藍牙或者wifi模塊 學習使用python配置好藍牙或者wifi模塊 學習內容(筆記): 一、 Arduino語法基礎 Arduino語法是基于C的語法,C又是c基礎上增加了面向對象思想等進階語言。那就只記錄沒見過的。 單多…

全國各城市-貨物進出口總額和利用外資-外商直接投資額實際使用額(1999-2020年)

最新數據顯示,全國各城市外商直接投資額實際使用額在過去一年中呈現了穩步增長的趨勢。這一數據為研究者提供了對中國外商投資活動的全面了解,并對未來投資趨勢和政策制定提供了重要參考。 首先,這一數據反映了中國各城市作為外商投資的熱門目…

Effective Java筆記(31)利用有限制通配符來提升 API 的靈活性

參數化類型是不變的&#xff08; invariant &#xff09; 。 換句話說&#xff0c;對于任何兩個截然不同的類型 Typel 和 Type2 而言&#xff0c; List<Type1 &#xff1e;既不是 List<Type 2 &#xff1e; 的子類型&#xff0c;也不是它的超類型 。雖然 L ist<String…

Oracle自定義函數生成MySQL表結構的DDL語句

1. 自定義函數fnc_table_to_mysql create or replace function fnc_table_to_mysql ( i_owner in string, i_table_name in string, i_number_default_type in string : decimal, i_auto_incretment_column_name in stri…

Linux 文件查看命令

一、cat命令 1.cat文件名&#xff0c;查看文件內容&#xff1a; 例如&#xff0c;查看main.c文件的內容&#xff1a; 2.cat < 文件名&#xff0c;往文件中寫入數據&#xff0c; Ctrld是結束輸入 例如&#xff0c;向文件a.txt中寫入數據&#xff1a; 查看剛剛寫入a.txt的…

Yolov5(一)VOC劃分數據集、VOC轉YOLO數據集

代碼使用方法注意修改一下路徑、驗證集比例、類別名稱&#xff0c;其他均不需要改動&#xff0c;自動劃分訓練集、驗證集、建好全部文件夾、一鍵自動生成Yolo格式數據集在當前目錄下&#xff0c;大家可以直接修改相應的配置文件進行訓練。 目錄 使用方法&#xff1a; 全部代碼…

解決監督學習,深度學習報錯:AttributeError: ‘xxx‘ object has no attribute ‘module‘!!!!

哈嘍小伙伴們大家好呀&#xff0c;很長時間沒有更新啦&#xff0c;最近在研究一個問題&#xff0c;就是AttributeError: xxx object has no attribute module 今天終于是解決了&#xff0c;所以來記錄分享一下&#xff1a; 我這里出現的問題是&#xff1a; 因為我的數據比較大…

SQL優化

一、插入數據 優化 1.1 普通插入&#xff08;小數據量&#xff09; 普通插入&#xff08;小數據量&#xff09;&#xff1a; 采用批量插入&#xff08;一次插入的數據不建議超過1000條&#xff09;手動提交事務主鍵順序插入 1.2 大批量數據插入 大批量插入&#xff1a;&…

Android 開發中需要了解的 Gradle 知識

作者&#xff1a;wkxjc Gradle 是一個基于 Groovy 的構建工具&#xff0c;用于構建 Android 應用程序。在 Android 開發中&#xff0c;了解 Gradle 是非常重要的&#xff0c;因為它是 Android Studio 默認的構建工具&#xff0c;可以幫助我們管理依賴項、構建應用程序、運行測試…

macOS 如何安裝git和nvm

首先&#xff1a;先來安裝git 打開macOS終端 將下面的命令復制粘貼進去&#xff1a; curl -O https://mirrors.edge.kernel.org/pub/software/scm/git/git-2.41.0.tar.gz 版本號可以參考一下官網的 我這里安裝的是目前最新的2.41.0 然后在終端輸入下面的代碼或者雙擊git的…

數據結構:力扣OJ題

目錄 ?編輯題一&#xff1a;鏈表分割 思路一&#xff1a; 題二&#xff1a;相交鏈表 思路一&#xff1a; 題三&#xff1a;環形鏈表 思路一&#xff1a; 題四&#xff1a;鏈表的回文結構 思路一&#xff1a; 鏈表反轉&#xff1a; 查找中間節點&#xff1a; 本人實力…

YOLOv8+ByteTrack多目標跟蹤(行人車輛計數與越界識別)

課程鏈接&#xff1a;https://edu.csdn.net/course/detail/38901 ByteTrack是發表于2022年的ECCV國際會議的先進的多目標跟蹤算法。YOLOv8代碼中已集成了ByteTrack。本課程使用YOLOv8和ByteTrack對視頻中的行人、車輛做多目標跟蹤計數與越界識別&#xff0c;開展YOLOv8目標檢測…

Leetcode每日一題:23. 合并 K 個升序鏈表(2023.8.12 C++)

目錄 23. 合并 K 個升序鏈表 題目描述&#xff1a; 實現代碼與解析&#xff1a; 優先級隊列&#xff1a; 原理思路&#xff1a; 23. 合并 K 個升序鏈表 題目描述&#xff1a; 給你一個鏈表數組&#xff0c;每個鏈表都已經按升序排列。 請你將所有鏈表合并到一個升序鏈表…