**標題:**On the Integration of Self-Attention and Convolution
**論文鏈接:**https://arxiv.org/pdf/2111.14556
**代碼鏈接:**https://github.com/LeapLabTHU/ACmix
創新點
1. 揭示卷積和自注意力的內在聯系
文章通過重新分解卷積和自注意力模塊的操作,發現它們在第一階段(特征投影)都依賴于 1×1 卷積操作,并且這一階段占據了大部分的計算復雜度(與通道數的平方成正比)。這一發現為整合兩種模塊提供了理論基礎。
2. 提出 ACmix 模型
基于上述發現,作者提出了 ACmix 模型,它通過共享 1×1 卷積操作來同時實現卷積和自注意力的功能。具體來說:
**第一階段:**輸入特征通過 1×1 卷積投影,生成中間特征。
**第二階段:**這些中間特征分別用于卷積路徑(通過移位和聚合操作)和自注意力路徑(計算注意力權重并聚合值)。最終,兩條路徑的輸出通過可學習的權重加權求和,得到最終輸出。
3. 改進的移位和聚合操作
文章還提出了一種改進的移位操作,通過使用 固定卷積核的分組卷積 來替代傳統的張量移位操作。這種方法不僅提高了計算效率,還允許卷積核的可學習性,進一步增強了模型的靈活性。
4. 適應性路徑權重
ACmix 引入了兩個可學習的標量參數(α 和 β),用于動態調整卷積路徑和自注意力路徑的權重。這種設計不僅提高了模型的靈活性,還允許模型在不同深度上自適應地選擇更適合的特征提取方式。實驗表明,這種設計在模型的不同階段表現出不同的偏好,例如在早期階段更傾向于卷積,在后期階段更傾向于自注意力。
整體結構
第一階段:特征投影
在第一階段,輸入特征通過三個1×1卷積進行投影,分別生成查詢(query)、鍵(key)和值(value)特征映射。這些特征映射隨后被重塑為N塊,形成一個包含3×N特征映射的中間特征集。
第二階段:特征聚合
在第二階段,中間特征集被分為兩個路徑進行處理:
- **自注意力路徑:**將中間特征集分為N組,每組包含三個特征映射(分別對應查詢、鍵和值)。這些特征映射按照傳統的多頭自注意力機制進行處理,計算注意力權重并聚合值。
- **卷積路徑:**通過輕量級的全連接層生成k2個特征映射(k為卷積核大小)。這些特征映射通過移位和聚合操作,以類似傳統卷積的方式處理輸入特征,從局部感受野收集信息。
輸出整合
最后,自注意力路徑和卷積路徑的輸出通過兩個可學習的標量參數(α和β)加權求和,得到最終的輸出。
改進的移位和聚合操作
為了提高計算效率,ACmix模型采用了改進的移位操作,通過固定卷積核的分組卷積來替代傳統的張量移位操作。這種方法不僅提高了計算效率,還允許卷積核的可學習性,進一步增強了模型的靈活性。
模型的靈活性和泛化能力
ACmix模型不僅適用于標準的自注意力機制,還可以與各種變體(如Patchwise Attention、Window Attention和Global Attention)結合使用。這種設計使得ACmix能夠適應不同的任務需求,具有廣泛的適用性。
消融實驗
1. 結合兩個路徑的輸出
消融實驗探索了卷積和自注意力輸出的不同組合方式對模型性能的影響。實驗結果表明:
- **卷積和自注意力的組合優于單一路徑:**使用卷積和自注意力模塊的組合始終優于僅使用單一路徑(如僅卷積或僅自注意力)的模型。
- **可學習參數的靈活性:**通過引入可學習的參數(如α和β)來動態調整卷積和自注意力路徑的權重,ACmix能夠根據網絡中不同位置的需求自適應地調整路徑強度,從而獲得更高的靈活性和性能。
2. 組卷積核的選擇
實驗還對組卷積核的設計進行了驗證,結果表明:
- **用組卷積替代張量位移:**通過使用組卷積替代傳統的張量位移操作,顯著提高了模型的推理速度。
- **可學習卷積核和初始化:**使用可學習的卷積核并結合精心設計的初始化方法,進一步增強了模型的靈活性,并有助于提升最終性能。
3. 不同路徑的偏好
ACmix模型引入了兩個可學習標量α和β,用于動態調整卷積和自注意力路徑的權重。通過平行實驗,觀察到以下趨勢:
- **早期階段偏好卷積:**在Transformer模型的早期階段,卷積作為特征提取器表現更好。
- **中間階段混合使用:**在網絡的中間階段,模型傾向于混合使用兩種路徑,并逐漸增加對卷積的偏好。
- **后期階段偏好自注意力:**在網絡的最后階段,自注意力表現優于卷積。
4. 對模型性能的影響
這些消融實驗結果表明,ACmix模型通過合理結合卷積和自注意力的優勢,并優化計算路徑,不僅在多個視覺任務上取得了顯著的性能提升,還保持了較高的計算效率
ACmix模塊的作用
1. 融合卷積和自注意力的優勢
ACmix模塊通過結合卷積的局部特征提取能力和自注意力的全局感知能力,實現了一種高效的特征融合策略。這種設計使得模型能夠同時利用卷積的局部感受野特性和自注意力的靈活性。
2. 優化計算路徑
ACmix通過優化計算路徑和減少重復計算,提高了整體模塊的計算效率。具體來說,它通過1×1卷積對輸入特征圖進行投影,生成中間特征,然后根據不同的范式(卷積和自注意力)分別重用和聚合這些中間特征。這種設計不僅減少了計算開銷,還提升了模型性能。
3. 改進的位移與求和操作
在卷積路徑中,ACmix采用深度可分離卷積(depthwise convolution)來替代低效的張量位移操作,從而提高了實際推理效率。
4. 動態調整路徑權重
ACmix引入了兩個可學習的標量參數(α和β),用于動態調整卷積和自注意力路徑的權重。這種設計使得模型能夠根據網絡中不同位置的需求自適應地調整路徑強度,從而獲得更高的靈活性。
5. 廣泛的應用潛力
ACmix在多個視覺任務(如圖像分類、語義分割和目標檢測)上均顯示出優于單一機制(僅卷積或僅自注意力)的性能,展示了其廣泛的應用潛力。
6. 實驗驗證
實驗結果表明,ACmix在保持較低計算開銷的同時,能夠顯著提升模型的性能。例如,在ImageNet分類任務中,ACmix模型在相同的FLOPs或參數數量下表現出色,并且在與競爭對手的基準比較中取得了持續的改進。此外,ACmix在ADE20K語義分割任務和COCO目標檢測任務中也顯示出明顯的改進
代碼實現
import torch
import torch.nn as nndef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.)class ACmix(nn.Module):def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,stride=stride)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stride# ### att# ## positional encodingpe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,self.kernel_att * self.kernel_att, h_out,w_out) # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,w_out) # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,h_out, w_out)out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)## convf_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w)], 1))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_conv#輸入 B C H W, 輸出 B C H W
if __name__ == '__main__':block = ACmix(in_planes=64, out_planes=64)input = torch.rand(3, 64, 32, 32)output = block(input)print(input.size(), output.size())