CVPR | CNN融合注意力機制,蕪湖起飛!

**標題:**On the Integration of Self-Attention and Convolution
**論文鏈接:**https://arxiv.org/pdf/2111.14556
**代碼鏈接:**https://github.com/LeapLabTHU/ACmix

創新點

1. 揭示卷積和自注意力的內在聯系

文章通過重新分解卷積和自注意力模塊的操作,發現它們在第一階段(特征投影)都依賴于 1×1 卷積操作,并且這一階段占據了大部分的計算復雜度(與通道數的平方成正比)。這一發現為整合兩種模塊提供了理論基礎。

2. 提出 ACmix 模型

基于上述發現,作者提出了 ACmix 模型,它通過共享 1×1 卷積操作來同時實現卷積和自注意力的功能。具體來說:
**第一階段:**輸入特征通過 1×1 卷積投影,生成中間特征。
**第二階段:**這些中間特征分別用于卷積路徑(通過移位和聚合操作)和自注意力路徑(計算注意力權重并聚合值)。最終,兩條路徑的輸出通過可學習的權重加權求和,得到最終輸出。

3. 改進的移位和聚合操作

文章還提出了一種改進的移位操作,通過使用 固定卷積核的分組卷積 來替代傳統的張量移位操作。這種方法不僅提高了計算效率,還允許卷積核的可學習性,進一步增強了模型的靈活性。

4. 適應性路徑權重

ACmix 引入了兩個可學習的標量參數(α 和 β),用于動態調整卷積路徑和自注意力路徑的權重。這種設計不僅提高了模型的靈活性,還允許模型在不同深度上自適應地選擇更適合的特征提取方式。實驗表明,這種設計在模型的不同階段表現出不同的偏好,例如在早期階段更傾向于卷積,在后期階段更傾向于自注意力。

整體結構

第一階段:特征投影

在第一階段,輸入特征通過三個1×1卷積進行投影,分別生成查詢(query)、鍵(key)和值(value)特征映射。這些特征映射隨后被重塑為N塊,形成一個包含3×N特征映射的中間特征集。

第二階段:特征聚合

在第二階段,中間特征集被分為兩個路徑進行處理:

  • **自注意力路徑:**將中間特征集分為N組,每組包含三個特征映射(分別對應查詢、鍵和值)。這些特征映射按照傳統的多頭自注意力機制進行處理,計算注意力權重并聚合值。
  • **卷積路徑:**通過輕量級的全連接層生成k2個特征映射(k為卷積核大小)。這些特征映射通過移位和聚合操作,以類似傳統卷積的方式處理輸入特征,從局部感受野收集信息。

輸出整合

最后,自注意力路徑和卷積路徑的輸出通過兩個可學習的標量參數(α和β)加權求和,得到最終的輸出。

改進的移位和聚合操作

為了提高計算效率,ACmix模型采用了改進的移位操作,通過固定卷積核的分組卷積來替代傳統的張量移位操作。這種方法不僅提高了計算效率,還允許卷積核的可學習性,進一步增強了模型的靈活性。

模型的靈活性和泛化能力

ACmix模型不僅適用于標準的自注意力機制,還可以與各種變體(如Patchwise Attention、Window Attention和Global Attention)結合使用。這種設計使得ACmix能夠適應不同的任務需求,具有廣泛的適用性。

消融實驗

1. 結合兩個路徑的輸出

消融實驗探索了卷積和自注意力輸出的不同組合方式對模型性能的影響。實驗結果表明:

  • **卷積和自注意力的組合優于單一路徑:**使用卷積和自注意力模塊的組合始終優于僅使用單一路徑(如僅卷積或僅自注意力)的模型。
  • **可學習參數的靈活性:**通過引入可學習的參數(如α和β)來動態調整卷積和自注意力路徑的權重,ACmix能夠根據網絡中不同位置的需求自適應地調整路徑強度,從而獲得更高的靈活性和性能。

2. 組卷積核的選擇

實驗還對組卷積核的設計進行了驗證,結果表明:

  • **用組卷積替代張量位移:**通過使用組卷積替代傳統的張量位移操作,顯著提高了模型的推理速度。
  • **可學習卷積核和初始化:**使用可學習的卷積核并結合精心設計的初始化方法,進一步增強了模型的靈活性,并有助于提升最終性能。

3. 不同路徑的偏好

ACmix模型引入了兩個可學習標量α和β,用于動態調整卷積和自注意力路徑的權重。通過平行實驗,觀察到以下趨勢:

  • **早期階段偏好卷積:**在Transformer模型的早期階段,卷積作為特征提取器表現更好。
  • **中間階段混合使用:**在網絡的中間階段,模型傾向于混合使用兩種路徑,并逐漸增加對卷積的偏好。
  • **后期階段偏好自注意力:**在網絡的最后階段,自注意力表現優于卷積。

4. 對模型性能的影響

這些消融實驗結果表明,ACmix模型通過合理結合卷積和自注意力的優勢,并優化計算路徑,不僅在多個視覺任務上取得了顯著的性能提升,還保持了較高的計算效率

ACmix模塊的作用

1. 融合卷積和自注意力的優勢

ACmix模塊通過結合卷積的局部特征提取能力和自注意力的全局感知能力,實現了一種高效的特征融合策略。這種設計使得模型能夠同時利用卷積的局部感受野特性和自注意力的靈活性。

2. 優化計算路徑

ACmix通過優化計算路徑和減少重復計算,提高了整體模塊的計算效率。具體來說,它通過1×1卷積對輸入特征圖進行投影,生成中間特征,然后根據不同的范式(卷積和自注意力)分別重用和聚合這些中間特征。這種設計不僅減少了計算開銷,還提升了模型性能。

3. 改進的位移與求和操作

在卷積路徑中,ACmix采用深度可分離卷積(depthwise convolution)來替代低效的張量位移操作,從而提高了實際推理效率。

4. 動態調整路徑權重

ACmix引入了兩個可學習的標量參數(α和β),用于動態調整卷積和自注意力路徑的權重。這種設計使得模型能夠根據網絡中不同位置的需求自適應地調整路徑強度,從而獲得更高的靈活性。

5. 廣泛的應用潛力

ACmix在多個視覺任務(如圖像分類、語義分割和目標檢測)上均顯示出優于單一機制(僅卷積或僅自注意力)的性能,展示了其廣泛的應用潛力。

6. 實驗驗證

實驗結果表明,ACmix在保持較低計算開銷的同時,能夠顯著提升模型的性能。例如,在ImageNet分類任務中,ACmix模型在相同的FLOPs或參數數量下表現出色,并且在與競爭對手的基準比較中取得了持續的改進。此外,ACmix在ADE20K語義分割任務和COCO目標檢測任務中也顯示出明顯的改進

代碼實現

import torch
import torch.nn as nndef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.)class ACmix(nn.Module):def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,stride=stride)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stride# ### att# ## positional encodingpe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,self.kernel_att * self.kernel_att, h_out,w_out) # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,w_out) # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,h_out, w_out)out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)## convf_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w)], 1))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_conv#輸入 B C H W, 輸出 B C H W
if __name__ == '__main__':block = ACmix(in_planes=64, out_planes=64)input = torch.rand(3, 64, 32, 32)output = block(input)print(input.size(), output.size())

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/67997.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/67997.shtml
英文地址,請注明出處:http://en.pswp.cn/web/67997.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

module ‘matplotlib.cm‘ has no attribute ‘get_cmap‘

目錄 解決方法1: 解決方法2,新版api改了: module matplotlib.cm has no attribute get_cmap 報錯代碼: cmap matplotlib.cm.get_cmap(Oranges) 解決方法1: pip install matplotlib3.7.3 解決方法2,新版…

使用Nuxt.js實現服務端渲染(SSR):提升SEO與性能的完整指南

使用Nuxt.js實現服務端渲染(SSR):提升SEO與性能的完整指南 使用Nuxt.js實現服務端渲染(SSR):提升SEO與性能的完整指南1. 服務端渲染(SSR)核心概念1.1 CSR vs SSR vs SSG1.2 SSR工作原…

解釋 Java 中的反射機制和動態代理的原理?

反射機制是Java語言的一個特性,它允許程序在運行時檢查和操作類、方法、字段等。 通過反射,我們可以在運行時獲取類的信息,創建對象,調用方法和訪問字段,即使這些信息在編譯時是未知的。 反射的基本用法 import jav…

http狀態碼:504 Gateway Timeout(網關超時)的原有以及排查問題的思路

504 Gateway Timeout(網關超時) 是一種常見的HTTP錯誤狀態碼,表示服務器作為網關或代理時,未能及時從上游服務器收到響應。以下是它的原因和排查問題的思路: 1. 504錯誤的含義 定義:服務器作為網關或代理時…

Linux 安裝 RabbitMQ

Linux下安裝RabbitMQ 1 、獲取安裝包 # 地址 https://github.com/rabbitmq/erlang-rpm/releases/download/v21.3.8.9/erlang-21.3.8.9-1.el7.x86_64.rpm erlang-21.3.8.9-1.el7.x86_64.rpmsocat-1.7.3.2-1.el6.lux.x86_64.rpm# 地址 https://github.com/rabbitmq/rabbitmq-se…

LOCAL_PREBUILT_JNI_LIBS使用說明

LOCAL_PREBUILT_JNI_LIBS使用說明 使用LOCAL_PREBUILT_JNI_LIBS,可用于控制APK集成時,其相關so的集成方式。 比如,用于將APK中的so,抽取出來。 LOCAL_PREBUILT_JNI_LIBS : \lib/arm64-v8a/libNativeCore.so \lib/arm64-v8a/liba…

Java中的object類

1.Object類是什么? 🟪Object 是 Java 類庫中的一個特殊類,也是所有類的父類(超類),位于類繼承層次結構的頂端。也就是說,Java 允許把任何類型的對象賦給 Object 類型的變量。 🟦Java里面除了Object類,所有的…

uniapp小程序自定義中間凸起樣式底部tabbar

我自己寫的自定義的tabbar效果圖 廢話少說咱們直接上代碼,一步一步來 第一步: 找到根目錄下的 pages.json 文件,在 tabBar 中把 custom 設置為 true,默認值是 false。list 中設置自定義的相關信息, pagePath&#x…

四、GPIO中斷實現按鍵功能

4.1 GPIO簡介 輸入輸出(I/O)是一個非常重要的概念。I/O泛指所有類型的輸入輸出端口,包括單向的端口如邏輯門電路的輸入輸出管腳和雙向的GPIO端口。而GPIO(General-Purpose Input/Output)則是一個常見的術語&#xff0c…

vscode+CMake+Debug實現 及權限不足等諸多問題匯總

環境說明 有空再補充 直接貼兩個json tasks.json {"version": "2.0.0","tasks": [{"label": "cmake","type": "shell","command": "cmake","args": ["../"…

【Elasticsearch】post_filter

post_filter是 Elasticsearch 中的一種后置過濾機制,用于在查詢執行完成后對結果進行過濾。以下是關于post_filter的詳細介紹: 工作原理 ? 查詢后過濾:post_filter在查詢執行完畢后對返回的文檔集進行過濾。這意味著所有與查詢匹配的文檔都…

《數據可視化新高度:Graphy的AI協作變革》

在數據洪流奔涌的時代,企業面臨的挑戰不再僅僅是數據的收集,更在于如何高效地將數據轉化為洞察,助力決策。Graphy作為一款前沿的數據可視化工具,憑借AI賦能的團隊協作功能,為企業打開了數據協作新局面,重新…

Vue 2 與 Vue 3 的主要區別

Vue.js 是一個流行的前端框架,用于構建用戶界面和單頁應用。自從 Vue 2 發布以來,社區對其進行了廣泛的應用和擴展,而 Vue 3 的發布則帶來了許多重要的改進和新特性。 性能提升 Vue 3 在響應式系統上進行了重大的改進,采用了基于…

從零開始:用Qt開發一個功能強大的文本編輯器——WPS項目全解析

文章目錄 引言項目功能介紹1. **文件操作**2. **文本編輯功能**3. **撤銷與重做**4. **剪切、復制與粘貼**5. **文本查找與替換**6. **打印功能**7. **打印預覽**8. **設置字體顏色**9. **設置字號**10. **設置字體**11. **左對齊**12. **右對齊**13. **居中對齊**14. **兩側對…

【IoCDI】_Spring的基本掃描機制

目錄 1. 創建測試項目 2. 改變啟動類所屬包 3. 使用ComponentScan 4. Spring基本掃描機制 程序通過注解告訴Spring希望哪些bean被管理,但在僅使用Bean時已經發現,Spring需要根據五大類注解才能進一步掃描方法注解。 由此可見,Spring對注…

vue 引入百度地圖和高德天氣 都得獲取權限

vue接入百度地圖---獲取ak https://blog.csdn.net/qq_57144407/article/details/143430661 vue接入高德天氣, 需要授權----獲取key https://www.jianshu.com/p/09ddd698eebe

通向AGI之路:人工通用智能的技術演進與人類未來

文章目錄 引言:當機器開始思考一、AGI的本質定義與技術演進1.1 從專用到通用:智能形態的范式轉移1.2 AGI發展路線圖二、突破AGI的五大技術路徑2.1 神經符號整合(Neuro-Symbolic AI)2.2 世界模型架構(World Models)2.3 具身認知理論(Embodied Cognition)三、AGI安全:價…

python中的命名規范

在python中,命名規范是編寫清晰,可讀性強代碼的重要部分,遵循這些規范可以使代碼更易于理解和維護。 Type命名約定命名例子函數(Function)小寫單詞,下劃線分割單詞function,delta_function方法&#xff08…

【工具變量】中國省級八批自由貿易試驗區設立及自貿區設立數據(2024-2009年)

一、測算方式:參考C刊《中國軟科學》任曉怡老師(2022)的做法,使用自由貿易試驗區(Treat Post) 表征,Treat為個體不隨時間變化的虛擬變量,如果該城市設立自由貿易試驗區則賦值為1,反之賦值為0&am…

Java進階總結——集合

Java進階總結——集合 說明:對于以上的框架圖有如下幾點說明 1.所有集合類都位于java.util包下。Java的集合類主要由兩個接口派生而出:Collection和Map,Collection和Map是Java集合框架的根接口,這兩個接口又包含了一些子接口或實…