FlashAttention-4: Algorithm and Kernel Pipelining
- #Attention Mechanism
- #Deep Learning
- #GPU Optimization
- On modern GPUs such as Blackwell, hardware scales asymmetrically: tensor core throughput grows faster than other resources, such as shared memory bandwidth and the special function units (SFUs) that evaluate transcendental functions like the exponential.
- FlashAttention-4 is a co-designed algorithm and kernel for the Blackwell architecture, achieving up to 1605 TFLOPs/s on B200 with BF16, outperforming cuDNN and Triton.
- Key innovations include new pipelines that overlap tensor core matrix multiplies with softmax and other work, software emulation of the exponential function via polynomial approximation to relieve the SFU bottleneck, and use of tensor memory (TMEM) to reduce shared memory traffic.
- Blackwell hardware features such as Tensor Memory (TMEM), fully asynchronous fifth-generation tensor cores, and the 2-CTA MMA mode enable larger tiles and reduced memory traffic, which improves performance.
- The backward pass is limited by shared memory bandwidth. Its optimizations include the 2-CTA MMA mode, which halves shared memory traffic and reduces the number of atomic adds, plus a deterministic mode for reproducible training.
- Scheduling improvements address load imbalance from causal masking and variable sequence lengths via grid linearization and longest-processing-time-first (LPT) scheduling.
- FlashAttention-4 is implemented in CuTe-DSL, a Python-embedded DSL, which cuts compile times significantly compared with C++ template-based kernels, and it has influenced optimizations in cuDNN versions 9.13 and later.
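The pipelining and exponential bottlenecks above both stem from the tiled "online softmax" that every FlashAttention version is built on: keys and values are visited one block at a time, and running statistics are rescaled whenever a new maximum score appears. The following is a minimal pure-Python sketch for a single query row, not FlashAttention-4's kernel; the function names and the `block` parameter are hypothetical, and the usual 1/sqrt(d) softmax scale is assumed to be folded into `q` already. Note that every score costs one exponential, which is the work that lands on the SFUs.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def attention_row_naive(q: Vec, K: List[Vec], V: List[Vec]) -> List[float]:
    # Reference: materialize every score, then apply a numerically stable softmax.
    scores = [dot(q, k) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[j] * V[j][c] for j in range(len(V))) / z for c in range(len(V[0]))]


def attention_row_online(q: Vec, K: List[Vec], V: List[Vec], block: int = 2) -> List[float]:
    # Tiled version: visit K/V one block at a time, keeping only a running
    # max (m), a running denominator (l) and an unnormalized output (acc).
    m = -math.inf
    l = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = [dot(q, k) for k in Kb]
        m_new = max(m, max(s))
        # Rescale earlier statistics to the new max (exp(-inf) == 0.0 on the first block).
        scale = math.exp(m - m_new)
        p = [math.exp(sj - m_new) for sj in s]  # one exponential per score
        l = l * scale + sum(p)
        acc = [a * scale + sum(p[j] * Vb[j][c] for j in range(len(Vb)))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]  # normalize once at the end
```

Because only `m`, `l` and `acc` survive between blocks, the full score row is never stored; a GPU kernel can therefore overlap the matrix multiply for the next block with the exponentials of the current one.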
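Software emulation of the exponential can be illustrated with the classic range-reduction trick. Since e^x = 2^(x · log2 e), it suffices to compute 2^x: split x = n + f with integer n and |f| ≤ 0.5, evaluate 2^f with a short polynomial (arithmetic that runs on ordinary FMA units rather than SFUs), and apply 2^n by adjusting the exponent bits. The sketch below is an assumption-laden illustration, not FlashAttention-4's code: it uses plain Taylor coefficients of degree 3, whereas a production kernel would likely use minimax-fitted coefficients, and `exp2_poly` and `exp_via_exp2` are hypothetical names.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # log2(e): e**x == 2**(x * LOG2E)

# Degree-3 Taylor coefficients of 2**f = exp(f * ln 2) around f = 0.
# Assumption: chosen for illustration; a real kernel would use fitted coefficients.
COEFFS = [math.log(2.0) ** k / math.factorial(k) for k in range(4)]


def exp2_poly(x: float) -> float:
    """Approximate 2**x with range reduction plus a cubic polynomial."""
    n = math.floor(x + 0.5)  # nearest integer, so f lies in [-0.5, 0.5)
    f = x - n
    # Horner's rule: c0 + f*(c1 + f*(c2 + f*c3)), FMA-friendly on a GPU.
    p = COEFFS[3]
    for c in reversed(COEFFS[:3]):
        p = p * f + c
    # Multiply by 2**n; on a GPU this is an integer add into the exponent field.
    return math.ldexp(p, n)


def exp_via_exp2(x: float) -> float:
    """Softmax-style exponential computed through exp2."""
    return exp2_poly(x * LOG2E)


if __name__ == "__main__":
    for x in (-10.0, -3.7, -1.0, 0.0, 0.25):
        approx, exact = exp_via_exp2(x), math.exp(x)
        print(f"x={x:6.2f} approx={approx:.6f} exact={exact:.6f} "
              f"rel_err={abs(approx - exact) / exact:.2e}")
```

With |f| ≤ 0.5 the Taylor remainder bounds the relative error below about 1e-3, which is within half an ulp of BF16 (2^-8 ≈ 3.9e-3), the precision the attention probabilities are consumed at in this setting.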
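The deterministic mode matters because floating-point addition is not associative: when partial gradient contributions (for example to dQ) are accumulated with atomic adds, their arrival order can change between runs, and so can the low-order bits of the result. The sketch below models this in Python; it is not how FlashAttention-4 enforces ordering on the GPU, and the function names and the input values are hypothetical, chosen with wildly different magnitudes so that the rounding effect is visible.

```python
import random
from typing import Sequence


def sum_in_order(partials: Sequence[float], order: Sequence[int]) -> float:
    # Sequential floating-point accumulation; each += rounds to nearest.
    total = 0.0
    for i in order:
        total += partials[i]
    return total


def atomic_style_sum(partials: Sequence[float], seed: int) -> float:
    # Mimic atomic adds: contributions arrive in an order that can change
    # from run to run (modeled here by a seeded shuffle).
    order = list(range(len(partials)))
    random.Random(seed).shuffle(order)
    return sum_in_order(partials, order)


def deterministic_sum(partials: Sequence[float]) -> float:
    # Deterministic mode: always accumulate in block-index order.
    return sum_in_order(partials, range(len(partials)))


if __name__ == "__main__":
    # Hypothetical per-block contributions with very different magnitudes.
    partials = [1e17, 1.0, -1e17, 1.0]
    print(sum_in_order(partials, [0, 1, 2, 3]))  # 1.0
    print(sum_in_order(partials, [0, 2, 1, 3]))  # 2.0
    print({atomic_style_sum(partials, s) for s in range(20)})
    print(deterministic_sum(partials))           # 1.0 on every run
```

Adding 1.0 to 1e17 is lost to rounding (the spacing between adjacent doubles near 1e17 is 16), so the result depends on whether the large terms cancel first. Fixing the order trades some parallelism for bit-identical gradients.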
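Longest-processing-time-first (LPT) scheduling is a classic greedy heuristic: sort jobs by decreasing cost and always hand the next job to the least-loaded worker. With a causal mask, query tiles near the end of the sequence do far more work than those near the start, so launching tiles in their natural order leaves some SMs idle while others finish. The sketch below assumes square tiles, a cost equal to the number of key/value tiles visited, and workers that stand in for SMs; these cost and worker models and all function names are hypothetical, not FlashAttention-4's actual scheduler.

```python
import heapq
from typing import Dict, List, Tuple

Job = Tuple[int, int, int]  # (cost, batch_head, q_tile)


def causal_tile_jobs(num_q_tiles: int, batch_heads: int) -> List[Job]:
    # Assumption: square tiles with a causal mask, so query tile q must
    # visit q + 1 key/value tiles; cost is measured in those tile visits.
    return [(q + 1, bh, q) for bh in range(batch_heads) for q in range(num_q_tiles)]


def lpt_schedule(jobs: List[Job], num_workers: int) -> Tuple[Dict[int, List[Job]], int]:
    # LPT: give the most expensive remaining job to whichever worker
    # currently has the smallest total load.
    heap = [(0, w) for w in range(num_workers)]  # (load, worker id)
    heapq.heapify(heap)
    assignment: Dict[int, List[Job]] = {w: [] for w in range(num_workers)}
    for job in sorted(jobs, key=lambda j: j[0], reverse=True):
        load, w = heapq.heappop(heap)
        assignment[w].append(job)
        heapq.heappush(heap, (load + job[0], w))
    makespan = max(load for load, _ in heap)
    return assignment, makespan


def round_robin_makespan(jobs: List[Job], num_workers: int) -> int:
    # Baseline: hand out jobs in launch order without looking at cost.
    loads = [0] * num_workers
    for i, job in enumerate(jobs):
        loads[i % num_workers] += job[0]
    return max(loads)


if __name__ == "__main__":
    jobs = causal_tile_jobs(num_q_tiles=8, batch_heads=4)
    _, lpt = lpt_schedule(jobs, num_workers=8)
    print("round-robin makespan:", round_robin_makespan(jobs, 8))  # 32
    print("LPT makespan:", lpt)                                    # 18
```

In this toy setup the total work is 144 tile visits over 8 workers, so 18 is the best possible makespan; LPT reaches it, while the cost-blind order leaves the busiest worker with 32.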