ICML | 千倍長度泛化!螞蟻新注意力機制GCA 16M長上下文精準理解

ICML | 千倍長度泛化!螞蟻新注意力機制GCA 16M長上下文精準理解

文章圖片

ICML | 千倍長度泛化!螞蟻新注意力機制GCA 16M長上下文精準理解

文章圖片

ICML | 千倍長度泛化!螞蟻新注意力機制GCA 16M長上下文精準理解

文章圖片

ICML | 千倍長度泛化!螞蟻新注意力機制GCA 16M長上下文精準理解

文章圖片


該工作第一作者為螞蟻技術研究院副研究員胡翔 , 螞蟻技術研究院高級研究員武威為通訊作者 。
在大語言模型如火如荼的當下 , 長文本建模仍然是一個極具挑戰的問題 。 糾其根源 , 一方面在于主流 LLMs 的架構 Transformers 中平方復雜度及隨序列長度線性增長的推理階段顯存開銷;另一方面在于 full-attention 有限的外推能力 , 難以泛化到遠超預訓練階段長度的輸入 。
而高效處理長上下文能力 , 除了簡單的工業界降本增效的需求外 , 還涉及通用人工智能 (AGI) 的核心問題:具有永久記憶的智能體 。 如果將人類從出生開始接收到的信息視作長上下文 , 人類擁有記憶無非是訪問這些上下文 。 因此記憶可以看作是超長上下文訪問能力 , 而擁有與用戶所有對話記憶的智能體 , 很可能為大語言模型公司構建數據護城河 (事實上 , OpenAI 已經開放了類似能力) 。
近日 , 螞蟻的研究團隊為這個問題帶來了一個新思路 。 就像人類開卷考試只會挑和當前問題相關的關鍵頁作為參考 , 語言模型也可以只關注與當前上下文相關的過去片段 。 以此為出發點 , 他們提出一種基于因果檢索的注意力機制 GCA (Grouped Cross Attention) , 完全端到端地學習如何從上文檢索并挑選最相關片段 , 從而實現超長序列高性能處理與泛化能力 。 人類記憶的另一個特性是大部分時候記憶處于沉睡狀態 , 相關記憶片段只會在激活時進入意識 。 類似地 , GCA 通過將上文信息卸載到 CPU / 磁盤 , 只在需要的時候動態加載需要的片段到 GPU 的方式 , 大幅降低了長文本處理的顯存開銷 。
目前 , GCA 的 Triton kernel 實現已全部開源 , 相關論文已被 ICML 2025 接收 。

  • 論文標題:Efficient Length-Generalizable Attention via Causal Retrieval for Long-Context Language Modeling
  • 論文地址:https://arxiv.org/abs/2410.01651
  • GitHub 主頁:https://github.com/ant-research/long-context-modeling
實驗結果也令人振奮:整合 GCA 的模型不僅在長文本數據集上展現了更優的 perplexity , 更展現了 1000 倍以上的長度泛化能力 , 在 16K 上下文預訓練的模型可在 16M 長上下文密鑰檢索 (passkey retrieval) 實現 100% 準確率 , 并在更復雜的多跳檢索任務持續展現了超強外推能力 。 此外長度泛化與檢索能力效果拔群 , 基于 GCA 的模型訓練開銷隨序列長度幾乎呈線性關系 , 并且推理的顯存開銷接近常數 , 同時基本持平 Transformers 推理速度 。
值得一提的是 , 本工作 24 年 10 月在 arXiv 公開后 , 國產之光 DeepSeek 在 25 年初公開了 NSA , 兩者思路都是通過挑選過去 chunk 并 attention 的方式實現性能優化 。 但各有側重 , GCA 核心亮點在于超長的長度泛化 , NSA 通過巧妙的 kernel 設計實現了逐 token 的稀疏 attention 。 受 NSA 的啟發 , GCA 的后繼工作 HSA (https://arxiv.org/abs/2504.16795) 結合了兩者的優點進行了融合 。
長文本處理難點及現有方案的局限性
近年來 , 有不少工作討論 Transformers (TRMs) 架構如何高效處理長文本 。 因為基于全量上文 attention 的 TRMs 有一個很顯著的局限:輸入長度超過預訓練長度一定程度后 , perplexity 會飆升 , 無法生成正常文本 。 如果只是解決正常生成的問題 , 一個最簡單的思路是滑動窗口注意力 , 即每個 token 僅關注最鄰近的 N 個 token 即可 。 這種方式可以保證 LLMs 持續生成 , 但它犧牲了長程信息獲取能力 。
另一種思路是認為 attention 窗口擴大到預訓練長度范圍之外后會導致原本的 attention 權重分布發生變化 , 因此通過調整 softmax 溫度的方式進行長度泛化 。 但這類方法經實驗驗證往往泛化的倍率也有限 。
因此 , attention 長度泛化的難點在于處理超長序列的同時 , 能夠真正有效利用上文中的信息 。
GCA: 基于端到端因果檢索的注意力機制
現有一些工作通過檢索增強 (RAG) 的思路來進行長文本建模 , 其基本思路是將文本分段 , 譬如每 64 個 token 為一個 chunk;每生成一個 chunk 后 , 模型根據當前上文信息檢索歷史 chunk 來輔助下一個 chunk 的生成 。 理想情況下 , 只要能檢索到對下文生成最有幫助的 chunk , 再通過 cross-attention 機制從相關 chunk 收集信息即可 。 但通常檢索模塊是單獨訓練的 , 只能檢索到相似內容 , 無法保證挑選對下文生成最有幫助的 chunk 。
和已有工作相比 , GCA 的一個顯著優勢是能夠與自回歸語言模型聯合預訓練 , 從而實現端到端學習 。

上圖對比了 GCA 與傳統檢索方式的運作區別 。 傳統方式中 (a) 檢索模塊檢索并返回相關 chunk , 但檢索分只用于挑選 chunk 完全不參與 forward 運算 , 因此無法獲得梯度 , 無法學習 。 GCA 的核心創新在于通過一種兩階段的注意力機制 , 使得每個 chunk 的檢索分能參與到自回歸預測中如圖中(b)所示 。
1. 分組注意力機制
不同于 (a) 中直接將 chunk 拼接在一起進行 attention ,GCA 分別對每個 chunk 進行 attention (分組 attention) , 從各個 chunk 收集 token 粒度的信息并整合 , 作為每個 chunk 整體的信息 。
2. Chunk-level 信息融合
GCA 將每個 chunk 的檢索相關分通過 softmax 得到一個概率分布 , 將其作為權重對第一步所有 chunk 的表征進行加權求和 , 融合所有 chunk 信息用于下一個 token 預測 。 在反向傳播過程中 , 更有助于預測下文的 chunk 將被分配更大的權重 , 從而實現檢索模塊的端到端學習 。
模型整體架構是通過 GCA 與 sliding window attention 結合實現長上下文建模;前者負責長程信息檢索 , 后者負責整合短程信息 。 為了進一步提升 GCA 性能 , 降低顯存開銷 , 研究團隊將整個 GCA 封裝成由 Triton 實現的 kernel , 方便未來工作可以直接復用 。
實驗結果
在語言模型 , 長程檢索等任務上的實驗表明:
1. 基于 GCA 的 128M 的模型在大海撈針任務即可超越大部分主流 7B 模型 , 達成 1000 倍外推 , 實現 16M 上下文的完美大海撈針 。
在該實驗中 , 所有模型都僅在不超過 16K 的上下文進行預訓練 , baseline 囊括了包含 sliding window attention 等主流注意力機制 。 基于 GCA 的模型無論在簡單大海撈針 , 還是更復雜的變量追蹤任務 , 都保持了穩定的外推能力 。
注意到幾乎所有 baseline 在上下文長度超過 64K 后幾乎都歸零 , 這些不同模型存在不同原因 。 劃窗注意力因為只能看最鄰近的 token , 無法實現長程信息獲?。 換諮方峁溝撓捎謁猩舷攣男畔⒍急謊顧踉諞桓齬潭ㄎ鵲謀碚?, 必然存在信息損失的問題;基于單獨訓練檢索器的模型 (RPTContriever) 的結果進一步驗證了檢索模型未必能檢索到對下文有幫助的上文 。
這一結果經驗性地為可長度泛化的注意力機制提供了一個成功的概念原型 。 同時證明可泛化的長程信息獲取能力取決于注意力機制原理上的改進 , 與參數量的提升無關 。

在摘要及 RULER 榜單的效果
2. 預訓練高效 , 推理時顯存開銷接近常數:GCA 是一種 sparse attention , 其 attention 的視野域保持常數 , 因此在 batch size 一定的情況下 , 訓練開銷幾乎與序列長度呈線性 。 由于 GCA 在生成階段將所有上文的 KV cache 都卸載到 CPU , 每次檢索的時候才把相關 chunk 的 kv cache 載入 GPU , 因此超長上文也不會有 KV cache 顯存爆炸的問題 。 而 GPU-CPU 的交換控制在每 64 個 token 一次 , 因此對推理速度影響非常小 , 從而實現接近常數的顯存開銷 , 但仍保持高效的推理速度及長程信息獲取能力 。

訓練時間及 ppl 隨序列長度的變化

推理速度與顯存開銷相比基線 (基于劃窗注意力的 Transformers) 的倍率關系(越低越好)

相同條件不同模型各個參數規模下的訓練吞吐量 , 相比劃窗注意力有額外 20% 的開銷 , 但帶來超長程信息獲取的能力
【ICML | 千倍長度泛化!螞蟻新注意力機制GCA 16M長上下文精準理解】3. 在 arXiv-math 上的數據分析發現 , 通過 GCA , 語言模型會根據當前上下文 , 檢索下文生成中可能會用到的引理及變量聲明 。 這說明 GCA 學到的不僅僅是字面相似性 , 更包含了語義乃至邏輯相關性 。

黑體是當前 chunk , 紅色 , 藍色 , 黃色 , 分別代表 top3 相關 chunk、
結語
本工作提出一種可以長度泛化的稀疏注意力機制 GCA 其核心在于可導的檢索模塊 , 可以有效處理 1000 倍于預訓練長度的文本 , 首次實現在 16M 長度完美的大海撈針 。 雖然當前實驗的模型規模較小 , 但期望該工作可以為機器如何實現永久記憶提供新的研究思路 。

    推薦閱讀