FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
논문 원본: https://arxiv.org/abs/2205.14135
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce
arxiv.org
Introduction
FlashAttention은 기존 Transformer의 attention이 느려지는 핵심 원인이 연산량이 아니라 메모리 입출력이라는 점에 집중해서 만들어진 알고리즘이다. 특히 길이가 긴 시퀀스를 처리할 때는 QK 전치곱, softmax, 그리고 결과를 V와 곱하는 과정에서 매우 큰 중간 행렬이 생성된다. 이 전체 행렬을 GPU 메모리에 반복적으로 읽고 쓰는 과정에서 병목이 발생하기 때문에, 이 부분을 줄이지 않는 한 FLOPs를 줄여도 실제 속도 개선이 거의 나타나지 않는다.
FlashAttention은 기존 attention과 수학적으로 동일한 결과를 내면서도 이 중간 행렬을 아예 생성하지 않는 방식으로 계산 패턴을 재구성한다. 이를 통해 시퀀스 길이가 길어질수록 크게 증가하던 메모리 사용량을 줄이고, 실제 추론 속도를 2배 이상 향상시키는 효과를 얻었다.

Preliminary
기존의 scaled dot-product attention은 크게 세 단계를 거친다.
- QK 전치곱으로 attention score 행렬을 생성
- score에 대해 row-wise softmax
- softmax 결과를 V와 곱해 최종 output 생성
문제는 첫 번째 단계에서 만들어지는 score 행렬의 크기가 n×n이라는 점이다. 시퀀스 길이가 늘어날수록 이 행렬이 매우 커지고, softmax나 V와의 곱을 계산할 때 매번 GPU 메모리(HBM)에서 이 행렬을 다시 읽어와야 한다. 이 과정 때문에 전체 attention 연산은 메모리 중심(workload)으로 분석되며, 실제 GPU 구조에서는 연산장치보다 메모리 대역폭이 먼저 한계에 걸린다.
FlashAttention에서는 이 구조적 병목을 해결하기 위해 Q, K, V를 작은 블록 단위로 분할하고, GPU의 온칩 SRAM에서만 계산이 이루어지도록 전체 로직을 재설계한다. 이렇게 하면 어텐션 전체 행렬을 메모리에 저장할 필요가 없어지고, HBM에서 SRAM으로의 데이터 이동량이 크게 줄어든다.
Method
FlashAttention의 핵심은 IO-aware tiling 방식이다. 전체 attention을 한 번에 계산하는 대신 다음과 같이 처리한다.
- Q를 작은 타일로 나누어 SRAM에 올린다.
- 각 Q 타일마다 K와 V를 chunk 단위로 읽어와 QK 곱을 계산한다.
- softmax는 전체 score를 한 번에 보지 않고, 각 타일마다 최대값과 exp 합을 누적하는 방식으로 계산해서 정확도를 유지한다.
- softmax 결과를 기반으로 partial output을 반복적으로 계산하고 누적한다.
- 마지막에 모든 partial output을 합쳐 최종 attention 출력을 얻는다.
이 과정에서 중간 n×n score 행렬이 생성되지 않기 때문에 메모리 사용량은 크게 감소한다. 또한 softmax를 타일 단위로 분해해 계산하더라도 running max와 running sum을 유지하면 원래 softmax와 동일한 값을 복원할 수 있기 때문에 정확도 손실이 없다. FlashAttention이 “근사 attention”이 아니라 완전히 동일한 full attention이라는 점이 중요한 특징이다.


Experiment
실험 결과 FlashAttention은 다음과 같은 효과를 보였다.
- GPT-2 및 GPT-Neo 모델 기준 최대 2.4배 속도 향상
- 시퀀스 길이가 길어질수록 속도 향상 폭 증가
- 메모리 사용량 약 2~3배 감소
- full attention과 출력이 완전히 동일
특히 GPU 자원을 한계까지 활용해야 하는 긴 문맥 처리나 대형 모델 추론 환경에서 FlashAttention은 일반 attention 대비 확실한 속도 이점을 가진다. 연산량 자체는 동일하지만, 메모리 이동량이 감소하여 실제 latency가 크게 개선되기 때문이다.
Related Work
기존 연구들은 attention 연산량을 줄이기 위해 low-rank approximation이나 sparse attention을 이용해 근사하는 방식이 많았다. 하지만 이러한 방법들은 정확도가 떨어진다는 단점이 있다. 다른 접근으로는 메모리 최적화나 chunk 단위 연산을 이용한 개선이 있었으나, 이들 역시 정확한 attention 결과를 유지하지 못하거나 구현 제약이 컸다. FlashAttention은 이들과 달리 정확한 full attention을 유지하면서도 메모리 사용을 근본적으로 줄이는 방식이라는 점에서 차별점을 가진다.
Conclusion
FlashAttention은 Transformer attention이 가진 구조적 병목을 메모리 관점에서 다시 분석하고, 이를 해결하기 위한 IO-aware tile-based 알고리즘을 제안한 연구다. 중간 행렬을 생성하지 않도록 설계하여 HBM-SRAM 간 데이터 이동을 최소화했으며, 이를 통해 정확도 손실 없이 실제 추론 속도를 크게 향상시켰다. 이후 FlashAttention-2, xFormers 최적화, MLA·CQ 같은 효율적 attention 연구들도 모두 이 논문이 제시한 IO-awareness 개념을 기반으로 발전하고 있어, LLM 효율화 연구에서 매우 중요한 기점이 되는 논문이라고 볼 수 있다.