FlashAttention:通過克服硬件性能瓶頸來革新Transformer

近年來,注意力模型(Attention Model)已成為各種應用的首選架構,包括自然語言處理(Natural Language Processing)和圖像分類等領域(Image Classification)。然而,隨著這些模型變得越來越大且更複雜,它們在運行時間和內存消耗方面遇到了重大挑戰,尤其是在處理長序列時。這些瓶頸的主要原因是自我注意機制(Self-attention Mechanism),其時間和內存複雜度與序列長度呈二次關係。為了解決這個問題,研究人員探索了各種近似注意力(approximate attention)方法。然而,這些方法在實際場景中往往無法實現顯著的加速,主要是由於忽略了一個重要因素:輸入/輸出(IO)操作。本文介紹了FlashAttention,一種新穎的注意力算法,它不僅顯著加速了注意力計算(attention computations),還通過引入IO感知來最小化內存開銷。

左圖:FlashAttention使用分塊技術來防止在(相對)較慢的GPU HBM上對大型N x N注意力矩陣(虛線框)的物化。在外循環(紅色箭頭)中,FlashAttention循環遍歷KV矩陣的塊並將它們加載到快速的on-chip SRAM中。在每個區塊中,FlashAttention循環遍歷Q矩陣區塊(藍色箭頭),將它們加載到SRAM中,並將注意力計算的輸出寫回HBM。
右圖:相對於GPT-2上的PyTorch實現,FlashAttention加速了注意力計算。FlashAttention不需要將大型N x N注意力矩陣讀取和寫入HBM,這導致注意力計算的加速比為7.6倍。

標準注意力模型(Standard Attention Models)的問題所在

標準的注意力模型在處理長序列時面臨兩個主要挑戰:計算複雜度和內存需求。自我注意機制(Self-attention Mechanism)的二次時間和內存複雜度使得對於大型輸入,計算速度緩慢且對內存需求較高。儘管近似注意力方法試圖減輕這些挑戰,但它們通常優先考慮減少浮點運算(FLOPs),而忽略了IO開銷。結果,它們無法實現顯著的實際速度提升。

此外,相對於內存速度,GPU處理器的計算速度改善更快,操作越來越受到高帶寬內存(HBM)存取的瓶頸限制。

介紹FlashAttention:分塊、重新計算和內核融合

FlashAttention通過引入IO感知並結合三個關鍵技術:分塊(Tiling)、重新計算(Recomputation)和內核融合(Kernel Fusion),來解決標準注意力模型的限制。FlashAttention利用這些技術來充分利用快速的SRAM。

分塊(Tiling):FlashAttention使用分塊技術,將輸入分成較小的塊,逐步進行注意力計算。這樣做可以減少在相對較慢的GPU高帶寬內存(HBM)上物化大型注意力矩陣(large attention matrix)的需求。分塊還使FlashAttention能夠利用快速的GPU on-chip SRAM來存儲和處理數據,減少不同內存層級之間的內存讀寫操作。

重新計算(Recomputation):為了進一步減少內存存取開銷,FlashAttention利用重新計算技術。它在反向傳播過程中即時重新計算注意力值(attention values),而不是將中間的注意力矩陣存儲(attention matrix)在內存中。這種方法消除了從HBM讀取注意力矩陣的需要,從而加速計算並降低內存使用。

內核融合(Kernel Fusion):FlashAttention將所有注意力操作合併為一個單獨的GPU內核,充分利用對內存讀寫的精細度控制。通過融合這些操作,FlashAttention減少了啟動多個內核所帶來的開銷,提高了效率。

FlashAttention帶來的重要優勢和進展

更快的模型訓練(Faster Model Training):FlashAttention實現了Transformer模型的更快訓練速度。實驗結果顯示,相較於MLPerf 1.1訓練速度記錄,對於BERT-large(序列長度512),FlashAttention實現了15%的端到端實時加速。相對於現有基準,FlashAttention在GPT-2(序列長度1K)上實現了3倍的加速,並在長範圍競技場(序列長度1K-4K)上實現了2.4倍的加速。

更高質量的模型(Higher Quality Models):通過在Transformer中實現更長的上下文,FlashAttention提供了更高質量的模型。它在GPT-2上取得了0.7的困惑度(perplexity)的進步,在長文檔分類上提升了6.4分。值得注意的是,FlashAttention使Transformer在具有挑戰性的任務上實現了優於隨機的性能,如Path-X挑戰(序列長度16K,61.4% 的準確率)和Path-256(序列長度64K,63.1% 的準確率)。

注意力基準(Benchmarking Attention):FlashAttention在常見的序列長度(從128到2K)上實現了高達3倍的加速,優於標準注意力模型(Standard Attention Models)實現。它可以擴展到序列長度64K,展示了出色的性能和效率。

結論

FlashAttention在解決注意力模型(Attention Models)所面臨的硬件性能瓶頸方面取得了重大突破。通過引入IO感知並利用分塊(Tiling)、重新計算(Recomputation)和內核融合(Kernel Fusion)等技術,FlashAttention在計算速度和內存效率方面實現了顯著的改進。實驗結果驗證了其在訓練高質量Transformer模型和克服長序列的限制方面的有效性。隨著注意力模型在各個領域的重要性不斷提升,FlashAttention為更高效和可擴展的Transformer訓練鋪平了道路,為深度學習領域(Deep Learning)的研究和應用開辟了新的可能性。

本文是在 AI 的協助下撰寫,並參考以下來源:
https://arxiv.org/pdf/2205.14135.pdf
https://github.com/Dao-AILab/flash-attention
https://youtu.be/FThvfkXWqtE?feature=shared

在此感謝 InnoHK、香港特別行政區政府及人工智能金融科技實驗室對本文的支持。
(AIFT 竭力但不能保證內容之準確和可靠,亦不會承擔因任何不準確或遺漏而引起的任何損失或損害。 )

分享此內容

地址

香港沙田香港科學園科技大道西 19 號
11樓 1101-1102 及 1121-1123 室

產品及解決方案

人才

工作機會

關於我們

地址

版權所有 © 2024 人工智能金融科技實驗室有限公司