學習注意力機制并將其應用到網絡中

什么是注意力機制

注意力機制的核心重點就是讓網絡關注到它更需要關注的地方

當我們使用卷積神經網絡去處理圖片的時候,我們會更希望卷積神經網絡去注意應該注意的地方,而不是什么都關注,我們不可能手動去調節需要注意的地方,這個時候,如何讓卷積神經網絡去自適應的注意重要的物體變得極為重要。

注意力機制就是實現網絡自適應注意的一個方式。

注意力機制可以分為通道注意力機制空間注意力機制,以及二者的結合。

通道注意力機制關注的是某些重要的通道,空間注意力機制關注的是圖片中某些重要的區域

注意力機制的實現方式

在深度學習中,常見的注意力機制的實現方式有SENet,CBAM,ECA等等。

1.SENet的實現

SENet是通道注意力機制的典型實現。
對于輸入進來的特征層,我們關注其每一個通道的權重,對于SENet而言,其重點是獲得輸入進來的特征層,每一個通道的權值。利用SENet,我們可以讓網絡關注它最需要關注的通道

其具體實現方式就是:
1、對輸入進來的特征層進行全局平均池化
2、然后進行兩次全連接,第一次全連接神經元個數較少,第二次全連接神經元個數和輸入特征層相同。
3、在完成兩次全連接后,我們再取一次Sigmoid將值固定到0-1之間,此時我們獲得了輸入特征層每一個通道的權值(0-1之間)。
4、在獲得這個權值后,我們將這個權值乘上原輸入特征層即可。

實現代碼:

def se_block(input_feature, ratio=16, name=""):channel = input_feature._keras_shape[-1]se_feature = GlobalAveragePooling2D()(input_feature)se_feature = Reshape((1, 1, channel))(se_feature)se_feature = Dense(channel // ratio,activation='relu',kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "se_block_one_"+str(name))(se_feature)se_feature = Dense(channel,kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "se_block_two_"+str(name))(se_feature)se_feature = Activation('sigmoid')(se_feature)se_feature = multiply([input_feature, se_feature])return se_feature

2.CBAM的實現

CBAM將通道注意力機制和空間注意力機制進行一個結合,相比于SENet只關注通道的注意力機制可以取得更好的效果。CBAM會對輸入進來的特征層,分別進行通道注意力機制的處理和空間注意力機制的處理。
通道注意力機制的實現可以分為兩個部分,我們會對輸入進來的單個特征層,分別進行全局平均池化和全局最大池化。之后對平均池化和最大池化的結果,利用共享的全連接層進行處理,我們會對處理后的兩個結果進行相加,然后取一個sigmoid,此時我們獲得了輸入特征層每一個通道的權值(0-1之間)。在獲得這個權值后,我們將這個權值乘上原輸入特征層即可

空間注意力機制的實現:我們會對輸入進來的特征層,在每一個特征點的通道上取最大值和平均值。之后將這兩個結果進行一個堆疊,利用一次通道數為1的卷積調整通道數,然后取一個sigmoid,此時我們獲得了輸入特征層每一個特征點的權值(0-1之間)。在獲得這個權值后,我們將這個權值乘上原輸入特征層即可。

實現代碼如下:

def channel_attention(input_feature, ratio=8, name=""):channel = input_feature._keras_shape[-1]shared_layer_one = Dense(channel//ratio,activation='relu',kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "channel_attention_shared_one_"+str(name))shared_layer_two = Dense(channel,kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "channel_attention_shared_two_"+str(name))avg_pool = GlobalAveragePooling2D()(input_feature)    max_pool = GlobalMaxPooling2D()(input_feature)avg_pool = Reshape((1,1,channel))(avg_pool)max_pool = Reshape((1,1,channel))(max_pool)avg_pool = shared_layer_one(avg_pool)max_pool = shared_layer_one(max_pool)avg_pool = shared_layer_two(avg_pool)max_pool = shared_layer_two(max_pool)cbam_feature = Add()([avg_pool,max_pool])cbam_feature = Activation('sigmoid')(cbam_feature)return multiply([input_feature, cbam_feature])def spatial_attention(input_feature, name=""):kernel_size = 7cbam_feature = input_featureavg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)concat = Concatenate(axis=3)([avg_pool, max_pool])cbam_feature = Conv2D(filters = 1,kernel_size=kernel_size,strides=1,padding='same',kernel_initializer='he_normal',use_bias=False,name = "spatial_attention_"+str(name))(concat)	cbam_feature = Activation('sigmoid')(cbam_feature)return multiply([input_feature, cbam_feature])def cbam_block(cbam_feature, ratio=8, name=""):cbam_feature = channel_attention(cbam_feature, ratio, name=name)cbam_feature = spatial_attention(cbam_feature, name=name)return cbam_feature

3、ECA的實現
ECANet是也是通道注意力機制的一種實現形式。ECANet可以看作是SENet的改進版
ECANet的作者認為SENet對通道注意力機制的預測帶來了副作用,捕獲所有通道的依賴關系是低效并且是不必要的
ECA模塊的思想是非常簡單的,它去除了原來SE模塊中的全連接層,直接在全局平均池化之后的特征上通過一個1D卷積進行學習

既然使用到了1D卷積,那么1D卷積的卷積核大小的選擇就變得非常重要了,了解過卷積原理的同學很快就可以明白,1D卷積的卷積核大小會影響注意力機制每個權重的計算要考慮的通道數量

實現代碼如下:

def eca_block(input_feature, b=1, gamma=2, name=""):channel = input_feature._keras_shape[-1]kernel_size = int(abs((math.log(channel, 2) + b) / gamma))kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1avg_pool = GlobalAveragePooling2D()(input_feature)x = Reshape((-1,1))(avg_pool)x = Conv1D(1, kernel_size=kernel_size, padding="same", name = "eca_layer_"+str(name), use_bias=False,)(x)x = Activation('sigmoid')(x)x = Reshape((1, 1, -1))(x)output = multiply([input_feature,x])return output

開始應用:將注意力機制加入到YOLOv8中

1.找到conv.py文件

2.在conv.py中添加名字

3.在__init__.py中添加名字

4.在tasks.py文件中添加名字

5.在tasks.py中添加配置

在該函數中添加代碼

添加的代碼為:

elif m in {CBAM}:c1, c2 = ch[f], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, *args[1:]]

添加后的為:

6.打開yaml文件

7.盡量不要在這個文件中更改內容,我們可以自己創建一個yaml文件(my_yolov8_CBAM.yaml),然后將yolov8.yaml中的內容復制過來

8.在backbone中進行修改

from列中的-1表示應用上一層的參數、repeats列表示重復多少次、module列表示模型的名字、args列表示參數

9.第八點操作添加完后層數會改變,head部分需要進行相應的修改

修改前:

# YOLOv8.0n head
head:- [ -1, 1, nn.Upsample, [ None, 2, "nearest" ] ]- [ [ -1, 6 ], 1, Concat, [ 1 ] ] # cat backbone P4- [ -1, 3, C2f, [ 512 ] ] # 12- [ -1, 1, nn.Upsample, [ None, 2, "nearest" ] ]- [ [ -1, 4 ], 1, Concat, [ 1 ] ] # cat backbone P3- [ -1, 3, C2f, [ 256 ] ] # 15 (P3/8-small)- [ -1, 1, Conv, [ 256, 3, 2 ] ]- [ [ -1, 12 ], 1, Concat, [ 1 ] ] # cat head P4- [ -1, 3, C2f, [ 512 ] ] # 18 (P4/16-medium)- [ -1, 1, Conv, [ 512, 3, 2 ] ]- [ [ -1, 9 ], 1, Concat, [ 1 ] ] # cat head P5- [ -1, 3, C2f, [ 1024 ] ] # 21 (P5/32-large)- [ [ 15, 18, 21 ], 1, Detect, [ nc ] ] # Detect(P3, P4, P5)

修改后:

為什么都+1了?

舉個例子,原來要連接第六層,加了注意力層后,原來的第六層就變成第七層,所以在Concat連接時需要修改相應的層數

至此,注意力機制已經插入,可以開始使用了

10.在根目錄下新建一個main.py文件,代碼如下:

from ultralytics import YOLOmodel = (YOLO("ultralytics/cfg/models/v8/my_yolov8_CBAM.yaml"))
model.train(**{'cfg': 'ultralytics/cfg/default.yaml'})

運行即可開始訓練

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/11411.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/11411.shtml
英文地址,請注明出處:http://en.pswp.cn/web/11411.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Pytest官方文檔翻譯及學習】2.1 如何調用pytest

目錄 2.1 如何調用pytest 2.1.1 指定要運行的測試 2.1.2 獲取有關版本、選項名稱、環境變量的幫助 2.1.3 分析測試執行時間 2.1.4 管理加載插件 2.1.5 調用pytest的其他方式 2.1 如何調用pytest 2.1.1 指定要運行的測試 Pytest支持幾種從命令行運行和選擇測試的方法。、…

證明力引導算法forceatlas2為什么不是啟發式算法

一、基本概念 吸引力 F a ( n i ) ∑ n j ∈ N c t d ( n i ) ω i , j d E ( n i , n j ) V i , j \displaystyle \bm{F}_a(n_i) \sum_{n_j \in \mathcal{N}_{ctd}(n_i)} \omega_{i,j} \; d_E(n_i,n_j) \bm{V}_{i,j} Fa?(ni?)nj?∈Nctd?(ni?)∑?ωi,j?dE?(ni?,nj?…

class常量池、運行時常量池和字符串常量池的關系

類常量池、運行時常量池和字符串常量池這三種常量池,在Java中扮演著不同但又相互關聯的角色。理解它們之間的關系,有助于深入理解Java虛擬機(JVM)的內部工作機制,尤其是在類加載、內存分配和字符串處理方面。 類常量池…

MinCED:注釋CRISPRs

GitHub - ctSkennerton/minced: Mining CRISPRs in Environmental Datasets 安裝 git clone http://github.com/ctSkennerton/minced cd minced make 使用 gunzip -k * cat *.fa > all_MAG_contig.fasta /home/zhongpei/hard_disk_sda2/zhongpei/Software/minced/minced…

NeurIPS‘24 截稿日期逼近 加拿大溫哥華邀你共赴盛會

會議之眼 快訊 第38屆NeurIPS24(Conference and Workshop on Neural Information Processing Systems)即神經信息處理系統研討會將于 2024 年 12月9日-15日在加拿大溫哥華會議中心舉行! NeurIPS 每一年都是全球AI領域的一場盛宴,吸引著來自世界各地的頂…

暴雨信息:IT是新質生產力的賦能者

5月11日下午,2024全球徽商上海論壇在上海國際會議中心舉辦。暴雨信息孫輝在會上發表歡迎辭。孫輝在致辭和會后接受采訪時表示,發展新質生產力要以“智”提質,發揮人工智能作為培育新質生產力的引擎作用,通過推廣混合式人工智能&am…

【小白誤闖】Activiti 框架你不得不知道的一些事

Activiti 是一個輕量級的、以Java為中心的開源工作流和業務流程管理(BPM)平臺。它允許用戶在業務應用程序中定義、執行和監控業務流程。以下是Activiti的核心組件: 8個核心組件概述 Activiti Engine:這是Activiti最核心的部分&am…

Java 面試問題及答案

Java 面試問題及答案 問題 1: 什么是Java虛擬機(JVM)?請簡述其主要組成部分及其作用。 回答: Java虛擬機(JVM)是一個可以執行Java字節碼的虛擬計算機。它是Java平臺的核心組成部分,使得Java能夠實現其核心特性之一&a…

Elasticsearch映射定義

文章目錄 認識映射元字段數據類型1.基本數據類型2.復雜數據類型專用數據類型多字段類型 認識映射 映射類似于關系型數據庫中的Schema(模式)。Schema在關系型數據庫中是指庫表包含的字段及字段存儲類型等基礎信息。 映射定義由兩部分組成:元…

一些python包缺失帶來的報錯及解決辦法

描述 一些python包缺失帶來的報錯及解決辦法 安裝 ModuleNotFoundError: No module named cv2 pip install opencv-pythonModuleNotFoundError: No module named torch 我的CSDN博客ModuleNotFoundError: No module named colorama pip install coloramaModuleNotFoundError…

5.10.8 Transformer in Transformer

Transformer iN Transformer (TNT)。具體來說,我們將局部補丁(例如,1616)視為“視覺句子”,并將它們進一步劃分為更小的補丁(例如,44)作為“視覺單詞”。每個單詞的注意力將與給定視…

信號和槽基本概念

🐌博主主頁:🐌?倔強的大蝸牛🐌? 📚專欄分類:QT??感謝大家點贊👍收藏?評論?? 目錄 一、概述 二、信號的本質 三、槽的本質 一、概述 在 Qt 中,用戶和控件的每次交互過程稱…

Bootloader+升級方案

隨著設備的功能越來越強大,系統也越來越復雜,產品升級也成為了開發過程不可或缺的一道程序。在工程應用中,如何在不更改硬件的前提下通過軟件的方式實現產品升級。通過Bootloader來實現固件的升級是一種極好的方式,Bootloader是單…

I2CKD : INTRA- AND INTER-CLASS KNOWLEDGE DISTILLATION FOR SEMANTIC SEGMENTATION

摘要 本文提出了一種新的針對圖像語義分割的知識蒸餾方法,稱為類內和類間知識蒸餾(I2CKD)。該方法的重點是在教師(繁瑣模型)和學生(緊湊模型)的中間層之間捕獲和傳遞知識。對于知識提取&#x…

12個乒乓球,有一個次品,不知輕重,用一臺無砝碼天平稱三次,找出次品,告知輕重?

前言 B站上看到個視頻:為什么有人不認可清北的學生大多是智商高的? 然后試了下,發現我真菜 自己的思路(失敗) 三次稱重要獲取到12個乒乓球中那個是次品,我想著將12個小球編號,分為四組,每組…

yo!這里是socket網絡編程相關介紹

目錄 前言 基本概念 源ip&&目的ip 源端口號&&目的端口號 udp&&tcp初識 socket編程 網絡字節序 socket常見接口 socket bind listen accept connect 地址轉換函數 字符串轉in_addr in_addr轉字符串 套接字讀寫函數 recvfrom&&a…

Java入門基礎學習筆記2——JDK的選擇下載安裝

搭建Java的開發環境: Java的產品叫JDK(Java Development Kit: Java開發者工具包),必須安裝JDK才能使用Java。 JDK的發展史: LTS:Long-term Support:長期支持版。指的Java會對這些版…

pycharm報錯Process finished with exit code -1073740791 (0xC0000409)

pycharm報錯Process finished with exit code -1073740791 (0xC0000409) 各種垃圾文章(包括chatgpt產生的垃圾文章),沒有給出具體的解決辦法。 解決辦法就是把具體報錯信息顯示出來,然后再去查。 勾選 然后再運行就能把錯誤顯示…

MetaRTC-play拉流客戶端代碼分析

渲染使用opengl,音頻播放使用alsa。 當點擊播放按鈕后,以此調用的類如下,開始建立rtc連接,AV解碼,音頻渲染,視頻渲染。 如果想去除QT,改為cmake工程管理,去掉渲染部分即可。 下方是…

VUE+PrintJS打印-邊距設置問題(提供解決方案)

VUE打印我們一般用printJS,雖然它也提供了邊距設置,但不管怎么調,感覺都不對,也換其他組件試過,沒啥區別,并不能解決問題。 今天又發來個需求,要求設置打印頁面的上、下、左、右邊距&#xff0…