【深度學習】注意力機制(一)

本文介紹一些注意力機制的實現,包括SE/ECA/GE/A2-Net/GC/CBAM。

目錄

一、SE(Squeeze-and-Excitation)

二、ECA(Efficient Channel Attention)

三、GE(Gather-Excite)

四、A2-Net(Double Attention Networks)

五、GCNet(Global Context)

六、CBAM(Convolutional Block Attention Module)


一、SE(Squeeze-and-Excitation)

SE是通道注意力機制,論文地址:論文地址

SE模塊流程:

1、輸入特征圖經過自適應池化變為NC11的特征圖,特征圖resize為NC;

2、經過全連接層和Relu、sigmoid生成權重;

3、將權重和輸入特征圖相乘。

如下所示:

torch代碼實現:

import numpy as np
import torch
from torch import nn
from torch.nn import initclass SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

二、ECA(Efficient Channel Attention)

ECA是通道注意力機制,論文:論文地址

ECA模塊過程:

1、使用自適應池化將NCHW的特征圖變為N1C的特征圖(自適應池化、squeeze、transpose);

2、使用1D卷積生成N1C的特征圖(在C通道做卷積),將經過1D卷積的特征圖變為NC11(transpose、unsqueeze);

3、特征圖通過sigmoid,生成NC11的權重,將權重與原特征圖相乘;

如下圖:

torch代碼:

import torch
from torch import nn
from torch.nn.parameter import Parameterclass ECALayer(nn.Module):"""Constructs a ECA module.Args:channel: Number of channels of the input feature mapk_size: Adaptive selection of kernel size"""def __init__(self, channel, k_size=3):super(eca_layer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid()def forward(self, x):# feature descriptor on the global spatial informationy = self.avg_pool(x)# Two different branches of ECA moduley = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)# Multi-scale information fusiony = self.sigmoid(y)return x * y.expand_as(x)

三、GE(Gather-Excite)

GE是空間注意力機制,論文:論文地址

該機制較為簡單,有四種方式,總體流程如下(看圖理解比較好,不多說了):

可以通過timm輕松調用該模塊,timm實現的源碼:

import mathfrom torch import nn as nn
import torch.nn.functional as Ffrom .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlpclass GatherExcite(nn.Module):""" Gather-Excite Attention Module"""def __init__(self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,rd_ratio=1./16, rd_channels=None,  rd_divisor=1, add_maxpool=False,act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):super(GatherExcite, self).__init__()self.add_maxpool = add_maxpoolact_layer = get_act_layer(act_layer)self.extent = extentif extra_params:self.gather = nn.Sequential()if extent == 0:assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'self.gather.add_module('conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))if norm_layer:self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))else:assert extent % 2 == 0num_conv = int(math.log2(extent))for i in range(num_conv):self.gather.add_module(f'conv{i + 1}',create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))if norm_layer:self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))if i != num_conv - 1:self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))else:self.gather = Noneif self.extent == 0:self.gk = 0self.gs = 0else:assert extent % 2 == 0self.gk = self.extent * 2 - 1self.gs = self.extentif not rd_channels:rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()self.gate = create_act_layer(gate_layer)def forward(self, x):size = x.shape[-2:]if self.gather is not None:x_ge = self.gather(x)else:if self.extent == 0:# global extentx_ge = x.mean(dim=(2, 3), keepdims=True)if self.add_maxpool:# experimental codepath, may remove or changex_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)else:x_ge = F.avg_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)if self.add_maxpool:# experimental codepath, may remove or changex_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)x_ge = self.mlp(x_ge)if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:x_ge = F.interpolate(x_ge, size=size)return x * self.gate(x_ge)

四、A2-Net(Double Attention Networks)

雙重注意力網絡(A2-Nets)方法引入了新的關系函數用于非局部(NL)塊,依次使用兩個連續的注意力塊。論文地址:論文地址

其計算過程類似于SelfAttention模塊,可以看diamagnetic對照理解。

如下圖:

代碼如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleAtten(nn.Module):"""A2-Nets: Double Attention Networks. NIPS 2018"""def __init__(self,in_c):""":paramin_c: 進行注意力refine的特征圖的通道數目;原文中的降維和升維沒有使用"""super(DoubleAtten,self).__init__()self.in_c = in_c"""以下對同一輸入特征圖進行卷積,產生三個尺度相同的特征圖,即為文中提到A, B, V"""self.convA = nn.Conv2d(in_c,in_c,kernel_size=1)self.convB = nn.Conv2d(in_c,in_c,kernel_size=1)self.convV = nn.Conv2d(in_c,in_c,kernel_size=1)def forward(self,input):feature_maps = self.convA(input)atten_map = self.convB(input)b, _, h, w = feature_maps.shapefeature_maps = feature_maps.view(b, 1, self.in_c, h*w) # 對 A 進行reshapeatten_map = atten_map.view(b, self.in_c, 1, h*w)       # 對 B 進行reshape 生成 attention_apsglobal_descriptors = torch.mean((feature_maps * F.softmax(atten_map, dim=-1)),dim=-1) # 特征圖與attention_maps 相乘生成全局特征描述子v = self.convV(input)atten_vectors = F.softmax(v.view(b, self.in_c, h*w), dim=-1) # 生成 attention_vectorsout = torch.bmm(atten_vectors.permute(0,2,1), global_descriptors).permute(0,2,1) # 注意力向量左乘全局特征描述子return out.view(b, _, h, w)

五、GCNet(Global Context)

全局上下文網絡(GC-Net)方法使用復雜的基于置換的操作將NL-塊和SE塊集成,以捕捉長期依賴關系。論文:論文地址

可以看出GC模塊是對SE的改進,如下圖:

該實現的初始化依賴于mmcv,代碼如下:

import torch
from mmcv.cnn import constant_init, kaiming_init
from torch import nndef last_zero_init(m):if isinstance(m, nn.Sequential):constant_init(m[-1], val=0)else:constant_init(m, val=0)class ContextBlock(nn.Module):def __init__(self,inplanes,ratio,pooling_type='att',fusion_types=('channel_add', )):super(ContextBlock, self).__init__()assert pooling_type in ['avg', 'att']assert isinstance(fusion_types, (list, tuple))valid_fusion_types = ['channel_add', 'channel_mul']assert all([f in valid_fusion_types for f in fusion_types])assert len(fusion_types) > 0, 'at least one fusion should be used'self.inplanes = inplanesself.ratio = ratioself.planes = int(inplanes * ratio)self.pooling_type = pooling_typeself.fusion_types = fusion_typesif pooling_type == 'att':self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)else:self.avg_pool = nn.AdaptiveAvgPool2d(1)if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_add_conv = Noneif 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_mul_conv = Noneself.reset_parameters()def reset_parameters(self):if self.pooling_type == 'att':kaiming_init(self.conv_mask, mode='fan_in')self.conv_mask.inited = Trueif self.channel_add_conv is not None:last_zero_init(self.channel_add_conv)if self.channel_mul_conv is not None:last_zero_init(self.channel_mul_conv)def spatial_pool(self, x):batch, channel, height, width = x.size()if self.pooling_type == 'att':input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]context_mask = self.conv_mask(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = self.softmax(context_mask)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)else:# [N, C, 1, 1]context = self.avg_pool(x)return contextdef forward(self, x):# [N, C, 1, 1]context = self.spatial_pool(x)out = xif self.channel_mul_conv is not None:# [N, C, 1, 1]channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))out = out * channel_mul_termif self.channel_add_conv is not None:# [N, C, 1, 1]channel_add_term = self.channel_add_conv(context)out = out + channel_add_termreturn out

六、CBAM(Convolutional Block Attention Module)

CBAM是通道-空間注意力機制,論文:論文地址

很簡單的通道注意力和空間注意力融合。

如下圖:

代碼如下:

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ChannelAttention(nn.Module):def __init__(self,channel,reduction=16):super().__init__()self.maxpool=nn.AdaptiveMaxPool2d(1)self.avgpool=nn.AdaptiveAvgPool2d(1)self.se=nn.Sequential(nn.Conv2d(channel,channel//reduction,1,bias=False),nn.ReLU(),nn.Conv2d(channel//reduction,channel,1,bias=False))self.sigmoid=nn.Sigmoid()def forward(self, x) :max_result=self.maxpool(x)avg_result=self.avgpool(x)max_out=self.se(max_result)avg_out=self.se(avg_result)output=self.sigmoid(max_out+avg_out)return outputclass SpatialAttention(nn.Module):def __init__(self,kernel_size=7):super().__init__()self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)self.sigmoid=nn.Sigmoid()def forward(self, x) :max_result,_=torch.max(x,dim=1,keepdim=True)avg_result=torch.mean(x,dim=1,keepdim=True)result=torch.cat([max_result,avg_result],1)output=self.conv(result)output=self.sigmoid(output)return outputclass CBAMBlock(nn.Module):def __init__(self, channel=512,reduction=16,kernel_size=49):super().__init__()self.ca=ChannelAttention(channel=channel,reduction=reduction)self.sa=SpatialAttention(kernel_size=kernel_size)def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()residual=xout=x*self.ca(x)out=out*self.sa(out)return out+residual

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/211363.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/211363.shtml
英文地址,請注明出處:http://en.pswp.cn/news/211363.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

二維碼智慧門牌管理系統升級解決方案:數字鑒權

文章目錄 前言一、數字鑒權的核心機制二、數字鑒權的意義和應用 前言 隨著科技的飛速發展,我們的生活逐漸進入數字化時代。在這個數字化的過程中,數據的安全性和門牌信息的保障變得至關重要。今天,我們要介紹的是二維碼智慧門牌管理系統升級…

【論文復現】zoedepth踩坑

注意模型IO: 保證輸入、輸出精度、類型與復現目標一致。 模型推理的代碼 from torchvision import transforms def image_to_tensor(img_path, unsqueezeTrue):rgb transforms.ToTensor()(Image.open(img_path))if unsqueeze:rgb rgb.unsqueeze(0)return rgbdef…

dockerdesktop 導出鏡像,導入鏡像

總體思路 備份時 容器 > 鏡像 > 本地文件 恢復時 本地文件 > 鏡像 > 容器 備份步驟 首先,把容器生成為鏡像 docker commit [容器名稱] [鏡像名稱] 示例 docker commit nginx mynginx然后,把鏡像備份為本地文件,如果使用的是Docker Desktop,打包備份的文件會自動存…

機器學習筆記 - 基于C# + .net framework 4.8的ONNX Runtime進行分類推理

該示例是從官方抄的,演示了如何使用 Onnx Runtime C# API 運行預訓練的 ResNet50 v2 ONNX 模型。 我這里的環境基于.net framework 4.8的一個winform項目,主要依賴下面版本的相關庫。 Microsoft.Bcl.Numerics.8.0.0 Microsoft.ML.OnnxRuntime.Gpu.1.16.3 SixLabors.ImageShar…

MyString:string類的模擬實現 1

MyString:string類的模擬實現 前言: 為了區分標準庫中的string,避免編譯沖突,使用命名空間 MyString。 namespace MyString {class string{private:char* _str;size_t _size;size_t _capacity;const static size_t npos -1;// C標…

2023年 - 我的程序員之旅和成長故事

2023年 - 我的程序員之旅和成長故事 🔥 1.前言 大家好,我是Leo哥🫣🫣🫣,今天咱們不聊技術,聊聊我自己,聊聊我從2023年年初到現在的一些經歷和故事,我也很愿意我的故事分…

TS學習——快速入門

TypeScript簡介 TypeScript是JavaScript的超集。它對JS進行了擴展,向JS中引入了類型的概念,并添加了許多新的特性。TS代碼需要通過編譯器編譯為JS,然后再交由JS解析器執行。TS完全兼容JS,換言之,任何的JS代碼都可以直…

Android 樣式小結

關于作者:CSDN內容合伙人、技術專家, 從零開始做日活千萬級APP。 專注于分享各領域原創系列文章 ,擅長java后端、移動開發、商業變現、人工智能等,希望大家多多支持。 目錄 一、導讀二、概覽三、使用3.1 創建并應用樣式3.2 創建并…

DJI ONBOARD SDK—— 基礎控制功能 Joystick的講解,使用和擴展

DJI ONBOARD SDK/DJI OSDK ROS—— 基礎控制功能 Joystick的使用 概述 使用OSDK/OSDK_ROS 的無人機飛行控制功能,能夠設置并獲取無人機各項基礎參數,控制無人機執行基礎飛行動作,通過Joystick 功能控制無人機執行復雜的飛行動作。 Joystic…

【精彩回顧】恒拓高科亮相第十一屆深圳軍博會

2023年12月6日-8日,由中國和平利用軍工技術協會、全國工商聯科技裝備業商會、深圳市國防科技工業協會等單位主辦以及政府相關部門支持,深圳企發展覽有限公司承的“2023第11屆中國(深圳)軍民兩用科技裝備博覽會(深圳軍博…

02 CSS基礎入門

文章目錄 一、CSS介紹1. 簡介2. 相關網站3. HTML引入方式 二、選擇器1. 標簽選擇器2. 類選擇器3. ID選擇器4. 群組選擇器 四、樣式1. 字體樣式2. 文本樣式3. 邊框樣式4. 表格樣式 五、模型和布局1. 盒子模型2. 網頁布局 一、CSS介紹 1. 簡介 CSS主要用于控制網頁的外觀&#…

C#如何使用SqlSugar操作MySQL/SQL Server數據庫

一. SqlSugar 連接MySQL數據庫 public class MySqlCNHelper : Singleton<MySqlCNHelper>{public static SqlSugarClient CnDB;public void InitDB() {//--------------------MySQL--------------------CnDB new SqlSugarClient(new ConnectionConfig(){ConnectionString…

窮舉問題-搬磚(for循環)

某工地需要搬運磚塊&#xff0c;已知男人一人搬3塊&#xff0c;女人一人搬2塊&#xff0c;小孩兩人搬1塊。如果想用n人正好搬n塊磚&#xff0c;問有多少種搬法&#xff1f; 輸入格式: 輸入在一行中給出一個正整數n。 輸出格式: 輸出在每一行顯示一種方案&#xff0c;按照&q…

玩轉大數據12:大數據安全與隱私保護策略

1. 引言 大數據的快速發展&#xff0c;為各行各業帶來了巨大的變革&#xff0c;也帶來了新的安全和隱私挑戰。大數據系統通常處理大量敏感數據&#xff0c;包括個人身份信息、財務信息、健康信息等。如果這些數據被泄露或濫用&#xff0c;可能會對個人、企業和社會造成嚴重的損…

Unity 資源管理之Resources

Resources是一個特殊的文件夾&#xff0c;用于存放運行時加載的資源。 Resources文件夾中可以放置各種類型的資源文件&#xff0c;如紋理、模型、音頻、預制體等&#xff0c;一般用來存儲預制體和紋理信息。 通過API可以加載和訪問該文件夾及其子文件夾中的資源。 當我們打包…

大數據Doris(三十五):Unique模型(唯一主鍵)介紹

文章目錄 Unique模型(唯一主鍵)介紹 一、創建doris表 二、插入數據

【華為OD題庫-076】執行時長/GPU算力-Java

題目 為了充分發揮GPU算力&#xff0c;需要盡可能多的將任務交給GPU執行&#xff0c;現在有一個任務數組&#xff0c;數組元素表示在這1秒內新增的任務個數且每秒都有新增任務。 假設GPU最多一次執行n個任務&#xff0c;一次執行耗時1秒&#xff0c;在保證GPU不空閑情況下&…

海外獨立站站長常用的ChatGPT通用提示詞模板

目標市場&#xff1a;如何確定目標市場&#xff1f; 用戶需求&#xff1a;如何了解用戶需求&#xff1f; 網站設計&#xff1a;如何設計一個優秀的網站&#xff1f; 用戶體驗&#xff1a;如何提升用戶體驗&#xff1f; 功能規劃&#xff1a;請幫助我規劃網站的功能。 內容…

linux 應用開發筆記---【標準I/O庫/文件屬性及目錄】

一&#xff0c;什么是標準I/O庫 標準c庫當中用于文件I/O操作相關的一套庫函數&#xff0c;實用標準I/O需要包含頭文件 二&#xff0c;文件I/O和標準I/O之間的區別 1.標準I/O是庫函數&#xff0c;而文件I/O是系統調用 2.標準I/O是對文件I/O的封裝 3.標準I/O相對于文件I/O具有更…

SpringBoot 項目 Jar 包加密,防止反編譯

1場景 最近項目要求部署到其他公司的服務器上&#xff0c;但是又不想將源碼泄露出去。要求對正式環境的啟動包進行安全性處理&#xff0c;防止客戶直接通過反編譯工具將代碼反編譯出來。 2方案 第一種方案使用代碼混淆 采用proguard-maven-plugin插件 在單模塊中此方案還算簡…