Enhancing linear attention with residual learning
利用残差学习增强线性注意力
https://arxiv.org/pdf/2509.25223v1
摘要
线性注意力为自注意力机制提供了一种线性时间复杂度的替代方案,但往往难以捕捉长距离模式。我们通过"预测-校正"的视角重新审视线性注意力,发现主流变体都可以被表示为历史预测与单令牌校正的组合,这造成了表达能力瓶颈。为解决这一瓶颈,我们提出了残差线性注意力(RLA),这是一个为线性注意力配备显式残差拟合机制的框架。RLA 维护一个辅助循环状态,用于学习随时间累积残差误差并校正基础预测。我们进一步实例化了一个 delta 规则版本——残差 Delta 网络(RDN),结合了自适应门控和残差裁剪以增强校正控制和稳定性。我们的实现利用了高度优化的线性注意力核函数,并保持线性的时间和内存复杂度。在语言建模和召回密集型评估中,RLA 和 RDN 始终优于各自的基线模型及其他现代线性注意力方法,在保持线性扩展性的同时缩小了与标准 Transformer 的差距。
1 引言
Transformer(Vaswani 等人,2017)架构已成为大型语言模型的标准。然而,其自注意力机制的二次时间复杂度仍然是一个关键瓶颈,限制了其在长序列上的应用(Li 等人,2024)。线性注意力最近作为标准自注意力的高效替代方案涌现,直接解决了其过高的二次复杂度问题。通过将注意力计算重构为循环过程,这些模型实现了线性时间的训练和推理,使其非常适合处理长序列。RetNet(Sun 等人,2023)和 Mamba(Gu & Dao,2023;Dao & Gu,2024)等架构已展现出具有竞争力的性能。GLA(Yang 等人,2023)和 DeltaNet(Yang 等人,2024b)等方法通过引入数据依赖的门控和状态更新规则来管理单一状态矩阵内的信息流,进一步改进了性能。
现代线性注意力方法可以被统一为学习从键到值的直接映射(Sun 等人,2024),这一过程类似于测试时训练。例如,delta 更新规则(Schlag 等人,2021)可以从二次损失目标的单步在线梯度下降推导得出。这一视角开辟了若干改进途径,包括探索不同的在线学习损失函数以推导新的更新规则(Schlag 等人,2021;Yang 等人,2024b)、采用更复杂的映射函数,或修改在线梯度更新机制(von Oswald 等人,2025;Siems 等人,2025)。例如,TTT-MLP(Sun 等人,2024)和 Titans(Behrouz 等人,2024)等近期工作利用多层感知机(MLP)作为深度记忆模块来实现更强大的映射。然而,这种方法牺牲了模型的线性循环特性,从而使并行训练变得复杂。
基于这一视角,我们对注意力输出提供了一种新的解释。我们证明,主流线性注意力模型的输出可以分解为来自历史状态的基础分量和仅源自当前令牌的校正项(见第 2.3 节)。依赖单一令牌来执行这种系统性校正造成了瓶颈,损害了模型的表达能力。为解决这些问题,我们引入了残差线性注意力,这是一个用显式残差拟合机制增强线性注意力模型的框架。我们的方法不依赖单一令牌进行校正,而是采用辅助状态矩阵来显式建模和校正基础线性注意力的系统性预测误差。最终输出是基础预测与该学习误差校正的组合。我们的方法可以推广为适用于各种线性注意力方法的统一框架,为构建更强大的序列模型提供了一种强大而高效的策略。
在现有线性注意力方法的基础上,我们提出了两种增强残差拟合的变体:残差线性注意力(RLA)和残差 Delta 网络(RDN)。我们在一系列基准测试上评估了它们,包括语言建模和召回密集型任务。我们的结果表明,这些模型优于各自的基线模型和其他现代线性注意力方法,而我们的消融分析证实了框架内每个关键设计选择的重要性。
2 预备知识
2.1 作为循环模型的线性注意力
Softmax 注意力机制相对于序列长度表现出二次计算复杂度,在处理长序列时构成了显著的瓶颈。线性注意力(Katharopoulos 等人,2020)架构通过移除 softmax 函数来解决这一问题,从而允许对计算进行重新排序。
这种循环形式在推理过程中保持每步恒定的时间和内存复杂度,并通过分块并行算法实现高效训练(Yang 等人,2023)。此外,门控机制的使用催生了更多变体的发展,如 RetNet(Sun 等人,2023)、Lightning Attention(Qin 等人,2024a)和 Mamba-2(Dao & Gu,2024)。
2.2 在线学习视角
这种形式化使 Delta Net(Yang 等人,2024b;Schlag 等人,2021)等模型能够实现细粒度的记忆控制。门控 Delta Net(Yang 等人,2024a)进一步通过在学习过程中引入权重衰减来增强这一方法。
2.3 分解为预测与校正
基于预测-校正的视角,我们引入了一个残差拟合框架来增强线性注意力。我们的框架通过显式拟合超出当前令牌的上下文信息,学习一个更具表达力的校正项。
3 方法
本节介绍我们提出的方法,该方法通过残差拟合过程来增强线性注意力。我们首先描述支撑我们方法的基础残差学习框架。接下来,我们引入自适应校正因子以增强建模能力,并引入裁剪方法来稳定残差拟合过程。最后,我们展示我们方法的两种最终变体。
3.1 显式残差拟合
利用第 2 节中线性注意力的在线学习视角,我们对辅助状态应用类似的更新规则。这产生了以下循环过程:
3.2 自适应门控与校正因子
这种形式化使用衰减因子和校正因子来分别对来自基础状态和辅助状态的检索进行动态门控。
3.3 归一化与残差裁剪
为确保计算稳定性,我们引入两种机制。首先,我们对查询和键向量应用 L2 归一化以提高数值稳定性。其次,我们通过裁剪残差来解决辅助状态中的潜在不稳定性:
这确保了误差校正状态保持稳定的学习轨迹,即使基础模型产生瞬态的、较大的预测误差。该裁剪方法的详细推导见附录 B。
3.4 最终形式化
残差拟合原理是一种通用技术,可以与各种线性注意力主干网络集成。通过将我们的残差机制应用于标准加法更新规则和 delta 更新规则,我们推导出两种强大的变体。这导出了我们的最终模型:
4 实验
4.1 实验设置
实现 为了最大化效率,我们在 Triton(Tillet 等人,2019)中实现了自定义注意力核函数,基于 flash-linear-attention 库(Yang & Zhang,2024)构建。我们利用了这样一个事实:我们的状态更新规则与线性注意力的相同,只需对其核函数进行微小修改:我们将其增强为返回注意力结果和中间残差。这种设计允许在所有残差拟合阶段重用相同的高度优化核函数,确保高吞吐量。
4.2 主要结果
核函数效率 我们将我们的核函数运行时间与线性注意力基线和 FlashAttention(Dao 等人,2022;Dao,2023)进行基准测试,如图 2 所示。尽管残差拟合过程增加了计算开销,但我们方法的运行时间随序列长度线性扩展。这使其在较长序列上显著快于二次扩展的 FlashAttention。关于吞吐量,我们的方法与其他线性注意力机制一样,保持几乎恒定的高吞吐量。相反,计算受限的 FlashAttention 的吞吐量随序列长度增加而迅速下降。
语言建模与常识推理 我们在 WikiText(Merity 等人,2016)困惑度以及一系列评估推理和常识理解的基准测试上评估 RLA 和 RDN。推理任务包括 ARC-Easy、ARC-Challenge(Clark 等人,2018)、PIQA(Bisk 等人,2020)和 MMLU(Hendrycks 等人,2020),而常识理解则在 HellaSwag(Zellers 等人,2019)、Winogrande(Sakaguchi 等人,2021)、SocialIQA(Sap 等人,2019)和 LAMBADA(Paperno 等人,2016)上进行评估。我们的主要结果总结于表 2,显示我们提出的残差学习变体 RLA 和 RDN 在困惑度上相对于各自的基线 sGLA 和 GDN 取得了一致的改进。此外,我们的模型在多个基准测试上优于其他领先的线性注意力方法,并提供与标准 Transformer 相当的性能。
召回密集型任务 为了评估记忆容量,我们在 Arora 等人(2024)的召回密集型任务上对我们的模型进行基准测试。此外,我们还直接使用"大海捞针"任务(NIAH)(gkamradt,2023)评估模型的检索能力,该任务需要检索插入在长文档不同深度的键值对。这些基准测试对线性注意力模型具有挑战性,因为它们的有限状态空间造成了信息瓶颈,如表 3 所示。结果表明,我们提出的 RLA 和 RDN 始终优于其相应的基线,在 DROP 和 FDA 基准测试上取得了特别显著的收益。此外,它们在 NIAH 任务上大幅优于其他模型,突显了增强的信息召回能力。
4.3 消融研究
在本节中,我们进行一系列消融研究以验证关键组件的贡献。我们首先量化我们学习的残差拟合方法相对于预定义校正的优势。接下来,我们研究使用专用校正因子的重要性,然后分析将基础预测与校正相结合的门控机制的必要性。最后,我们检查归一化和残差裁剪的效果。
如表 4 所示,缺乏显式残差拟合的变体表现不如我们的完整方法。尽管该消融变体在某些基准测试上保持竞争力,但它在训练集和评估集上的困惑度都显著增加。这种性能下降延伸到专业领域,在 GSM8k(Cobbe 等人,2021)和 HumanEval(Chen 等人,2021)的困惑度测量中,其数学和代码能力显著退化。这证明了辅助状态在累积过去残差以有效细化模型输出方面的关键作用。
专用校正因子 我们通过将我们的完整模型与 γ 绑定到更新因子 β 的变体进行比较,分析使用专用校正因子 γ 的益处。在图 3a 中,具有独立 γ 的模型始终实现更低的评估损失,其中 RDN 变体显示出更大的改进。这一趋势延伸到下游性能,如图 3b 的结果所示,该图还显示专用校正因子在多个基准测试上带来性能提升。值得注意的是,我们的基础架构(不需要额外的 γ)仍然比基线线性注意力方法有显著改进。
归一化与残差裁剪 最后,我们研究归一化和残差裁剪的重要性。我们通过对 RLA 移除归一化和裁剪来进行消融研究。如图 4 所示,两个组件对稳定训练都至关重要;移除它们会导致无界激活和性能退化。相比之下,RDN 模型对残差裁剪很大程度上不敏感。这种鲁棒性归因于其 delta 规则更新的固有稳定性,即使没有残差裁剪也能保持一致的损失曲线(图 4b)。
5 相关工作
序列建模历史上由循环神经网络(RNN)(Lipton 等人,2015)主导,包括长短期记忆网络(LSTM)(Hochreiter & Schmidhuber,1997)和门控循环单元(GRU)(Cho 等人,2014)等变体。虽然有效,但其固有的顺序性质阻碍了训练并行化。Transformer 架构(Vaswani 等人,2017)克服了这一限制,成为序列建模的事实标准。然而,其自注意力机制具有相对于序列长度的二次计算复杂度,对长上下文应用构成了显著瓶颈。
为解决这些挑战,近期研究重新审视了线性 RNN,将其作为高效 Transformer 替代方案的基础。通过将序列处理形式化为线性循环,这些模型实现了可并行化训练和线性时间推理。该领域的早期探索,如 S4(Gu 等人,2021)、LRU(Orvieto 等人,2023)和 RetNet(Sun 等人,2023),利用了结构化状态转移矩阵。通过引入数据依赖的动态特性,后续实现了性能飞跃。Mamba(Gu & Dao,2023;Dao & Gu,2024)、HGRN(Qin 等人,2023;2024b)和门控线性注意力(Yang 等人,2023)等模型利用输入依赖的门控来动态控制状态转移,从而增强其表达能力。
更先进的方法引入了 delta 学习规则,将状态更新从简单的门控衰减重新框架为细粒度的记忆校正。这种方法以 DeltaNet(Yang 等人,2024b;Schlag 等人,2021)和门控 DeltaNet(Yang 等人,2024a)为代表,实现了更精确的动态记忆修改。该机制可以从在线学习视角理解,其中状态更新被框架为优化过程,如 TTT(Sun 等人,2024)所探索的。这一观点启发了进一步的工作,旨在发现和改进序列模型内的内在学习算法(von Oswald 等人,2023;2025)。
同期研究聚焦于增加状态转移的表达能力。例如,RWKV-7(Peng 等人,2025)采用对角加低秩结构,而 DeltaProduct(Siems 等人,2025)通过每令牌执行多步更新来推广 DeltaNet。为进一步提升容量,近期架构如 Titans(Behrouz 等人,2024)和 Miras(Behrouz 等人,2025)引入了非线性深度记忆,用 MLP 对状态进行参数化。
6 结论
在本文中,我们介绍了残差线性注意力,这是一个通过显式残差拟合过程来增强线性注意力模型的框架。我们的方法利用辅助状态来校正基础模型的预测误差,从而构建更鲁棒和准确的上下文表示。该框架具有高度适应性,可应用于各种线性注意力方法。我们的实验证明了这种多功能性,显示我们的方法始终优于各自的基线。虽然这种改进以拟合过程的额外计算为代价,但平衡这一权衡为未来的研究提供了一个有前景的方向。
原文链接:https://arxiv.org/pdf/2509.25223v1
热门跟贴