FlashAttention is a memory-efficient attention algorithm for transformer models that computes scaled dot-product attention without materializing the full attention matrix in GPU high-bandwidth memory. It uses kernel fusion and tiling to reduce reads and writes to high-bandwidth memory, making attention faster and enabling longer sequences within the same hardware budget.
What is FlashAttention?
In standard attention, the model forms an (n × n) attention score matrix for a sequence of length n, applies softmax, then multiplies by the values. For long sequences this matrix grows quadratically and becomes a memory bottleneck even when compute is not the limiting factor. FlashAttention restructures the computation so that it streams blocks of queries, keys, and values through on-chip SRAM (shared memory) and computes partial softmax statistics incrementally, which avoids ever writing the full (n × n) matrix to global memory.
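The restructuring can be made concrete with a small reference sketch. The NumPy code below is a simplified, single-head illustration (not the actual CUDA kernel, and it tiles only keys and values, whereas real implementations also tile queries); the function name, block size, and shapes are illustrative. It processes keys and values in blocks while maintaining running softmax statistics, so no (n × n) score matrix is ever formed.

```python
import numpy as np

def tiled_attention(q, k, v, block_size=128):
    """Single-head attention computed block by block over keys/values.

    Maintains running softmax statistics (row-wise max `m` and normalizer `l`)
    so the full (n x n) score matrix is never materialized.
    Shapes: q, k, v are (n, d).
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q, dtype=np.float64)
    m = np.full(n, -np.inf)   # running row-wise max of scores seen so far
    l = np.zeros(n)           # running softmax normalizer

    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]      # (b, d) block of keys
        v_blk = v[start:start + block_size]      # (b, d) block of values
        s = (q @ k_blk.T) * scale                # (n, b) partial scores

        m_new = np.maximum(m, s.max(axis=1))     # updated row-wise max
        correction = np.exp(m - m_new)           # rescale earlier accumulations
        p = np.exp(s - m_new[:, None])           # numerically stable partial probs
        l = l * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ v_blk
        m = m_new

    return out / l[:, None]
```

Running this against a naive implementation that builds the full score matrix gives the same output up to floating-point rounding, which is the equivalence property discussed below.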
Practically, FlashAttention relies on a fused GPU kernel that combines multiple steps, such as matmul, scaling, masking, softmax, and the value-weighted sum, into a single efficient pass. It also uses numerically stable online softmax to ensure the result matches standard attention within floating-point error bounds. The result is lower peak memory usage and higher throughput, especially for long-context training and inference.
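In practice, most users reach the fused kernel through a framework rather than writing it themselves. A minimal sketch using PyTorch's scaled_dot_product_attention follows; PyTorch can dispatch this call to a FlashAttention-style fused kernel when the GPU, dtypes, and shapes allow it, but which backend is chosen depends on the PyTorch version and hardware, so treat the eligibility details as assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim). Half precision is
# typically required for the fused flash backend to be eligible.
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)

# PyTorch selects an efficient attention backend (flash, memory-efficient,
# or plain math) automatically; is_causal applies the autoregressive mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```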
Where it’s used and why it matters
FlashAttention is widely used in LLM training stacks and serving engines to push sequence lengths higher without linear growth in memory overhead. It matters because attention memory often limits batch size and context length more than raw FLOPs. By reducing memory traffic, FlashAttention can improve training speed, reduce cost, and make long-context models more practical. It also improves latency and throughput in inference when attention dominates runtime, such as during long prompt processing.
Examples
- Long-context training: Train a transformer with longer sequences while keeping GPU memory stable (see the sketch after this list).
- Efficient prefill: Speed up prompt processing for chat models with large context windows.
- Serving optimization: Combine with KV cache, paged attention, and quantization for higher concurrency.
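As a concrete illustration of the long-context training and serving cases above, the sketch below loads a causal language model with FlashAttention enabled through Hugging Face transformers. The model ID is a placeholder, and the attn_implementation flag assumes a recent transformers version with the flash-attn package installed and a supported GPU.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-org/your-long-context-model"  # placeholder model ID

# attn_implementation="flash_attention_2" asks transformers to use fused
# FlashAttention kernels (requires the flash-attn package and a supported
# GPU); half precision is needed for the fused path.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```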
FAQs
- Does FlashAttention change model outputs?
It is designed to be functionally equivalent to standard attention up to floating-point rounding differences.
- Is FlashAttention only for training?
No. It can accelerate both training and inference, especially prefill.
- When does FlashAttention help the most?
It helps most for long sequences where attention memory traffic dominates runtime.
- Do all GPUs support it?
Performance depends on GPU architecture and kernel implementations, so support varies by framework and hardware generation.