**Summary**

Today, the only way to get variable-sequence-length support in PyTorch attention is through nested tensors ([tutorial](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support)). We also want an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA. This PR builds out `varlen_attn`, the public API that users call for the forward pass, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding. Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, sequence lengths are random multiples of 64 up to `max_seq_len`
- 100 runs

|         | Variable Length API | SDPA      |
|---------|---------------------|-----------|
| Runtime | 0.2175 ms           | 0.4317 ms |
| TFLOPs  | 231.812             | 320.840   |

The sparsity is 0.453, which matches the roughly 50% speedup we see from varlen (see the input-layout sketch below). TFLOPs stay in the same range, with SDPA slightly higher due to its extra overhead and because its total FLOPs scale with the padded sequence length.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm that varlen outputs numerically match SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
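As a hedged illustration of the benchmark setup above, the sketch below constructs the two input layouts being compared: a packed `(total_tokens, num_heads, head_dim)` layout with cumulative sequence offsets for the varlen path, and a padded, boolean-masked layout for the SDPA baseline, along with the sparsity figure. The `varlen_attn` argument names and order (`cu_seq`, `max_seq_len`) are assumptions and the call is left commented out; only the module path `torch.nn.attention.varlen` and the function name follow from this PR. Requires a CUDA device.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device, dtype = "cuda", torch.bfloat16
batch_size, max_seq_len, embed_dim, num_heads = 8, 2048, 1024, 16
head_dim = embed_dim // num_heads

# Per-sequence lengths: random multiples of 64 up to max_seq_len, as in the benchmark.
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64
total_tokens = int(seq_lens.sum())

# Varlen layout: all sequences packed back-to-back, described by cumulative offsets.
cu_seq = torch.cat([torch.zeros(1, dtype=torch.int64), seq_lens.cumsum(0)]).to(device, torch.int32)
q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)
# Hypothetical call -- argument names/order are assumptions, not the confirmed signature:
# out = torch.nn.attention.varlen.varlen_attn(q, k, v, cu_seq, cu_seq, max_seq_len, max_seq_len)

# Padded SDPA baseline: every sequence padded to max_seq_len, padding masked out.
q_pad = torch.randn(batch_size, num_heads, max_seq_len, head_dim, device=device, dtype=dtype)
k_pad, v_pad = torch.randn_like(q_pad), torch.randn_like(q_pad)
key_valid = torch.arange(max_seq_len, device=device)[None, :] < seq_lens.to(device)[:, None]
attn_mask = key_valid[:, None, None, :]  # True = attend; broadcasts over heads and queries
out_pad = F.scaled_dot_product_attention(q_pad, k_pad, v_pad, attn_mask=attn_mask)

# Sparsity = fraction of padded slots that carry no real tokens
# (on the order of the 0.453 reported above for this distribution of lengths).
print(f"sparsity = {1 - total_tokens / (batch_size * max_seq_len):.3f}")
```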
.. role:: hidden
    :class: hidden-section

torch.nn.attention
==================

.. automodule:: torch.nn.attention

Utils
-------------------

.. autosummary::
    :toctree: generated
    :nosignatures:

    sdpa_kernel
    SDPBackend

Submodules
----------

.. autosummary::
    :nosignatures:

    flex_attention
    bias
    experimental
    varlen

.. toctree::
    :hidden:

    nn.attention.flex_attention
    nn.attention.bias
    nn.attention.experimental
    nn.attention.varlen