Disabled by default for now behind `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1`
Just wanted to get this out before starting a series of SDPA cleanup PRs---the biggest thing is we don't need the boilerplate around all of the `build_graph_and_tensors*` functions anymore as we can now use the `UID`-style referencing of tensor nodes as was done for the Conv-V8 API backend.
CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141178
Approved by: https://github.com/jbschlosser
Fixes 3 issues:
1. The test wasn't actually testing SDPA: both were checking cuda, and the inputs to SDPA were not transposed.
2. FlopCounterMode has been renamed _FlopCounterMode (and a wrapper named FlopCounterMode has been added)
3. offsets_to_list also needs to ignore the actual offset values if offsets is a meta tensor.
Differential Revision: [D69558785](https://our.internmc.facebook.com/intern/diff/D69558785)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147032
Approved by: https://github.com/jbschlosser
Fixes#146404
Adds changes to the matmul and matmul_backward operation for nested jagged tensors, to support back propagation when the output is a regular strided tensor.
This required adding support for the nested matmul operation to work when the nested tensor wasn't 'self', i.e
`A@B` where `A` isn't nested but `B` is.
The operation schemas had to be updated to reflect that either input can be a strided tensor instead (and the gradient), so an extra assertion is added in an edge case where neither input is nested.
Unit tests are also added.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146405
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
Fixes#144761
This PR adds NJT impls for those *_like functions that were previously missing:
* `full_like()`
* `rand_like()`
* `randint_like()`
It also fixes a bug in existing *_like functions when a new device is specified. Fix is to also transfer `offsets` / `lengths` to the new device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144889
Approved by: https://github.com/soulitzer
Updated nested tensor docs to be NJT-centric (instead of NST-centric). They now include:
* High-level description of NST vs. NJT + a recommendation to use NJT
* General NJT construction / usage
* torch.compile() integration w/ dynamic shapes
* Common errors and how to fix them
* Contribution guide
* Data layout / shape information (with diagram)
* Links to more extensive tutorials involving Transformers / SDPA / FlexAttention
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145402
Approved by: https://github.com/soulitzer
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.
This PR implements missing backward support for NJT matmul. Notably, for dense tensors, matmul dispatches to bmm. However, due to historical reasons related to NST, NJT handles matmul directly, and thus can't rely on the CompositeImplicit impl of matmul to get the derivative formula.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144587
Approved by: https://github.com/soulitzer
ghstack dependencies: #144586
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.
This PR implements the missing `fill.Scalar` support, which works fine for contiguous inputs, but there is still some AOTAutograd debugging required to handle non-contiguous transposed NJTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144586
Approved by: https://github.com/soulitzer
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.
Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. The op returns a tuple, however, and the fallback doesn't handle that. This PR defines an explicit impl for `frexp()` that wraps both returned `(mantissa, exponent)` as NJTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.
Implements `chunk()` backward on the batch dim, which was left out before. This PR unbinds the components and invokes `copy_()` on these to pass along the appropriate gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583
Part of my BE project addressing NJT bugs surfaced via OpInfo tests.
`value_selecting_reduction_backward()` is used in the backward for min / max, so this PR implements it for NJT. Notably, this isn't enough for reducing over the ragged dim, since that results in a dense tensor and thus NJT's torch_dispatch will not be called for this op. We need factory function support for nested ints to fix that case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144583
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582
Allow mutations mutations for subclasses that are non-contiguous.
Changes:
Removing assert in collect_metadata_analysis
Main requested testcase:
Compilation of NJT.index_put()
Adding test in test_nestedtensor.py, that compiles NJT.index_put()
It is decomposed to NJT split,unbind, which needed additional `torch._check`, `torch._check_is_size` for NJT.unbind() and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.
Special case:
If tangent is mutated outside of the graph, it does not participate in backward graph. Autograd in this case will set this tangent to zeros tensor.
We handle it separately in CompiledFunction.backward: not doing any processing for this tangent and broadcast to number of expected subclass unwrapped arguments.
disabling for dynamo 2 tests:
1/ For nested tensor - symbolic shapes issue on nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/item() ops than it was before.
2/ As we do not fail with exception in collect_metadata_analysis new paths for dynamo started working and it started failing with smth strange that set_ in storage_offset (because of test for views) handling updates storage "cpu" -> "meta"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
**Background:** conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.
This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual ragged_idx. Since ragged_idx tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass this.
* Fixes outputs across keepdim settings when reducing over ragged / batch dims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125
Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736, #140161
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.
Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").
Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`
Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.
I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).
**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
* Uses limited `broadcast_tensors()` / `broadcast_to()` support
* Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`
**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125
Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #140736, #140161
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.
Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").
Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`
Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.
I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #140736
Per discussion with @malfet, only allow weights_only unpickler to load NJT if `torch.nested` and `torch._dynamo` are imported
(this is slightly weird as technically `torch.nested` is actually imported by default and `torch._dynamo.decorators._DimRange` is actually what needs to be imported)
we can't import this from `torch.nested` as this would
- undo dynamo lazy import
- cause circular import
===========================
Redo of https://github.com/pytorch/pytorch/pull/140304 caused issues as `torch.nested._internal.foo` needs to be imported, which causes issues like
```python
torch/_weights_only_unpickler.py", line 339, in load
if full_path in _get_allowed_globals():
torch/_weights_only_unpickler.py", line 188, in _get_allowed_globals
torch.nested._internal.nested_tensor.NestedTensor
AttributeError: module 'torch.nested' has no attribute '_internal'
```
**This likely wasn't caught in our CI because imports are global during unit tests(?), so we use subprocess to properly test this time**
Differential Revision: [D65961691](https://our.internmc.facebook.com/intern/diff/D65961691)
@jbschlosser
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140739
Approved by: https://github.com/malfet
This PR contains several fixes related to non-contiguous NJTs:
1. Propagates `lengths` through op calls appropriately (see desc of #138098)
* SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()`
2. Allows non-contig NJTs in unsqueeze / transpose / select
3. Expands padded dense -> NJT conversion to support non-contig NJTs
4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140160
Approved by: https://github.com/cpuhrsch
Fixes#137512
Relaxes the restriction that the ragged dim is immediately next to the batch dim e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly. It's possible before this PR to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it. This PR allows a user to go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.
At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit.
It applies to the following ops:
* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`
The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.
Extensive test coverage includes:
* reductions across ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple dim reductions (for ops that support this)
* full reduction -> scalar
Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR handles conversions for indices within the giant "stacked sequence" -> sequence relative indices automatically.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
* `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
* Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.
Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query) # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)
def causal_score_mod(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, float("-inf"))
# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```
TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
* Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
Before this PR, NJT would dispatch e.g. `NJT * nested_int` to `mul.Tensor`, wrongly interpreting the SymInt as a tensor and outputting garbage. This PR verifies that there are no nested ints in the list of args before dispatching for pointwise ops.
I originally tried checking that `the number of passed tensor args == the number of func schema tensor args`, but this wrongly disallows `nt * 2`, which (non-intuitively to me at least at first) dispatches via the `mul.Tensor` overload.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138602
Approved by: https://github.com/soulitzer