pytorch/docs/source
Joel Schlosser 8ba9063002 FlexAttention support for NJT (#136792)
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR automatically converts indices in the giant "stacked sequence" to sequence-relative indices (see the sketch after this list).
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
    * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
      * Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.
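To make the conversion concrete, here is a minimal sketch of the idea under assumed helper names (`build_seq_idx` and `convert_mask_mod` are illustrative, not the PR's actual internals): a materialized `seq_idx` tensor maps each stacked position to its sequence id, and subtracting that sequence's starting offset recovers the relative index.

```python
import torch

def build_seq_idx(offsets):
    # offsets delimit each sequence within the stacked dim,
    # e.g. offsets = [0, 3, 5] -> seq_idx = [0, 0, 0, 1, 1]
    total_len = int(offsets[-1])
    positions = torch.arange(total_len, device=offsets.device)
    return torch.searchsorted(offsets, positions, right=True) - 1

def convert_mask_mod(mask_mod, offsets, seq_idx):
    def stacked_mask_mod(b, h, q_idx, kv_idx):
        q_seq, kv_seq = seq_idx[q_idx], seq_idx[kv_idx]
        # never attend across sequence boundaries; within a sequence,
        # apply the user's mask_mod in sequence-relative coordinates
        same_seq = q_seq == kv_seq
        return same_seq & mask_mod(
            q_seq, h, q_idx - offsets[q_seq], kv_idx - offsets[kv_seq]
        )
    return stacked_mask_mod
```

The same offset subtraction applies to `score_mod` adapters; `create_nested_block_mask()` and the NJT path of `flex_attention()` perform this conversion automatically.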

Example usage:
```python
import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # block mask conceptual shape is (1, 1, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```
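For completeness, here is one self-contained way to construct the NJT inputs above (sizes and lengths are illustrative assumptions; the `(B, H, S*, D)` layout is obtained by transposing a `(B, S*, H, D)` jagged tensor):

```python
import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

B, H, D = 4, 8, 16        # illustrative sizes
seq_lens = [3, 7, 5, 11]  # ragged per-sequence lengths S*

def make_njt():
    # build (B, S*, H, D) with jagged dim 1, then transpose to (B, H, S*, D)
    components = [torch.randn(s, H, D) for s in seq_lens]
    return torch.nested.nested_tensor(components, layout=torch.jagged).transpose(1, 2)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query, key, value = make_njt(), make_njt(), make_njt()
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)
output = flex_attention(query, key, value, block_mask=block_mask)
```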

TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ (verify this with others, though)
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(S*)`-length `seq_idx` used for conversion between stacked-sequence and sequence-relative indices?
    * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this, though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
2024-10-28 20:01:27 +00:00
_static Add Programmable Google Search (#137716) 2024-10-18 18:18:16 +00:00
_templates Add Programmable Google Search (#137716) 2024-10-18 18:18:16 +00:00
community Update maintainers for inductor and x86 CPU (#136839) 2024-10-11 07:24:07 +00:00
elastic DOC: add docstring to construct_and_record_rdzv_event() (#128189) 2024-06-10 22:17:33 +00:00
notes reimport pr137735 due to merging check issues (#138959) 2024-10-27 16:31:34 +00:00
rpc [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
scripts [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
accelerator.rst Introduce a device-agnostic runtime API design (#132204) 2024-10-27 10:37:09 +00:00
amp.rst Update document for autocast on CPU (#135299) 2024-09-13 09:11:47 +00:00
autograd.rst Add torch.library.register_autograd (#124071) 2024-04-18 12:47:59 +00:00
backends.rst Clarify opt-einsum usage, fix #127109 (#137596) 2024-10-09 20:31:24 +00:00
benchmark_utils.rst Adding Compare in torch.utils.benchmark documentation (#125009) 2024-05-03 00:50:54 +00:00
bottleneck.rst
checkpoint.rst [checkpoint] Clean up selective activation checkpoint and make public (#125795) 2024-06-18 18:18:50 +00:00
complex_numbers.rst Document complex optimizer semantic behavior (#121667) 2024-03-16 00:43:47 +00:00
cond.rst [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
conf.py Update copyrights to 2024 (#138638) 2024-10-22 21:00:58 +00:00
config_mod.rst
cpp_extension.rst
cpp_index.rst
cpu.rst Add current_device() to torch.cpu (#110987) 2023-10-11 05:13:10 +00:00
cuda_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
cuda._sanitizer.rst
cuda.rst raw_alloc ignores PYTORCH_NO_CUDA_MEMORY_CACHING (#131114) 2024-10-04 15:36:29 +00:00
cuda.tunable.rst [ROCm] Tunableop record untuned (#128813) 2024-10-09 21:59:03 +00:00
cudnn_persistent_rnn.rst
cudnn_rnn_determinism.rst
data.rst Revert "reseed all Generators in Dataloader's _worker_loop() -- via GC (#107131)" 2023-08-23 17:08:07 +00:00
ddp_comm_hooks.rst
debugging_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
deploy.rst
deterministic.rst Add torch.utils.deterministic.fill_uninitialized_memory flag (#111377) 2023-11-01 16:10:09 +00:00
distributed.algorithms.join.rst
distributed.checkpoint.rst [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
distributed.elastic.rst Reapply "distributed debug handlers (#126601)" (#127805) 2024-06-04 19:44:30 +00:00
distributed.optim.rst
distributed.pipelining.rst [Pipelining] Refactor Interleaved1F1B and ZeroBubble (#137783) 2024-10-16 03:05:14 +00:00
distributed.rst [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
distributed.tensor.parallel.rst Update link in distributed.tensor.parallel.rst (#136103) 2024-09-15 19:36:29 +00:00
distributed.tensor.rst [dtensor][experimental] expose DTensor Context Parallel API (#137038) 2024-10-02 18:00:23 +00:00
distributions.rst Add inverse gamma distribution and fix sign bug in PowerTransform. (#104501) 2023-11-01 02:26:25 +00:00
dlpack.rst
docutils.conf
export.ir_spec.rst [export] Remove torch._export.export (#119095) 2024-02-08 21:22:04 +00:00
export.rst Replace torch.export default decomp table to be lazily populated (#137650) 2024-10-18 19:28:52 +00:00
fft.rst
fsdp.rst [FSDP][state_dict] Expose optimizer state_dict config (#105949) 2023-08-21 07:29:49 +00:00
func.api.rst
func.batch_norm.rst
func.migrating.rst
func.rst
func.ux_limitations.rst
func.whirlwind_tour.rst
future_mod.rst Add swap_tensors path to nn.Module._apply (#117167) 2024-02-07 18:55:44 +00:00
futures.rst
fx.experimental.rst Remove parallel_and and parallel_or (#138135) 2024-10-23 00:22:22 +00:00
fx.rst Consolidate SymDispatchMode into ProxyTensorMode (#132674) 2024-08-08 12:02:54 +00:00
hub.rst
index.rst Introduce a device-agnostic runtime API design (#132204) 2024-10-27 10:37:09 +00:00
jit_builtin_functions.rst
jit_language_reference_v2.rst [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
jit_language_reference.rst [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
jit_python_reference.rst
jit_unsupported.rst Add support for torch.Generator type in TorchScript (#110413) 2023-11-21 23:07:21 +00:00
jit_utils.rst
jit.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
library.rst Link directly to new Custom Ops Landing Page (#137933) 2024-10-15 21:18:21 +00:00
linalg.rst
logging.rst Change classification to beta for TORCH_LOGS (#118682) 2024-01-31 21:50:55 +00:00
masked.rst Add MaskedTensor passthrough: unfold, F.Unfold, F.Fold, stack (#125262) 2024-09-06 19:06:23 +00:00
math-quantizer-equation.png
meta.rst Add documentation for meta device (#119119) 2024-02-04 01:05:22 +00:00
miscellaneous_environment_variables.rst Add environment variable to force no weights_only load (#138225) 2024-10-21 23:26:15 +00:00
mobile_optimizer.rst Add ExecuTorch warning to mobile_optimizer (#134697) 2024-09-04 17:47:14 +00:00
model_zoo.rst
module_tracker.rst Add module tracker (#125352) 2024-05-04 18:33:35 +00:00
monitor.rst
mps_environment_variables.rst [MPS] Add mps profiler env vars to docs (#129552) 2024-07-04 06:44:48 +00:00
mps.rst Add support in Python API for the recommended max working set size. (#128289) 2024-06-12 16:03:57 +00:00
mtia.rst [MTIA] Support torch.cuda.get_device_capability equivalent API on MTIA (#135889) 2024-09-17 17:42:56 +00:00
multiprocessing.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
name_inference.rst [docs] Properly link register_post_accumulate_grad_hook docs (#108157) 2023-08-29 22:13:33 +00:00
named_tensor.rst fixing named tensor unflatten example (#106921) 2023-08-22 18:00:10 +00:00
nested.rst
nn.attention.bias.rst Remove sdp_kernel and replace with sdpa_kernel in attention namespace (#114689) 2024-01-24 22:28:04 +00:00
nn.attention.flex_attention.rst FlexAttention support for NJT (#136792) 2024-10-28 20:01:27 +00:00
nn.attention.rst Make FlexAttention API public (#130755) 2024-07-16 16:21:25 +00:00
nn.functional.rst Add RMSNorm module (#121364) 2024-03-29 18:05:28 +00:00
nn.init.rst
nn.rst Make adding Buffers more like adding Parameters (#125971) 2024-07-31 10:32:40 +00:00
onnx_dynamo_onnxruntime_backend.rst Follow-up #108379 (#108905) 2023-09-09 01:38:36 +00:00
onnx_dynamo.rst [ONNX] Improves documentation of ONNX exporter (#135372) 2024-09-09 15:09:01 +00:00
onnx_torchscript_supported_aten_ops.rst Refactor torch.onnx documentation (#108379) 2023-09-08 18:23:48 +00:00
onnx_torchscript.rst [ONNX] Remove deprecated export_to_pretty_string (#137790) 2024-10-21 18:17:48 +00:00
onnx.rst [ONNX] Improves documentation of ONNX exporter (#135372) 2024-09-09 15:09:01 +00:00
optim.rst Documentation Update: Fix Missing Whitespace in Optimizer Docs (#138321) 2024-10-18 15:41:43 +00:00
package.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
profiler.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
quantization-accuracy-debugging.rst
quantization-backend-configuration.rst
quantization-support.rst Update pt2e numeric debugger to use node.meta["custom"] field (#134040) 2024-08-27 19:51:03 +00:00
quantization.rst Cleanup some duplicated placeholder py:module docs (#123244) 2024-04-05 03:18:53 +00:00
random.rst
rpc.rst [BE] RPC is missing RRef docs (#106902) 2023-08-10 16:26:27 +00:00
signal.rst
size.rst Added a docstring for torch.Size.numel. (#124186) 2024-04-19 09:23:02 +00:00
sparse.rst SparseCsrCUDA: cuDSS backend for linalg.solve (#129856) 2024-08-22 07:57:30 +00:00
special.rst
storage.rst
tensor_attributes.rst Refine the logic of device construction when only device index is given (#129119) 2024-07-15 14:34:29 +00:00
tensor_view.rst
tensorboard.rst
tensors.rst add xpu to torch.tensors (#127280) 2024-06-11 18:13:01 +00:00
testing.rst
threading_environment_variables.rst Add doc page for environment variables that effect PyTorch Runtime (#119087) 2024-02-15 21:41:38 +00:00
torch_cuda_memory.rst Fix typo under docs directory (#110359) 2023-10-03 16:36:05 +00:00
torch_environment_variables.rst [Docs][MPS] Add mps environment variable table (#129008) 2024-06-20 03:30:35 +00:00
torch_nccl_environment_variables.rst [c10d][doc] Add docs for ENV variables TORCH_NCCL_ASYNC_ERROR_HANDLING TORCH_NCCL_TRACE_CPP_STACK and TORCH_NCCL_COORD_CHECK_MILSEC (#132920) 2024-08-09 21:08:20 +00:00
torch.ao.ns._numeric_suite_fx.rst
torch.ao.ns._numeric_suite.rst
torch.compiler_aot_inductor.rst [AOTI] docs: add suggestion to turn on freezing on CPU (#128010) 2024-06-07 08:57:02 +00:00
torch.compiler_api.rst [dynamo] add torch.compiler.set_stance (#137504) 2024-10-16 16:18:25 +00:00
torch.compiler_best_practices_for_backends.rst
torch.compiler_cudagraph_trees.rst [CUDAGraph] add more docs for cudagraph trees (#127963) 2024-06-18 02:07:07 +00:00
torch.compiler_custom_backends.rst Fix a link in the compiler backend doc (#126079) 2024-05-21 20:16:04 +00:00
torch.compiler_dynamic_shapes.rst feat: Add min, max ranges to mark_dynamic API (#119737) 2024-03-07 23:26:03 +00:00
torch.compiler_dynamo_deepdive.rst Stop immediately specializing common constants 0/1 for plain int (#128327) 2024-07-03 16:41:51 +00:00
torch.compiler_dynamo_overview.rst Rename TorchDynamo -> Dyanamo in the dynamo tutorial doc (#123431) 2024-05-07 05:07:00 +00:00
torch.compiler_fake_tensor.rst [BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770) 2024-08-08 23:07:23 +00:00
torch.compiler_faq.rst [dynamo] Retire CompileProfiler (#135133) 2024-09-05 01:08:40 +00:00
torch.compiler_fine_grain_apis.rst [Doc] fix some typos (found by codespell and typos) (#132544) 2024-08-05 17:21:56 +00:00
torch.compiler_get_started.rst [Inductor] Update AttrsDescriptor instantiation for Triton changes (#137458) 2024-10-14 20:20:29 +00:00
torch.compiler_inductor_profiling.rst
torch.compiler_ir.rst [export] torch.export landing page (#108783) 2023-09-10 01:40:42 +00:00
torch.compiler_nn_module.rst Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964) 2023-10-11 05:16:47 +00:00
torch.compiler_performance_dashboard.rst
torch.compiler_profiling_torch_compile.rst [EZ] Fix spelling typo (#136157) 2024-09-16 19:30:30 +00:00
torch.compiler_transformations.rst Fix typo under docs directory (#110359) 2023-10-03 16:36:05 +00:00
torch.compiler_troubleshooting.rst Add link to torch.compile the missing manual in troubleshooting (#137301) 2024-10-04 18:19:30 +00:00
torch.compiler.rst add xpu to torch.compile (#127279) 2024-06-13 21:15:09 +00:00
torch.overrides.rst Doc test non packages (#110568) 2023-10-06 14:16:01 +00:00
torch.rst Introduce torch.sym_add, variadic add (#138660) 2024-10-23 17:42:41 +00:00
type_info.rst
utils.rst New swap function (#111747) 2023-12-08 18:49:35 +00:00
xpu.rst Add torch.xpu.get_arch_list and torch.xpu.get_gencode_flags for XPU (#137773) 2024-10-18 02:28:08 +00:00