FlashAttention:通过克服硬件性能瓶颈来革新Transformer

近年来,注意力模型(Attention Model)已成为各种应用的首选架构,包括自然语言处理(Natural Language Processing)和图像分类等领域(Image Classification)。然而,随着这些模型变得越来越大且更复杂,它们在运行时间和内存消耗方面遇到了重大挑战,尤其是在处理长序列时。这些瓶颈的主要原因是自我注意机制(Self-attention Mechanism),其时间和内存复杂度与序列长度呈二次关系。为了解决这个问题,研究人员探索了各种近似注意力(approximate attention)方法。然而,这些方法在实际场景中往往无法实现显著的加速,主要是由于忽略了一个重要因素:输入/输出(IO)操作。本文介绍了FlashAttention,一种新颖的注意力算法,它不仅显著加速了注意力计算(attention computations),还通过引入IO感知来最小化内存开销。

左图:FlashAttention使用分块技术来防止在(相对)较慢的GPU HBM上对大型N x N注意力矩阵(虚线框)的物化。在外循环(红色箭头)中,FlashAttention循环遍历KV矩阵的块并将它们加载到快速的on-chip SRAM中。在每个区块中,FlashAttention循环遍历Q矩阵区块(蓝色箭头),将它们加载到SRAM中,并将注意力计算的输出写回HBM。
右图:相对于GPT-2上的PyTorch实现,FlashAttention加速了注意力计算。 FlashAttention不需要将大型N x N注意力矩阵读取和写入HBM,这导致注意力计算的加速比为7.6倍。

标准注意力模型(Standard Attention Models)的问题所在

标准的注意力模型在处理长序列时面临两个主要挑战:计算复杂度和内存需求。自我注意机制(Self-attention Mechanism)的二次时间和内存复杂度使得对于大型输入,计算速度缓慢且对内存需求较高。尽管近似注意力方法试图减轻这些挑战,但它们通常优先考虑减少浮点运算(FLOPs),而忽略了IO开销。结果,它们无法实现显著的实际速度提升。

此外,相对于内存速度,GPU处理器的计算速度改善更快,操作越来越受到高带宽内存(HBM)存取的瓶颈限制。

介绍FlashAttention:分块、重新计算和内核融合

FlashAttention通过引入IO感知并结合三个关键技术:分块(Tiling)、重新计算(Recomputation)和内核融合(Kernel Fusion),来解决标准注意力模型的限制。 FlashAttention利用这些技术来充分利用快速的SRAM。

分块(Tiling):FlashAttention使用分块技术,将输入分成较小的块,逐步进行注意力计算。这样做可以减少在相对较慢的GPU高带宽内存(HBM)上物化大型注意力矩阵(large attention matrix)的需求。分块还使FlashAttention能够利用快速的GPU on-chip SRAM来存储和处理数据,减少不同内存层级之间的内存读写操作。

重新计算(Recomputation):为了进一步减少内存存取开销,FlashAttention利用重新计算技术。它在反向传播过程中即时重新计算注意力值(attention values),而不是将中间的注意力矩阵存储(attention matrix)在内存中。这种方法消除了从HBM读取注意力矩阵的需要,从而加速计算并降低内存使用。

内核融合(Kernel Fusion):FlashAttention将所有注意力操作合并为一个单独的GPU内核,充分利用对内存读写的精细度控制。通过融合这些操作,FlashAttention减少了启动多个内核所带来的开销,提高了效率。

FlashAttention带来的重要优势和进展

更快的模型训练(Faster Model Training):FlashAttention实现了Transformer模型的更快训练速度。实验结果显示,相较于MLPerf 1.1训练速度记录,对于BERT-large(序列长度512),FlashAttention实现了15%的端到端实时加速。相对于现有基准,FlashAttention在GPT-2(序列长度1K)上实现了3倍的加速,并在长范围竞技场(序列长度1K-4K)上实现了2.4倍的加速。

更高质量的模型(Higher Quality Models):通过在Transformer中实现更长的上下文,FlashAttention提供了更高质量的模型。它在GPT-2上取得了0.7的困惑度(perplexity)的进步,在长文档分类上提升了6.4分。值得注意的是,FlashAttention使Transformer在具有挑战性的任务上实现了优于随机的性能,如Path-X挑战(序列长度16K,61.4% 的准确率)和Path-256(序列长度64K,63.1%的准确率)。

注意力基准(Benchmarking Attention):FlashAttention在常见的序列长度(从128到2K)上实现了高达3倍的加速,优于标准注意力模型(Standard Attention Models)实现。它可以扩展到序列长度64K,展示了出色的性能和效率。

结论

FlashAttention在解决注意力模型(Attention Models)所面临的硬件性能瓶颈方面取得了重大突破。通过引入IO感知并利用分块(Tiling)、重新计算(Recomputation)和内核融合(Kernel Fusion)等技术,FlashAttention在计算速度和内存效率方面实现了显著的改进。实验结果验证了其在训练高质量Transformer模型和克服长序列的限制方面的有效性。随着注意力模型在各个领域的重要性不断提升,FlashAttention为更高效和可扩展的Transformer训练铺平了道路,为深度学习领域(Deep Learning)的研究和应用开辟了新的可能性。

本文是在 AI 的协助下撰写,并参考以下来源:
https://arxiv.org/pdf/2205.14135.pdf
https://github.com/Dao-AILab/flash-attention
https://youtu.be/FThvfkXWqtE?feature=shared

在此感谢 InnoHK、香港特别行政区政府及人工智能金融科技实验室对本文的支持。
(AIFT 竭力但不能保证内容之准确和可靠,亦不会承担因任何不准确或遗漏而引起的任何损失或损害。)

分享此內容

宋林琦教授于日内瓦国际发明展获得2项银奖

地址

香港沙田香港科学园科技大道西 19号
11楼 1101-1102 及 1121-1123 室

产品及解决方案

人才

工作机会

关于我们

地址

版权所有 © 2024 人工智能金融科技实验室有限公司