【Block總結】掩碼窗口自注意力 (M-WSA)

在這里插入圖片描述

摘要

論文鏈接:https://arxiv.org/pdf/2404.07846
論文標題:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising
Masked Window-Based Self-Attention (M-WSA) 是一種新穎的自注意力機制,旨在解決傳統自注意力方法在處理圖像時的局限性,特別是在圖像去噪和恢復任務中。M-WSA 通過引入掩碼機制,確保在計算注意力時遵循盲點要求,從而避免信息泄露。

設計原理

  1. 窗口自注意力:M-WSA 基于窗口自注意力(Window Self-Attention, WSA)的概念,將輸入圖像劃分為多個不重疊的窗口。在每個窗口內,計算自注意力以捕捉局部特征。這種方法的計算復雜度相對較低,適合處理高分辨率圖像。

  2. 掩碼機制:為了滿足盲點要求,M-WSA 在計算注意力時應用了掩碼。具體而言,掩碼限制了每個像素只能關注其窗口內的特定像素,從而避免了對盲點信息的訪問。這一設計確保了網絡在去噪時不會泄露噪聲信息。

  3. 擴張卷積模擬:M-WSA 的掩碼設計模仿了擴張卷積的感受野,使得網絡能夠在保持計算效率的同時,捕捉到更大范圍的上下文信息。這種方法有效地擴展了網絡的感受野,增強了特征提取能力。
    在這里插入圖片描述

優勢

  • 高效性:通過限制注意力計算在窗口內,M-WSA 顯著降低了計算復雜度,使其適用于大規模圖像處理任務。

  • 信息保護:掩碼機制確保了盲點信息不被泄露,從而提高了去噪效果,特別是在處理具有空間相關噪聲的圖像時。

  • 靈活性:M-WSA 可以與其他網絡架構結合使用,增強其在各種視覺任務中的表現,尤其是在自我監督學習和圖像恢復領域。

實驗結果

在多個真實世界的圖像去噪數據集上進行的實驗表明,M-WSA 顯著提高了去噪性能,超越了傳統的卷積網絡和其他自注意力機制。這一結果表明,M-WSA 在處理復雜噪聲模式時具有良好的適應性和有效性。

代碼

Masked Window-Based Self-Attention (M-WSA) 通過結合窗口自注意力和掩碼機制,為圖像去噪和恢復任務提供了一種有效的解決方案。其設計不僅提高了計算效率,還確保了信息的安全性,展示了在自我監督學習中的廣泛應用潛力。代碼:

import torch
import torch.nn as nn
from einops import rearrange
from torch import einsumdef to(x):return {'device': x.device, 'dtype': x.dtype}def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_hclass FixedPosEmb(nn.Module):def __init__(self, window_size, overlap_window_size):super().__init__()self.window_size = window_sizeself.overlap_window_size = overlap_window_sizeattention_mask_table = torch.zeros((window_size + overlap_window_size - 1),(window_size + overlap_window_size - 1))attention_mask_table[0::2, :] = float('-inf')attention_mask_table[:, 0::2] = float('-inf')attention_mask_table = attention_mask_table.view((window_size + overlap_window_size - 1) * (window_size + overlap_window_size - 1))# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size)coords_w = torch.arange(self.window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten_1 = torch.flatten(coords, 1)  # 2, Wh*Wwcoords_h = torch.arange(self.overlap_window_size)coords_w = torch.arange(self.overlap_window_size)coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten_2 = torch.flatten(coords, 1)relative_coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.overlap_window_size - 1  # shift to start from 0relative_coords[:, :, 1] += self.overlap_window_size - 1relative_coords[:, :, 0] *= self.window_size + self.overlap_window_size - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.attention_mask = nn.Parameter(attention_mask_table[relative_position_index.view(-1)].view(1, self.window_size ** 2, self.overlap_window_size ** 2), requires_grad=False)def forward(self):return self.attention_maskclass DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return outif __name__ == "__main__":dim = 64window_size = 8overlap_ratio = 0.5num_heads = 2dim_head = 16# 初始化 DilatedOCA 模塊oca_attention = DilatedOCA(dim=dim,window_size=window_size,overlap_ratio=overlap_ratio,num_heads=num_heads,dim_head=dim_head,bias=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")oca_attention = oca_attention.to(device)print(oca_attention)x = torch.randn(1, 32, 640, 480).to(device)# 前向傳播output = oca_attention(x)print("input張量形狀:", x.shape)print("output張量形狀:", output.shape)

DilatedOCA模塊詳解

代碼結構

import torch
import torch.nn as nn
from einops import rearrange
  • 導入庫:首先導入 PyTorch 和 einops 庫。einops 用于簡化張量的重排操作。

模塊定義

class DilatedOCA(nn.Module):def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):super(DilatedOCA, self).__init__()self.num_spatial_heads = num_headsself.dim = dimself.window_size = window_sizeself.overlap_win_size = int(window_size * overlap_ratio) + window_sizeself.dim_head = dim_headself.inner_dim = self.dim_head * self.num_spatial_headsself.scale = self.dim_head ** -0.5self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size,padding=(self.overlap_win_size - window_size) // 2)self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)self.rel_pos_emb = RelPosEmb(block_size=window_size,rel_size=window_size + (self.overlap_win_size - window_size),dim_head=self.dim_head)self.fixed_pos_emb = FixedPosEmb(window_size, self.overlap_win_size)
  • 初始化方法__init__ 方法定義了模塊的結構。

    • dim:輸入特征的通道數。

    • window_size:窗口的大小,用于空間注意力計算。

    • overlap_ratio:重疊窗口的比例,決定了窗口之間的重疊程度。

    • num_heads:空間注意力的頭數。

    • dim_head:每個頭的維度。

  • 層的定義

    • self.unfold:用于將輸入張量展開為重疊窗口的操作。

    • self.qkv:一個 1x1 的卷積層,用于生成查詢(Q)、鍵(K)和值(V)三個特征圖。

    • self.project_out:一個 1x1 的卷積層,用于將輸出特征映射回原始通道數。

    • self.rel_pos_embself.fixed_pos_emb:用于位置編碼的模塊,增強模型對空間位置的感知。

前向傳播

def forward(self, x):b, c, h, w = x.shapeqkv = self.qkv(x)qs, ks, vs = qkv.chunk(3, dim=1)# spatial attentionqs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)ks, vs = map(lambda t: self.unfold(t), (ks, vs))ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))# split headsqs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads),(qs, ks, vs))# attentionqs = qs * self.scalespatial_attn = (qs @ ks.transpose(-2, -1))spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb()spatial_attn = spatial_attn.softmax(dim=-1)out = (spatial_attn @ vs)out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads,h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)# merge spatial and channelout = self.project_out(out)return out
  • 輸入形狀x 的形狀為 (batch_size, channels, height, width),其中 b 是批量大小,c 是通道數,hw 是圖像的高度和寬度。

  • 特征提取

    • qkv = self.qkv(x):通過 qkv 層生成 Q、K、V 特征圖。

    • qs, ks, vs = qkv.chunk(3, dim=1):將 Q、K、V 特征圖沿通道維度分離。

  • 空間注意力計算

    • qs 被重排為適合空間注意力計算的格式。

    • ksvs 通過 unfold 操作展開為重疊窗口。

  • 分頭處理

    • 使用 einops.rearrange 將 Q、K、V 的形狀調整為適合多頭自注意力計算的格式。
  • 計算注意力

    • qs = qs * self.scale:對 Q 進行縮放以提高穩定性。

    • spatial_attn = (qs @ ks.transpose(-2, -1)):計算注意力分數。

    • spatial_attn += self.rel_pos_emb(qs)spatial_attn += self.fixed_pos_emb():添加位置編碼以增強空間感知。

    • spatial_attn = spatial_attn.softmax(dim=-1):對注意力分數進行 softmax 歸一化。

  • 輸出計算

    • out = (spatial_attn @ vs):使用注意力權重對 V 進行加權求和,得到最終輸出。
  • 重排輸出

    • out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', ...):將輸出重排回原始形狀。
  • 最終投影

    • out = self.project_out(out):通過投影層將輸出映射回原始通道數。

總結

DilatedOCA 模塊結合了擴張卷積和空間注意力機制,通過重疊窗口的設計增強了對圖像局部特征的捕捉能力。該模塊在圖像處理任務中具有廣泛的應用潛力,尤其是在需要精細特征提取的場景中。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/65996.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/65996.shtml
英文地址,請注明出處:http://en.pswp.cn/web/65996.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Linux】統信UOS服務器安裝MySQL8.0(RPM)

目錄 一、下載安裝包 二、安裝MySQL 2.1hive適配 2.2ranger適配 3.2DolphinScheduler適配 一、下載安裝包 官網下載安裝包:MySQL :: MySQL Downloads 選擇社區版本下載 點擊MySQL Community Server 選擇對應系統的MySQL版本號 統信1060a 操作系統對應 redhat8…

小白:react antd 搭建框架關于 RangePicker DatePicker 時間組件使用記錄 2

文章目錄 一、 關于 RangePicker 組件返回的moment 方法示例 一、 關于 RangePicker 組件返回的moment 方法示例 moment方法中日后開發有用的方法如下: form.getFieldsValue().date[0].weeksInWeekYear(),form.getFieldsValue().date[0].zoneName(), form.getFiel…

Jenkins簡單的安裝運行

一、下載 官網下載:https://www.jenkins.io/download/ 清華大學開源軟件鏡像站:https://mirrors.tuna.tsinghua.edu.cn/jenkins/ 官網資料豐富,介紹了各種平臺安裝以及下載。安裝簡單,按照說明來就行。下面我介紹一個非常簡單的…

【CSS】HTML頁面定位CSS - position 屬性 relative 、absolute、fixed 、sticky

目錄 relative 相對定位 absolute 絕對定位 fixed 固定定位 sticky 粘性定位 position:relative 、absolute、fixed 、sticky (四選一) top:距離上面的像素 bottom:距離底部的像素 left:距離左邊的像素…

網絡安全 | WAF防護開通流程與技術原理詳解

關注:CodingTechWork 引言 隨著互聯網安全形勢的日益嚴峻,Web應用防火墻(WAF, Web Application Firewall)逐漸成為網站和應用的標準防護措施。WAF能夠有效識別和防止如SQL注入、跨站腳本攻擊(XSS)、惡意流…

小結:路由器和交換機的指令對比

路由器和交換機的指令有一定的相似性,但也有明顯的區別。以下是兩者指令的對比和主要差異: 相似之處 基本操作 兩者都支持類似的基本管理命令,比如: 進入系統視圖:system-view查看當前配置:display current…

Ubuntu中雙擊自動運行shell腳本

方法1: 修改文件雙擊反應 參考: https://blog.csdn.net/miffywm/article/details/103382405 chmod x test.sh鼠標選中待執行文件,在窗口左上角edit菜單中選擇preference設計雙擊執行快捷鍵,如下圖: 方法2: 設置一個應用 參考: https://blo…

從0開始學習搭網站的第一天

前言,以下內容學習自mdn社區,感興趣的朋友可以直接去看原文章web技術 目錄 web機制互聯網是怎么運作的網站服務器是什么什么是URL?什么是web服務器?什么是域名什么是超鏈接什么是網頁DOMgoole瀏覽器開發者工具 web機制 互聯網是怎…

java小灶課詳解:關于char和string的區別和對應的詳細操作

char和string的區別與操作詳解 在編程語言中,char和string是用于處理字符和字符串的兩種重要數據類型。它們在存儲、操作和應用場景上存在顯著差異。本文將從以下幾個方面詳細解析兩者的區別及常見操作。 1. 基本定義與存儲差異 char: 定義:…

黑馬linux筆記(03)在Linux上部署各類軟件 MySQL5.7/8.0 Tomcat(JDK) Nginx RabbitMQ

文章目錄 實戰章節:在Linux上部署各類軟件tar -zxvf各個選項的含義 為什么學習各類軟件在Linux上的部署 一 MySQL數據庫管理系統安裝部署【簡單】MySQL5.7版本在CentOS系統安裝MySQL8.0版本在CentOS系統安裝MySQL5.7版本在Ubuntu(WSL環境)系統…

[Transformer] The Structure of GPT, Generative Pretrained Transformer

The Structure of Generative Pretrained Transformer Reference: The Transformer architecture of GPT models How GPT Models Work

淺談云計算04 | 云基礎設施機制

探秘云基礎設施機制:云計算的基石 一、云基礎設施 —— 云計算的根基![在這里插入圖片描述](https://i-blog.csdnimg.cn/direct/1fb7ff493d3c4a1a87f539742a4f57a5.png)二、核心機制之網絡:連接云的橋梁(一)虛擬網絡邊界&#xff…

國內主流的Spring微服務方案指南

構建一個完整的 Spring 微服務方案涉及多個關鍵組件的集成與配置,包括服務注冊與發現、配置管理、API 網關、負載均衡、服務調用、熔斷與限流、消息中間件、分布式追蹤、服務網格、容器編排以及數據庫與緩存等。以下將結合前述內容,詳細介紹一個完整的中…

解鎖 JMeter 的 ForEach Controller 高效測試秘籍

各位小伙伴們,今天咱就來嘮嘮 JMeter 里超厲害的 “寶藏工具”——ForEach Controller,它可是能幫咱們在性能測試的江湖里 “大殺四方” 哦! 一、ForEach Controller 是啥 “神器” 想象一下,你手頭有一串神秘鑰匙,每…

【QT】QComboBox:activated信號和currentIndexChanged信號的區別

目錄 1、activated1.1 原型1.2 觸發機制1.3 使用場景1.4 連接信號和槽的方法1.4.1 方式一1.4.2 方式二 2、currentIndexChanged2.1 原型2.2 觸發機制2.3 使用場景2.4 連接信號和槽的方法 1、activated 1.1 原型 [signal] void QComboBox::activated(int index) [signal] void…

PHP 循環控制結構深度剖析:從基礎到實戰應用

PHP 循環控制結構深度剖析:從基礎到實戰應用 PHP提供了多種控制結構,其中循環控制結構是最常見的結構之一。它們使得我們能夠高效地重復執行一段代碼,直到滿足某個條件為止。本文將從PHP循環的基礎知識出發,逐步分析其在實際項目…

根據瀏覽器的不同類型動態加載不同的 CSS 文件

實現思路: 安裝并引入 vue 項目相關的 CSS 文件:首先確保你有為不同瀏覽器準備了不同的 CSS 文件(例如,style-chrome.css,style-firefox.css,style-ie.css 等)。 在 index.js 中根據瀏覽器類型…

JAVA之單例模式

單例模式(Singleton Pattern)是一種設計模式,用于確保一個類只有一個實例,并提供一個全局訪問點來獲取該實例。在軟件設計中,單例模式常用于控制對資源的訪問,例如數據庫連接、線程池等。以下是單例模式的詳…

Rust 1.84.0 發布

Cargo 依賴版本選擇改進 穩定了最小支持 Rust 版本(MSRV)感知的解析器,該解析器會優先選擇與項目聲明的 MSRV 兼容的依賴版本,減少了維護者支持舊工具鏈的工作量,無需手動為每個依賴選擇舊版本。可以通過.cargo/config…

sosadmin相關命令

sosadmin命令 以下是本人翻譯的官方文檔,如有不對,還請指出,引用請標明出處。 原本有個對應表可以跳轉的,但是CSDN的這個[](#)跳轉好像不太一樣,必須得用html標簽,就懶得改了。 sosadmin help 用法 sosadm…