論文《Scaling Local Self-Attention for Parameter Efficient Visual Backbones》
1、作用
HaloNet通過引入Haloing機制和高效的注意力實現,在圖像識別任務中達到了最先進的準確性。這些模型通過局部自注意力機制,有效地捕獲像素間的全局交互,同時通過分塊和Haloing策略,顯著提高了處理速度和內存效率。
2、機制
1、Haloing策略:
為了克服傳統自注意力的計算和內存限制,HaloNet采用了Haloing策略,將圖像分割成多個塊,并為每個塊擴展一定的Halo區域,僅在這些區域內計算自注意力。這種方法減少了計算量,同時保持了較大的感受野。
2、多尺度特征層次:
HaloNet構建了多尺度特征層次結構,通過分層采樣和跨尺度的信息流,有效捕獲不同尺度的圖像特征,增強了模型對圖像中對象大小變化的適應性。
3、高效的自注意力實現:
通過改進的自注意力算法,包括非中心化的局部注意力和分層自注意力下采樣操作,HaloNet在保持高準確性的同時,提高了訓練和推理速度。
3、獨特優勢
1、參數效率:
HaloNet通過局部自注意力機制和Haloing策略,大幅度減少了所需的計算量和內存需求,實現了與當前最佳卷積模型相當甚至更好的性能,但使用更少的參數。
2、適應多尺度:
多尺度特征層次結構使得HaloNet能夠有效處理不同尺度的對象,提高了對復雜視覺任務的適應性和準確性。
3、提升速度和效率:
通過優化的自注意力實現,HaloNet在不犧牲準確性的前提下,實現了比現有技術更快的訓練和推理速度,使其更適合實際應用。
4、代碼
import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeat# 將設備和數據類型轉換為字典格式def to(x):return {'device': x.device, 'dtype': x.dtype}# 確保輸入是元組形式
def pair(x):return (x, x) if not isinstance(x, tuple) else x# 在指定維度上擴展張量
def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)# 將相對位置編碼轉換為絕對位置編碼
def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_x# 生成一維的相對位置logits
def relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logits# 相對位置嵌入類
class RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_h# HaloAttention類class HaloAttention(nn.Module):def __init__(self,*,dim,block_size,halo_size,dim_head=64,heads=8):super().__init__()assert halo_size > 0, 'halo size must be greater than 0'self.dim = dimself.heads = headsself.scale = dim_head ** -0.5self.block_size = block_sizeself.halo_size = halo_sizeinner_dim = dim_head * headsself.rel_pos_emb = RelPosEmb(block_size=block_size,rel_size=block_size + (halo_size * 2),dim_head=dim_head)self.to_q = nn.Linear(dim, inner_dim, bias=False)self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)self.to_out = nn.Linear(inner_dim, dim)def forward(self, x):# 驗證輸入特征圖維度是否符合要求b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.deviceassert h % block == 0 and w % block == 0, assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=block, p2=block)kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c=c)#生成查詢、鍵、值q = self.to_q(q_inp)k, v = self.to_kv(kv_inp).chunk(2, dim=-1)# 拆分頭部q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))# 縮放查詢向量q *= self.scale# 計算注意力sim = einsum('b i d, b j d -> b i j', q, k)# 添加相對位置偏置sim += self.rel_pos_emb(q)# 掩碼填充mask = torch.ones(1, 1, h, w, device=device)mask = F.unfold(mask, kernel_size=block + (halo * 2), stride=block, padding=halo)mask = repeat(mask, '() j i -> (b i h) () j', b=b, h=heads)mask = mask.bool()max_neg_value = -torch.finfo(sim.dtype).maxsim.masked_fill_(mask, max_neg_value)# 注意力機制attn = sim.softmax(dim=-1)# 聚合out = einsum('b i j, b j d -> b i d', attn, v)# 合并和組合頭部out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)out = self.to_out(out)# 將塊合并回原始特征圖out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b=b, h=(h // block), w=(w // block), p1=block,p2=block)return out# 輸入 N C H W, 輸出 N C H W
if __name__ == '__main__':block = HaloAttention(dim=512,block_size=2,halo_size=1, ).cuda()# 創建HaloAttention實例input = torch.rand(1, 512, 64, 64).cuda()# 創建隨機輸入output = block(input) # 前向傳播print(output.shape)