coredot.today
FlashAttention 해부: 박사과정 학생이 만든 커널이 AI 산업 전체를 바꿨다
블로그로 돌아가기
FlashAttentionTransformerGPU 최적화메모리 계층긴 컨텍스트

FlashAttention 해부: 박사과정 학생이 만든 커널이 AI 산업 전체를 바꿨다

GPT-3의 컨텍스트가 2K에 머물렀던 이유? 어텐션이 O(N²) 메모리를 잡아먹었기 때문이다. 한 박사과정 학생이 GPU 메모리 계층을 이해하고, 수학은 그대로 두되 메모리 접근만 바꿔서 2~4배 빠르고 10~20배 적은 메모리를 달성했다. 정확도 손실 0%.

코어닷투데이2025-11-0424

들어가며: 더 빠르게 계산하는 게 아니라, 덜 걸어다니는 것

FlashAttention: 작은 주방에서 효율적으로 요리하기

당신은 요리사다. 주방 조리대(SRAM)는 아주 작지만, 재료를 꺼내면 즉시 쓸 수 있다. 식료품 저장고(HBM)는 지하실에 있어서 넓지만, 다녀오는 데 시간이 걸린다.

바보 요리사: 재료 하나가 필요할 때마다 지하실로 내려간다. 소금 가져오고, 후추 가져오고, 기름 가져오고... 요리 시간의 대부분을 계단 오르내리는 데 쓴다.

똑똑한 요리사 (FlashAttention): 레시피를 보고, 다음 몇 단계에 필요한 재료를 한 번에 가져온다. 조리대 위에서 모든 작업을 끝내고, 결과물만 접시에 담아 내보낸다. 총 요리 동작은 같지만, 지하실 왕복이 극적으로 줄었다.

이것이 FlashAttention의 핵심이다. 수학을 바꾸지 않고, 메모리 접근 패턴만 바꿔서 Transformer 어텐션을 24배 빠르게, 메모리를 1020배 적게 사용하게 만들었다. 정확도 손실? 정확히 0%.

그리고 이 최적화가 GPT-3의 2K 토큰 컨텍스트를 GPT-4의 128K, 나아가 100만 토큰 시대로 열어젖혔다.


제1장: 왜 어텐션이 병목이었는가

O(N²)의 저주

Transformer의 셀프 어텐션은 시퀀스의 모든 위치가 다른 모든 위치를 "본다." 시퀀스 길이 N이면, N×N 어텐션 점수 행렬을 계산해야 한다.

420만 N=2,048 (GPT-3) 헤드당 어텐션 점수
6,700만 N=8,192 4배 길어지면 16배
164억 N=128,000 (GPT-4) 헤드당 164억 점수!
0.07→18 GB GPT-2 어텐션 메모리 512토큰 → 8,192토큰

시퀀스를 2배 늘리면 메모리가 4배. BERT는 512, GPT-3는 2,048 토큰에 갇혔던 이유다.

GPU 메모리 계층: 진짜 병목은 어디인가

NVIDIA A100 GPU 메모리 계층
SRAM (온칩) ~20 MB | 19 TB/s 매우 빠르지만 매우 작다
HBM (글로벌 메모리) 80 GB | 2 TB/s 넓지만 SRAM보다 ~10배 느리다

SRAM은 HBM보다 ~10배 빠르지만 4,000배 작다. 표준 어텐션은 N×N 행렬을 HBM에 반복적으로 읽고 쓴다. GPU의 연산 유닛은 대부분의 시간을 데이터가 도착하기를 기다리며 놀고 있다. 어텐션 연산의 50% 이상이 메모리 대기.

💡
핵심 통찰: 어텐션의 병목은 연산 수(FLOPs)가 아니라 메모리 접근(IO)이다. 연산량을 줄이는 근사 방법이 실제로 빨라지지 않는 이유: 더 적은 연산을 하지만 메모리 접근 패턴이 더 나쁘다. FlashAttention의 혁신은 연산량이 아니라 IO 복잡도를 최적화한 것이다.

근사 어텐션이 실패한 이유

FlashAttention 이전, 연구자들은 N×N 어텐션을 근사하려 했다:

방법아이디어왜 실패했나
Sparse Attention (2019)일부 위치만 어텐드불규칙 메모리 접근 → IO 악화
Linformer (2020)저랭크 근사정보 손실, 품질 저하
Performer (2020)커널 근사실제 속도 향상 미미
BigBird (2020)희소+랜덤+전역구현 복잡, 퓨즈 불가

공통 문제: FLOPs는 줄었지만, 메모리 접근 패턴이 GPU에 비친화적이어서 실제 벽시계 속도가 나아지지 않았다. 그리고 근사이므로 정확도가 떨어졌다.


제2장: FlashAttention — 수학은 그대로, IO만 바꾸기

논문 기본 정보

  • 저자: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
  • 소속: Stanford University, University at Buffalo
  • 발표: NeurIPS 2022 (arXiv 2022년 5월 27일)
  • 인용: 3,300+ (Semantic Scholar), Tri Dao 전체 29,000+

Tri Dao: Stanford CS 박사, Christopher Ré와 Stefano Ermon 공동 지도. 현재 Princeton 조교수 + Together AI 공동 창립자 겸 수석 과학자. Schmidt Sciences AI2050 펠로우, Google ML and Systems Junior Faculty Award. 이후 Mamba(상태 공간 모델), Flash-Decoding 등 개발.

핵심: 타일링(Tiling) + 재계산(Recomputation)

표준 어텐션 vs FlashAttention

표준 어텐션의 문제:

  1. S = QK^T 계산 → N×N 행렬을 HBM에 저장
  2. P = softmax(S) → N×N 행렬을 HBM에서 읽고 다시 저장
  3. O = PV → N×N 행렬을 HBM에서 읽음

세 단계마다 거대한 행렬이 느린 HBM을 왕복. 총 HBM 접근: Θ(Nd + N²).

FlashAttention의 해법:

Q, K, V를 SRAM에 맞는 블록으로 분할
K, V 블록 하나를 SRAM에 로드
SRAM에서 로컬 어텐션 계산 (온라인 softmax)
누적 결과를 점진적 업데이트
다음 블록 → ⟳ 반복
최종 결과만 HBM에 1회 저장

N×N 어텐션 행렬이 HBM에 절대 저장되지 않는다. SRAM에서 일시적으로만 존재하고, 최종 출력만 HBM에 기록.

온라인 Softmax 트릭

일반 softmax는 전체 행을 봐야 계산 가능 (최댓값, 합 필요). 온라인 softmax (Milakov & Gimelshein, 2018)는 블록 단위로 점진적으로 계산:

블록 j 처리:
  m_new = max(m_old, max(current_block))     # 러닝 최댓값 갱신
  l_new = exp(m_old - m_new) * l_old         # 이전 합 스케일 조정
        + sum(exp(current_block - m_new))     # 현재 블록 합 추가
  O_new = rescale(O_old) + new_contribution  # 출력 점진 갱신

블록을 하나씩 처리하면서도 정확한 softmax 결과를 유지한다. 수학적으로 표준 softmax와 비트 단위 동일.

역전파: 저장 대신 재계산

표준 학습은 역전파를 위해 N×N 어텐션 행렬을 저장한다 → O(N²) 메모리. FlashAttention은 출력 O와 softmax 통계(m, l)만 저장하고, 역전파 시 Q, K, V 블록에서 어텐션 점수를 재계산한다.

💡
"더 많이 계산하는 게 더 빠르다"의 역설: 재계산은 FLOPs를 늘린다. 하지만 HBM에서 N×N 행렬을 읽는 것보다 SRAM에서 재계산하는 것이 더 빠르다. GPU 연산 유닛이 메모리를 기다리며 놀고 있었기 때문. 빈 연산 슬롯에 재계산을 채워넣는 것이다.

IO 복잡도와 최적성

  • 표준 어텐션 HBM 접근: Θ(Nd + N²)
  • FlashAttention HBM 접근: O(N²d²/M) (M = SRAM 크기)

d=64, M10⁵일 때 FlashAttention은 **925배 적은 HBM 접근**. 논문은 이것이 SRAM 크기 범위 내에서 증명 가능하게 최적(로그 팩터까지)임을 증명했다.


제3장: 결과 — 정확도 0% 손실, 속도 2~4배

2~4× A100에서 벽시계 속도 향상 표준 PyTorch 어텐션 대비
10~20× 메모리 절감 O(N²) → O(N)
0% 정확도 손실 정확(exact) 어텐션, 근사 아님
15% BERT 학습 속도 향상 MLPerf 1.1 기록 대비

Long-Range Arena에서 역사적 첫 성과:

  • Path-X (16K 토큰): 61.4% — Transformer로 최초로 무작위 이상의 성능 달성
  • Path-256 (64K 토큰): 63.1% — 이 역시 최초 (블록-희소 FlashAttention)

이전의 모든 Transformer는 이 과제에서 메모리 부족이거나 무작위 수준이었다.


제4장: 진화 — V2, V3, V4

2022.5FlashAttention V1NeurIPS 2022. 타일링+재계산. A100에서 2~4× 속도, 10~20× 메모리 절감
2023.7FlashAttention-2작업 분할 개선, 루프 재배치. V1 대비 2× 빠름. 이론적 최대의 50~73% 달성
2024.7FlashAttention-3 (H100용)워프 특화, FP8, 비동기 처리. H100에서 740 TFLOPS. NeurIPS 2024 Spotlight
2026.3FlashAttention-4 (Blackwell용)소프트웨어 지수 에뮬레이션, 조건부 리스케일링. B200에서 1,613 TFLOPS

각 세대가 새 GPU 아키텍처의 병목을 정확히 겨냥:

  • V1: HBM 대역폭 병목 (A100)
  • V2: GPU 점유율 병목 (A100 최적화)
  • V3: 비동기 실행 병목 (H100 Hopper)
  • V4: 지수 함수 연산 유닛 병목 (B200 Blackwell)

제5장: FlashAttention이 없었다면 불가능했을 것들

긴 컨텍스트 혁명

컨텍스트 길이 진화 (FlashAttention 기여)
BERT (2018)
512
GPT-3 (2020)
2K
GPT-4 (2023)
128K
Gemini 1.5 (2024)
1M+

FlashAttention의 O(N) 메모리 없이는, 128K 컨텍스트가 2K 대비 약 4,000배 더 많은 어텐션 메모리를 요구하여 완전히 비실용적이었을 것이다.

어디서나 사용되는 FlashAttention

  • PyTorch 2.0+: scaled_dot_product_attention에 내장. 자동 호출
  • HuggingFace: attn_implementation="flash_attention_2"
  • 모든 주요 LLM 학습: GPT-4, Claude, Gemini, LLaMA, Mistral
  • 추론 서빙: vLLM, SGLang, FlashInfer

제6장: 유사 개념 비교

🐢 표준 어텐션
메모리: O(N²)
N×N 행렬을 HBM에 저장
이론적 최대의 25~40% 활용
정확(exact)
⚡ FlashAttention
메모리: O(N) — 선형!
N×N 행렬이 HBM에 절대 안 감
이론적 최대의 50~73% 활용
정확(exact) — 근사 아님
표준 어텐션근사 방법들FlashAttention
정확도정확근사 (손실 있음)정확
메모리O(N²)O(N) ~ O(N√N)O(N)
실제 속도기준종종 더 느림 (IO 악화)2~4× 빠름
추론 지연기준다양동일

제7장: 한 박사과정 학생이 만든 커널

🎯
FlashAttention의 교훈: Transformer의 수학을 한 글자도 바꾸지 않았다. 동일한 입력에 동일한 출력. 기여는 전적으로 계산이 하드웨어에 매핑되는 방식을 바꾼 것이다. 메모리 계층을 존중하고, 데이터 이동을 최소화하고, 연산을 퓨즈한다. 이 시스템 수준의 알고리즘 혁신이 대부분의 아키텍처 혁신보다 실전적 영향이 컸다.

Tri Dao가 Stanford 박사과정 중 개발한 FlashAttention은 현재 PyTorch에 내장되고, 모든 주요 LLM 제공업체가 사용하며, 4개 주요 버전이 연속 GPU 아키텍처를 타겟한다. Tri Dao의 Google Scholar 인용 총 29,000+회. 그는 이후 Together AI를 공동 창립하고, Mamba(선형 시간 시퀀스 모델)를 개발하며, Princeton 조교수로 임용됐다.


맺으며: 알고리즘과 하드웨어 사이의 간극

이 글의 서사를 한 문장으로: 가장 큰 성능 향상은 모델 아키텍처를 바꾸는 것이 아니라, 알고리즘이 하드웨어와 상호작용하는 방식을 바꾸는 데서 나온다.

GPT-3 글에서 "규모가 질적 변화를 만든다"를, LoRA 글에서 "적응은 효율적이어야 한다"를 봤다. FlashAttention은 세 번째 교훈을 추가한다: "같은 수학이라도 하드웨어를 이해하면 모든 것이 달라진다."

GPT-3가 2K 토큰에 갇혀 있을 때, 해법은 더 큰 GPU가 아니었다. 더 영리한 메모리 접근이었다. 이 원리는 하드웨어가 더 전문화될수록 더 중요해진다. FlashAttention V1→V2→V3→V4가 보여주듯, 하드웨어가 진화하면 알고리즘도 함께 진화해야 한다.

코어닷투데이의 AI 제품에서 FlashAttention은 보이지 않는 기반이다. AI 아르스 키오스크의 실시간 응답, 의정지원 AI의 긴 정책 문서 처리, Sharp-PINN의 대규모 시계열 분석 — 이 모든 것이 O(N²)이 아닌 O(N) 메모리로 가능해진 것은 한 박사과정 학생이 "GPU의 지하실을 덜 왕복하는 법"을 찾았기 때문이다.