본문 바로가기
Paper review/Efficiency

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

by 오서영 2025. 11. 14.

논문 원본: https://arxiv.org/abs/2205.14135

 

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce

arxiv.org


Introduction

FlashAttention은 기존 Transformer의 attention이 느려지는 핵심 원인이 연산량이 아니라 메모리 입출력이라는 점에 집중해서 만들어진 알고리즘이다. 특히 길이가 긴 시퀀스를 처리할 때는 QK 전치곱, softmax, 그리고 결과를 V와 곱하는 과정에서 매우 큰 중간 행렬이 생성된다. 이 전체 행렬을 GPU 메모리에 반복적으로 읽고 쓰는 과정에서 병목이 발생하기 때문에, 이 부분을 줄이지 않는 한 FLOPs를 줄여도 실제 속도 개선이 거의 나타나지 않는다.

FlashAttention은 기존 attention과 수학적으로 동일한 결과를 내면서도 이 중간 행렬을 아예 생성하지 않는 방식으로 계산 패턴을 재구성한다. 이를 통해 시퀀스 길이가 길어질수록 크게 증가하던 메모리 사용량을 줄이고, 실제 추론 속도를 2배 이상 향상시키는 효과를 얻었다.

(1) 왼쪽: GPU 메모리 계층도 SRAM → HBM → DRAM 순으로 용량은 작지만 대역폭이 매우 빠른 메모리(SRAM) 용량은 크지만 느린 메모리(HBM/DRAM) 이렇게 계층적으로 구성되어 있음을 보여준다. (2) 가운데: FlashAttention의 타일링 구조 기존 attention에서는 QKᵀ 전체 NxN 행렬을 만들어 HBM에 저장하지만, FlashAttention은 이를 타일 단위로 나누어 계산한다. K, V를 블록 단위로 SRAM에 복사 Q도 블록 단위로 SRAM에 불러옴 SRAM 안에서 바로 QKᵀ, softmax, (softmax)V까지 모든 계산 수행 최종 결과만 HBM으로 기록 즉, NxN 행렬을 절대 materialize하지 않는다. (3) 오른쪽: PyTorch vs FlashAttention 속도 비교 PyTorch는 matmul → mask → softmax → dropout 등 여러 커널을 순차 실행하지만, FlashAttention은 이를 하나의 fused kernel로 처리한다. 결과적으로 GPT-2 attention에서 7.6배 속도 향상이 난다는 것을 보여준다.

 

Preliminary

기존의 scaled dot-product attention은 크게 세 단계를 거친다.

  1. QK 전치곱으로 attention score 행렬을 생성
  2. score에 대해 row-wise softmax
  3. softmax 결과를 V와 곱해 최종 output 생성

문제는 첫 번째 단계에서 만들어지는 score 행렬의 크기가 n×n이라는 점이다. 시퀀스 길이가 늘어날수록 이 행렬이 매우 커지고, softmax나 V와의 곱을 계산할 때 매번 GPU 메모리(HBM)에서 이 행렬을 다시 읽어와야 한다. 이 과정 때문에 전체 attention 연산은 메모리 중심(workload)으로 분석되며, 실제 GPU 구조에서는 연산장치보다 메모리 대역폭이 먼저 한계에 걸린다.

FlashAttention에서는 이 구조적 병목을 해결하기 위해 Q, K, V를 작은 블록 단위로 분할하고, GPU의 온칩 SRAM에서만 계산이 이루어지도록 전체 로직을 재설계한다. 이렇게 하면 어텐션 전체 행렬을 메모리에 저장할 필요가 없어지고, HBM에서 SRAM으로의 데이터 이동량이 크게 줄어든다.

Method

FlashAttention의 핵심은 IO-aware tiling 방식이다. 전체 attention을 한 번에 계산하는 대신 다음과 같이 처리한다.

  1. Q를 작은 타일로 나누어 SRAM에 올린다.
  2. 각 Q 타일마다 K와 V를 chunk 단위로 읽어와 QK 곱을 계산한다.
  3. softmax는 전체 score를 한 번에 보지 않고, 각 타일마다 최대값과 exp 합을 누적하는 방식으로 계산해서 정확도를 유지한다.
  4. softmax 결과를 기반으로 partial output을 반복적으로 계산하고 누적한다.
  5. 마지막에 모든 partial output을 합쳐 최종 attention 출력을 얻는다.

이 과정에서 중간 n×n score 행렬이 생성되지 않기 때문에 메모리 사용량은 크게 감소한다. 또한 softmax를 타일 단위로 분해해 계산하더라도 running max와 running sum을 유지하면 원래 softmax와 동일한 값을 복원할 수 있기 때문에 정확도 손실이 없다. FlashAttention이 “근사 attention”이 아니라 완전히 동일한 full attention이라는 점이 중요한 특징이다.

(1) 왼쪽 테이블: Standard vs FlashAttention 성능 비교 FLOPs는 거의 비슷하지만 HBM read/write는 40GB → 4.4GB로 거의 10배 감소 실제 runtime도 41.7ms → 7.3ms로 대폭 줄어듦 FlashAttention의 핵심이 “연산 최적화가 아니라 I/O 최적화”라는 점을 보여준다. (2) 가운데 그래프: Block size에 따른 HBM 접근량·런타임 블록 크기를 크게 할수록 HBM 접근량이 줄고 속도도 빨라지지만 너무 크게 잡으면 SRAM에 다 못 올라가기 때문에 중간 지점에서 최적이 생긴다. FlashAttention이 타일 크기를 설계할 때 I/O와 SRAM 크기를 함께 고려해야 하는 이유를 설명한다. (3) 오른쪽 그래프: Block-sparse FlashAttention의 추가 속도 향상 시퀀스 길이가 4K인 경우 dense FlashAttention보다 block-sparse FlashAttention이 훨씬 더 빠르다. 희소성이 높아질수록 속도가 비례해서 개선된다는 점을 보여준다.
(1) 왼쪽: Attention runtime (forward + backward) 시퀀스 길이가 짧을 때는 기존 구현도 빠르지만 길이가 길어질수록 PyTorch·Megatron 등은 기하급수적으로 느려진다. 반면 FlashAttention은 증가 속도가 훨씬 완만해서 중간 지점에서 crossover가 발생한다. 즉, 시퀀스가 길수록 FlashAttention의 이점이 극대화된다는 것을 보여준다. (2) 오른쪽: Attention memory usage 메모리 사용량을 시퀀스 길이별로 비교한 그래프다. 기존 Transformer Attention은 O(n²) 메모리 증가 FlashAttention은 block 단위로 계산하므로 증가 폭이 낮다. Linformer 등 저랭크 기반의 다른 접근들과 비교했을 때도, FlashAttention은 동일하거나 더 낮은 메모리 사용량을 유지한다.

 

Experiment

실험 결과 FlashAttention은 다음과 같은 효과를 보였다.

  • GPT-2 및 GPT-Neo 모델 기준 최대 2.4배 속도 향상
  • 시퀀스 길이가 길어질수록 속도 향상 폭 증가
  • 메모리 사용량 약 2~3배 감소
  • full attention과 출력이 완전히 동일

특히 GPU 자원을 한계까지 활용해야 하는 긴 문맥 처리나 대형 모델 추론 환경에서 FlashAttention은 일반 attention 대비 확실한 속도 이점을 가진다. 연산량 자체는 동일하지만, 메모리 이동량이 감소하여 실제 latency가 크게 개선되기 때문이다.

Related Work

기존 연구들은 attention 연산량을 줄이기 위해 low-rank approximation이나 sparse attention을 이용해 근사하는 방식이 많았다. 하지만 이러한 방법들은 정확도가 떨어진다는 단점이 있다. 다른 접근으로는 메모리 최적화나 chunk 단위 연산을 이용한 개선이 있었으나, 이들 역시 정확한 attention 결과를 유지하지 못하거나 구현 제약이 컸다. FlashAttention은 이들과 달리 정확한 full attention을 유지하면서도 메모리 사용을 근본적으로 줄이는 방식이라는 점에서 차별점을 가진다.

Conclusion

FlashAttention은 Transformer attention이 가진 구조적 병목을 메모리 관점에서 다시 분석하고, 이를 해결하기 위한 IO-aware tile-based 알고리즘을 제안한 연구다. 중간 행렬을 생성하지 않도록 설계하여 HBM-SRAM 간 데이터 이동을 최소화했으며, 이를 통해 정확도 손실 없이 실제 추론 속도를 크게 향상시켰다. 이후 FlashAttention-2, xFormers 최적화, MLA·CQ 같은 효율적 attention 연구들도 모두 이 논문이 제시한 IO-awareness 개념을 기반으로 발전하고 있어, LLM 효율화 연구에서 매우 중요한 기점이 되는 논문이라고 볼 수 있다.