Forcing Flash Attention onto a TPU and Learning the Hard Way
6 days ago
- #JAX
- #Flash Attention
- #TPU
- The post recounts porting Flash Attention from GPU (Triton) to TPU in JAX, and examines why the port played out differently than expected.
- JAX's functional programming model requires immutable arrays and pure functions, differing from Triton's mutable pointers.
- The initial Flash Attention port was slower than standard attention on TPU: XLA already fuses the standard version aggressively, so the manual tiling added overhead without saving memory traffic.
- Replacing the sequential `fori_loop` over outer Q blocks with `vmap` improved performance by 45x, because `vmap` exposes the blocks as independent work the compiler can parallelize.
- The TPU's architecture, with its MXU systolic arrays and large VMEM, executes matrix operations differently from a GPU, making some GPU-oriented optimizations unnecessary.
- The post explores TPU's memory hierarchy and how XLA's fusion avoids intermediate HBM round-trips, contrasting with GPU's need for manual tiling.
- A systolic array emulator was built to understand TPU's matmul timing, showing how data flows through the hardware.
- Production-grade attention on TPU uses `jax.nn.dot_product_attention`, which matches or outperforms manual implementations by leveraging XLA optimizations.
- Key lesson: hardware and compiler capabilities determine whether an optimization like Flash Attention pays off; it is essential on GPU but, for many TPU configurations, redundant.