UNet改進(4):交叉注意力(Cross Attention)-多模態/多特征交互

在計算機視覺領域,UNet因其優異的性能在圖像分割任務中廣受歡迎。本文將介紹一種改進的UNet架構——UNetWithCrossAttention,它通過引入交叉注意力機制來增強模型的特征融合能力。

1. 交叉注意力機制

交叉注意力(Cross Attention)是一種讓模型能夠動態地從輔助特征中提取相關信息來增強主特征的機制。在我們的實現中,CrossAttention類實現了這一功能:

class CrossAttention(nn.Module):def __init__(self, channels):super(CrossAttention, self).__init__()self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x1, x2):batch_size, C, height, width = x1.size()# 投影到query, key, value空間proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)proj_key = self.key_conv(x2).view(batch_size, -1, height * width)proj_value = self.value_conv(x2).view(batch_size, -1, height * width)# 計算注意力圖energy = torch.bmm(proj_query, proj_key)attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)# 應用注意力out = torch.bmm(proj_value, attention.permute(0, 2, 1))out = out.view(batch_size, C, height, width)# 殘差連接out = self.gamma * out + x1return out

該模塊的工作原理是:

  1. 將主特征x1投影為query,輔助特征x2投影為key和value

  2. 計算query和key的相似度得到注意力權重

  3. 使用注意力權重對value進行加權求和

  4. 通過殘差連接將結果與原始主特征融合

2. 雙卷積模塊

DoubleConv是UNet中的基礎構建塊,包含兩個連續的卷積層,并可選擇性地加入交叉注意力:

class DoubleConv(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(DoubleConv, self).__init__()self.use_cross_attention = use_cross_attentionself.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if use_cross_attention:self.cross_attention = CrossAttention(out_channels)def forward(self, x, aux_feature=None):x = self.conv1(x)x = self.conv2(x)if self.use_cross_attention and aux_feature is not None:x = self.cross_attention(x, aux_feature)return x

3. 下采樣和上采樣模塊

下采樣模塊Down結合了最大池化和雙卷積:

class Down(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),DoubleConv(in_channels, out_channels, use_cross_attention))def forward(self, x, aux_feature=None):return self.downsampling[1](self.downsampling[0](x), aux_feature)

上采樣模塊Up使用轉置卷積進行上采樣并拼接特征:

class Up(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Up, self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)def forward(self, x1, x2, aux_feature=None):x1 = self.upsampling(x1)x = torch.cat([x2, x1], dim=1)x = self.conv(x, aux_feature)return x

4. 完整的UNetWithCrossAttention架構

將上述模塊組合起來,我們得到了完整的UNetWithCrossAttention:

class UNetWithCrossAttention(nn.Module):def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):super(UNetWithCrossAttention, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.use_cross_attention = use_cross_attention# 編碼器self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)self.down1 = Down(64, 128, use_cross_attention)self.down2 = Down(128, 256, use_cross_attention)self.down3 = Down(256, 512, use_cross_attention)self.down4 = Down(512, 1024, use_cross_attention)# 解碼器self.up1 = Up(1024, 512, use_cross_attention)self.up2 = Up(512, 256, use_cross_attention)self.up3 = Up(256, 128, use_cross_attention)self.up4 = Up(128, 64, use_cross_attention)self.out_conv = OutConv(64, num_classes)def forward(self, x, aux_feature=None):# 編碼過程x1 = self.in_conv(x, aux_feature)x2 = self.down1(x1, aux_feature)x3 = self.down2(x2, aux_feature)x4 = self.down3(x3, aux_feature)x5 = self.down4(x4, aux_feature)# 解碼過程x = self.up1(x5, x4, aux_feature)x = self.up2(x, x3, aux_feature)x = self.up3(x, x2, aux_feature)x = self.up4(x, x1, aux_feature)x = self.out_conv(x)return x

5. 應用場景與優勢

這種帶有交叉注意力的UNet架構特別適合以下場景:

  1. 多模態圖像分割:當有來自不同成像模態的輔助信息時,交叉注意力可以幫助模型有效地融合這些信息

  2. 時序圖像分析:對于視頻序列,前一幀的特征可以作為輔助特征來增強當前幀的分割

  3. 弱監督學習:當有額外的弱監督信號時,可以通過交叉注意力將其融入主網絡

相比于傳統UNet,這種架構的優勢在于:

  • 能夠動態地關注輔助特征中最相關的部分

  • 通過注意力機制實現更精細的特征融合

  • 保留了UNet原有的多尺度特征提取能力

  • 通過殘差連接避免了信息丟失

6. 總結

本文介紹了一種增強版的UNet架構,通過引入交叉注意力機制,使模型能夠更有效地利用輔助特征。這種設計既保留了UNet原有的優勢,又增加了靈活的特征融合能力,特別適合需要整合多源信息的復雜視覺任務。

在實際應用中,可以根據具體任務需求選擇在哪些層級啟用交叉注意力,也可以調整注意力模塊的復雜度來平衡模型性能和計算開銷。

希望這篇文章能幫助你理解交叉注意力在UNet中的應用。如果你有任何問題或建議,歡迎在評論區留言討論!

完整代碼

如下:

import torch.nn as nn
import torch
import mathclass CrossAttention(nn.Module):def __init__(self, channels):super(CrossAttention, self).__init__()self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x1, x2):"""x1: 主特征 (batch, channels, height, width)x2: 輔助特征 (batch, channels, height, width)"""batch_size, C, height, width = x1.size()# 投影到query, key, value空間proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')proj_key = self.key_conv(x2).view(batch_size, -1, height * width)  # (B, C', N)proj_value = self.value_conv(x2).view(batch_size, -1, height * width)  # (B, C, N)# 計算注意力圖energy = torch.bmm(proj_query, proj_key)  # (B, N, N)attention = torch.softmax(energy / math.sqrt(proj_key.size(-1)), dim=-1)# 應用注意力out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)out = out.view(batch_size, C, height, width)# 殘差連接out = self.gamma * out + x1return outclass DoubleConv(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(DoubleConv, self).__init__()self.use_cross_attention = use_cross_attentionself.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True)self.conv2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))if use_cross_attention:self.cross_attention = CrossAttention(out_channels)def forward(self, x, aux_feature=None):x = self.conv1(x)x = self.conv2(x)if self.use_cross_attention and aux_feature is not None:x = self.cross_attention(x, aux_feature)return xclass Down(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Down, self).__init__()self.downsampling = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2),DoubleConv(in_channels, out_channels, use_cross_attention))def forward(self, x, aux_feature=None):return self.downsampling[1](self.downsampling[0](x), aux_feature)class Up(nn.Module):def __init__(self, in_channels, out_channels, use_cross_attention=False):super(Up, self).__init__()self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)def forward(self, x1, x2, aux_feature=None):x1 = self.upsampling(x1)x = torch.cat([x2, x1], dim=1)x = self.conv(x, aux_feature)return xclass UNetWithCrossAttention(nn.Module):def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):super(UNetWithCrossAttention, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.use_cross_attention = use_cross_attention# 編碼器self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)self.down1 = Down(64, 128, use_cross_attention)self.down2 = Down(128, 256, use_cross_attention)self.down3 = Down(256, 512, use_cross_attention)self.down4 = Down(512, 1024, use_cross_attention)# 解碼器self.up1 = Up(1024, 512, use_cross_attention)self.up2 = Up(512, 256, use_cross_attention)self.up3 = Up(256, 128, use_cross_attention)self.up4 = Up(128, 64, use_cross_attention)self.out_conv = OutConv(64, num_classes)def forward(self, x, aux_feature=None):# 編碼過程x1 = self.in_conv(x, aux_feature)x2 = self.down1(x1, aux_feature)x3 = self.down2(x2, aux_feature)x4 = self.down3(x3, aux_feature)x5 = self.down4(x4, aux_feature)# 解碼過程x = self.up1(x5, x4, aux_feature)x = self.up2(x, x3, aux_feature)x = self.up3(x, x2, aux_feature)x = self.up4(x, x1, aux_feature)x = self.out_conv(x)return x

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/88062.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/88062.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/88062.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

C#里從CSV文件加載BLOB數據字段到數據庫的處理

大量的數據保存在CSV文件, 當需要把這些數據加載到數據庫,然后使用數據庫來共享出去。 就需要把CSV文件導入數據庫, 怎么樣快速地把CSV文件導入數據庫呢? 這個就需要使用類MySqlBulkLoader,它是mariadb數據庫快速導入的方式。 一般使用SQL語句導入是10秒,那么使用這種方…

【后端】負載均衡

長期不定期更新補充。 定義 負載均衡(Load Balancing)是指將來自客戶端的請求合理分發到多個服務器或服務節點,以提高系統性能、可用性與可靠性。 分工 前端不做負載均衡,前端只發請求,不知道請求去哪臺服務器。 負…

記錄一次:Java Web 項目 CSS 樣式/圖片丟失問題:一次深度排查與根源分析

記錄一次:Java Web 項目 CSS 樣式/圖片丟失問題:一次深度排查與根源分析 **記錄一次:Java Web 項目 CSS 樣式丟失問題:一次深度排查與根源分析****第一層分析:資源路徑問題****第二層分析:服務端跳轉邏輯**…

torchmd-net開源程序是訓練神經網絡潛力

?一、軟件介紹 文末提供程序和源碼下載 TorchMD-NET 提供最先進的神經網絡電位 (NNP) 和訓練它們的機制。如果有多個 NNP,它可提供高效、快速的實現,并且它集成在 GPU 加速的分子動力學代碼中,如 ACEMD、OpenMM 和 …

在Docker上安裝Mongo及Redis-NOSQL數據庫

應用環境 Ubuntu 20.04.6 LTS (GNU/Linux 5.15.0-139-generic x86_64) Docker version 28.1.1, build 4eba377 文章目錄 一、部署Mongo1. 拉取容器鏡像2. 生成Run腳本2.1 準備條件2.2 參數解讀2.3 實例腳本 3. 實例操作3.1 Mongo bash控制臺3.2 庫表操作 4. MongoDB Compass (G…

Java 編程之責任鏈模式

一、什么是責任鏈模式? 責任鏈模式(Chain of Responsibility Pattern) 是一種行為型設計模式,它讓多個對象都有機會處理請求,從而避免請求的發送者和接收者之間的耦合關系。將這些對象連成一條鏈,沿著這條…

1、做中學 | 一年級上期 Golang簡介和安裝環境

一、什么是golang Golang,通常簡稱 Go,是由 Google 公司的 Robert Griesemer、Rob Pike 和 Ken Thompson 于 2007 年創建的一種開源編程語言,并在 2009 年正式對外公布。 已經有了很多編程語言,為什么還要創建一種新的編程語言&…

Linux--迷宮探秘:從路徑解析到存儲哲學

上一篇博客我們說完了文件系統在硬件層面的意義,今天我們來說說文件系統在軟件層是怎么管理的。 Linux--深入EXT2文件系統:數據是如何被組織、存儲與訪問的?-CSDN博客 🌌 引言:文件系統的宇宙觀 "在Linux的宇宙中…

淘寶商品數據實時獲取方案|API 接口開發與安全接入

在電商數據獲取領域,除了官方 API,第三方數據 API 接入也是高效獲取淘寶商品數據的重要途徑。第三方數據 API 憑借豐富的功能、靈活的服務,為企業和開發者提供了多樣化的數據解決方案。本文將聚焦第三方數據 API 接入,詳細介紹其優…

什么是防抖和節流?它們有什么區別?

文章目錄 一、防抖(Debounce)1.1 什么是防抖?1.2 防抖的實現 二、節流(Throttle)2.1 什么是節流?2.2 節流的實現方式 三、防抖與節流的對比四、總結 在前端開發中,我們經常會遇到一些高頻觸發的…

Springboot集成阿里云OSS上傳

Springboot集成阿里云OSS上傳 API 接口描述 DEMO提供的四個API接口,支持不同方式的文件和 JSON 數據上傳: 1. 普通文件上傳接口 上傳任意類型的文件 2. JSON 字符串上傳接口 上傳 JSON 字符串 3. 單個 JSON 壓縮上傳接口 上傳并壓縮 JSON 字符串…

刪除大表數據注意事項

數據庫是否會因刪除操作卡死,沒有固定的 “安全刪除條數”,而是受數據庫配置、表結構、操作方式、當前負載等多種因素影響。以下是關鍵影響因素及實踐建議: 一、導致數據庫卡死的核心因素 硬件與數據庫配置 CPU / 內存瓶頸:刪除…

Redis 是單線程模型?|得物技術

一、背景 使用過Redis的同學肯定都了解過一個說法,說Redis是單線程模型,那么實際情況是怎樣的呢? 其實,我們常說Redis是單線程模型,是指Redis采用單線程的事件驅動模型,只有并且只會在一個主線程中執行Re…

[特殊字符] AIGC工具深度實戰:GPT與通義靈碼如何徹底重構企業開發流程

🔍 第一模塊:理念顛覆——為什么AIGC不是“玩具”而是“效能倍增器”? ▍企業開發的核心痛點圖譜(2025版) ??研發效能瓶頸??:需求膨脹與交付時限矛盾持續尖銳,傳統敏捷方法論已觸天花板?…

(LeetCode 面試經典 150 題) 169. 多數元素(哈希表 || 二分查找)

題目&#xff1a;169. 多數元素 方法一&#xff1a;二分法&#xff0c;最壞的時間復雜度0(nlogn)&#xff0c;但平均0(n)即可。空間復雜度為0(1)。 C版本&#xff1a; int nnums.size();int l0,rn-1;while(l<r){int mid(lr)/2;int ans0;for(auto x:nums){if(xnums[mid]) a…

(17)java+ selenium->自動化測試-元素定位大法之By css上

1.簡介 CSS定位方式和xpath定位方式基本相同,只是CSS定位表達式有其自己的格式。CSS定位方式擁有比xpath定位速度快,且比CSS穩定的特性。下面詳細介紹CSS定位方式的使用方法。相對CSS來說,具有語法簡單,定位速度快等優點。 2.CSS定位優勢 CSS定位是平常使用過程中非常重要…

【軟考高級系統架構論文】企業集成平臺的技術與應用

論文真題 企業集成平臺是一個支持復雜信息環境下信息系統開發、集成和協同運行的軟件支撐環境。它基于各種企業經營業務的信息特征,在異構分布環境(操作系統、網絡、數據庫)下為應用提供一致的信息訪問和交互手段,對其上運行的應用進行管理,為應用提供服務,并支持企業信息…

i.MX8MP LVDS 顯示子系統全解析:設備樹配置與 DRM 架構詳解

&#x1f525; 推薦&#xff1a;《Yocto項目實戰教程&#xff1a;高效定制嵌入式Linux系統》 京東正版促銷&#xff0c;歡迎支持原創&#xff01; 鏈接&#xff1a;https://item.jd.com/15020438.html i.MX8MP LVDS 顯示子系統全解析&#xff1a;設備樹配置與 DRM 架構詳解 在…

keep-alive實現原理及Vue2/Vue3對比分析

一、keep-alive基本概念 keep-alive是Vue的內置組件&#xff0c;用于緩存組件實例&#xff0c;避免重復渲染。它具有以下特點&#xff1a; 抽象組件&#xff1a;自身不會渲染DOM&#xff0c;也不會出現在父組件鏈中包裹動態組件&#xff1a;緩存不活動的組件實例&#xff0c;…

安卓jetpack compose學習筆記-Navigation基礎學習

目錄 一、Navigation 二、BottomNavigation Compose是一個偏向靜態刷新的UI組件&#xff0c;如果不想要自己管理頁面切換的復雜狀態&#xff0c;可以以使用Navigation組件。 頁面間的切換可以NavHost&#xff0c;使用底部頁面切換欄&#xff0c;可以使用腳手架的bottomBarNav…