該篇文章,是我解析 Swin transformer 論文原理(結合pytorch版本代碼)所記,圖片來源于源paper或其他相應博客。
代碼也非原始代碼,而是從代碼里摘出來的片段,配上簡單數據,以便理解。
當然,也可能因為設置數據不當,造成誤解,請多指教。
剛寫了一部分。先發布。希望多多指正。
Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers ,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.
模型結構圖
Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.
Stage 1 – Patch Embedding
It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.
Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.
In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48.(channel–3)
A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
這個表述,linear embedding layer,我感覺不太準確,但是,后半部分比較準確,哈哈,將channel–3變成了96.
Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.
The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.
代碼
以下代碼來自于model.py:
class PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""
"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as Fclass PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):super().__init__()patch_size = (patch_size, patch_size)self.patch_size = patch_sizeself.in_chans = in_cself.embed_dim = embed_dimself.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shape# padding# 如果輸入圖片的H,W不是patch_size的整數倍,需要進行paddingpad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)if pad_input:# to pad the last 3 dimensions,# (W_left,W_right, H_top,H_bottom, C_front,C_back)x = F.pad(x,(0, self.patch_size[1] - W % self.patch_size[1],0, self.patch_size[0] - H % self.patch_size[0],0, 0))# 下采樣patch_size倍x = self.proj(x)_, _, H, W = x.shape# flatten: [B, C, H, W] -> [B, C, HW]# transpose: [B, C, HW] -> [B, HW, C]x = x.flatten(2).transpose(1, 2)x = self.norm(x)print(x.shape)# torch.Size([1, 3136, 96])# 224/4 * 224/4 = 3136return x, H, Wif __name__ == '__main__':img_path = "tulips.jpg"img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]print(img.size)# (500,375)#img_size = 224data_transform = transforms.Compose([transforms.Resize(int(img_size * 1.14)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])img = data_transform(img)print(img.shape)# torch.Size([3, 224, 224])# expand batch dimensionimg = torch.unsqueeze(img, dim=0)print(img.shape)# torch.Size([1, 3, 224, 224])# split image into non-overlapping patchespatch_embed = PatchEmbed(norm_layer=nn.LayerNorm)patch_embed(img)
Stage 2 – 3.2. Shifted Window based Self-Attention
Shifted window partitioning in successive blocks
The window-based self-attention module lacks connections across windows, which limits its modeling power.
To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows,
we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.
為了在保持非重疊窗口高效計算的同時引入跨窗口連接,我們提出了一種移位窗口劃分方法,該方法在連續的Swin Transformer塊中交替使用兩種不同的劃分配置。
Figure 2.
In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window.
In the next layer l + 1 (right), the window partitioning is shifted, resulting in new windows.
The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.
在新窗口中進行的自注意力計算跨越了第l層中先前窗口的邊界,從而在它們之間建立了連接。
Efficient batch computation for shifted configuration
An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.
Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction(向左上方向循環移動), as illustrated in Figure 4.
這里的 more efficient,是說相對于直觀方法 padding—mask來說:
A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.
Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.
After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.
在此轉換之后,批處理窗口可能由特征圖中不相鄰的幾個子窗口組成,因此采用掩蔽機制將自注意力計算限制在每個子窗口內。
With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.
通過循環移位,批處理窗口的數量與常規窗口分區的數量保持不變,因此也是高效的。
上圖和敘述,并不太直觀,找了相關資料,一起分析:
移動完成之后,4是一個單獨區域,5、3為一組,7、1為一組,8、6、2、0為一組。
但,5、3本身是兩個圖像的邊緣,混在一起計算不是亂了嗎?一起計算也沒問題,ViT也是全局計算的。
但,Swin-Transformer為了防止這個問題,在代碼中使用了masked MSA,這樣就能夠通過設置蒙板來隔絕不同區域的信息了。
源碼中具體的方法就是將不計算的位置元素減去100。
這里需要注意的是,在窗口數據進行滑動完之后,需要將數據還原回去,即挪回到原來的位置上。
代碼
以下代碼來自于model.py:
def window_partition(x, window_size: int):"""將feature map按照window_size劃分成一個個沒有重疊的window主要思路是將feature轉成 (num_windows*B, window_size*window_size, C)的shape,把需要self-attn計算的window排列到第0維,一次并行的qkv就可以了Args:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shape# B,224,224,C# B,56,56,Cx = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# B,32,7,32,7,C# B,8,7,8,7,C# permute:# [B, H//Mh, Mh, W//Mw, Mw, C] -># [B, H//Mh, W//Mh, Mw, Mw, C]# B,32,32,7,7,C# B,8,8,7,7,C# view:# [B, H//Mh, W//Mw, Mh, Mw, C] -># [B*num_windows, Mh, Mw, C]# B*1024,7,7,C# B*64,7,7,C# 32*32 = 1024# 224 / 7 = 32windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows
分析:將 [B, C, 56, 56] 最后變成了[64B, C, 7, 7],原先的 B*C 張 56*56 的特征圖,最后變成了 B*64*C張7*7的特征;
即,我們有64B個樣本,每個樣本包含C個7x7的通道。
注意,window_size–M–7,是每個window的大小,7*7,不是7*7個window,我剛開始混淆了這一點。
class BasicLayer(nn.Module):# A basic Swin Transformer layer for one stage.def __init__(self, dim, depth, num_heads, window_size,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):super().__init__()self.dim = dimself.depth = depthself.window_size = window_sizeself.use_checkpoint = use_checkpointself.shift_size = window_size // 2# 7//2 = 3# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else self.shift_size,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop,attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer)for i in range(depth)])...# depth: 2, 2, 6, 2# 即,第一層,depth=2, 有兩個SwinTransformerBlock,shift_size分別為:0,3# 即,第二層,depth=2, 有兩個SwinTransformerBlock,shift_size分別為:0,3# 即,第三層,depth=6, 有兩個SwinTransformerBlock,shift_size分別為:# 0,3,0,3,0,3# 即,第四層,depth=2, 有兩個SwinTransformerBlock,shift_size分別為:0,3def create_mask(self, x, H, W):# calculate attention mask for SW-MSA
import numpy as np
import torchH = 7
W = 7
window_size = 7
shift_size = 3Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size# 擁有和feature map一樣的通道排列順序,方便后續window_partition
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, '\n')h_slices = (slice(0, -window_size),slice(-window_size, -shift_size),slice(-shift_size, None)
)
print(h_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))w_slices = (slice(0, -window_size),slice(-window_size, -shift_size),slice(-shift_size, None)
)
print(w_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))cnt = 0
for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1print(img_mask)
import torchimg_mask = torch.rand((2, 3))
print(img_mask)
'''
tensor([[0.7410, 0.6020, 0.5195],[0.9214, 0.2777, 0.8418]])
'''
attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
print(attn_mask)
'''
tensor([[[ 0.0000, -0.1390, -0.2215],[ 0.1390, 0.0000, -0.0825],[ 0.2215, 0.0825, 0.0000]],[[ 0.0000, -0.6437, -0.0796],[ 0.6437, 0.0000, 0.5642],[ 0.0796, -0.5642, 0.0000]]])
'''print(img_mask.unsqueeze(1))
'''
tensor([[[0.7410, 0.6020, 0.5195]],[[0.9214, 0.2777, 0.8418]]])
'''
print(img_mask.unsqueeze(2))
'''
tensor([[[0.7410],[0.6020],[0.5195]],[[0.9214],[0.2777],[0.8418]]])
'''
上面那個代碼,需要根據下面這個代碼對應著走,shift_size–torch.roll()
class SwinTransformerBlock(nn.Module):# Swin Transformer Block....def forward(self, x, attn_mask):H, W = self.H, self.WB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window size# 把feature map給pad到window size的整數倍pad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_size# 注意F.pad的順序,剛好是反著來的, 例如:# x.shape = (b, h, w, c)# x = F.pad(x, (1, 1, 2, 2, 3, 3))# x.shape = (b, h+6, w+4, c+2)# 源碼可能有誤,修改成下面的# x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))x = F.pad(x, (0, 0, pad_t, pad_b, pad_l, pad_r))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:# paper中,滑動的size是窗口大小的/2(向下取整)# torch.roll以H,W的維度為例子,負值往左上移動,正值往右下移動。# 溢出的值在對角方向出現。即循環移動。shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]...
其中,torch.roll()方法簡易示例如下:
import torchx = torch.randn(1, 4, 4, 3)
print(x, '\n')shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, '\n')
為了方便理解,我更換了維度:
import torchx = torch.randn(1, 3, 7, 7)
print(x, '\n')shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, '\n')
Stage 3 – patch merging layers
To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.
The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.
首個補丁合并層將每組2×2相鄰補丁的特征進行拼接,并在拼接后的4C維特征上應用一個線性層。
This reduces the number of tokens by a multiple of 2×2=4(2 ×downsampling of resolution), and the output dimension is set to 2C.
Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.
同樣,結合其他大神分析,圖展示如下:
Related Work
Self-attention based backbone architectures
Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.
。。。。。
Cited link or paper name
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
- https://blog.csdn.net/weixin_42392454/article/details/141395092