Swin Transformer Paper Reading Notes and Code Analysis


This article records my reading of the Swin Transformer paper, analyzed together with a PyTorch implementation. The figures come from the original paper or from related blog posts.

The code here is not the original source code either; it consists of snippets extracted from it, paired with simple example data to make it easier to understand.

Of course, poorly chosen example data may also cause misunderstandings — corrections are welcome.

Only part of it is written so far; I am publishing it early and welcome feedback.


Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.

Model architecture

Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.

Stage 1 – Patch Embedding

It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.

Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.

In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48 (3 input channels).

A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
The phrase "linear embedding layer" doesn't feel entirely accurate to me (in the code it is actually a strided Conv2d), but the latter half of the sentence is spot on: the 3-channel input is projected to C = 96.
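As a quick sanity check (my own toy example, not taken from the repository), the 4×4/stride-4 Conv2d used below really is the same linear map applied to each flattened 4×4×3 = 48-dimensional patch:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)                          # tiny "image": a 2x2 grid of 4x4 patches

conv = nn.Conv2d(3, 96, kernel_size=4, stride=4)     # patch-embedding conv, as in the code below
linear = nn.Linear(48, 96)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(96, -1))    # reuse the conv kernels as a 48->96 linear map
    linear.bias.copy_(conv.bias)

y_conv = conv(x).flatten(2).transpose(1, 2)          # [1, 4, 96]

# build the raw 48-dim patch features by hand: [1, 4, 3*4*4]
patches = x.unfold(2, 4, 4).unfold(3, 4, 4)          # [1, 3, 2, 2, 4, 4]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 4, 48)
y_linear = linear(patches)

print(torch.allclose(y_conv, y_linear, atol=1e-5))   # True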

Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.

The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.

Code

The following code is taken from model.py:

"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # if H, W of the input image are not integer multiples of patch_size, pad first
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # pad the last 3 dimensions:
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten:   [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW]   -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        print(x.shape)
        # torch.Size([1, 3136, 96])
        # 224/4 * 224/4 = 3136
        return x, H, W


if __name__ == '__main__':
    img_path = "tulips.jpg"
    img = Image.open(img_path)
    plt.imshow(img)
    print(img.size)
    # (500, 375)

    img_size = 224
    data_transform = transforms.Compose([
        transforms.Resize(int(img_size * 1.14)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = data_transform(img)
    print(img.shape)
    # torch.Size([3, 224, 224])

    # expand batch dimension: [C, H, W] -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)
    print(img.shape)
    # torch.Size([1, 3, 224, 224])

    # split the image into non-overlapping patches
    patch_embed = PatchEmbed(norm_layer=nn.LayerNorm)
    patch_embed(img)

Stage 2 – Shifted Window based Self-Attention (paper Section 3.2)

Shifted window partitioning in successive blocks

The window-based self-attention module lacks connections across windows, which limits its modeling power.

To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows,
we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.

Figure 2.
In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window.
In the next layer l + 1 (right), the window partitioning is shifted, resulting in new windows.
The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.

Efficient batch computation for shifted configuration

An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.

Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction, as illustrated in Figure 4.

"More efficient" here is relative to the naive padding-and-masking approach:

A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.
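A quick count with illustrative numbers (M = 7 and a 56×56 stage-1 feature map; the sizes are my choice, not quoted from the paper):

# number of windows per feature map, naive padding vs. cyclic shift (illustrative sizes)
H = W = 56
M = 7
regular = (H // M) * (W // M)                  # 8 * 8 = 64 windows
naive_padded = (H // M + 1) * (W // M + 1)     # 9 * 9 = 81 windows after the shift + padding
cyclic = regular                               # cyclic shift keeps the count at 64
print(regular, naive_padded, cyclic)           # 64 81 64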


Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.


After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.

With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.


The figure and description above are not very intuitive, so I collected some related material and analyze it together here:

[Figures: the window partition regions numbered 0–8, before and after the cyclic shift toward the top-left]
After the shift, region 4 forms a window on its own, regions 5 and 3 form one window, regions 7 and 1 form one window, and regions 8, 6, 2, 0 form one window.

But regions 5 and 3 originally sit at opposite edges of the image — wouldn't mixing them in one attention computation be messy? Computing them together would not actually break anything (ViT computes attention globally, after all).

Still, to avoid this, Swin Transformer uses masked MSA in the code, so that a mask can isolate the information of the different regions from each other.

Concretely, the source code adds -100 to the attention logits at positions that should not be computed together, so that after the softmax their weights are effectively zero.
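A toy illustration of why -100 works (the values are made up; the point is only the softmax behavior):

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1]])        # fake attention logits for one query
mask = torch.tensor([[0.0, 0.0, -100.0, -100.0]])    # last two keys belong to another sub-window
attn = F.softmax(scores + mask, dim=-1)
print(attn)   # ~[[0.73, 0.27, 0.00, 0.00]] -- masked positions get (almost) zero weight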

Note that after the windowed attention has been computed on the shifted data, the feature map needs to be shifted back, i.e. rolled back to its original position.
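A minimal check (my own example) that rolling back restores the original layout:

import torch

x = torch.arange(16).view(1, 4, 4, 1)                          # [B, H, W, C]
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))          # shift toward the top-left
restored = torch.roll(shifted, shifts=(2, 2), dims=(1, 2))     # roll back toward the bottom-right
print(torch.equal(x, restored))                                # True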

Code

The following code is taken from model.py:

def window_partition(x, window_size: int):
    """
    Partition the feature map into non-overlapping windows of size window_size.
    The main idea is to reshape the feature into (num_windows*B, window_size*window_size, C),
    arranging the windows that need self-attention along dim 0 so that a single batched
    qkv pass covers all of them.
    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # B, 224, 224, C
    # B, 56, 56, C
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # B, 32, 7, 32, 7, C
    # B, 8, 7, 8, 7, C
    # permute:
    # [B, H//Mh, Mh, W//Mw, Mw, C] ->
    # [B, H//Mh, W//Mw, Mh, Mw, C]
    # B, 32, 32, 7, 7, C
    # B, 8, 8, 7, 7, C
    # view:
    # [B, H//Mh, W//Mw, Mh, Mw, C] ->
    # [B*num_windows, Mh, Mw, C]
    # B*1024, 7, 7, C
    # B*64, 7, 7, C
    # 32*32 = 1024
    # 224 / 7 = 32
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

Analysis: the 56×56 stage-1 feature map, [B, 56, 56, C] (channels-last here), becomes [64·B, 7, 7, C]; in other words, the original B×C feature maps of size 56×56 turn into 64·B×C feature maps of size 7×7.

That is, we now have 64·B samples, each containing C channels of size 7×7.

Note: window_size (M = 7) is the size of each window, i.e. 7×7 positions per window — not 7×7 windows. I confused the two at first.
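A quick shape check of window_partition, plus the inverse operation (the reference implementation calls it window_reverse; the body below is my reconstruction and should be read as a sketch):

import torch

def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of window_partition: [num_windows*B, Mh, Mw, C] -> [B, H, W, C]."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

x = torch.randn(2, 56, 56, 96)                              # B=2 stage-1 feature map, channels-last
windows = window_partition(x, 7)                            # window_partition as defined above
print(windows.shape)                                        # torch.Size([128, 7, 7, 96]) = [64*B, Mh, Mw, C]
print(torch.equal(window_reverse(windows, 7, 56, 56), x))   # True: a lossless round trip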


class BasicLayer(nn.Module):
    # A basic Swin Transformer layer for one stage.
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2
        # 7 // 2 = 3

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
        ...
        # depth: 2, 2, 6, 2
        # Stage 1: depth=2, two SwinTransformerBlocks with shift_size 0, 3
        # Stage 2: depth=2, two SwinTransformerBlocks with shift_size 0, 3
        # Stage 3: depth=6, six SwinTransformerBlocks with shift_size 0, 3, 0, 3, 0, 3
        # Stage 4: depth=2, two SwinTransformerBlocks with shift_size 0, 3

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        ...
The mask construction inside create_mask, unrolled as a standalone script with H = W = 7 (a single window):

import numpy as np
import torch

H = 7
W = 7
window_size = 7
shift_size = 3

Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size

# same channel layout as the feature map, so window_partition can be applied later
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, '\n')

h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(h_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(w_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
print(img_mask)

[Output: the printed img_mask, with each position labeled by the region index assigned in the loop above]

The broadcasting trick that turns such a mask into a pairwise attention mask:

import torch

img_mask = torch.rand((2, 3))
print(img_mask)
'''
tensor([[0.7410, 0.6020, 0.5195],
        [0.9214, 0.2777, 0.8418]])
'''
attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
print(attn_mask)
'''
tensor([[[ 0.0000, -0.1390, -0.2215],
         [ 0.1390,  0.0000, -0.0825],
         [ 0.2215,  0.0825,  0.0000]],

        [[ 0.0000, -0.6437, -0.0796],
         [ 0.6437,  0.0000,  0.5642],
         [ 0.0796, -0.5642,  0.0000]]])
'''

print(img_mask.unsqueeze(1))
'''
tensor([[[0.7410, 0.6020, 0.5195]],

        [[0.9214, 0.2777, 0.8418]]])
'''
print(img_mask.unsqueeze(2))
'''
tensor([[[0.7410],
         [0.6020],
         [0.5195]],

        [[0.9214],
         [0.2777],
         [0.8418]]])
'''
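Putting the two demos together: in create_mask the img_mask is window-partitioned, flattened, and the same unsqueeze-subtract trick plus a -100 fill produces the final attention mask. The lines below are my reconstruction of that remaining step (using the window_partition defined earlier), consistent with the -100 masking described above:

# continue from the img_mask built in the create_mask demo (H = W = window_size = 7)
mask_windows = window_partition(img_mask, window_size)             # [nW, Mh, Mw, 1]
mask_windows = mask_windows.view(-1, window_size * window_size)    # [nW, Mh*Mw]
# pairwise region difference: 0 where two positions share a region, nonzero otherwise
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, Mh*Mw, Mh*Mw]
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)                                             # torch.Size([1, 49, 49])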

The code above should be read alongside the code below — note how shift_size works together with torch.roll().

class SwinTransformerBlock(nn.Module):
    # Swin Transformer Block.
    ...
    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad the feature map to an integer multiple of the window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        # Note that F.pad pads starting from the last dimension, e.g.:
        #   x.shape = (b, h, w, c)
        #   x = F.pad(x, (1, 1, 2, 2, 3, 3))
        #   x.shape = (b, h+6, w+4, c+2)
        # so here (0, 0) pads C, (pad_l, pad_r) pads W and (pad_t, pad_b) pads H,
        # which is exactly what the original source line does.
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            # In the paper the shift size is window_size // 2 (floor division).
            # For torch.roll on the H, W dims: negative shifts move toward the top-left,
            # positive shifts move toward the bottom-right; values that fall off one edge
            # reappear on the opposite edge (i.e. a cyclic shift).
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)               # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]
        ...

A minimal example of torch.roll():

import torch

x = torch.randn(1, 4, 4, 3)
print(x, '\n')

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, '\n')

To make it easier to read, I swapped the dimension order (channels-first):

import torch

x = torch.randn(1, 3, 7, 7)
print(x, '\n')

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, '\n')

[Output: the tensor before and after torch.roll — rows and columns that roll off the top/left reappear at the bottom/right]

Stage 3 – patch merging layers

To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.

The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.

This reduces the number of tokens by a multiple of 2×2 = 4 (2× downsampling of resolution), and the output dimension is set to 2C.

Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.
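For a 224×224 input with C = 96 (Swin-T), the token grid and channel width per stage work out as follows (simple arithmetic from the description above):

H = W = 224
C = 96
for stage in range(4):
    down = 4 * (2 ** stage)            # total downsampling: 4, 8, 16, 32
    dim = C * (2 ** stage)             # channels: 96, 192, 384, 768
    print(f"stage {stage + 1}: {H // down} x {W // down} tokens, dim {dim}")
# stage 1: 56 x 56 tokens, dim 96
# stage 2: 28 x 28 tokens, dim 192
# stage 3: 14 x 14 tokens, dim 384
# stage 4: 7 x 7 tokens, dim 768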

Again, drawing on other people's analyses, the process is illustrated below:

[Figure: patch merging — each 2×2 group of neighboring patches is concatenated (4C) and then projected to 2C by a linear layer]
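A minimal sketch of a patch merging layer matching the description above (modeled on the reference implementation's PatchMerging; padding for odd H/W is omitted here, so treat it as illustrative):

import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Concatenate each 2x2 group of neighboring patches (C -> 4C), then project to 2C."""
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        # x: [B, H*W, C]  (assumes H and W are even; the real code pads odd sizes first)
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]   # top-left of each 2x2 group,  [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]   # bottom-left
        x2 = x[:, 0::2, 1::2, :]   # top-right
        x3 = x[:, 1::2, 1::2, :]   # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)   # [B, H/2, W/2, 4C]
        x = x.view(B, -1, 4 * C)                  # [B, (H/2)*(W/2), 4C]
        x = self.norm(x)
        x = self.reduction(x)                     # [B, (H/2)*(W/2), 2C]
        return x

# e.g. stage-1 output -> stage-2 input: [1, 56*56, 96] -> [1, 28*28, 192]
x = torch.randn(1, 56 * 56, 96)
print(PatchMerging(96)(x, 56, 56).shape)          # torch.Size([1, 784, 192])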

Related Work

Self-attention based backbone architectures

Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.

。。。。。

Cited links and papers

  1. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
  2. https://blog.csdn.net/weixin_42392454/article/details/141395092
