前言
在自然語言處理領域,隨著大語言模型(LLMs)不斷拓展其閱讀、理解和生成文本的能力,如何高效處理長文本成為一項關鍵挑戰。近日,Moonshot AI Research 聯合清華大學、浙江大學的研究人員提出了一種創新方法 —— 混合塊注意力機制(Mixture of Block Attention,MoBA),它將專家混合(Mixture of Experts,MoE)原理應用于注意力機制,為解決長文本處理難題帶來了新的思路。
在 Transformer 架構廣泛應用的當下,其注意力機制存在明顯弊端。在處理長文本時,傳統注意力機制需將每個 token 與其他所有 token 進行比較,這使得計算成本隨序列長度呈二次方增長。當模型處理長篇文檔、多章書籍、法律簡報或大型代碼庫等包含大量文本信息的任務時,這種計算成本會變得難以承受。此前,為解決這一問題,研究人員嘗試過多種方法。例如,滑動窗口機制將 token 限制在局部鄰域內,雖降低了計算量,但會忽略重要的全局關系;而一些徹底改變基本架構的方法,如用全新結構替代 softmax 注意力機制,往往需要從頭開始重新訓練模型,難以利用現有的預訓練成果。
核心原理
MoBA 的出現有效彌補了上述方法的不足。它的核心在于將輸入劃分為易于管理的 “塊”,并借助可訓練的門控系統來確定每個查詢 token 相關的塊。這種設計遵循 “少結構” 原則,不預先定義哪些 token 應該相互作用,而是由學習到的門控網絡做出決策。與固定結構或近似處理的方法不同,MoBA 能讓模型自主學習注意力的聚焦點。而且,MoBA 可與現有的基于 Transformer 的模型無縫協作,它作為一種 “插件” 或替代方案,保持與原模型相同的參數數量,避免架構膨脹,同時保留因果掩碼,確保自回歸生成的準確性。在實際應用中,MoBA 能在稀疏注意力和全注意力之間靈活切換。處理超長輸入時,稀疏注意力可提升速度;而在訓練的某些層或階段,若需要全注意力,模型也能切換回標準模式。
從技術細節來看,MoBA 將上下文劃分為多個塊,每個塊包含連續的 token 序列。門控機制通過比較查詢 token 與塊的池化鍵表示,計算查詢 token 與每個塊之間的 “親和度” 分數,然后選擇得分最高的塊。這樣,只有最相關塊中的 token 才會對最終的注意力分布產生影響。同時,包含查詢 token 本身的塊始終被納入,以確保局部上下文信息可訪問。并且,MoBA 執行因果掩碼,防止 token 關注未來位置,維持從左到右的自回歸屬性。這種基于塊的方法大幅減少了 token 比較次數,使計算規模低于二次方,隨著上下文長度增加到數十萬甚至數百萬個 token,效率提升愈發顯著。此外,MoBA 與現代加速器和專用內核兼容性良好。研究人員將 MoBA 與 FlashAttention(一種高性能的快速、內存高效的精確注意力庫)相結合,根據所選塊對查詢 - 鍵 - 值操作進行精心分組,進一步優化了計算流程。實驗數據顯示,在處理一百萬個 token 時,MoBA 相比傳統全注意力機制速度提升約 6 倍,凸顯了其在實際應用中的優勢。
在性能測試方面,MoBA 表現出色。技術報告顯示,在多種任務中,MoBA 的性能與全注意力機制相當,但在處理長序列時可顯著節省計算資源。在語言建模數據測試中,當序列長度為 8192 或 32768 個 token 時,MoBA 的困惑度與全注意力 Transformer 相近。更為關鍵的是,當研究人員將上下文長度逐漸擴展到 128000 及更長時,MoBA 仍能保持強大的長上下文理解能力。在 “尾隨 token” 評估中,MoBA 能夠有效處理長提示末尾附近的 token 預測任務,且預測質量沒有明顯下降。研究人員還對 MoBA 的塊大小和門控策略進行了敏感性探索。實驗表明,細化粒度(使用更小的塊但選擇更多的塊)有助于模型更接近全注意力的效果。即使在忽略大部分上下文的情況下,自適應門控也能識別與查詢真正相關的塊。此外,“混合” 模式展現出一種平衡策略:部分層繼續使用 MoBA 提升速度,少數層則恢復全注意力。這種混合方法在監督微調任務中尤為有益,例如當輸入中的某些位置在訓練目標中被屏蔽時,保留少數上層的全注意力,可使模型保持廣泛的上下文覆蓋,有助于需要全局視角的任務。
關鍵代碼分析:
以下是對 MoBA 庫關鍵代碼?MixedAttention
?類的分析以及關鍵代碼的摘錄與注釋:
整體分析
MixedAttention
?類是一個自定義的?torch.autograd.Function
,用于實現混合塊注意力機制。這個類主要包含兩個靜態方法:forward
?和?backward
,分別用于前向傳播和反向傳播。
class MixedAttention(torch.autograd.Function):# 前向傳播函數@staticmethoddef forward(ctx,q, # 查詢張量k, # 鍵張量v, # 值張量self_attn_cu_seqlen, # 自注意力累積序列長度moba_q, # MoBA 查詢張量moba_kv, # MoBA 鍵值張量moba_cu_seqlen_q, # MoBA 查詢累積序列長度moba_cu_seqlen_kv, # MoBA 鍵值累積序列長度max_seqlen, # 最大序列長度moba_chunk_size, # MoBA 塊大小moba_q_sh_indices, # MoBA 查詢塊索引):# 保存一些參數,用于后續的反向傳播ctx.max_seqlen = max_seqlenctx.moba_chunk_size = moba_chunk_sizectx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5)# 自注意力計算_, _, _, _, self_attn_out_sh, self_attn_lse_hs, _, _ = (_flash_attn_varlen_forward(q=q,k=k,v=v,cu_seqlens_q=self_attn_cu_seqlen,cu_seqlens_k=self_attn_cu_seqlen,max_seqlen_q=max_seqlen,max_seqlen_k=max_seqlen,softmax_scale=softmax_scale,causal=True,dropout_p=0.0,))# MoBA 注意力計算_, _, _, _, moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward(q=moba_q,k=moba_kv[:, 0],v=moba_kv[:, 1],cu_seqlens_q=moba_cu_seqlen_q,cu_seqlens_k=moba_cu_seqlen_kv,max_seqlen_q=max_seqlen,max_seqlen_k=moba_chunk_size,softmax_scale=softmax_scale,causal=False,dropout_p=0.0,)# 轉換 lse 形狀,從 hs 轉換為 sh(遵循傳統混合注意力邏輯)self_attn_lse_sh = self_attn_lse_hs.t().contiguous()moba_attn_lse = moba_attn_lse_hs.t().contiguous()# 初始化輸出緩沖區,形狀與 q 相同output = torch.zeros((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)# 將輸出張量展平為二維,便于后續索引操作output_2d = output.view(-1, q.shape[2])# 計算混合 lse# 減去最大 lse 以避免指數爆炸max_lse_1d = self_attn_lse_sh.view(-1)max_lse_1d = max_lse_1d.index_reduce(0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax")self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)moba_attn_lse = (moba_attn_lse.view(-1).sub(max_lse_1d.index_select(0, moba_q_sh_indices)).reshape_as(moba_attn_lse))# 計算自注意力和 MoBA 注意力的 softmax 結果mixed_attn_se_sh = self_attn_lse_sh.exp()moba_attn_se = moba_attn_lse.exp()# 將 MoBA 注意力結果累加到自注意力結果上mixed_attn_se_sh.view(-1).index_add_(0, moba_q_sh_indices, moba_attn_se.view(-1))mixed_attn_lse_sh = mixed_attn_se_sh.log()# 加權自注意力輸出factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [ vS, H ]self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1)output_2d += self_attn_out_sh.reshape_as(output_2d)# 加權 MoBA 輸出mixed_attn_lse = (mixed_attn_lse_sh.view(-1).index_select(0, moba_q_sh_indices).view_as(moba_attn_lse))factor = (moba_attn_lse - mixed_attn_lse).exp() # [ vS, H ]moba_attn_out = moba_attn_out * factor.unsqueeze(-1)raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1])output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)# 將輸出轉換為與輸入相同的數據類型output = output.to(q.dtype)# 恢復最大 lsemixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)# 保存中間結果,用于反向傳播ctx.save_for_backward(output,mixed_attn_lse_sh,q,k,v,self_attn_cu_seqlen,moba_q,moba_kv,moba_cu_seqlen_q,moba_cu_seqlen_kv,moba_q_sh_indices,)return output# 反向傳播函數@staticmethoddef backward(ctx, d_output):# 從上下文中獲取保存的參數max_seqlen = ctx.max_seqlenmoba_chunk_size = ctx.moba_chunk_sizesoftmax_scale = ctx.softmax_scale(output,mixed_attn_vlse_sh,q,k,v,self_attn_cu_seqlen,moba_q,moba_kv,moba_cu_seqlen_q,moba_cu_seqlen_kv,moba_q_sh_indices,) = ctx.saved_tensors# 確保輸入梯度連續d_output = d_output.contiguous()# 計算自注意力的梯度dq, dk, dv, _ = _flash_attn_varlen_backward(dout=d_output,q=q,k=k,v=v,out=output,softmax_lse=mixed_attn_vlse_sh.t().contiguous(),dq=None,dk=None,dv=None,cu_seqlens_q=self_attn_cu_seqlen,cu_seqlens_k=self_attn_cu_seqlen,max_seqlen_q=max_seqlen,max_seqlen_k=max_seqlen,softmax_scale=softmax_scale,causal=True,dropout_p=0.0,window_size=(-1, -1),softcap=0.0,alibi_slopes=None,deterministic=True,)# 計算 MoBA 注意力的梯度headdim = q.shape[-1]d_moba_output = (d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1))moba_output = (output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1))mixed_attn_vlse = (mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1))dmq, dmk, dmv, _ = _flash_attn_varlen_backward(dout=d_moba_output,q=moba_q,k=moba_kv[:, 0],v=moba_kv[:, 1],out=moba_output,softmax_lse=mixed_attn_vlse,dq=None,dk=None,dv=None,cu_seqlens_q=moba_cu_seqlen_q,cu_seqlens_k=moba_cu_seqlen_kv,max_seqlen_q=max_seqlen,max_seqlen_k=moba_chunk_size,softmax_scale=softmax_scale,causal=False,dropout_p=0.0,window_size=(-1, -1),softcap=0.0,alibi_slopes=None,deterministic=True,)# 合并 MoBA 的鍵和值的梯度dmkv = torch.stack((dmk, dmv), dim=1)return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None
代碼關鍵部分解釋
-
前向傳播 (
forward
):- 分別計算自注意力和 MoBA 注意力的結果。
- 對注意力分數進行處理,包括形狀轉換、歸一化等操作,以避免指數爆炸。
- 將自注意力和 MoBA 注意力的結果進行加權合并,得到最終的輸出。
- 保存中間結果,用于后續的反向傳播。
-
反向傳播 (
backward
):- 根據前向傳播保存的中間結果,計算自注意力和 MoBA 注意力的梯度。
- 最終返回各個輸入張量的梯度。
小結
通過這種方式,MixedAttention
?類實現了 MoBA 混合塊注意力機制,通過將上下文劃分為塊并進行選擇性的注意力計算,有效減少了計算量,提升了處理長文本的效率。
總結
總體而言,MoBA 非常適合處理涉及大量上下文的任務,如長篇文檔閱讀理解、大規模代碼補全以及需要完整對話歷史的多輪對話系統。它在提高效率的同時,性能損失極小,為大規模訓練大語言模型提供了一種極具吸引力的方法。雖然目前 MoBA 主要應用于文本領域,但研究人員認為,其底層機制在其他數據模態中也具有應用潛力。只要序列長度足夠長,引發計算或內存問題,將查詢分配給塊 “專家” 的思路就有望緩解瓶頸,同時保持處理關鍵全局依賴關系的能力。隨著語言應用中的序列長度持續增長,像 MoBA 這樣的方法可能會在推動神經語言建模的可擴展性和成本效益方面發揮關鍵作用,為人工智能的發展注入新的活力。