FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention
Freeing users from the software-lottery tyranny of fused attention implementations.
I have a blog post up on the PyTorch blog (as part of my day job) on FlexAttention: https://pytorch.org/blog/flexattention/ (work done with Driss Guessous, Yanbo Liang, and Joy Dong). And here’s a tweet thread.
The beginning excerpted here:
In theory, Attention is All You Need. In practice, however, we also need optimized attention implementations like FlashAttention.
Although these fused attention implementations have substantially improved performance and enabled long contexts, this efficiency has come with a loss of flexibility. You can no longer try out a new attention variant by writing a few PyTorch operators - you often need to write a new custom kernel! This operates as a sort of “software lottery” for ML researchers - if your attention variant doesn’t fit into one of the existing optimized kernels, you’re doomed to slow runtime and CUDA OOMs.
For some examples of attention variants, we have Causal, Relative Positional Embeddings, Alibi, Sliding Window Attention, PrefixLM, Document Masking/Sample Packing/Jagged Tensors, Tanh Soft-Capping, PagedAttention, etc. Even worse, folks often want combinations of these! Sliding Window Attention + Document Masking + Causal + Context Parallelism? Or what about PagedAttention + Sliding Window + Tanh Soft-Capping?
The left picture below represents the state of the world today - some combinations of masking + biases + settings have existing kernels implemented. But the various options lead to an exponential number of settings, and so overall we end up with fairly spotty support. Even worse, new attention variants that researchers come up with will have zero support.
To solve this hypercube problem once and for all, we introduce FlexAttention, a new PyTorch API.
- We provide a flexible API that allows implementing many attention variants (including all the ones mentioned in the blog post so far) in a few lines of idiomatic PyTorch code.
- We lower this into a fused FlashAttention kernel through torch.compile, generating a FlashAttention kernel that doesn’t materialize any extra memory and has performance competitive with handwritten ones.
- We also automatically generate the backwards pass, leveraging PyTorch’s autograd machinery.
- Finally, we can also take advantage of sparsity in the attention mask, resulting in significant improvements over standard attention implementations.
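To make the “few lines of idiomatic PyTorch” claim concrete, here is a minimal sketch of what an attention variant looks like under this API. It assumes the torch.nn.attention.flex_attention module described in the linked post (available in recent PyTorch releases); exact import paths and signatures may shift between versions, so treat it as illustrative rather than definitive.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# An attention variant expressed as a score_mod: tweak each query/key score
# before the softmax. Here, a simple relative positional bias.
def relative_positional(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# Toy inputs with shape (batch, heads, seq_len, head_dim).
q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# torch.compile lowers this into a single fused FlashAttention-style kernel,
# so the full (seq_len, seq_len) score matrix is never materialized.
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=relative_positional)
```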
With FlexAttention, we hope that trying new attention variants will only be limited by your imagination.
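The sparsity point deserves its own sketch. In the same spirit (and with the same caveat that the API details are taken from the linked post and may differ by version), mask-based variants are expressed as a mask_mod plus create_block_mask, which lets the generated kernel skip blocks that are fully masked out; that is where the speedups over dense attention come from.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

WINDOW = 256  # hypothetical sliding-window size, just for illustration

# A mask_mod returns True where attention is allowed. This one combines
# causal masking with a sliding window.
def sliding_window_causal(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

# Precompute which (query-block, key-block) pairs are entirely masked;
# those blocks are skipped by the generated kernel.
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=4096, KV_LEN=4096
)

q, k, v = (
    torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```

Since the mask here depends only on positions, the block mask can be built once and reused across layers and iterations.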
My next post (~70% done) will also be about attention! This one will be a historical retrospective on why Tri Dao was the one to invent FlashAttention, and not any of the large tech companies.
There are some other interesting topics I’m considering writing about.
In particular, some potential titles:
- What’s the point of ML compilers when Attention is All You Need? (how I think about building ML systems in a world where everybody uses one ML architecture)
- Performance Metrics Were Made for Man, Not Man for Performance Metrics (how do you choose the right performance metric? in particular, flop counting)
- My ML Framework Isn’t Obeying Mathematics! (a primer on floating point “nondeterminism” for machine learning settings)