題目:Dual Aggregation Transformer(雙聚合Transformer) for Image Super-Resolution(圖像超分辨)
論文(ICCV):Chen_Dual_Aggregation_Transformer_for_Image_Super-Resolution_ICCV_2023_paper.pdf (thecvf.com)
源碼:zhengchen1999/DAT: PyTorch code for our ICCV 2023 paper "Dual Aggregation Transformer for Image Super-Resolution" (github.com)?
Super Resolution:超分辨率(Super-Resolution),簡稱超分(SR)。是指利用光學及其相關光學知識,根據已知圖像信息恢復圖像細節和其他數據信息的過程,簡單來說就是增大圖像的分辨率,防止其圖像質量下降。
?一、摘要
研究背景:Transformer最近在低級視覺任務中獲得了相當大的流行,包括圖像超分辨率(SR)。這些網絡沿著不同的維度、空間或通道利用自注意力,并取得了令人印象深刻的性能。這激勵我們將 Transformer 中的兩個維度結合起來,以獲得更強大的表示能力。
主要工作:基于上述思想,本文提出了一種新的 Transformer 模型,雙聚合 Transformer(DAT),用于 SR 圖像。該 DAT?以 模塊間 和 模塊內 雙重方式聚合了 跨空間?和 跨通道維度?的特征。
- 1. 交替地在連續的 Transformer 塊中應用 空間 和 通道自注意力。該策略使 DAT 能夠捕獲全局上下文并實現?模塊間特征聚合?。
- 2. 提出了自適應交互模塊(AIM)和空間門前饋網絡(SGFN)來實現?模塊內特征聚合?。AIM 從相應維度補充了兩種自注意力機制。
- 3. 同時,SGFN 在前饋網絡中引入了額外的非線性空間信息。
實驗效果:大量實驗表明,DAT方法優于現有方法。
二、引言
圖像超分辨任務的背景、挑戰以及基于CNN網絡的方法的不足(在全局依賴上)—> transformer簡介 + 在超分辨方向上transformer相關的研究工作(主要為自注意力方向,兩個方面:空間層面和通道層面)+ 概括 Spatial window self-attention(SW-SA)和?Channel-wise self-attention (CW-SA) 的作用(對超分辨)—> DAT網絡、AIM模塊和SGFN模塊的設計動機(為了解決哪些問題)、設計思路(如何實現,網絡具體實現是怎么做的)、功能和作用?—> 貢獻:
- 1.?設計了一種新的圖像SR模型--雙聚合transformer(DAT)。DAT以塊間和塊內雙重方式聚合空間和通道特征,以獲得強大的表示能力。(主要工作概述)
- 2.?交替采用空間和通道自關注,實現塊間空間和通道特征聚合。此外,還提出了AIM和SGFN來實現塊內特征聚合。(新模塊概述)
- 3.?進行了大量的實驗,以證明DAT優于最先進的方法,同時保持了較低的復雜性和模型大小。(實驗效果概述)
三、方法
3.1 架構概述??
Dual Aggregation Transformer (DAT) 的網絡體系結構如下圖所示。雙空間transformer模塊 (DSTB)和雙通道transformer模塊 (DCTB)是兩個連續的雙聚合transformer模塊 (DATB)。(DSTB和DCTB只在注意力有所不同,因此將他們都看作DATB模塊)
整個網絡包括三個模塊:淺層特征提取、深層特征提取和圖像重建。
淺層特征提取(淺層卷積):首先,給定一幅低分辨率(LR)輸入圖像?,使用卷積層對其進行處理并生成淺層特征?
。
深層特征提取(DSTB +?DCTB + 2× Conv):淺層特征??在深特征提取模塊內進行處理,以獲得深層特征?
?。該模塊由N1個殘差組(RG)堆疊。每個RG包含n2對雙聚合transformer模塊(DATB)。每個DATB對包含兩個transformer模塊,分別利用空間和通道自注意力。在RG的末尾引入一個卷積層來細化從變壓器塊中提取的特征。此外,對于每個RG,使用殘差連接。
圖像重建(conv + pixel shuffle + conv):在該模塊中,通過 pixel shuffle 方法對深度特征??進行上采樣。并在上采樣操作之前和之后使用卷積層聚集特征。
Q:pixel shuffle 方法是什么?
3.2?Dual Aggregation Transformer Block(雙聚合transformer模塊)
DATB有兩種類型:雙空間transformer模塊 (DSTB)和雙通道transformer模塊 (DCTB)。?
DSTB 和 DCTB 分別基于?Spatial Window Self-Attention(空間窗口自注意力)?和 Channel-Wise Self-Attention(逐通道自注意力)。通過交替應用 DSTB 和 DCTB ,DAT可以實現空間維度和通道維度之間的塊間特征聚合。此外,還提出了自適應交互模塊(AIM)和空間門前饋網絡(SGFN)來實現模塊內特征聚合。
1)Spatial Window Self-Attention(空間窗口自注意力)
如圖所示,空間窗口自注意力(SW-SA)計算窗口內的注意。
過程:
1. 給定輸入?,通過線性投影生成查詢Q、鍵K和值V矩陣。該過程被定義為:
其中,是省略偏差的線性投影。
2. 隨后,將Q、K和V劃分為不重疊的窗口,并展平每個包含??個像素的窗口。將重塑的投影矩陣表示為?
。然后,將?
?分成 h 個頭:
,
,且?
?。每個頭的維度為?
?。第 i 個頭的輸出?
?定義為:
其中,D表示相對位置編碼。(自注意力計算)
3. 最后,通過對所有??的重塑和拼接,得到特征?
。?這一過程的公式如下:
其中,?是融合所有特征的線性投影。(這里提到默認使用Swin transformer中的移位窗口操作來捕捉更多的空間信息)
代碼實現:
def img2windows(img, H_sp, W_sp): # 劃分窗口"""Input: Image (B, C, H, W)Output: Window Partition (B', N, C)"""B, C, H, W = img.shapeimg_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C)return img_permclass Spatial_Attention(nn.Module):""" Spatial Window Self-Attention.It supports rectangle window (containing square window).Args:dim (int): Number of input channels.idx (int): The indentix of window. (0/1)split_size (tuple(int)): Height and Width of spatial window.dim_out (int | None): The dimension of the attention output. Default: Nonenum_heads (int): Number of attention heads. Default: 6attn_drop (float): Dropout ratio of attention weight. Default: 0.0proj_drop (float): Dropout ratio of output. Default: 0.0qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if setposition_bias (bool): The dynamic relative position bias. Default: True"""def __init__(self, dim, idx, split_size=[8,8], dim_out=None, num_heads=6, attn_drop=0., proj_drop=0., qk_scale=None, position_bias=True):super().__init__()self.dim = dimself.dim_out = dim_out or dimself.split_size = split_sizeself.num_heads = num_headsself.idx = idxself.position_bias = position_biashead_dim = dim // num_heads # 每個頭的維度self.scale = qk_scale or head_dim ** -0.5if idx == 0:H_sp, W_sp = self.split_size[0], self.split_size[1]elif idx == 1:W_sp, H_sp = self.split_size[0], self.split_size[1]else:print ("ERROR MODE", idx)exit(0)self.H_sp = H_spself.W_sp = W_spif self.position_bias:self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)# generate mother-setposition_bias_h = torch.arange(1 - self.H_sp, self.H_sp)position_bias_w = torch.arange(1 - self.W_sp, self.W_sp)biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))biases = biases.flatten(1).transpose(0, 1).contiguous().float()self.register_buffer('rpe_biases', biases)# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.H_sp)coords_w = torch.arange(self.W_sp)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.H_sp - 1relative_coords[:, :, 1] += self.W_sp - 1relative_coords[:, :, 0] *= 2 * self.W_sp - 1relative_position_index = relative_coords.sum(-1)self.register_buffer('relative_position_index', relative_position_index)self.attn_drop = nn.Dropout(attn_drop)def im2win(self, x, H, W): # 將Q、K和V劃分為不重疊的窗口, (B N C) --> (num_win num_heads H_sp* W_sp C//num_heads)B, N, C = x.shapex = x.transpose(-2,-1).contiguous().view(B, C, H, W)x = img2windows(x, self.H_sp, self.W_sp)x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()return xdef forward(self, qkv, H, W, mask=None):"""Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window sizeOutput: x (B, H, W, C)"""q,k,v = qkv[0], qkv[1], qkv[2]B, L, C = q.shapeassert L == H * W, "flatten img_tokens has wrong size"# partition the q,k,v, image to windowq = self.im2win(q, H, W)k = self.im2win(k, H, W)v = self.im2win(v, H, W)q = q * self.scaleattn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N# calculate drpeif self.position_bias:pos = self.pos(self.rpe_biases)# select position biasrelative_position_bias = pos[self.relative_position_index.view(-1)].view(self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()attn = attn + relative_position_bias.unsqueeze(0)N = attn.shape[3]# use mask for shift windowif mask is not None:nW = mask.shape[0]attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)attn = self.attn_drop(attn)x = (attn @ v)x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C# merge the window, window to imagex = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' Creturn x
2)Channel-Wise Self-Attention(逐通道自注意力)
通道自注意力(channel-wise self-attention, CW-SA)中的自注意力機制是沿著通道維度進行的。?
方法:按通道劃分為頭部,并分別對每個頭部進行注意力計算。
過程:給定輸入X,應用線性投影來生成查詢、鍵和值矩陣,并將它們重塑為??大小。用?
,
?和?
?表示重構矩陣。與SW-SA中的操作相同,將投影向量分成 h 個頭。則第 i 頭的通道自注意力過程可計算為:
其中,?是第 i 個頭的輸出,α 是可學習的參數,用于在softmax函數之前調整內積。最后,通過對所有?
?進行重塑和拼接(這里與空間窗口自注意力操作相同),得到注意力特征?
。
class Adaptive_Channel_Attention(nn.Module):# The implementation builds on XCiT code https://github.com/facebookresearch/xcit""" Adaptive Channel Self-AttentionArgs:dim (int): Number of input channels.num_heads (int): Number of attention heads. Default: 6qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.attn_drop (float): Attention dropout rate. Default: 0.0drop_path (float): Stochastic depth rate. Default: 0.0"""def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headsself.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.dwconv = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim),nn.BatchNorm2d(dim),nn.GELU())self.channel_interaction = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(dim, dim // 8, kernel_size=1),nn.BatchNorm2d(dim // 8),nn.GELU(),nn.Conv2d(dim // 8, dim, kernel_size=1),)self.spatial_interaction = nn.Sequential(nn.Conv2d(dim, dim // 16, kernel_size=1),nn.BatchNorm2d(dim // 16),nn.GELU(),nn.Conv2d(dim // 16, 1, kernel_size=1))def forward(self, x, H, W):"""Input: x: (B, H*W, C), H, WOutput: x: (B, H*W, C)"""B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) # 按通道劃分頭部qkv = qkv.permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]q = q.transpose(-2, -1)k = k.transpose(-2, -1)v = v.transpose(-2, -1)v_ = v.reshape(B, C, N).contiguous().view(B, C, H, W)q = torch.nn.functional.normalize(q, dim=-1)k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperatureattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)# attention outputattened_x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)# convolution outputconv_x = self.dwconv(v_)# Adaptive Interaction Module (AIM)# C-Map (before sigmoid)attention_reshape = attened_x.transpose(-2,-1).contiguous().view(B, C, H, W)channel_map = self.channel_interaction(attention_reshape)# S-Map (before sigmoid)spatial_map = self.spatial_interaction(conv_x).permute(0, 2, 3, 1).contiguous().view(B, N, 1)# S-Iattened_x = attened_x * torch.sigmoid(spatial_map)# C-Iconv_x = conv_x * torch.sigmoid(channel_map)conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, N, C)x = attened_x + conv_xx = self.proj(x)x = self.proj_drop(x)return x
??
3)Adaptive Interaction Module(自適應交互模塊)?
? ? ? ? ? ? ? ? ? ? ? ? ? ?? ? ?
下分支:由于自注意力專注于捕捉全局特征,納入了一個平行于自注意力模塊的卷積分支(DW-Conv),將局部性引入Transformer。?
問題:
- 1. 簡單地添加卷積分支不能有效地融合全局和局部特征。
- 2. 盡管SW-SA和CW-SA交替執行可以同時捕獲空間和通道特征,但不同維度的信息仍然不能在單個自注意力中有效利用。
目的:為克服這些問題,本文提出了自適應交互模塊(AIM),根據自注意力機制的類型,從空間或通道維度自適應地重新加權兩個分支的特征。
過程:首先,對 V 進行并行深度卷積(DW-Conv),以建立自注意力和卷積之間的直接聯系。卷積輸出為??。然后引入AIM,對兩個特征進行自適應調整。具體而言,AIM包括兩個交互操作:空間交互(S-I)?和?通道交互(C-I)。給定兩個輸入特征,
?和?
,空間交互計算一個輸入的空間注意力圖( 記為S-Map,大小為?
?)。通道交互計算通道注意力圖( 記為C-Map,大小為?
?)。以 B 為例,公式表達如下:
其中??表示全局平均池,
?表示Sigmoid函數,
?表示GELU函數。
?表示用于縮小或放大通道維度的逐點卷積的權重。W1和W2的縮放比率分別為 r1,C/r1。W3的縮放比率為r2,并且W4膨脹比率為 r2。
隨后,相互將注意力圖應用于另一個輸入,從而實現交互。這一過程的公式如下:
其中,⊙表示逐元素乘法。
最后,基于AIM,在SW-SA和CW-SA的基礎上,設計了兩種新的自注意力機制AS-SA和AC-SA。對于SW-SA,我們引入了兩個分支之間的通道-空間相互作用。對于CW-SA,我們采用空間-信道交互。給定輸入?,過程定義為:
其中,、
?和?
是上面定義的SW-SA、CW-SA和DW-Conv的輸出。
4)Spatial-Gate Feed-Forward Network(空間門前饋網絡)
?問題:
- 1.?前饋網絡(FFN)難以捕獲空間信息。
- 2. 此外,通道中的冗余信息阻礙了特征表達能力。
解決方法:提出了空間門前饋網絡(SGFN),將空間門(SG)引入到FFN中。
結構:SG模塊是一個簡單的門機制,由深度卷積和逐元素乘法組成。沿著通道維度,將特征映射分為卷積支路和乘法支路兩部分。總體而言,給定輸入?,SGFN計算公式如下:
其中,?和?
?表示線性投影,σ 表示Gelu函數,
?表示深度卷積的可學習參數。
?和?
?
?空間中,其中 C' 表示SGFN中的隱維度。
四、實驗
訓練設置:本文訓練了 patch 大小為64×64,批次大小為32的模型。訓練迭代次數為500K。通過ADAM優化器( β1=0.9和β2=0.99 ),通過最小化 L1 損失來優化模型。將學習速率設置為2×10?4,并以[250K,400K,450K,475K]為標記減半。此外,在訓練期間,隨機使用90?、180?和270?的旋轉和水平翻轉來增強數據。本文的模型是基于4個A100圖形處理器的PyTorch實現的。
數據集:DIV2K 和 Flickr2K用于訓練,以及五個基準數據集:Set5、Set14、B100、Urban100和Manga109用于測試。分別在×2、×3、×4三種尺度下進行了實驗。
評估指標:PSNR 和 SSIM,這兩個度量是在YCbCR空間的Y通道( 即,亮度 )上計算的。
4.1?消融實驗
為了調查交替使用SW-SA和CW-SA的策略的效果,本文進行了幾個實驗:
- 1. 表的第一行和第二行表示用 CW-SA 或 SW-SA 替換 DAT 中的所有注意模塊,其中SW-SA采用8x8窗口大小。(單一模塊)
- 2. 第三行表示在 DAT 中的連續transformer模塊中交替應用兩個SA。此外,在SA中,所有模型都采用規則的FFN,而不采用AIM。(本文方法)
比較這三種模型,可以觀察到,使用SW-SA的模型的性能優于使用CW-SA的模型。此外,交替應用兩個SA可以獲得33.34dB的最佳性能。這表明,同時利用通道信息和空間信息是精確圖像恢復的關鍵。?
4.2 與最先進的方法進行比較
定量比較:同時,除了在Urban100數據集(×4)上的PSNR值與CAT-A相比外,DAT的性能要好于以前的方法。具體地說,與SwinIR和CAT-A相比(比較對象),DAT在Manga109數據集(×2)上(數據集)獲得了顯著的增益,分別獲得了0.41db和0.23db的改進(提升比例)。此外,小視覺模型DAT-S也取得了與以往方法相當或更好的性能。所有這些定量結果表明,聚合塊間和塊內的空間和通道信息可以有效地提高圖像重建質量(結論)。?
定性比較:在一些具有挑戰性的場景中,以前的方法可能會遇到模糊偽影、扭曲或不準確的紋理恢復(對比方法定性描述)。與之形成鮮明對比的是,本文的方法有效地減少了偽影,保留了更多的結構和更精細的細節(本文方法定性描述)。這主要是因為本文的方法通過從不同維度提取復雜特征,具有更強的表示能力(結論)。
五、結論
主要工作:本文提出了一種新的圖像SR變換模型--雙聚集變換(DAT)。DAT以塊間和塊內雙重方式聚合空間和信道特征,以獲得強大的表示能力(概述,方法 + 作用)。
- 1. 具體地說,連續的transformer模塊交替地應用空間窗口和通道方式的自注意力。DAT可以通過這種替代策略對全局依賴關系進行建模,并實現空間維度和通道維度之間的塊間特征聚合。
- 2. 此外,還提出了自適應交互模塊(AIM)和空間門前饋網絡(SGFN)來增強每個塊并實現兩維之間的塊內特征聚合。目的從相應維度強化兩種自我注意機制的建模能力。(逐模塊細化概述,方法 + 作用)
- 3. 同時,SGFN利用非線性空間信息對前饋網絡進行補充。
實驗結果:大量的實驗表明,DAT的性能優于以往的方法。