Transformer 模型的计算复杂度随着输入序列长度的增加,呈平方的增加,那么这样就会消耗大量的内存资源,如何降低Transformer 模型的计算复杂度

打开网易新闻 查看精彩图片

Flash Attention是一种在Transformer模型中优化注意力机制(Attention Mechanism)的技术,它的主要目标是减少计算成本和内存需求,同时保持或提升模型的性能。传统的自我注意力层的计算复杂度为O(N^2),其中N是序列长度,这是因为每个位置的输出都需要与序列中的所有其他位置进行交互。

Flash Attention通过以下方式降低了计算复杂度:

1 局部注意力机制

Flash Attention利用了序列中信息的局部相关性,这意味着一个位置上的token通常与附近的token有更强的相关性。因此,它可以通过限制注意力窗口的大小来减少计算量,只考虑固定数量的前后token进行注意力计算。

2. 分块技术

将输入序列分成多个小块,然后在这些小块之间进行注意力计算。这种方法减少了全序列计算的需求,从而降低了计算复杂度。

具体实现:

将 Q, K, V 矩阵分割成多个小块,块大小根据 SRAM 的容量确定。

外循环遍历 K 和 V 的块,将其加载到 SRAM 中。

内循环遍历 Q 的块,也加载到 SRAM 中。

在 SRAM 中计算每个 Q 块与当前 K 块的局部注意力分数 Sij,并进行 softmax 和其他必要的计算。

更新全局统计量,并计算最终的注意力输出 O。

打开网易新闻 查看精彩图片

3. 优化 softmax 计算

Softmax 的 tiling 展开:Flash Attention 采用了 softmax 的 tiling 展开技术,支持 softmax 的拆分并行计算,从而提升计算效率。这种技术可以更有效地利用 GPU 的并行计算能力。

Safe softmax:为了处理 softmax 中 e^{x_i} 容易溢出的问题,Flash Attention 引入了 safe softmax。通过对每个 x_i 减去一个最大值 m(即 m = max^N_{j=1}(x_j)),使得 x_i - m ≪ 0,这时幂操作符对负数输入的计算是准确且安全的。

4. 稀疏注意力

只选择部分键值对进行注意力计算,而不是计算所有可能的键值对,这可以显著减少计算量。

5. 内存效率优化

传统注意力机制在计算过程中需要存储全部的Q(Query)、K(Key)、V(Value)矩阵以及注意力权重矩阵,而Flash Attention通过优化算法,可以在不存储完整注意力矩阵的情况下计算输出,从而大大节省了内存。

打开网易新闻 查看精彩图片