pytorch/torch
Angel Li 476b149a00 bwd pass (#164504)
**Summary**
This implements the backward pass for the Varlen API and registers `_varlen_attn()` as a custom op.
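Variable-length attention kernels commonly avoid padding by concatenating every sequence in a batch along the token dimension and marking the boundaries with cumulative offsets. The plain-Python sketch below assumes that convention. The helper names (`cumulative_offsets`, `pack`, `unpack`) are hypothetical and are not the `_varlen_attn()` signature.

```python
from itertools import accumulate


def cumulative_offsets(lengths):
    """Hypothetical helper: [0, l0, l0+l1, ...]; sequence i spans offsets[i]:offsets[i+1]."""
    return [0, *accumulate(lengths)]


def pack(sequences):
    """Concatenate sequences into one flat list (no padding) plus their offsets."""
    flat = [tok for seq in sequences for tok in seq]
    return flat, cumulative_offsets([len(seq) for seq in sequences])


def unpack(flat, offsets):
    """Recover the individual sequences from the packed layout."""
    return [flat[start:end] for start, end in zip(offsets, offsets[1:])]


batch = [[1, 2, 3], [4], [5, 6]]
flat, offsets = pack(batch)
print(flat)     # [1, 2, 3, 4, 5, 6]
print(offsets)  # [0, 3, 4, 6]
assert unpack(flat, offsets) == batch
```

In this layout, attention is computed only within each `offsets[i]:offsets[i+1]` slice, so no work is spent on padding tokens.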

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.
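The PR does not say how TFLOPs were computed. One common convention, assumed here and not confirmed by the source, counts 4 * L^2 * d FLOPs for the forward pass of non-causal attention over one sequence of length L. This comes from two matmuls (QK^T and PV) at 2 * L^2 * d each, with d = `num_heads` * head_dim = `embed_dim`. The convention takes the backward pass to be about 2.5x the forward pass. A sketch of that bookkeeping:

```python
def attention_flops(seq_lens, embed_dim, mode="fwd_bwd", causal=False):
    """Approximate attention FLOPs under an assumed convention (not from the PR).

    Forward: QK^T and PV each cost 2 * L * L * embed_dim FLOPs per sequence,
    summed over heads (embed_dim = num_heads * head_dim).
    Backward is taken as 2.5x forward, so forward + backward is 3.5x forward.
    """
    fwd = sum(4 * L * L * embed_dim for L in seq_lens)
    if causal:
        fwd /= 2  # roughly half of the score matrix is masked out
    factor = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return fwd * factor


def tflops(flops, runtime_ms):
    """Achieved throughput in TFLOP/s from FLOPs per run and runtime in milliseconds."""
    return flops / (runtime_ms / 1e3) / 1e12


# Forward FLOPs for 8 sequences padded to 2048 tokens with embed_dim=1024.
padded_fwd = attention_flops([2048] * 8, embed_dim=1024, mode="fwd")
print(padded_fwd)  # 137438953472.0
```

Under this convention, the reported TFLOPs of the padded baseline depend on whether FLOPs on padding tokens are counted. The source does not say which choice was made.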

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for the variable-length case, sequence lengths are random multiples of 64 up to `max_seq_len`
- 100 runs
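The variable-length setting can be approximated as follows. This sketch relies on two assumptions that the PR does not state. First, lengths are drawn uniformly from {64, 128, ..., `max_seq_len`}, with no zero-length sequences. Second, the padded SDPA baseline pads every sequence to `max_seq_len`. The sketch also counts the query-key pairs each layout must process, because attention cost grows with the square of sequence length.

```python
import random


def sample_lengths(batch_size, max_seq_len, multiple=64, seed=None):
    """Random lengths that are multiples of `multiple`, in [multiple, max_seq_len]."""
    rng = random.Random(seed)
    return [multiple * rng.randint(1, max_seq_len // multiple) for _ in range(batch_size)]


def attention_pairs(lengths, max_seq_len):
    """(padded, packed) counts of query-key pairs; attention cost scales with these."""
    padded = len(lengths) * max_seq_len ** 2  # every sequence padded to max_seq_len
    packed = sum(n * n for n in lengths)      # varlen: attend only within each sequence
    return padded, packed


lengths = sample_lengths(batch_size=8, max_seq_len=2048, seed=0)
padded, packed = attention_pairs(lengths, max_seq_len=2048)
print(lengths)
print(f"padded/packed work ratio: {padded / packed:.2f}")
```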

|              | Variable Length API | SDPA (padded) |
|--------------|--------------------:|--------------:|
| Runtime (ms) | 0.819               | 3.264         |
| TFLOPs       | 268.652             | 158.731       |

Runtime with Varlen is roughly 4x lower than with padded SDPA (0.819 ms vs. 3.264 ms), and its throughput is about 1.7x higher.
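The ratios can be checked directly from the table. Because TFLOPs = FLOPs / runtime, multiplying the two columns back together gives the FLOP count that each measurement implies per run:

```python
# Values copied from the benchmark table.
varlen_ms, sdpa_ms = 0.8189142608642578, 3.263883056640625
varlen_tflops, sdpa_tflops = 268.652, 158.731

speedup = sdpa_ms / varlen_ms                   # ≈ 3.99
throughput_ratio = varlen_tflops / sdpa_tflops  # ≈ 1.69

# TFLOP/s * seconds = TFLOP performed per run, as implied by each column.
varlen_work = varlen_tflops * varlen_ms / 1e3   # ≈ 0.220 TFLOP
sdpa_work = sdpa_tflops * sdpa_ms / 1e3         # ≈ 0.518 TFLOP

print(f"speedup {speedup:.2f}x, throughput {throughput_ratio:.2f}x")
print(f"implied work: varlen {varlen_work:.3f} TFLOP, SDPA {sdpa_work:.3f} TFLOP")
```

If both columns used the same FLOP count, the two ratios would be equal. They are not (about 4.0x vs. 1.7x), so the reported TFLOPs presumably count different amounts of work for the two methods, for example padded positions in the SDPA case. The PR does not state which convention it used.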

**Testing**

Run `python test/test_varlen_attention.py` for the unit tests. They verify basic functionality and confirm that the varlen gradients numerically match the SDPA gradients.
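The kind of gradient comparison described above can be illustrated with SDPA alone. The sketch below does not call the varlen API. As a hypothetical stand-in for a varlen kernel, it runs `torch.nn.functional.scaled_dot_product_attention` separately on each unpadded sequence. It then checks that the outputs and gradients match a padded, masked SDPA call on the valid positions. The lengths, shapes, and dtype are arbitrary choices made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
lengths = [3, 5, 2]  # illustrative sequence lengths
num_heads, head_dim = 2, 4
batch, max_len = len(lengths), max(lengths)

# Padded layout (batch, heads, max_len, head_dim); float64 keeps round-off negligible.
q, k, v = (
    torch.randn(batch, num_heads, max_len, head_dim, dtype=torch.float64, requires_grad=True)
    for _ in range(3)
)

# valid[b, t] is True for real tokens and False for padding.
valid = torch.arange(max_len)[None, :] < torch.tensor(lengths)[:, None]

# Boolean SDPA mask (True = may attend). Only padded keys are masked, so padded
# query rows stay finite; they are dropped from the loss instead.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=valid[:, None, None, :])
(out * valid[:, None, :, None]).sum().backward()

max_grad_err = 0.0
for b, n in enumerate(lengths):
    # Stand-in for a varlen kernel: attention over one unpadded sequence.
    qb, kb, vb = (t.detach()[b:b + 1, :, :n].clone().requires_grad_() for t in (q, k, v))
    ref = F.scaled_dot_product_attention(qb, kb, vb)
    ref.sum().backward()
    torch.testing.assert_close(out.detach()[b:b + 1, :, :n], ref.detach())
    for padded, unpadded in ((q, qb), (k, kb), (v, vb)):
        diff = padded.grad[b:b + 1, :, :n] - unpadded.grad
        max_grad_err = max(max_grad_err, diff.abs().max().item())
        torch.testing.assert_close(padded.grad[b:b + 1, :, :n], unpadded.grad)

print(f"max gradient difference: {max_grad_err:.2e}")
```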

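For custom op testing, `test_custom_op_registration` uses logging mode to verify that `_varlen_attn()` is called, and it also runs the op under `torch.compile`. `test_custom_op_compliances` uses `torch.library.opcheck()` to check the op's registration, including its schema, fake (meta) implementation, and autograd support.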

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
Approved by: https://github.com/drisspg
2025-10-30 03:46:37 +00:00
_awaits
_C [ROCm][CUDA] add unit test utility busy_wait_for_flag (#166218) 2025-10-29 22:40:23 +00:00
_C_flatbuffer
_custom_op Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
_decomp address DDE in matmul decomp (#166541) 2025-10-30 03:19:29 +00:00
_dispatch Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
_dynamo Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)" 2025-10-29 22:46:48 +00:00
_export Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)" 2025-10-29 22:46:48 +00:00
_functorch Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. (#166277) 2025-10-30 00:34:05 +00:00
_higher_order_ops Revert "[1/N] Remove unused loop variables (#166258)" 2025-10-29 11:10:37 +00:00
_inductor [xpu][test] Reuse native_mm and mix_order_reduction for Intel GPU. (#166384) 2025-10-30 03:38:35 +00:00
_lazy Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
_library Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_logging Clean up unused Pyrefly suppressions (#166178) 2025-10-25 05:32:21 +00:00
_numpy Enable PLW0127 in ruff (#165851) 2025-10-21 03:30:57 +00:00
_prims Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_prims_common Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_refs Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
_strobelight Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
_subclasses Revert "[1/N] Remove unused loop variables (#166258)" 2025-10-29 11:10:37 +00:00
_vendor
accelerator Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
amp Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
ao Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)" 2025-10-29 22:46:48 +00:00
autograd Revert "[1/N] Remove unused loop variables (#166258)" 2025-10-29 11:10:37 +00:00
backends Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
compiler Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
contrib
cpu Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
csrc shrink_group implementation to expose ncclCommShrink API (#164518) 2025-10-30 01:50:54 +00:00
cuda [ROCm][CUDA] add unit test utility busy_wait_for_flag (#166218) 2025-10-29 22:40:23 +00:00
distributed Enable local tensor mode for DTensor attention and convolution tests (#166406) 2025-10-30 02:48:02 +00:00
distributions Fix pyrelfy ignore syntax in distributions and ao (#166248) 2025-10-26 22:13:48 +00:00
export Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)" 2025-10-29 22:46:48 +00:00
fft
func
futures
fx [symbolic shapes] remove maybe_guard_rel warning (#166553) 2025-10-30 00:57:28 +00:00
headeronly Add TORCH_TARGET_VERSION for stable ABI (#164356) 2025-10-29 15:41:28 +00:00
jit Revert "[1/N] Remove unused loop variables (#166258)" 2025-10-29 11:10:37 +00:00
legacy
lib [2/N] Mark unused parameters in C++ code (#165121) 2025-10-15 03:04:39 +00:00
linalg
masked Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
monitor
mps
mtia Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
multiprocessing Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
nativert [triton][nativert] Add num_cpu_threads for triton-cpu (#166255) 2025-10-28 08:40:04 +00:00
nested Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
nn bwd pass (#164504) 2025-10-30 03:46:37 +00:00
numa Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
onnx [ONNX] Ignore pyrefly errors in torchlib (#166588) 2025-10-30 03:43:52 +00:00
optim Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
package Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
profiler Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
quantization Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
signal Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
sparse Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
special
testing Enable local tensor mode for DTensor attention and convolution tests (#166406) 2025-10-30 02:48:02 +00:00
utils Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)" 2025-10-29 22:46:48 +00:00
xpu Introduce a new API torch.xpu.set_per_process_memory_fraction (#165510) 2025-10-29 03:24:52 +00:00
__config__.py
__future__.py
__init__.py Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_appdirs.py
_classes.py
_compile.py
_custom_ops.py
_environment.py
_guards.py Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
_jit_internal.py Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_linalg_utils.py
_lobpcg.py Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_lowrank.py
_meta_registrations.py [Inductor][Triton][FP8] Support deepseek-style scaling in Inductor (#164404) 2025-10-28 03:38:54 +00:00
_namedtensor_internals.py
_ops.py Fix pyrefly ignore syntax (#166438) 2025-10-29 00:02:21 +00:00
_python_dispatcher.py
_size_docs.py
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py
_tensor_str.py Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
_tensor.py Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
_thread_safe_fork.py
_torch_docs.py Clarrifying input output angle unit in the docs for trigonometric fun… (#161248) 2025-10-18 11:53:48 +00:00
_utils_internal.py Fix pyrefly error syntax (2/n) (#166448) 2025-10-29 00:36:40 +00:00
_utils.py Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
_VF.py
_vmap_internals.py
_weights_only_unpickler.py Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
CMakeLists.txt [ROCm] Use a ROCm version string without hash. (#166336) 2025-10-28 03:53:55 +00:00
custom_class_detail.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
custom_class.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
extension.h
functional.py Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
header_only_apis.txt Move toUnderlying to headeronly (#165694) 2025-10-22 05:31:16 +00:00
hub.py Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
library.h Mark unused parameters in C++ code (#164912) 2025-10-09 06:23:25 +00:00
library.py Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
overrides.py Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
py.typed
quasirandom.py
random.py Fix flake8 B028 warnings (#166224) 2025-10-26 06:18:55 +00:00
return_types.py
script.h
serialization.py Fix syntax for pyrefly errors (#166496) 2025-10-29 20:00:25 +00:00
storage.py Fix pyrefly ignores 1/n (#166239) 2025-10-26 00:44:10 +00:00
torch_version.py
types.py Enable PLC0414 on ruff (#165828) 2025-10-22 04:56:52 +00:00
version.py.tpl