pytorch 實現 Restormer 主要模塊(多頭通道自注意力機制和門控制結構)

????????前面的博文讀論文:Restormer: Efficient Transformer for High-Resolution Image Restoration?介紹了 Restormer 網絡結構的網絡技術特點,本文用 pytorch 實現其中的主要網絡結構模塊。

1. MDTA(Multi-Dconv Head Transposed Attention:多頭注意力機制

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):def __init__(self, dim, num_heads, bias):super(Attention, self).__init__()self.num_heads = num_heads  # 注意力頭的個數self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))  # 可學習系數# 1*1 升維self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)# 3*3 分組卷積self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)# 1*1 卷積self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)def forward(self, x):b,c,h,w = x.shape  # 輸入的結構 batch 數,通道數和高寬qkv = self.qkv_dwconv(self.qkv(x))q,k,v = qkv.chunk(3, dim=1)  #  第 1 個維度方向切分成 3 塊# 改變 q, k, v 的結構為 b head c (h w),將每個二維 plane 展平q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)q = torch.nn.functional.normalize(q, dim=-1)  # C 維度標準化,這里的 C 與通道維度略有不同k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperature # @ 是矩陣乘attn = attn.softmax(dim=-1)out = (attn @ v)  # 注意力圖(嚴格來說不算圖)# 將展平后的注意力圖恢復out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)# 真正的注意力圖out = self.project_out(out)return out

2.?GDFN( Gated-Dconv Feed-Forward Network)?

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):def __init__(self, dim, ffn_expansion_factor, bias):super(FeedForward, self).__init__()# 隱藏層特征維度等于輸入維度乘以擴張因子hidden_features = int(dim*ffn_expansion_factor)# 1*1 升維self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)# 3*3 分組卷積self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)# 1*1 降維self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)def forward(self, x):x = self.project_in(x)x1, x2 = self.dwconv(x).chunk(2, dim=1)  # 第 1 個維度方向切分成 2 塊x = F.gelu(x1) * x2  # gelu 相當于 relu+dropoutx = self.project_out(x)return x

3.?TransformerBlock

## 就是標準的 Transformer 架構
class TransformerBlock(nn.Module):def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):super(TransformerBlock, self).__init__()self.norm1 = LayerNorm(dim, LayerNorm_type)  # 層標準化self.attn = Attention(dim, num_heads, bias)  # 自注意力self.norm2 = LayerNorm(dim, LayerNorm_type)  # 層表轉化self.ffn = FeedForward(dim, ffn_expansion_factor, bias)  # FFNdef forward(self, x):x = x + self.attn(self.norm1(x))  # 殘差x = x + self.ffn(self.norm2(x))  # 殘差return x

4. 測試樣例

model = Restormer()
print(model)  # 打印網絡結構x = torch.randn((1, 3, 512, 512))  #隨機生成輸入圖像
x = model(x)  # 送入網絡
print(x.shape) # 打印網絡輸入的圖像結構

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/536813.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/536813.shtml
英文地址,請注明出處:http://en.pswp.cn/news/536813.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

選擇免費的SSL證書,還是付費的?

作為一個互聯網文章作者,我會根據具體的使用場景和需求來選擇SSL證書。通常情況下,如果是用于個人網站或者小型項目,我會傾向于選擇免費的SSL證書,比如 JoySSL提供的免費證書。這樣可以在不增加額外費用的情況下為網站提供安全的加…

靜態HTTP與CDN:如何優化內容分發

大家好,今天我們來聊聊靜態HTTP和CDN這對“黃金搭檔”。沒錯,就是那個讓你的網站內容像閃電一樣傳遍全球的CDN! 首先,我們來了解一下靜態HTTP。它就像是那個老實可靠的郵差,每次都按時按點地把你的內容送到用戶手中。…

第二十一章博客

計算機應用實現了多臺計算機間的互聯,使得它們彼此之間能夠進行數據交流。網絡應用程序就是在已連接的不同計算機上運行的程序,這些程序借助于網絡協議,相互之間可以交換數據。編寫網絡應用程序前,首先必須明確所要使用的網絡協議…

Node.js中處理特殊字符的文件名,安全穩妥的方案

在Node.js中,通過path模塊提供的basename方法,我們可以輕松地從文件路徑中提取文件名。然而,這個方法在處理特殊字符時存在一些問題,因為它會對這些字符進行轉義,導致在不同操作系統上的兼容性問題。在這篇文章中&…

C++ boost planner_cond_.wait(lock) 報錯1225

1.如下程序段 boost unique_lock doesn’t own the mutex: Operation not permitted 問題: 其中makePlan是一個線程。這里的unlock導致錯誤這個報錯 boost unique_lock doesn’t own the mutex: Operation not permitted bool navigation::makePlan(){ //cv::named…

MySQL中如何快速定位占用CPU過高的SQL

作為DBA工作中都會遇到過數據庫服務器CPU飆升的場景,我們該如何快速定位問題?又該如何快速找到具體是哪個SQL引發的CPU異常呢?下面我們說兩個方法。聊聊MySQL中如何快速定位占用CPU過高的SQL。 技術人人都可以磨煉,但處理問題的思…

華為OD機試 - 多段線數據壓縮(Java JS Python C)

在線OJ刷題 題目詳情 - 多段線數據壓縮 - Hydro 題目描述 下圖中,每個方塊代表一個像素,每個像素用其行號和列號表示。 為簡化處理,多線段的走向只能是水平、豎直、斜向45度。 上圖中的多線段可以用下面的坐標串表示:(2,8),(3,7),(3,6),(3,5),(4,4),(5,3),(6,2),(7,3),(…

042、序列模型

之——從時序中獲取信息 目錄 之——從時序中獲取信息 雜談 正文 1.建模 2.方案A-馬爾科夫假設 3.方案B-潛變量模型 4.簡單實現 雜談 很多連續的數據都是有前后的時間相關性的,并不是每一個單獨的數據是隨機出現的。在時序中會蘊含一些空間結構的變化信息、…

【數據科學】一文徹底理清數據、數據類型、數據結構的概念

一、什么是數據? 入門數據學科,首先第一步要認識數據什么,可能大多數人都無法對數據做一個準確的定義,在我們印象中,提到數據首先頭腦浮現的是數據表格,是一堆堆數字,那么數據就是數字嗎&#x…

SpringBoot 2.0 中默認 HikariCP 數據庫連接池原理解析

作為后臺服務開發,在日常工作中我們天天都在跟數據庫打交道,一直在進行各種CRUD操作,都會使用到數據庫連接池。按照發展歷程,業界知名的數據庫連接池有以下幾種:c3p0、DBCP、Tomcat JDBC Connection Pool、Druid 等&am…

阿里云服務器記錄

阿里云服務器記錄 CentOS 8.4 64位 SCC版 CentOS 7.9 64位 SCC版 CentOS 7.9 64位 CentOS 7.9 64位 UEFI版 Alibaba Cloud Linux Anolis OS CentOS Windows Server Ubuntu Debian Fedora OpenSUSE Rocky Linux CentOS Stream AlmaLinux 阿里云服務器有個scc版,這個…

Flask+Mysql項目docker-compose部署(Pythondocker-compose詳細步驟)

一、前言 環境: Linux、docker、docker-compose、python(Flask)、Mysql 簡介: 簡單使用Flask框架寫的查詢Mysql數據接口,使用docker部署,shell腳本啟動 優勢: 采用docker方式部署更加便于維護,更加簡單快…

如何在Go中使用模板

引言 您是否需要以格式良好的輸出、文本報告或HTML頁面呈現一些數據?你可以使用Go模板來做到這一點。任何Go程序都可以使用text/template或html/template包(兩者都包含在Go標準庫中)來整齊地顯示數據。 這兩個包都允許你編寫文本模板并將數據傳遞給它們,以按你喜歡的格式呈…

“C語言“——scanf()、getchar() 、putchar()、之間的關系

scanf函數說明 scanf函數是對來自于標準輸入流的輸入數據作格式轉換,并將轉換結果保存至format后面的實參所指向的對象。 而const char*format 指向的字符串為格式控制字符串,它指定了可輸入的字符串以及賦值時轉換方法。 簡單來說給一個打印格式(輸入…

【并發編程篇】源碼分析,手動創建線程池

文章目錄 🛸前言🌹Executors的三大方法 🍔簡述線程池🎆手動創建線程池?源碼分析?代碼實現,手動創建線程池🎈CallerRunsPolicy()🎈AbortPolicy()🎈DiscardPolicy()🎈Dis…

LNPMariadb數據庫分離|web服務器集群

LNP&Mariadb數據庫分離|web服務器集群 網站架構演變單機版LNMP獨立數據庫服務器web服務器集群與Session保持 LNP與數據庫分離1. 準備一臺獨立的服務器,安裝數據庫軟件包2. 將之前的LNMP網站中的數據庫遷移到新的數據庫服務器3. 修改wordpress網站配置…

2023.12.24 關于 Redis 中 String 類型內部編碼 及 應用場景

目錄 String 類型內部編碼 3 種內部編碼方式 String 類型應用場景 Cache 緩存 鍵名命名規則 計數(Counter) 共享會話(Session ) 手機驗證碼 總結 String 類型內部編碼 3 種內部編碼方式 int:用來表示 64 位 —…

vue3菜單權限管理實現

前提 你的菜單是根據路由動態生成的,具體可以參考這篇博客對el-menu組件進行遞歸封裝(根據路由配置動態生成) 描述 首先將路由分為常量路由constantRoute(所有用戶都有的路由)和異步路由asyncRoute(需要動…

Gradle 插件

自定義Gradle插件 - 簡書

小天使的小難題:新生兒疝氣的關注與溫馨呵護

引言: 新生兒疝氣是一種在出生后可能出現的常見情況,雖然通常不會造成長期影響,但對于家長而言,了解如何正確應對新生兒疝氣是至關重要的。本文將深入探討新生兒疝氣的原因、癥狀,以及家長在面對這一問題時應該采取的…