一、引言
在計算機視覺領域的發展歷程中,卷積神經網絡(CNN) 長期占據主導地位。從早期的 LeNet 到后來的 AlexNet、VGGNet、ResNet 等,CNN 在圖像分類、目標檢測、語義分割等任務中取得了顯著成果。然而,CNN 在捕捉全局信息和處理長距離依賴關系方面存在局限性。與此同時,Transformer Architektur 在自然語言處理(NLP)領域表現出色,憑借自注意力機制有效捕捉序列數據中的長距離依賴關系,例如 GPT 系列模型在語言生成和問答系統中的成功應用。
將 Transformer 直接應用于視覺任務面臨挑戰,例如計算復雜度高,尤其是在處理高分辨率圖像時,計算量會隨著圖像尺寸增加而顯著增長,對硬件資源和計算時間要求較高。此外,Transformer 最初為序列數據設計,在提取圖像局部特征方面不如 CNN 有效。
Swin Transformer 通過引入 窗口注意力機制,將特征圖劃分為多個不重疊窗口,在每個窗口內進行自注意力計算,從而降低了計算復雜度。它采用分層結構,類似 CNN 的層次設計,能夠提取不同尺度的特征,適應多尺度視覺任務。此外,補丁合并層 通過減少特征圖尺寸并增加通道數進一步提升性能。Swin Transformer 在多個視覺任務中表現出色,成為計算機視覺領域的研究重點。本文將深入分析其原理、結構、優勢及應用案例。
二、Swin Transformer 的背景
在深度學習發展中,卷積神經網絡(CNN) 在計算機視覺領域占據重要地位。LeNet 在手寫數字識別中取得初步成功,為 CNN 奠定了基礎。2012 年,AlexNet 在 ImageNet 挑戰賽中以更深的網絡結構和 ReLU 激活函數大幅提升準確率,推動了深度學習在視覺領域的快速發展。此后,VGGNet 通過堆疊小卷積核減少參數,ResNet 通過殘差連接解決深層網絡的梯度問題,使網絡能夠更深層并學習復雜特征。
然而,CNN 在捕捉全局信息方面能力較弱。卷積操作主要提取局部特征,通過多層卷積擴大感受野,但對長距離依賴關系建模仍有限制。與此同時,Transformer 在 NLP 領域憑借自注意力機制和并行計算能力取得成功,GPT-3 等模型展示了其語言理解和生成能力。
研究者嘗試將 Transformer 應用于視覺任務,但面臨圖像數據與文本數據的結構差異及高計算復雜度問題。Swin Transformer 的提出旨在將 Transformer 的能力引入視覺領域,通過窗口注意力機制和分層結構降低計算復雜度,提升特征提取能力。
三、核心原理剖析
Swin Transformer 是一種基于 Transformer 的視覺模型,其核心創新在于層次化架構設計(Hierarchical Architecture)和移位窗口自注意力(Shifted Window Self-Attention)。這一設計使其能夠高效處理圖像數據,同時兼容卷積神經網絡(CNN)的多尺度特征提取能力,適用于分類、檢測、分割等任務。
(一)整體架構與分層設計
Swin Transformer 的整體架構分為 4 個階段(Stage),每個階段通過 Patch Merging 操作逐步降低特征圖分辨率,同時增加通道維度,形成金字塔式的層次化特征表示。整體流程如下:
-
輸入處理
? Patch Partition:將輸入圖像劃分為 4 × 4 4\times4 4×4 的非重疊塊(Patch),每個塊通過 線性投影(Linear Embedding)轉換為特征向量。例如,輸入圖像尺寸為 H × W × 3 H \times W \times 3 H×W×3,處理后得到 ( H 4 × W 4 × C \frac{H}{4} \times \frac{W}{4} \times C 4H?×4W?×C) 的特征圖( C 為嵌入維度,默認 96 C 為嵌入維度,默認 96 C為嵌入維度,默認96)。 -
Stage 1~4
? 每個 Stage 包含若干 Swin Transformer Block 和一個 Patch Merging 層(最后一個 Stage 無 Patch Merging)。
? Swin Transformer Block:交替使用 窗口多頭自注意力(W-MSA) 和 移位窗口多頭自注意力(SW-MSA),通過窗口劃分減少計算復雜度。
? Patch Merging:將相鄰的 2 × 2 2\times2 2×2 塊合并為一個塊(類似池化),分辨率降低為原來的 1 2 \frac{1}{2} 21?,通道數增加為原來的 2 2 2 倍(例如從 C C C 到 2 C 2C 2C)。
典型配置示例
Stage | 特征圖分辨率 | 通道數 | Swin Block 數量 | 窗口大小 |
---|---|---|---|---|
1 | H 4 × W 4 \frac{H}{4} \times \frac{W}{4} 4H?×4W? | 96 | 2 | 7×7 |
2 | H 8 × W 8 \frac{H}{8} \times \frac{W}{8} 8H?×8W? | 192 | 2 | 7×7 |
3 | H 16 × W 16 \frac{H}{16} \times \frac{W}{16} 16H?×16W? | 384 | 6 | 7×7 |
4 | H 32 × W 32 \frac{H}{32} \times \frac{W}{32} 32H?×32W? | 768 | 2 | 7×7 |
(二)窗口注意力機制(W-MSA)
窗口注意力機制 是 Swin Transformer 的核心創新。傳統 Transformer 的全局自注意力計算復雜度隨圖像尺寸平方增長,而 Swin Transformer 將特征圖劃分為不重疊窗口(如 7x7),在窗口內進行自注意力計算。窗口大小為 MxM 時,計算復雜度為 O(M2 * H/W),遠低于全局自注意力的 O(HW2)。
實現過程包括:將特征圖劃分為窗口,計算每個窗口內的 Query、Key 和 Value 矩陣,通過矩陣運算生成注意力權重并與 Value 相乘。這種方式降低計算量,同時保留局部特征提取能力。
(三)移位窗口機制(SW-MSA)
窗口注意力機制雖高效,但窗口間無交互可能限制全局信息捕捉。為此,Swin Transformer 引入 移位窗口機制。在連續自注意力層間,窗口位置移動(如右、下移),使相鄰窗口部分重疊,促進信息交互。超出邊界的區域通過填充處理。這一機制在保持低計算復雜度的同時增強全局上下文建模能力。
(四)補丁合并層(Patch Merging)
補丁合并層 用于構建層次特征。將特征圖按 2x2 窗口切分,拼接為 4C 維向量(C 為原通道數),通過線性層降維至 2C,最終特征圖尺寸減半,通道數翻倍。這一過程逐步整合局部特征,提取更具代表性的全局特征。
(五)多頭自注意力機制(Multi-Head Self-Attention)
Swin Transformer 沿用 多頭自注意力機制,通過多個線性變換生成多組 Query、Key 和 Value 矩陣,分別計算注意力并拼接輸出。不同注意力頭關注圖像的不同特征(如形狀、紋理),提升模型對復雜任務的適應性。
應用領域展示
(一)圖像分類
在 ImageNet 數據集上,Swin Transformer 表現優異。例如,Swin-B 在 ImageNet-22K 預訓練后,在 ImageNet-1K 上 Top-1 準確率達 87.3%,優于 ResNet50(約 76%)及 Vision Transformer(ViT)。其優勢在于窗口注意力機制和移位窗口機制結合,有效捕捉全局和局部信息。
(二)目標檢測
在 COCO 數據集上,以 Swin Transformer 為骨干網絡的 Mask R-CNN 模型 mAP 達 49.5,超越 Faster R-CNN(約 38)。分層結構提取多尺度特征,窗口機制增強上下文信息捕捉,提升檢測精度。
(三)語義分割
在 ADE20K 數據集上,基于 Swin Transformer 的 UperNet 模型 mIoU 達 44.5,高于 FCN(約 41)。其多尺度特征提取和上下文理解能力提升像素級分類準確性。
優勢對比分析
(一)與傳統 CNN 對比
與 CNN 相比,Swin Transformer 在全局信息和長距離依賴建模上更強。CNN 通過卷積提取局部特征,依賴多層堆疊擴大感受野,而 Swin Transformer 的自注意力機制直接捕捉全局依賴。計算復雜度方面,Swin Transformer 通過窗口機制實現近似線性增長,優于傳統 Transformer 的平方級增長。
(二)與Vision Transformer 模型對比
相比 Vision Transformer(ViT),Swin Transformer 通過窗口機制降低計算復雜度,提升空間和計算效率。其多層次設計同時捕捉局部和全局特征,且靈活的窗口調整適應不同任務,性能更優。
特性 | Swin Transformer | Vision Transformer (ViT) |
---|---|---|
特征圖分辨率 | 多尺度(4 個 Stage) | 單尺度(固定分辨率) |
計算復雜度 | 線性復雜度(窗口劃分) | 平方復雜度(全局注意力) |
適用任務 | 分類、檢測、分割 | 主要分類 |
位置編碼 | 相對位置編碼 | 絕對位置編碼 |
代碼實現示例
以下是使用 Python 和 PyTorch 框架實現 Swin Transformer 中關鍵模塊的代碼示例,并對代碼進行詳細解釋,以幫助讀者更好地理解模型的實現細節。
(一)窗口注意力機制(W-MSA)
WindowAttention
是 Swin Transformer 的核心模塊之一,實現了窗口內的多頭自注意力機制(Window-based Multi-head Self-Attention, W-MSA),通過限制注意力計算范圍降低復雜度,并加入相對位置偏置以增強空間感知能力。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass WindowAttention(nn.Module):"""基于窗口的多頭自注意力模塊,包含相對位置編碼(Swin Transformer的核心組件)"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):"""Args:dim (int): 輸入特征維度window_size (tuple): 窗口大小 (h, w)num_heads (int): 注意力頭的數量qkv_bias (bool): 是否在qkv線性層添加偏置qk_scale (float): 縮放因子,默認為 head_dim^-0.5attn_drop (float): 注意力dropout概率proj_drop (float): 輸出投影層的dropout概率"""super().__init__()self.dim = dimself.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_heads # 每個注意力頭的維度# 縮放因子,用于縮放點積注意力得分self.scale = qk_scale or head_dim ?**? -0.5# 定義相對位置編碼表:存儲所有可能相對位置的位置偏置# 形狀為 [(2h-1)*(2w-1), num_heads],用于表示不同相對位置的注意力偏置self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))# 生成窗口內每個位置的坐標(用于計算相對位置索引)coords_h = torch.arange(self.window_size[0]) # 高度方向坐標 [0,1,...,h-1]coords_w = torch.arange(self.window_size[1]) # 寬度方向坐標 [0,1,...,w-1]coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='xy')) # 網格坐標 [2, h, w]coords_flatten = torch.flatten(coords, 1) # 展平為 [2, h*w]# 計算相對坐標差值(每個位置與其他位置的相對坐標差)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, h*w, h*w]relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 調整維度順序為 [h*w, h*w, 2]# 將相對坐標偏移到非負數范圍(方便作為索引)relative_coords[:, :, 0] += self.window_size[0] - 1 # 行偏移到 [0, 2h-2]relative_coords[:, :, 1] += self.window_size[1] - 1 # 列偏移到 [0, 2w-2]# 將二維相對坐標轉換為一維索引(用于查表)relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 行坐標乘以跨度relative_position_index = relative_coords.sum(-1) # 合并坐標得到一維索引 [h*w, h*w]# 注冊為不參與梯度更新的緩沖區(在forward中通過索引獲取位置偏置)self.register_buffer("relative_position_index", relative_position_index)# 定義qkv投影層:將輸入特征映射為query, key, valueself.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)# 定義輸出投影層和dropoutself.proj = nn.Linear(dim, dim) # 合并多頭輸出self.attn_drop = nn.Dropout(attn_drop) # 注意力分數dropoutself.proj_drop = nn.Dropout(proj_drop) # 輸出投影dropout# 初始化相對位置偏置表(正態分布)nn.init.normal_(self.relative_position_bias_table, std=0.02)def forward(self, x, mask=None):"""Args:x (Tensor): 輸入特征,形狀為 [batch_size*num_windows, num_patches, dim]mask (Tensor): 窗口注意力掩碼(用于SW-MSA),形狀為 [num_windows, num_patches, num_patches]Returns:Tensor: 輸出特征,形狀同輸入"""B_, N, C = x.shape # B_ = batch_size * num_windows, N = num_patches (h*w), C = dim# 生成qkv向量,并重塑為多頭形式qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2] # 分離q/k/v [B_, num_heads, N, head_dim]# 縮放點積得分q = q * self.scale # 縮放query# 計算原始注意力分數 [B_, num_heads, N, N]attn = (q @ k.transpose(-2, -1)) # 矩陣乘法計算注意力分數# 添加相對位置偏置(從預定義的表中獲取)relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1) # 將索引展平查表].view(self.window_size[0] * self.window_size[1], # 窗口內總位置數(h*w)self.window_size[0] * self.window_size[1], -1) # 形狀變為 [h*w, h*w, num_heads]relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [num_heads, h*w, h*w]attn = attn + relative_position_bias.unsqueeze(0) # 廣播到batch維度 [B_, num_heads, N, N]# 應用掩碼(用于SW-MSA的移位窗口)if mask is not None:nW = mask.shape[0] # 窗口數量# 將attn拆分為不同窗口的注意力并添加掩碼attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N) # 重新合并batch維度attn = F.softmax(attn, dim=-1) # 帶掩碼的softmaxelse:attn = F.softmax(attn, dim=-1) # 普通softmaxattn = self.attn_drop(attn) # 應用注意力dropout# 計算加權值向量并合并多頭輸出x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # [B_, N, dim]# 輸出投影和dropoutx = self.proj(x)x = self.proj_drop(x)return x
(二)補丁合并層(Patch Merging)
PatchMerging
是 Swin Transformer 中用于降采樣和通道擴展的模塊,通過合并相鄰補丁減少空間分辨率并增加特征維度。
import torch
import torch.nn as nnclass PatchMerging(nn.Module):"""空間下采樣模塊,用于Swin Transformer的層次化特征提取(類似CNN中的池化層)功能:將特征圖分辨率降低為1/2,通道數增加為2倍(通過合并相鄰2x2區域的特征)"""def __init__(self, dim, norm_layer=nn.LayerNorm):"""Args:dim (int): 輸入特征維度norm_layer (nn.Module): 歸一化層,默認為LayerNorm"""super().__init__()self.dim = dim# 定義線性投影層:將4*dim維特征映射到2*dim維(通道數翻倍)self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)# 歸一化層:作用于合并后的特征(輸入維度為4*dim)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""Args:x (Tensor): 輸入特征,形狀為 [batch_size, H*W, dim]H, W (int): 特征圖的高度和寬度Returns:Tensor: 下采樣后的特征,形狀為 [batch_size, (H//2)*(W//2), 2*dim]"""B, L, C = x.shapeassert L == H * W, "輸入特征長度必須等于H*W"# 重塑為空間結構 [B, H, W, C]x = x.view(B, H, W, C)# 處理奇數尺寸:當H或W為奇數時,通過padding補充1行/列(右下補零)pad_input = (H % 2 == 1) or (W % 2 == 1)if pad_input:# padding格式:(左, 右, 上, 下, 前, 后) -> 此處僅padding高度和寬度的右側/底部x = nn.functional.pad(x, (0, 0, # 通道維度不padding0, W % 2, # 寬度右側補 (0或1列)0, H % 2)) # 高度底部補 (0或1行)# 劃分2x2區域并拼接(空間下采樣核心操作)# 通過切片操作提取相鄰2x2區域的四個子塊x0 = x[:, 0::2, 0::2, :] # 左上塊 [B, H//2, W//2, C]x1 = x[:, 1::2, 0::2, :] # 左下塊x2 = x[:, 0::2, 1::2, :] # 右上塊x3 = x[:, 1::2, 1::2, :] # 右下塊# 沿通道維度拼接 -> 通道數變為4倍 [B, H//2, W//2, 4*C]x = torch.cat([x0, x1, x2, x3], dim=-1)# 展平空間維度 -> [B, (H//2)*(W//2), 4*C]x = x.view(B, -1, 4 * C) # 歸一化處理x = self.norm(x)# 線性投影降維:4*C -> 2*C(通道數翻倍)x = self.reduction(x)return x
(三)demo調用及輸出結果展示
我們創建一個 demo 函數來演示 Swin Transformer 中 PatchMerging
和 WindowAttention
的調用流程
import torchdef demo():"""演示 Swin Transformer 中 PatchMerging 和 WindowAttention 的調用流程"""# --------------- 參數設置 ---------------batch_size = 2 # 批大小height, width = 16, 16 # 輸入特征圖的高和寬dim = 96 # 輸入特征的維度(通道數)window_size = (4, 4) # 窗口大小(高方向4像素,寬方向4像素)num_heads = 4 # 多頭注意力頭數# 創建隨機輸入特征圖 [B, H*W, C]x = torch.randn(batch_size, height * width, dim)print(f"原始輸入形狀: {x.shape}") # 預期輸出: [2, 256, 96]# --------------- 調用 PatchMerging 模塊 ---------------patch_merge = PatchMerging(dim=dim)x_patch = patch_merge(x, height, width)new_height, new_width = height // 2, width // 2 # 下采樣后特征圖尺寸print(f"PatchMerging 輸出形狀: {x_patch.shape}") # 預期輸出: [2, 64, 192]# 注:H/2 * W/2 = 8 * 8=64,通道數從96擴展到192# --------------- 調整形狀以適應 WindowAttention 輸入 ---------------# 計算窗口劃分后的參數num_windows_h = new_height // window_size[0] # 窗口行數 8//4=2num_windows_w = new_width // window_size[1] # 窗口列數 8//4=2num_windows = num_windows_h * num_windows_w # 總窗口數 2 * 2=4tokens_per_window = window_size[0] * window_size[1] # 每個窗口的token數 4 * 4=16new_dim = 2 * dim # PatchMerging后的通道數 96 * 2=192# 重塑特征圖為窗口形式 [B * num_windows, tokens_per_window, new_dim]x_window = x_patch.view(batch_size, num_windows_h, num_windows_w, # 窗口的行列數tokens_per_window, # 每個窗口的token數new_dim # 新通道數)x_window = x_window.permute(0, 1, 2, 3, 4).contiguous() # 維度調整 [B, num_h, num_w, tokens, C]x_window = x_window.view(batch_size * num_windows, tokens_per_window, new_dim)print(f"調整為窗口輸入形狀: {x_window.shape}") # 預期輸出: [8, 16, 192] (2 * 4=8窗口)# --------------- 調用 WindowAttention 模塊 ---------------window_attn = WindowAttention(dim=new_dim, window_size=window_size, num_heads=num_heads)x_out = window_attn(x_window) # 輸入形狀 [8, 16, 192]print(f"WindowAttention 輸出形狀: {x_out.shape}") # 預期輸出: [8, 16, 192]# --------------- 將輸出還原為特征圖形式 ---------------# 逆向重塑操作(僅用于展示,實際可能不需要)x_out = x_out.view(batch_size, num_windows_h, num_windows_w, # 窗口行列數tokens_per_window, # 每個窗口的token數new_dim # 通道數)x_out = x_out.permute(0, 1, 3, 2, 4).contiguous() # [B, num_h, tokens, num_w, C]x_out = x_out.view(batch_size, new_height, new_width, new_dim)print(f"最終特征圖形狀: {x_out.shape}") # 預期輸出: [2, 8, 8, 192]if __name__ == "__main__":demo()
輸出結果如下:
通過代碼示例,我們不僅理解了 Swin Transformer ?層次化架構、窗口注意力和移位窗口機制的實現細節,更深入認識到其設計哲學:在保持 Transformer 全局建模能力的同時,通過局部計算和層次化設計逼近 CNN 的效率優勢。這種平衡使其成為視覺任務的通用 Backbone,為后續研究(如 SwinV2、Uniformer)提供了重要參考。
總結與展望
Swin Transformer 在計算機視覺領域具有重要地位,通過 窗口注意力機制、移位窗口機制 和 補丁合并層 等設計降低計算復雜度,提升特征提取能力,在多項任務中表現出色。對于研究者而言,它是一個值得深入探索的模型,未來可在更多領域發揮作用。
Swin Transformer 的研究正處于快速發展階段。優化方向包括改進窗口注意力機制(如動態窗口劃分)和降低計算復雜度(如稀疏注意力)。多模態融合 是另一熱點,與 NLP 結合可實現圖像描述和視覺問答等任務。在應用上,Swin Transformer 在自動駕駛(車輛檢測)和醫療影像分析(腫瘤檢測)中展現潛力。未來,其通用性和計算效率有望進一步提升,應用范圍將更廣泛。
延伸閱讀
-
AI Agent 系列文章
-
計算機視覺系列文章
-
機器學習核心算法系列文章
-
深度學習系列文章