在 Hacker News 及 Twitter 等社交網絡上,該論文都反響熱烈,有網友表示差分 Transformer 提出的改進簡單又美麗,而帶來的提升又非常顯著。

甚至已有開發者做出了差分 Transformer 的輕量實現!

那么差分 Transformer 彌補了原生 Transformer 的哪些問題呢?如下圖所示,Transformer 往往會過度關注不相關的上下文,該團隊將此稱為注意力噪聲(attention noise)。而差分 Transformer 則能放大對答案范圍的注意力并消除噪音,從而增強上下文建模的能力。這就要用到該團隊新提出的差分注意力機制(differential attention mechanism)了。

差分注意力機制可以消除注意力噪聲,鼓勵模型重點關注關鍵信息。該方法有些類似于電氣工程中的降噪耳機和差分放大器。

下面我們就來詳細了解一下差分 Transformer 的設計思路。

差分 Transformer

差分 Transformer 是一種用于序列建模的基礎模型架構。為了方便說明,他們使用了僅解碼器(decoder-only)模型作為示例來描述該架構。

該模型堆疊了 L 個 Diff Transformer 層。給定一個輸入序列 x,將輸入嵌入打包成 X^0。輸入會被進一步上下文化來獲得輸出 X^L。每一層都由兩個模塊組成:一個差分注意力模塊和之后的前向網絡模塊。

相比于 Transformer,差分 Transformer 的主要差別在于使用差分注意力替換了傳統的 softmax 注意力,同時保持整體宏觀布局不變。此外,他們也參考 LLaMA 采用了 pre-RMSNorm 和 SwiGLU 這兩項改進措施。

差分注意力

差分注意力機制的作用是將查詢、鍵和值向量映射成輸出。這里使用查詢和鍵向量來計算注意力分數,然后計算值向量的加權和。

此處的關鍵設計是使用一對 softmax 函數來消除注意力分數的噪聲。具體來說,給定輸入 X,首先將它們投射成查詢、鍵和值 Q_1、Q_2、K_1、K_2、V。然后差分注意力算子 DiffAttn (?) 通過以下方式計算輸出:

其中 W^Q、W^K 、W^V 是參數,λ 是可學習的標量。為了同步學習動態,將標量 λ 重新參數化為:

其中 λ_q1、λ_k1、λ_q2、λ_k2 是可學習的向量,λ_init ∈ (0, 1) 是用于初始化 λ 的常數。該團隊通過經驗發現,設置 λ_init = 0.8 ? 0.6 × exp (?0.3?(l ? 1)) 在實踐中效果很好,其中 l ∈ [1, L] 表示層索引。它在實驗中被用作默認策略。

他們也探索了另一種初始化策略:對所有層使用相同的 λ_init(例如 0.8)。如后面消融研究所示,使用不同的初始化策略時,性能相對穩健。

差分注意力利用兩個 softmax 注意力函數之間的差來消除注意力噪聲。這個想法類似于電氣工程中提出的差分放大器,其中兩個信號之間的差用作輸出,這樣就可以消除輸入的共模噪聲。此外,降噪耳機的設計也基于類似的想法。

該團隊也為差分注意力使用了多頭機制。令 h 表示注意力頭的數量。他們對各個頭使用不同的投影矩陣 W^Q_i 、W^K_i 、W^V_i ,i ∈ [1, h]。標量 λ 在同一層內的頭之間共享。然后對頭輸出執行歸一化,并投射成最終結果,如下所示:


其中 λ_init 是 (2) 式中的常數標量,W^O 是可學習的投影矩陣,LN (?) 是對每個頭使用 RMSNorm,Concat (?) 的作用是沿通道維度將頭連接在一起。這里使用一個固定乘數(1 ? λ_init)作為 LN (?) 的縮放尺度,以使梯度與 Transformer 對齊。

圖 2 使用了 GroupNorm (?) 來強調 LN (?) 獨立應用于每個 head。由于差分注意力往往具有更稀疏的模式,因此頭之間的統計信息更加多樣化。為了改進梯度的統計情況,LN (?) 算子會在連接操作之前對每個頭進行歸一化。

整體架構

其整體架構會堆疊 L 層,其中每層包含一個多頭差分注意力模塊和一個前向網絡模塊。如此,便可將差分 Transformer 層描述為:

其中 LN (?) 是 RMSNorm,SwiGLU (X) = (swish (XW^G) ⊙ XW_1) W_2,且 W^G、W_1、W_2 是可學習的矩陣。

實驗

該團隊從以下角度評估了差分 Transformer 在 LLM 中的應用,包括對比評估、應用評估和消融研究。這里我們僅關注實驗結果,更多實驗過程請訪問原論文。

語言建模評估

該團隊評估了差分 Transformer 的語言建模能力。為此,他們使用 1T token 訓練了一個 3B 大小的差分 Transformer 語言模型,并與之前的 Transformer 語言模型做了比較。

結果見表 1,其中報告的是在 LM Eval Harness 基準上的零樣本結果。

可以看到,3B 規模下,差分 Transformer 語言模型的表現優于之前的 Transformer 語言模型。此外,實驗也表明差分 Transformer 在多種任務上都勝過 Transformer,詳見原論文附錄。

與 Transformer 的可擴展性比較

該團隊也比較了新舊 Transformer 的可擴展性。結果見圖 3,其中 a 比較了模型規模方面的可擴展性,而 b 則是訓練 token 數量方面的可擴展性。

可以看到,在這兩個方面,差分 Transformer 的可擴展性均優于常規 Transformer:僅需后者 65% 左右的模型大小或訓練 token 數量就能達到相媲美的性能。

長上下文評估

當 3B 模型上下文長度增長至 64K,模型的表現又如何呢?又使用另外 1.5B token 訓練了 3B 版本的檢查點模型之后,該團隊發現隨著上下文長度的增加,累積平均負對數似然(NLL)持續下降。差分 Transformer 得到的 NLL 值低于常規 Transformer。見圖 4,這樣的結果表明,差分 Transformer 可以有效地利用不斷增加的上下文。

關鍵信息檢索

為了檢驗差分 Transformer 檢索關鍵信息的能力,該團隊執行了 Needle-In-A-Haystack(草堆找針)測試。

表 2 給出了 4K 上下文長度的情況,其中 N 是針的數量,R 是查詢引用的數量。可以看到,差分 Transformer 的多針檢索準確度高于常規 Transformer,尤其是當針數量較多時,差分 Transformer 的優勢會更加明顯。

那么當上下文長度提升至 64K 時,又會如何呢?結果見圖 5,這里使用的上下文長度在 8K 到 64K 之間,使用了 N = 8 和 R = 1 的設置。

可以看到,在不同的上下文長度下,差分 Transformer 能夠保持相對穩定的性能。而當上下文長度越來越大時,常規 Transformer 的性能會逐漸下降。

另外,表 3 展示了分配給關鍵信息檢索任務的答案范圍和噪聲上下文的注意力分數。該分數可代表模型保留有用信息、抵抗注意力噪聲的能力。

可以看到,相比于常規 Transformer,差分 Transformer 能為答案范圍分配更高的注意力分數,同時為注意力噪聲分配更低的注意力分數。

上下文學習能力評估

該團隊從兩個角度評估模型的上下文學習能力,包括多樣本分類和上下文學習的穩健性。

圖 6 展示了新舊 Transformer 模型的多樣本分類結果。結果表明,在不同的數據集和不同的演示樣本數量上,差分 Transformer 均穩定地優于 Transformer。此外,差分 Transformer 的平均準確度優勢也很明顯,從 5.2% 到 21.6% 不等。

圖 7 則展示了兩種模型的上下文學習穩健性結果。該分析基于 TREC 數據集,并且采用了兩種提示詞格式:示例隨機排列(圖 7a)和按類別交替排列(圖 7b)。

在這兩種設置下,差分 Transformer 的性能方差要小得多。結果表明,新方法在上下文學習任務中更為穩健。相比之下,Transformer 容易受到順序排列的影響,導致最佳結果與最差結果之間差距巨大。

上下文幻覺評估

該團隊基于文本摘要和問答任務評估了模型的上下文幻覺現象。結果見表 4。

可以看到,相比于常規 Transformer,差分 Transformer 在摘要和問答任務上的上下文幻覺更低。該團隊表示,原因可能是差分 Transformer 能更好地關注任務所需的基本信息,而不是無關上下文。

激活異常值分析

在 LLM 中,一部分激活值明顯大于大多數激活值的現象被稱為激活異常值(activation outliers)。異常值導致訓練和推理過程中模型量化困難。實驗表明差分 Transformer 可以降低激活異常值的幅度,從而可能實現更低的量化位寬。

表 5 展示了兩個訓練得到 Transformer 和差分 Transformer 模型的激活值統計情況。這里分析了兩種類型的激活,包括注意力 logit(即 pre-softmax 激活)和隱藏狀態(即層輸出)。可以看到,盡管中位數相似,但與 Transformer 相比,差分 Transformer 的較大激活值要低得多。這表明新方法產生的激活異常值較少。

圖 8 則展示了將注意力 logit 量化到更低位的情況。這里使用的方案是:使用 absmax 量化的動態后訓練量化。其中,16 位配置表示未經量化的原始結果。模型逐步量化為 8 位、6 位和 4 位。這里報告的是在 HellaSwag 上的零樣本準確度,但該團隊也指出在其它數據集上也有類似表現。

從圖中可知,即使降低位寬,差分 Transformer 也能保持較高性能。相較之下,常規 Transformer 的準確度在 6 位和 4 位量化時會顯著下降。這一結果表明,差分 Transformer 本身就能緩解注意力分數中的激活異常值問題,從而可為低位 FlashAttention 的實現提供新機會。

最后,該團隊也進行了消融實驗,證明了各個新設計的有效性。

文章轉自微信公眾號@算法進階

上一篇:

超完整!11 種經典時間序列預測方法!

下一篇:

圖神經網絡加速綜述: 算法、系統和硬件
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費