Table of Contents
- 1. Multiscale Dual-Representation Alignment Filter
- 2. Code Implementation
Paper: SFFNet: A Wavelet-Based Spatial and Frequency Domain Fusion Network for Remote Sensing Segmentation
Code: https://github.com/yysdck/SFFNet
1. Multiscale Dual-Representation Alignment Filter
Frequency-domain features and spatial-domain features capture different aspects and properties of an image, but there is a semantic gap between them. Fusing the two directly can lead to inconsistent feature representations and fail to exploit the strengths of either. A method is therefore needed to align their semantics and select the more representative features for fusion. To this end, the paper proposes a Multiscale Dual-Representation Alignment Filter, which consists of two main parts:
- Multiscale Mapping: strip convolutions of different scales process the frequency-domain and spatial-domain features; the processed features are combined and passed through a 1x1 convolution to obtain the matrices Q, K, V at a unified scale, which serve as the inputs to the attention.
- Multi-Domain Attention Fusion (MDAF): a DAF (Dual-Representation Alignment Filter) structure uses cross-attention to achieve semantic alignment and feature selection; attention is computed by querying the other representation's (and its own) key-value pairs, the features are weighted accordingly, and feature selection is performed.
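To make the multiscale mapping concrete, below is a minimal sketch, not the paper's exact module (the class name MultiscaleMapping is a placeholder): depthwise strip convolutions of lengths 7, 11, and 21 in both orientations process one feature map, their responses are summed, and a 1x1 convolution produces the unified representation from which Q, K, V are then taken. The full implementation in Section 2 follows the same pattern for both inputs.

import torch
import torch.nn as nn

class MultiscaleMapping(nn.Module):
    """Sketch of the multiscale mapping: multi-scale strip convolutions + 1x1 projection."""
    def __init__(self, dim: int):
        super().__init__()
        # horizontal (1 x k) and vertical (k x 1) depthwise strip convolutions at three scales
        self.strips = nn.ModuleList(
            [nn.Conv2d(dim, dim, (1, k), padding=(0, k // 2), groups=dim) for k in (7, 11, 21)]
            + [nn.Conv2d(dim, dim, (k, 1), padding=(k // 2, 0), groups=dim) for k in (7, 11, 21)]
        )
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)  # unify the scales

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = sum(conv(x) for conv in self.strips)  # aggregate the multi-scale responses
        return self.proj(out)

# usage: map a spatial-domain feature Fs of shape (B, C, H, W)
# fs = torch.randn(2, 64, 32, 32)
# print(MultiscaleMapping(64)(fs).shape)  # torch.Size([2, 64, 32, 32])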
Implementation process:
- Multiscale mapping: apply multiscale mapping to the spatial-domain feature Fs and the frequency-domain feature Ff separately, yielding two sets of matrices (Q1, K1, V1) and (Q2, K2, V2).
- DAF computation: compute the DAF outputs F1 and F2 (see the sketch after this list): (1) F1 = δ1×1(Attn(Q2, K1, V1)): attention is computed with Ff's query Q2 against Fs's key-value pair (K1, V1), and the features are weighted accordingly; (2) F2 = δ1×1(Attn(Q1, K2, V2)): attention is computed with Fs's query Q1 against Ff's key-value pair (K2, V2), and the features are weighted accordingly.
- MDAF output: F1 and F2 are concatenated to obtain the final output features.
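As referenced above, here is a minimal sketch of a single DAF step under simplified assumptions (single head, tokens already flattened, the 1x1 projection δ1×1 and residual terms of the full module omitted; the function name daf is a placeholder). The point is that the query comes from the other representation while keys and values come from this one, so one domain filters and selects features from the other.

import torch
import torch.nn.functional as F

def daf(q_other: torch.Tensor, k_self: torch.Tensor, v_self: torch.Tensor) -> torch.Tensor:
    """Cross-attention: query from the other domain, keys/values from this domain."""
    q = F.normalize(q_other, dim=-1)
    k = F.normalize(k_self, dim=-1)
    attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, N, N) similarity across domains
    return attn @ v_self  # weighted selection of this domain's features

# F1 = daf(Q2, K1, V1): Ff's queries select from Fs
# F2 = daf(Q1, K2, V2): Fs's queries select from Ff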
Structure diagram of the Multiscale Dual-Representation Alignment Filter:
2. Code Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from einops import rearrange


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


class MDAF(nn.Module):
    def __init__(self, dim, num_heads=8, LayerNorm_type='WithBias'):
        super(MDAF, self).__init__()
        self.num_heads = num_heads
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1)
        # multiscale mapping: depthwise strip convolutions at three scales for each input
        self.conv1_1_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv1_1_2 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_1_3 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv1_2_1 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_2_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv1_2_3 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv2_1_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv2_1_2 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv2_1_3 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2_1 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv2_2_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_2_3 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)

    def forward(self, x1, x2):
        b, c, h, w = x1.shape
        x1 = self.norm1(x1)
        x2 = self.norm2(x2)
        # multiscale mapping of both inputs
        attn_111 = self.conv1_1_1(x1)
        attn_112 = self.conv1_1_2(x1)
        attn_113 = self.conv1_1_3(x1)
        attn_121 = self.conv1_2_1(x1)
        attn_122 = self.conv1_2_2(x1)
        attn_123 = self.conv1_2_3(x1)
        attn_211 = self.conv2_1_1(x2)
        attn_212 = self.conv2_1_2(x2)
        attn_213 = self.conv2_1_3(x2)
        attn_221 = self.conv2_2_1(x2)
        attn_222 = self.conv2_2_2(x2)
        attn_223 = self.conv2_2_3(x2)
        out1 = attn_111 + attn_112 + attn_113 + attn_121 + attn_122 + attn_123
        out2 = attn_211 + attn_212 + attn_213 + attn_221 + attn_222 + attn_223
        out1 = self.project_out(out1)
        out2 = self.project_out(out2)
        # build Q, K, V; the queries are swapped between the two representations
        k1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        v1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        k2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        v2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        q2 = rearrange(out1, 'b (head c) h w -> b head w (h c)', head=self.num_heads)
        q1 = rearrange(out2, 'b (head c) h w -> b head h (w c)', head=self.num_heads)
        q1 = torch.nn.functional.normalize(q1, dim=-1)
        q2 = torch.nn.functional.normalize(q2, dim=-1)
        k1 = torch.nn.functional.normalize(k1, dim=-1)
        k2 = torch.nn.functional.normalize(k2, dim=-1)
        # cross-attention along rows (branch 1) and columns (branch 2)
        attn1 = (q1 @ k1.transpose(-2, -1))
        attn1 = attn1.softmax(dim=-1)
        out3 = (attn1 @ v1) + q1
        attn2 = (q2 @ k2.transpose(-2, -1))
        attn2 = attn2.softmax(dim=-1)
        out4 = (attn2 @ v2) + q2
        out3 = rearrange(out3, 'b head h (w c) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out4 = rearrange(out4, 'b head w (h c) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        # fuse both branches and keep residual connections to the inputs
        out = self.project_out(out3) + self.project_out(out4) + x1 + x2
        return out


if __name__ == '__main__':
    x = torch.randn(4, 64, 128, 128).cuda()
    y = torch.randn(4, 64, 128, 128).cuda()
    model = MDAF(64).cuda()
    out = model(x, y)
    print(out.shape)