Pytorch注意力機制應用到具體網絡方法(閉眼都會版)

文章目錄

  • 以YoloV4-tiny為例
    • 要加入的注意力機制代碼
    • 模型中插入注意力機制

以YoloV4-tiny為例

在這里插入圖片描述
解釋一下各個部分:

  • 最左邊這部分為主干提取網絡,功能為特征提取
  • 中間這邊部分為FPN,功能是加強特征提取
  • 最后一部分為yolo head,功能為獲得我們具體的一個預測結果

需要明白幾個點:

  • 注意力機制模塊是一個即插即用的模塊,理論上是可以添加到任何一個特征圖后面
  • 但是,不建議添加到主干部分(即最左邊的那部分),主干部分所用的特征是我們后面處理所用的基礎,故不建議添加到主干部分
  • 如果添加到主干部分,由于注意力機制模塊 它的權值模塊是隨機初始化的,那主干部分的權值就被破壞了,最開始提取出來的特征就不好用了。
  • 故建議把注意力機制模塊添加到主干以外的部分

本節把注意力機制添加到加強網絡里面,即上圖的中間部分。
添加注意力機制可以添加到上圖標注的部分。

要加入的注意力機制代碼

這一部分為要加入的注意力機制模塊,文件名為attention.py

import torch
from torch import nn
# 通道注意力機制
class channel_attention(nn.Module):def __init__(self,channel,ration=16):   #因為要進行全連接,故需要傳入通道數量,及縮放比例super(channel_attention,self).__init__()  #初始化#定義最大池化層self.max_pool = nn.AdaptiveMaxPool2d(1) #輸出層的高和寬是1#定義平均池化self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(#定義第一次全連接nn.Linear(channel,channel // ration ,False),nn.ReLU(),# 定義第二次全連接nn.Linear(channel//ration,channel,False))#由于圖中的通道注意力機制是連個全連接層相加之后再取sigmoidself.sigmoid=nn.Sigmoid()#前傳部分def forward(self,x):b,c,h,w=x.size()#首先對輸入進來的x先進行一個全局最大池化 在進行一個全局平均池化max_pool_out=self.max_pool(x).view([b,c])avg_pool_out=self.avg_pool(x).view([b,c])#然后對兩次池化后的結果用共享的全連接層fc進行處理max_fc_out=self.fc(max_pool_out)avg_fc_out=self.fc(avg_pool_out)#最后將上面的兩個結果進行相加out=max_fc_out + avg_fc_outout=self.sigmoid(out).view([b,c,1,1])#print(out)return out * x
# 空間注意力機制
class spacial_attention(nn.Module):def __init__(self,kernel_size=7):   #空間注意力沒有通道數,故不用傳入channel和ration#但是空間注意力會進行一次卷積,故我們需要關注卷積核大小,一般為3或7super(spacial_attention,self).__init__()  #初始化padding=7//2  #卷積核大小整除輸入通道數self.conv=nn.Conv2d(2,1,kernel_size,1,padding,bias=False)#由圖可知輸入通道數是2,輸出通道數為1,卷積核大小默認設置為7,步長為1,因為不需要壓縮特征層阿高和寬#由于圖中的通道注意力機制是連個全連接層相加之后再取sigmoidself.sigmoid=nn.Sigmoid()#空間注意力機制前傳部分def forward(self,x):b,c,h,w=x.size()max_pool_out,_= torch.max(x,dim=1,keepdim=True)#需要把通道這一維度保留下來,故設置keepdim為True#對于pytorch來講,它的通道是在第一維度,也就是batchsize后面的那個維度故定義dim為1mean_pool_out = torch.mean(x,dim = 1,keepdim=True)#對最大值和平均值進行一個堆疊pool_out = torch.cat([max_pool_out, mean_pool_out],dim=1)#對堆疊后的結果取一個卷積out=self.conv(pool_out)out=self.sigmoid(out)print(out)return out * x#把空間注意力機制和通道注意力機制進行一個融合
class Cbam(nn.Module):def __init__(self,channel,ratio=16,kernel_size=7):super(Cbam,self).__init__()#調用已經定義好的2個注意力機制self.channel_attention=channel_attention(channel,ratio)self.spacial_attention = spacial_attention(kernel_size)#融合后機制的前傳部分def forward(self,x):x=self.channel_attention(x)x=self.spacial_attention(x)return x

在模型文件(yolo.py)中,首行添加如下部分

from .attention import se_block,cbam_block,eca_block
attention_blocks=[se_block,cbam_block,eca_block]
為何要設置成上面的形式?
為了方便調用,到時候可以直接編寫下面的代碼調用具體的注意力機制模塊
attention_blocks[0]

之后,需要找到yolo.py里面的模型主體部分,大概形式如下代碼

class YoloBody(nn.Module):def __init__(self,anchors_mask,num_classes,phi=0)#在原來的代碼上只是添加了phi,代表我們選用的注意力機制模塊,默認情況下為0super(YoloBody, self).__init__()self.backbone       = darknet53_tiny(None)self.conv_for_P5    = BasicConv(512,256,1)self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)self.upsample       = Upsample(256,128)self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)#下面這部分為自己填寫self.phi    = phi  #這個是自己添加的if 1 <= self.phi and self.phi <= 3:self.feat1_att      = attention_block[self.phi - 1](256)  #通道數為256self.feat2_att      = attention_block[self.phi - 1](512)#通道數為512self.upsample_att   = attention_block[self.phi - 1](128)#通道數為128#通道數到底是多少看這個模型的前傳部分的通道數為多少def forward(self, x):#---------------------------------------------------##   生成CSPdarknet53_tiny的主干模型#   feat1的shape為26,26,256#   feat2的shape為13,13,512#---------------------------------------------------#feat1, feat2 = self.backbone(x)#下面代碼為自己填寫if 1 <= self.phi and self.phi <= 3:#如果滿足條件就添加具體的注意力機制feat1 = self.feat1_att(feat1)feat2 = self.feat2_att(feat2)#下面代碼模型自帶# 13,13,512 -> 13,13,256P5 = self.conv_for_P5(feat2)# 13,13,256 -> 13,13,512 -> 13,13,255out0 = self.yolo_headP5(P5) # 13,13,256 -> 13,13,128 -> 26,26,128P5_Upsample = self.upsample(P5)# 26,26,256 + 26,26,128 -> 26,26,384#上面代碼模型自帶,下面代碼自己編寫if 1 <= self.phi and self.phi <= 3:P5_Upsample = self.upsample_att(P5_Upsample)#下面代碼模型自帶P4 = torch.cat([P5_Upsample,feat1],axis=1)# 26,26,384 -> 26,26,256 -> 26,26,255out1 = self.yolo_headP4(P4)return out0, out1

模型中插入注意力機制

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/64010.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/64010.shtml
英文地址,請注明出處:http://en.pswp.cn/web/64010.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

修改el-select下拉框高度;更新:支持動態修改

文章目錄 效果動態修改&#xff1a;效果代碼固定高度版本動態修改高度版本&#xff08;2024-12-25 更新&#xff1a; 支持動態修改下拉框高度&#xff09; 效果 動態修改&#xff1a;效果 代碼 固定高度版本 注意點&#xff1a; popper-class 盡量獨一無二&#xff0c;防止影…

開關電源特點、分類、工作方式

什么叫開關電源隨著電力電子技術的發展和創新&#xff0c;使得開關電源技術也在不斷地創新。目前&#xff0c;開關電源以小型、輕量和高效率的特點被廣泛應用幾乎所有的電子設備&#xff0c;是當今電子信息產業飛速發展不可缺少的一種電源方式。 開關電源是利用現代電力電子技…

Linux應用軟件編程-文件操作(目錄io)

1.打開目錄&#xff1a; DIR *opendir(const char *name); 功能&#xff1a;打開一個目錄獲得一個目錄流指針 參數: name:目錄名 返回值&#xff1a;成功返回目錄流指針&#xff1b;失敗返回NULL 2.讀目錄&#xff1a; struct dirent *readdir(DIR *dirp); 功能&…

有哪些開發者模式?

1、單例開發模式&#xff08;Singleton Pattern&#xff09; 單例模式是一種創建型設計模式&#xff0c;目的是確保在程序運行期間&#xff0c;某個類只有一個實例&#xff0c;并提供一個全局訪問點來訪問該實例。 核心特點 唯一實例&#xff1a;一個類只能創建一個對象實例。…

如何完全剔除對Eureka的依賴,報錯Cannot execute request on any known server

【現象】 程序運行報錯如下&#xff1a; com.netflix.discovery.shared.transport.TransportException報錯Cannot execute request on any known server 【解決方案】 &#xff08;1&#xff09;在Maven工程中的pom去掉Eureka相關的引用&#xff08;注釋以下部分&#xff0…

vscode寫python,遇到問題:ModuleNotFoundError: No module named ‘pillow‘(已解決 避坑)

1 問題&#xff1a; ModuleNotFoundError: No module named pillow 2 原因&#xff1a; 原因1&#xff1a;安裝Pillow的pip命令所處的python版本與vscode調用的python解釋器版本不同。 如&#xff1a; 原因2&#xff1a;雖然用的是pillow&#xff0c;但是寫代碼的時候只能用…

Ashy的考研游記

文章目錄 摘要12.1112.2012.21 DAY1&#xff08;政治/英語&#xff09;政治英語 12.22 DAY2&#xff08;數學/專業課&#xff09;數學專業課 結束估分 摘要 在24年的12月里&#xff0c;Ashy完成了他的考研沖刺&#xff0c;順利的結束了他本年度的考研之旅。 在十二月里&#…

AIGC實踐|AI/AR助力文旅沉浸式互動體驗探索

前言&#xff1a; 本篇文章的創作靈感來源于近期熱門話題——讓文物“動起來”&#xff0c;各大博物館成為新進潮流打卡地。結合之前創作的AI文旅宣傳片良好的流量和反饋&#xff0c;外加最近比較感興趣的AR互動探索&#xff0c;想嘗試看看自己能不能把這些零碎的內容整合起來…

tcp 的三次握手與四次揮手

問1: 請你說一下tcp的三次握手一次握手兩次握手三次握手問: 為什么不四(更多)次握手? 問 2: 請說一下 tcp 的 4 次揮手一次揮手兩次揮手問題:能不能等到數據傳輸完成再返回 ack? 三次揮手四次揮手問: 為什么要等兩個最大報文存在時間? bg: tcp 是可靠的連接,如何保證 建立連…

Kubernetes(k8s)離線部署DolphinScheduler3.2.2

1.環境準備 1.1 集群規劃 本次安裝環境為&#xff1a;3臺k8s現有的postgreSql數據庫zookeeper服務 1.2 下載及介紹 DolphinScheduler-3.2.2官網&#xff1a;https://dolphinscheduler.apache.org/zh-cn/docs/3.2.2 官網安裝文檔&#xff1a;https://dolphinscheduler.apach…

C++的侵入式鏈表

非侵入式鏈表 非侵入式鏈表是一種鏈表數據結構&#xff0c;其中每個元素&#xff08;節點&#xff09;并不需要自己包含指向前后節點的指針。鏈表的結構和節點的存儲是分開的&#xff0c;鏈表容器會單獨管理這些指針。 常見的非侵入式鏈表節點可以由以下所示&#xff0c;即&a…

Flutter組合動畫學習

如何使用動畫控制器和動畫來創建一個簡單的動畫效果。具體來說&#xff0c;它通過一個 AnimationController 來控制兩個動畫&#xff0c;一個用于旋轉&#xff0c;一個用于繪制。 前置知識點學習 SingleTickerProviderStateMixin SingleTickerProviderStateMixin 是 Flutter …

在vscode的ESP-IDF中使用自定義組件

以hello-world為例&#xff0c;演示步驟和注意事項 1、新建ESP-IDF項目 選擇模板 從hello-world模板創建 2、打開項目 3、編譯結果沒錯 正在執行任務: /home/azhu/.espressif/python_env/idf5.1_py3.10_env/bin/python /home/azhu/esp/v5.1/esp-idf/tools/idf_size.py /home…

2025差旅平臺怎么選?一體化、全流程降本案例解析

差旅支出在企業中一直是一項重要但容易被忽視的成本開支&#xff0c;尤其是在項目驅動型企業中&#xff0c;因頻繁的差旅需求&#xff0c;支出規模往往持續增長。以差旅平臺分貝通簽約伙伴——某智能制造業的業務模式為例&#xff0c;該模式要求員工定期前往不同的工廠、供應商…

【linux】NFS實驗

NFS NFS服務 nfs,最早是Sun這家公司所發展出來的,它最大的功能就是可以透過網絡,讓不同的機器,不同的操作系統,進行實現文檔的共享。所以你可以簡單的將他看做是文件服務器。 實驗準備 ①先準備一個服務器端的操作系統和客戶端的操作系統(Red Hat)。 ②選擇NAT模式,…

智源研究院與安謀科技達成戰略合作,共建開源AI“芯”生態

12月25日&#xff0c;智源研究院與安謀科技&#xff08;中國&#xff09;有限公司&#xff08;以下簡稱“安謀科技”&#xff09;與正式簽署戰略合作協議&#xff0c;雙方將面向多元AI芯片領域開展算子庫優化與適配、編譯器與工具鏈支持、生態系統建設與推廣等一系列深入合作&a…

ROG NUC:強大內核激發創意,AI賦能學子科技探索

有這么一款能夠激發無限創意、助力科技探索的迷你主機&#xff0c;它以其卓越的性能和迷你的身材成為了成為了ProArt百校行活動中的明星產品&#xff0c;助力廣大學子勇敢探索未知&#xff0c;追逐屬于自己的科技夢想。它就是ROG NUC 2024&#xff01; 強大性能&#xff0c;創意…

從零玩轉CanMV-K230(8)-多線程例程

文章目錄 前言一、_thread模塊API二、使用示例創建并啟動線程停止線程_thread.exit() 總結 前言 K230上不支持threading&#xff0c;只能支持_thread&#xff0c;該模塊實現了相應 CPython 模塊的子集&#xff0c;CPython 是 Python 編程的參考實現 語言&#xff0c;也是最著名…

yii2 手動添加 phpoffice\phpexcel

1.下載地址&#xff1a;https://github.com/PHPOffice/PHPExcel 2.解壓并修改文件名為phpexcel 在yii項目的vendor目錄下創建一個文件夾命名為phpoffice 把phpexcel目錄放到phpoffic文件夾下 查看vendor\phpoffice\phpexcel目錄下會看到這些文件 3.到vendor\composer目錄下…

安卓多渠道apk配置不同簽名

一般簽名都是放在buildTypes里面&#xff1a; ... android {...defaultConfig {...}signingConfigs {release {storeFile file("myreleasekey.keystore")storePassword "password"keyAlias "MyReleaseKey"keyPassword "password"}}bu…