23 - HaLoAttention模塊

論文《Scaling Local Self-Attention for Parameter Efficient Visual Backbones》

1、作用

HaloNet通過引入Haloing機制和高效的注意力實現,在圖像識別任務中達到了最先進的準確性。這些模型通過局部自注意力機制,有效地捕獲像素間的全局交互,同時通過分塊和Haloing策略,顯著提高了處理速度和內存效率。

2、機制

1、Haloing策略

為了克服傳統自注意力的計算和內存限制,HaloNet采用了Haloing策略,將圖像分割成多個塊,并為每個塊擴展一定的Halo區域,僅在這些區域內計算自注意力。這種方法減少了計算量,同時保持了較大的感受野。

2、多尺度特征層次

HaloNet構建了多尺度特征層次結構,通過分層采樣和跨尺度的信息流,有效捕獲不同尺度的圖像特征,增強了模型對圖像中對象大小變化的適應性。

3、高效的自注意力實現

通過改進的自注意力算法,包括非中心化的局部注意力和分層自注意力下采樣操作,HaloNet在保持高準確性的同時,提高了訓練和推理速度。

3、獨特優勢

1、參數效率

HaloNet通過局部自注意力機制和Haloing策略,大幅度減少了所需的計算量和內存需求,實現了與當前最佳卷積模型相當甚至更好的性能,但使用更少的參數。

2、適應多尺度

多尺度特征層次結構使得HaloNet能夠有效處理不同尺度的對象,提高了對復雜視覺任務的適應性和準確性。

3、提升速度和效率

通過優化的自注意力實現,HaloNet在不犧牲準確性的前提下,實現了比現有技術更快的訓練和推理速度,使其更適合實際應用。

4、代碼

import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeat# 將設備和數據類型轉換為字典格式def to(x):return {'device': x.device, 'dtype': x.dtype}# 確保輸入是元組形式
def pair(x):return (x, x) if not isinstance(x, tuple) else x# 在指定維度上擴展張量
def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)# 將相對位置編碼轉換為絕對位置編碼
def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_x# 生成一維的相對位置logits
def relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logits# 相對位置嵌入類
class RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_h# HaloAttention類class HaloAttention(nn.Module):def __init__(self,*,dim,block_size,halo_size,dim_head=64,heads=8):super().__init__()assert halo_size > 0, 'halo size must be greater than 0'self.dim = dimself.heads = headsself.scale = dim_head ** -0.5self.block_size = block_sizeself.halo_size = halo_sizeinner_dim = dim_head * headsself.rel_pos_emb = RelPosEmb(block_size=block_size,rel_size=block_size + (halo_size * 2),dim_head=dim_head)self.to_q = nn.Linear(dim, inner_dim, bias=False)self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)self.to_out = nn.Linear(inner_dim, dim)def forward(self, x):# 驗證輸入特征圖維度是否符合要求b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.deviceassert h % block == 0 and w % block == 0, assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=block, p2=block)kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c=c)#生成查詢、鍵、值q = self.to_q(q_inp)k, v = self.to_kv(kv_inp).chunk(2, dim=-1)# 拆分頭部q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))# 縮放查詢向量q *= self.scale# 計算注意力sim = einsum('b i d, b j d -> b i j', q, k)# 添加相對位置偏置sim += self.rel_pos_emb(q)# 掩碼填充mask = torch.ones(1, 1, h, w, device=device)mask = F.unfold(mask, kernel_size=block + (halo * 2), stride=block, padding=halo)mask = repeat(mask, '() j i -> (b i h) () j', b=b, h=heads)mask = mask.bool()max_neg_value = -torch.finfo(sim.dtype).maxsim.masked_fill_(mask, max_neg_value)# 注意力機制attn = sim.softmax(dim=-1)# 聚合out = einsum('b i j, b j d -> b i d', attn, v)# 合并和組合頭部out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)out = self.to_out(out)# 將塊合并回原始特征圖out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b=b, h=(h // block), w=(w // block), p1=block,p2=block)return out# 輸入 N C H W,  輸出 N C H W
if __name__ == '__main__':block = HaloAttention(dim=512,block_size=2,halo_size=1, ).cuda()# 創建HaloAttention實例input = torch.rand(1, 512, 64, 64).cuda()# 創建隨機輸入output = block(input) # 前向傳播print(output.shape)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/87456.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/87456.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/87456.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2025Mybatis最新教程(五)

第5章 ORM映射 5.1 MyBatis自動ORM失效 MyBatis只能自動維護庫表”列名“與”屬性名“相同時的對應關系,二者不同時,無法自動ORM。 自動ORM失效建表 create table t_managers(mgr_id int primary key auto_increment,mgr_name varchar(50),mgr_pwd varchar(50) ); 添加數據…

解決lombok注解失效問題

Lombok 注解失效是 Java 開發中的常見問題,通常由依賴配置、IDE 支持或構建工具設置引起。最近在拉取別人springboot3jdk21版本的項目時遇到了lombok注解失效,導致項目無法啟動的問題,以下是我的解決方案: 首先檢查idea 的lombok…

3分鐘搭建LarkXR實時云渲染PaaS平臺,實現各類3D/XR應用的一鍵推流

LarkXR是由Paraverse平行云自主研發的去中心化實時云渲染平臺,以其卓越的性能和豐富完備的功能插件,引領3D/XR云化行業風向標。LarkXR適用于3D/XR開發者、設計師、終端用戶等創新用戶,可以在零硬件負擔下,輕松實現超高清低時延的3…

vue3 watch監視詳解

watch監視 一 &#xff1a;watch監視{ref}定義的基本類型結構 <template><div class"person"><h1>情況一:watch監視{ref}定義的基本類型結構</h1><h1>當前的和為{{ sum }}</h1><button click"changeSum">點我…

TensorFlow Serving學習筆記2: 模型服務

本文深入剖析 TensorFlow Serving 的核心架構與實現機制&#xff0c;結合源碼分析揭示其如何實現高可用、動態更新的生產級模型服務。 一、TensorFlow Serving 核心架構 1.1 分層架構設計 TensorFlow Serving 采用模塊化分層設計&#xff0c;各組件職責分明&#xff1a; 組件…

共享云桌面為什么能打敗傳統電腦

近年來&#xff0c;隨著云桌面技術的快速發展&#xff0c;共享云桌面作為一種新型的計算模式&#xff0c;正在逐步改變人們的工作和生活方式。它憑借其獨特的優勢&#xff0c;正在逐步取代傳統電腦&#xff0c;成為企業和個人用戶的新選擇。之所以在部分場景中展現出替代傳統電…

B站PWN教程筆記-12

完結撒花。 今天還是以做題為主。 fmtstruaf 格式化字符串USER AFTER FREE 首先補充一個背景知識&#xff0c;指針也是有數據類型的&#xff0c;不同數據類型的指針xx&#xff0c;所加的字節數也不一樣&#xff0c;其實是指針指的項目的下一項。如int a[20]&#xff0c;a是…

零基礎設計模式——總結與進階 - 3. 學習資源與下一步

第五部分&#xff1a;總結與進階 - 3. 學習資源與下一步 到這里&#xff0c;你已經完成了設計模式主要內容的學習。但這僅僅是一個開始&#xff0c;設計模式的精髓在于實踐和持續學習。本節將為你提供一些優質的學習資源和后續學習的建議&#xff0c;幫助你在這條道路上走得更…

多模態大語言模型arxiv論文略讀(125)

Uni-Med: A Unified Medical Generalist Foundation Model For Multi-Task Learning Via Connector-MoE ?? 論文標題&#xff1a;Uni-Med: A Unified Medical Generalist Foundation Model For Multi-Task Learning Via Connector-MoE ?? 論文作者&#xff1a;Xun Zhu, Yi…

【學習筆記】NLP 基礎概念

1.1 什么是 NLP 定義&#xff1a; 自然語言處理&#xff08;NLP&#xff09;**是一種讓計算機理解、解釋和生成人類語言的技術。它是人工智能領域中極為活躍且重要的研究方向&#xff0c;旨在模擬人類對語言的認知和使用過程 特點&#xff1a; 多學科交叉&#xff1a;結合計…

RNN為什么不適合大語言模型

在自然語言處理&#xff08;NLP&#xff09;領域中&#xff0c;循環神經網絡&#xff08;RNN&#xff09;及衍生架構&#xff08;如LSTM&#xff09;采用序列依序計算的模式&#xff0c;這種模式之所以“限制了計算機并行計算能力”&#xff0c;核心原因在于其時序依賴的特性&a…

微信小程序一款不錯的文字動畫

效果圖 .js Page({data: {list:[],animation:[text-left,text-right,text-top,text-bottom],text:[[春眠不覺曉&#xff0c;處處聞啼鳥。,夜來風雨聲&#xff0c;花落知多少。 ],[床前明月光&#xff0c;疑是地上霜。,舉頭望明月&#xff0c;低頭思故鄉。],[千山鳥飛絕&#…

循環神經網絡(RNN):序列數據處理的強大工具

在人工智能和機器學習的廣闊領域中&#xff0c;處理和理解序列數據一直是一個重要且具有挑戰性的任務。循環神經網絡&#xff08;Recurrent Neural Network&#xff0c;RNN&#xff09;作為一類專門設計用于處理序列數據的神經網絡&#xff0c;在諸多領域展現出了強大的能力。從…

手機SIM卡通話中隨時插入錄音語音片段(Windows方案)

手機SIM卡通話中隨時插入錄音語音片段&#xff08;Windows方案&#xff09; --本地AI電話機器人 上一篇&#xff1a;手機SIM卡通話中隨時插入錄音語音片段&#xff08;Android方案&#xff09;??????? 下一篇&#xff1a;???????編寫中 一、前言 書接上文《手…

阿里云通義大模型:AI浪潮中的領航者

通義大模型初印象 在當今 AI 領域蓬勃發展的浪潮中&#xff0c;阿里云通義大模型宛如一顆璀璨的明星&#xff0c;迅速崛起并占據了重要的地位。隨著人工智能技術的不斷突破&#xff0c;大模型已成為推動各行業數字化轉型和創新發展的核心驅動力。通義大模型憑借其強大的技術實…

【算法篇】逐步理解動態規劃模型7(兩個數組dp問題)

目錄 兩個數組dp問題 1.最長公共子序列 2.不同的子序列 3.通配符匹配 本文旨在通過對力扣上三道題進行講解來讓大家對使用動態規劃解決兩個數組的dp問題有一定思路&#xff0c;培養大家對狀態定義&#xff0c;以及狀態方程書寫的思維。 順序&#xff1a; 題目鏈接-》算法思…

什么是 HTTP Range 請求(范圍請求)

HTTP Range 請求&#xff0c;即范圍請求&#xff0c;是一種 HTTP 請求方法&#xff0c;允許客戶端請求資源的部分數據。這種請求在處理大型文件&#xff08;如視頻、音頻、或大文件下載&#xff09;時特別有用&#xff0c;因為它可以有效地進行斷點續傳和按需加載數據&#xff…

java集合(十) ---- LinkedList 類

目錄 十、LinkedList 類 10.1 位置 10.2 特點 10.3 與 ArrayList 的區別 10.4 構造方法 10.5 常用方法 十、LinkedList 類 10.1 位置 LinkedList 類位于 java.util 包下 10.2 特點 是 List 接口的實現類是 Deque 接口的實現類底層使用雙向循環鏈表結構 10.3 與 Arra…

kafka消費的模式及消息積壓處理方案

目錄 1、kafka消費的流程 2、kafka的消費模式 2.1、點對點模式 2.2、發布-訂閱模式 3、consumer消息積壓 3.1、處理方案 3.2、積壓量 4、消息過期失效 5、kafka注意事項 Kafka消費積壓(Consumer Lag)是指消費者處理消息的速度跟不上生產者發送消息的速度&#xff0c;導致消息在…

RAG實踐:Routing機制與Query Construction策略

Routing機制與Query Construction策略 前言RoutingLogical RoutingChatOpenAIStructuredRouting DatasourceConclusion Semantic RoutingEmbedding & LLMPromptRounting PromptConclusion Query ConstructionGrab Youtube video informationStructuredPrompt GithubReferen…