一、整體介紹?
The FFT Strikes Again: An Efficient Alternative to Self-Attention
FFT再次出擊:一種高效的自注意力替代方案
圖1:FFTNet整體流程,包括局部窗口處理(STFT或小波變換,可選)和全局FFT,隨后在頻率/變換域進行等距融合(或門控)。
朋友們,今天為大家介紹一個非常有潛力,未來可能會在自然語言處理、計算機視覺、圖像處理等領域發揮重大作用的方法。
中心思想:該方法來源arXiv[1],是2025年3月16日最新公開論文,提出了一種名為FFTNet的自適應頻譜濾波框架,該框架利用快速傅里葉變換(FFT)在O(nlogn)時間內實現全局標記混合,有效解決了傳統自注意力機制在處理長序列時的二次復雜度問題,把自注意力機制(Self-Attention)的時間復雜度從O(n2)降到O(nlogn)。
實現動機:傳統的注意力機制在計算成對標記交互時,隨著序列長度n的增加,成本呈二次方增長,這使得處理長序列變得昂貴。相比之下,離散傅里葉變換(DFT)在O(nlogn)時間內自然編碼全局交互,因為它將標記序列分解為正交頻率分量。
核心原理:根據帕塞瓦爾定理,在傅里葉變換下,對于輸入序列X及其傅里葉變換F=FFT(X),信號的總能量保持不變,除了一個常數縮放因子。這一能量保持保證了自適應濾波和非線性操作不會意外扭曲輸入信號的固有信息。(學過數字信號處理課程的朋友應該更容易理解,總結起來就是一句話:把信號轉換到頻域進行處理,不會丟失信號信息,但是可以減少計算量)
證明公式和復雜度計算的公式較為枯燥,本文省略。
下面以代碼為例,展示原理及用法。
二、代碼與原理解讀?
1. 基于快速傅里葉變換的基礎網絡塊——FFTNetBlock
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModReLU(nn.Module):def __init__(self, features):super().__init__()self.b = nn.Parameter(torch.Tensor(features))self.b.data.uniform_(-0.1, 0.1)def forward(self, x):return torch.abs(x) * F.relu(torch.cos(torch.angle(x) + self.b))
class FFTNetBlock(nn.Module):def __init__(self, dim):super().__init__()self.dim = dimself.filter = nn.Linear(dim, dim)self.modrelu = ModReLU(dim)def forward(self, x):# x: [batch_size, seq_len, dim]x_fft = torch.fft.fft(x, dim=1) # FFT along the sequence dimensionx_filtered = self.filter(x_fft.real) + 1j * self.filter(x_fft.imag)x_filtered = self.modrelu(x_filtered)x_out = torch.fft.ifft(x_filtered, dim=1).realreturn x_out
if __name__ == '__main__':# 參數設置batch_size = 1 # 批量大小seq_len = 224 * 224 # 序列長度(Transformer 中的 token 數量)dim = 32 # 維度# 創建隨機輸入張量,形狀為 (batch_size, seq_len, embed_dim)x = torch.randn(batch_size, seq_len, dim)# 初始化 FFTNetBlock 模塊model = FFTNetBlock(dim = dim)print(model)print("微信公眾號: AI縫合術!")output = model(x)print(x.shape)print(output.shape)
運行結果:
該代碼實現了一個基于 FFT(快速傅里葉變換)的神經網絡塊,稱為 FFTNetBlock,并在 forward 過程中對輸入信號進行頻域處理。
實現流程:
①使用 FFT 進行頻域轉換:輸入 x 通過 FFT 轉換到頻域,在頻域進行操作。
②使用可學習的濾波器:通過 nn.Linear 進行頻域的線性變換,相當于卷積核在頻域對信號進行加權處理。
③使用 ModReLU 進行非線性處理:由于 FFT 產生的結果是復數,傳統的 ReLU 不能直接作用,因此使用 ModReLU 進行非線性變換。ModReLU為修正的 ReLU 激活函數,作用類似于ReLU在實數域上的作用,但應用于復數域,通過修改相位角(angle)并結合 ReLU 進行修正。
④最終通過 iFFT 還原回時序空間:經過處理的頻域信息通過逆 FFT(ifft)變換回時序域,得到最終輸出。
2. 基于快速傅里葉變換的ViT網絡——FFTNetViT
import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_path(x, drop_prob: float = 0., training: bool = False):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_prob# Generate binary tensor mask; shape: (batch_size, 1, 1, ..., 1)shape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_() # binarizeoutput = x.div(keep_prob) * random_tensorreturn output
class DropPath(nn.Module):"""DropPath module that performs stochastic depth."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)class MultiHeadSpectralAttention(nn.Module):def __init__(self, embed_dim, seq_len, num_heads=4, dropout=0.1, adaptive=True):"""頻譜注意力模塊,在保持 O(n log n) 計算復雜度的同時,引入額外的非線性和自適應能力。參數:- embed_dim: 總的嵌入維度。- seq_len: 序列長度(例如 Transformer 中 token 的數量,包括類 token)。- num_heads: 注意力頭的數量。- dropout: 逆傅里葉變換(iFFT)后的 dropout 率。- adaptive: 是否啟用自適應 MLP 以生成乘法和加法的自適應調制參數。"""super().__init__()if embed_dim % num_heads != 0:raise ValueError("embed_dim 必須能被 num_heads 整除")self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.seq_len = seq_lenself.adaptive = adaptive# 頻域的 FFT 頻率桶數量: (seq_len//2 + 1)self.freq_bins = seq_len // 2 + 1# 基礎乘法濾波器: 每個注意力頭和頻率桶一個self.base_filter = nn.Parameter(torch.ones(num_heads, self.freq_bins, 1))# 基礎加性偏置: 作為頻率幅度的學習偏移self.base_bias = nn.Parameter(torch.full((num_heads, self.freq_bins, 1), -0.1))if adaptive:# 自適應 MLP: 每個頭部和頻率桶生成 2 個值(縮放因子和偏置)self.adaptive_mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, num_heads * self.freq_bins * 2))self.dropout = nn.Dropout(dropout)# 預歸一化層,提高傅里葉變換的穩定性self.pre_norm = nn.LayerNorm(embed_dim)def complex_activation(self, z):"""對復數張量應用非線性激活函數。該函數計算 z 的幅度,將其傳遞到 GELU 進行非線性變換,并按比例縮放 z,以保持相位不變。參數:z: 形狀為 (B, num_heads, freq_bins, head_dim) 的復數張量返回:經過非線性變換的復數張量,形狀相同。"""mag = torch.abs(z)# 對幅度進行非線性變換,GELU 提供平滑的非線性mag_act = F.gelu(mag)# 計算縮放因子,防止除零錯誤scale = mag_act / (mag + 1e-6)return z * scaledef forward(self, x):"""增強型頻譜注意力模塊的前向傳播。參數:x: 輸入張量,形狀為 (B, seq_len, embed_dim)返回:經過頻譜調制和殘差連接的張量,形狀仍為 (B, seq_len, embed_dim)"""B, N, D = x.shape# 預歸一化,提高頻域變換的穩定性x_norm = self.pre_norm(x)# 重新排列張量以分離不同的注意力頭,形狀變為 (B, num_heads, seq_len, head_dim)x_heads = x_norm.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 沿著序列維度計算 FFT,結果為復數張量,形狀為 (B, num_heads, freq_bins, head_dim)F_fft = torch.fft.rfft(x_heads, dim=2, norm='ortho')# 計算自適應調制參數(如果啟用)if self.adaptive:# 全局上下文:對 token 維度求均值,形狀為 (B, embed_dim)context = x_norm.mean(dim=1)# 經過 MLP 計算自適應參數,輸出形狀為 (B, num_heads*freq_bins*2)adapt_params = self.adaptive_mlp(context)adapt_params = adapt_params.view(B, self.num_heads, self.freq_bins, 2)# 劃分為乘法縮放因子和加法偏置adaptive_scale = adapt_params[..., 0:1] # 形狀: (B, num_heads, freq_bins, 1)adaptive_bias = adapt_params[..., 1:2] # 形狀: (B, num_heads, freq_bins, 1)else:# 如果不使用自適應機制,則縮放因子和偏置設為 0adaptive_scale = torch.zeros(B, self.num_heads, self.freq_bins, 1, device=x.device)adaptive_bias = torch.zeros(B, self.num_heads, self.freq_bins, 1, device=x.device)# 結合基礎濾波器和自適應調制參數# effective_filter: 影響頻譜響應的縮放因子effective_filter = self.base_filter * (1 + adaptive_scale)# effective_bias: 影響頻譜響應的偏置effective_bias = self.base_bias + adaptive_bias# 在頻域進行自適應調制# 先進行乘法縮放,再添加偏置(在 head_dim 維度上廣播)F_fft_mod = F_fft * effective_filter + effective_bias# 在頻域應用非線性激活F_fft_nl = self.complex_activation(F_fft_mod)# 逆傅里葉變換(iFFT)還原到時序空間# 需要指定 n=self.seq_len 以確保輸出長度匹配輸入x_filtered = torch.fft.irfft(F_fft_nl, dim=2, n=self.seq_len, norm='ortho')# 重新排列張量,將注意力頭合并回嵌入維度x_filtered = x_filtered.permute(0, 2, 1, 3).reshape(B, N, D)# 殘差連接并應用 Dropoutreturn x + self.dropout(x_filtered)
class TransformerEncoderBlock(nn.Module):def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.1, attention_module=None, drop_path=0.0):"""一個通用的 Transformer 編碼器塊,集成了 drop path 隨機深度 。- embed_dim: 嵌入維度。- mlp_ratio: MLP 的擴展因子。- dropout: dropout 比率。- attention_module: 處理自注意力的模塊。- drop_path: 隨機深度的 drop path 比率。"""super().__init__()if attention_module is None:raise ValueError("必須提供一個注意力模塊! 此處應調用 MultiHeadSpectralAttention")self.attention = attention_moduleself.mlp = nn.Sequential(nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),nn.GELU(),nn.Dropout(dropout),nn.Linear(int(embed_dim * mlp_ratio), embed_dim),nn.Dropout(dropout))self.norm = nn.LayerNorm(embed_dim)# 用于隨機深度的 drop path 層self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()def forward(self, x):# 在殘差連接中應用帶有 drop path 的注意力。x = x + self.drop_path(self.attention(x))# 在殘差連接中應用 MLP(經過層歸一化)并加入 drop path。x = x + self.drop_path(self.mlp(self.norm(x)))return xif __name__ == '__main__':# 參數設置batch_size = 1 # 批大小seq_len = 224 * 224 # 序列長度embed_dim = 32 # 嵌入維度num_heads = 4 # 注意力頭數# 創建隨機輸入張量 (batch_size, seq_len, embed_dim)x = torch.randn(batch_size, seq_len, embed_dim)# 初始化 MultiHeadSpectralAttentionattention_module = MultiHeadSpectralAttention(embed_dim=embed_dim, seq_len=seq_len, num_heads=num_heads)# 初始化 TransformerEncoderBlocktransformer_block = TransformerEncoderBlock(embed_dim=embed_dim, attention_module=attention_module)print(transformer_block)print("微信公眾號: AI縫合術!")# 前向傳播測試output = transformer_block(x)# 打印輸出形狀print("輸入形狀:", x.shape)print("輸出形狀:", output.shape)
運行結果:
乍一看代碼比較多,其實原理非常簡單,該代碼實現了一個標準的Transformer編碼器結構,除去兩個固定操作的隨機深度DropPath,剩下僅有兩個類組成,MultiHeadSpectralAttention實現了基于快速傅里葉變換的高效多頭自注意力,TransformerEncoderBlock是一個通用的Transformer編碼器模塊。
上圖是ViT的經典結構圖,我們只看右側編碼器部分,上述代碼實現的就是右側的編碼器,只是將多頭注意力轉換到頻域來進行計算,非常容易理解。
采用上面方法構建的FFTNetViT在LRA和ImageNet兩個數據集上的廣泛評估確認,FFTNet不僅實現了有競爭力的準確性,而且與固定傅里葉方法和標準自注意力相比,顯著提高了計算效率。
以上兩個模塊均可即插即用,應用在自然語言處理、圖像處理、計算機視覺等各類任務上,是非常好的創新!
https://github.com/AIFengheshu/Plug-play-modules
2025年全網最全即插即用模塊,免費分享!包含人工智能全領域(機器學習、深度學習等),適用于圖像分類、目標檢測、實例分割、語義分割、全景分割、姿態識別、醫學圖像分割、視頻目標分割、圖像摳圖、圖像編輯、單目標跟蹤、多目標跟蹤、行人重識別、RGBT、圖像去噪、去雨、去霧、去陰影、去模糊、超分辨率、去反光、去摩爾紋、圖像恢復、圖像修復、高光譜圖像恢復、圖像融合、圖像上色、高動態范圍成像、視頻與圖像壓縮、3D點云、3D目標檢測、3D語義分割、3D姿態識別等各類計算機視覺和圖像處理任務,以及自然語言處理、大語言模型、多模態等其他各類人工智能相關任務。持續更新中.....