
FlashAttention 해부: 박사과정 학생이 만든 커널이 AI 산업 전체를 바꿨다
GPT-3의 컨텍스트가 2K에 머물렀던 이유? 어텐션이 O(N²) 메모리를 잡아먹었기 때문이다. 한 박사과정 학생이 GPU 메모리 계층을 이해하고, 수학은 그대로 두되 메모리 접근만 바꿔서 2~4배 빠르고 10~20배 적은 메모리를 달성했다. 정확도 손실 0%.

GPT-3의 컨텍스트가 2K에 머물렀던 이유? 어텐션이 O(N²) 메모리를 잡아먹었기 때문이다. 한 박사과정 학생이 GPU 메모리 계층을 이해하고, 수학은 그대로 두되 메모리 접근만 바꿔서 2~4배 빠르고 10~20배 적은 메모리를 달성했다. 정확도 손실 0%.

당신은 요리사다. 주방 조리대(SRAM)는 아주 작지만, 재료를 꺼내면 즉시 쓸 수 있다. 식료품 저장고(HBM)는 지하실에 있어서 넓지만, 다녀오는 데 시간이 걸린다.
바보 요리사: 재료 하나가 필요할 때마다 지하실로 내려간다. 소금 가져오고, 후추 가져오고, 기름 가져오고... 요리 시간의 대부분을 계단 오르내리는 데 쓴다.
똑똑한 요리사 (FlashAttention): 레시피를 보고, 다음 몇 단계에 필요한 재료를 한 번에 가져온다. 조리대 위에서 모든 작업을 끝내고, 결과물만 접시에 담아 내보낸다. 총 요리 동작은 같지만, 지하실 왕복이 극적으로 줄었다.
이것이 FlashAttention의 핵심이다. 수학을 바꾸지 않고, 메모리 접근 패턴만 바꿔서 Transformer 어텐션을 24배 빠르게, 메모리를 1020배 적게 사용하게 만들었다. 정확도 손실? 정확히 0%.
그리고 이 최적화가 GPT-3의 2K 토큰 컨텍스트를 GPT-4의 128K, 나아가 100만 토큰 시대로 열어젖혔다.
Transformer의 셀프 어텐션은 시퀀스의 모든 위치가 다른 모든 위치를 "본다." 시퀀스 길이 N이면, N×N 어텐션 점수 행렬을 계산해야 한다.
시퀀스를 2배 늘리면 메모리가 4배. BERT는 512, GPT-3는 2,048 토큰에 갇혔던 이유다.
SRAM은 HBM보다 ~10배 빠르지만 4,000배 작다. 표준 어텐션은 N×N 행렬을 HBM에 반복적으로 읽고 쓴다. GPU의 연산 유닛은 대부분의 시간을 데이터가 도착하기를 기다리며 놀고 있다. 어텐션 연산의 50% 이상이 메모리 대기.
FlashAttention 이전, 연구자들은 N×N 어텐션을 근사하려 했다:
| 방법 | 아이디어 | 왜 실패했나 |
|---|---|---|
| Sparse Attention (2019) | 일부 위치만 어텐드 | 불규칙 메모리 접근 → IO 악화 |
| Linformer (2020) | 저랭크 근사 | 정보 손실, 품질 저하 |
| Performer (2020) | 커널 근사 | 실제 속도 향상 미미 |
| BigBird (2020) | 희소+랜덤+전역 | 구현 복잡, 퓨즈 불가 |
공통 문제: FLOPs는 줄었지만, 메모리 접근 패턴이 GPU에 비친화적이어서 실제 벽시계 속도가 나아지지 않았다. 그리고 근사이므로 정확도가 떨어졌다.
Tri Dao: Stanford CS 박사, Christopher Ré와 Stefano Ermon 공동 지도. 현재 Princeton 조교수 + Together AI 공동 창립자 겸 수석 과학자. Schmidt Sciences AI2050 펠로우, Google ML and Systems Junior Faculty Award. 이후 Mamba(상태 공간 모델), Flash-Decoding 등 개발.

표준 어텐션의 문제:
세 단계마다 거대한 행렬이 느린 HBM을 왕복. 총 HBM 접근: Θ(Nd + N²).
FlashAttention의 해법:
N×N 어텐션 행렬이 HBM에 절대 저장되지 않는다. SRAM에서 일시적으로만 존재하고, 최종 출력만 HBM에 기록.
일반 softmax는 전체 행을 봐야 계산 가능 (최댓값, 합 필요). 온라인 softmax (Milakov & Gimelshein, 2018)는 블록 단위로 점진적으로 계산:
블록 j 처리:
m_new = max(m_old, max(current_block)) # 러닝 최댓값 갱신
l_new = exp(m_old - m_new) * l_old # 이전 합 스케일 조정
+ sum(exp(current_block - m_new)) # 현재 블록 합 추가
O_new = rescale(O_old) + new_contribution # 출력 점진 갱신
블록을 하나씩 처리하면서도 정확한 softmax 결과를 유지한다. 수학적으로 표준 softmax와 비트 단위 동일.
표준 학습은 역전파를 위해 N×N 어텐션 행렬을 저장한다 → O(N²) 메모리. FlashAttention은 출력 O와 softmax 통계(m, l)만 저장하고, 역전파 시 Q, K, V 블록에서 어텐션 점수를 재계산한다.
d=64, M10⁵일 때 FlashAttention은 **925배 적은 HBM 접근**. 논문은 이것이 SRAM 크기 범위 내에서 증명 가능하게 최적(로그 팩터까지)임을 증명했다.
Long-Range Arena에서 역사적 첫 성과:
이전의 모든 Transformer는 이 과제에서 메모리 부족이거나 무작위 수준이었다.
각 세대가 새 GPU 아키텍처의 병목을 정확히 겨냥:
FlashAttention의 O(N) 메모리 없이는, 128K 컨텍스트가 2K 대비 약 4,000배 더 많은 어텐션 메모리를 요구하여 완전히 비실용적이었을 것이다.
scaled_dot_product_attention에 내장. 자동 호출attn_implementation="flash_attention_2"| 표준 어텐션 | 근사 방법들 | FlashAttention | |
|---|---|---|---|
| 정확도 | 정확 | 근사 (손실 있음) | 정확 |
| 메모리 | O(N²) | O(N) ~ O(N√N) | O(N) |
| 실제 속도 | 기준 | 종종 더 느림 (IO 악화) | 2~4× 빠름 |
| 추론 지연 | 기준 | 다양 | 동일 |
Tri Dao가 Stanford 박사과정 중 개발한 FlashAttention은 현재 PyTorch에 내장되고, 모든 주요 LLM 제공업체가 사용하며, 4개 주요 버전이 연속 GPU 아키텍처를 타겟한다. Tri Dao의 Google Scholar 인용 총 29,000+회. 그는 이후 Together AI를 공동 창립하고, Mamba(선형 시간 시퀀스 모델)를 개발하며, Princeton 조교수로 임용됐다.
이 글의 서사를 한 문장으로: 가장 큰 성능 향상은 모델 아키텍처를 바꾸는 것이 아니라, 알고리즘이 하드웨어와 상호작용하는 방식을 바꾸는 데서 나온다.
GPT-3 글에서 "규모가 질적 변화를 만든다"를, LoRA 글에서 "적응은 효율적이어야 한다"를 봤다. FlashAttention은 세 번째 교훈을 추가한다: "같은 수학이라도 하드웨어를 이해하면 모든 것이 달라진다."
GPT-3가 2K 토큰에 갇혀 있을 때, 해법은 더 큰 GPU가 아니었다. 더 영리한 메모리 접근이었다. 이 원리는 하드웨어가 더 전문화될수록 더 중요해진다. FlashAttention V1→V2→V3→V4가 보여주듯, 하드웨어가 진화하면 알고리즘도 함께 진화해야 한다.
코어닷투데이의 AI 제품에서 FlashAttention은 보이지 않는 기반이다. AI 아르스 키오스크의 실시간 응답, 의정지원 AI의 긴 정책 문서 처리, Sharp-PINN의 대규모 시계열 분석 — 이 모든 것이 O(N²)이 아닌 O(N) 메모리로 가능해진 것은 한 박사과정 학생이 "GPU의 지하실을 덜 왕복하는 법"을 찾았기 때문이다.