Fixes#137512
Relaxes the restriction that the ragged dim is immediately next to the batch dim e.g. `(B, *, D_0, ..., D_N)`. This allows for constructing NJTs of shape e.g. `(B, D, j0)` directly. It's possible before this PR to get an NJT of e.g. shape `(B, D, j0)` by constructing an NJT of shape `(B, j0, D)` and transposing it. This PR allows a user to go straight there without the transpose. The standard `torch.nested.nested_tensor(list)` constructor has been updated to support this.
At the very least, this is useful for testing on transposed NJTs. I'm willing to make this functionality private if needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137125
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit.
It applies to the following ops:
* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`
The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.
Extensive test coverage includes:
* reductions across ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple dim reductions (for ops that support this)
* full reduction -> scalar
Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR handles conversions for indices within the giant "stacked sequence" -> sequence relative indices automatically.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
* `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
* Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.
Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query) # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)
def causal_score_mod(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, float("-inf"))
# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```
TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
* Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
Before this PR, NJT would dispatch e.g. `NJT * nested_int` to `mul.Tensor`, wrongly interpreting the SymInt as a tensor and outputting garbage. This PR verifies that there are no nested ints in the list of args before dispatching for pointwise ops.
I originally tried checking that `the number of passed tensor args == the number of func schema tensor args`, but this wrongly disallows `nt * 2`, which (non-intuitively to me at least at first) dispatches via the `mul.Tensor` overload.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138602
Approved by: https://github.com/soulitzer
Does what it says on the tin. I believe the right behavior here is to ensure that `record_stream()` is called on all tensor components of the NJT to ensure they all live until stream computation is complete.
This is an ask from torchrec as the op is used there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137099
Approved by: https://github.com/ngimel
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For the NJT forward output case, the size info stored will contain a nested int, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```
This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would want to save just the NJT shape and be able to construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875
Approved by: https://github.com/soulitzer, https://github.com/huydhn
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. For the NJT forward output case, the size info stored will contain a nested int, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```
This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is pretty inefficient; ideally we would want to save just the NJT shape and be able to construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875
Approved by: https://github.com/soulitzer
Fixes#129366
Since NJT has custom serialization logic, we need an NJT-specific fix to clear out cached sizes / strides PyCapsules. Eventually, we should switch NJT to use the default serialization logic, but this depends on #125622 being addressed.
This PR also makes serialization more complete by explicitly handling `lengths`, `ragged_idx`, and the `metadata_cache`, ensuring working operation for both contiguous and non-contiguous NJTs,
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137031
Approved by: https://github.com/soulitzer
ghstack dependencies: #137030
Prior to this PR, calling `reshape()` under `inference_mode()` would throw a `NotImplementedError`. This is because `inference_mode()` disables autograd key dispatch, incidentally preventing the decomposition of reshape for NJT.
This PR fixes this by redispatching on the `CompositeImplicitAutogradNestedTensor` key whenever a composite implicit op is encountered in `NJT.__torch_dispatch__()`. This fixes reshape and any other composite implicit ops underneath `inference_mode()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134683
Approved by: https://github.com/soulitzer, https://github.com/albanD
ghstack dependencies: #136566
Related: #132695
This PR uses padded dense <-> jagged conversions to handle binary pointwise broadcasting of (NT, T) and (T, NT). This includes:
* `(B, j0, D) + (1, 1, 1)`
* `(B, j0, D) + (B, 1, 1)`
* `(B, j0, D) + (B, 1, D)`
* etc.
This PR also adds (hacky) support for bool inputs to the jagged <-> padded dense conversions. The underlying CUDA kernels do not support integer / bool inputs; so the following workaround is employed: `convert input -> half, run conversion kernel, convert output -> bool`. Note that this bool support is needed specifically for the backward formula of `fmax`, and likely others.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133021
Approved by: https://github.com/cpuhrsch
`rms_norm()` is a nice-to-have for ViT :)
This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
* Note: there is currently no public API for this; design booted to a future PR
TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
* Note: there is currently no public API for this; design booted to a future PR
TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
Summary:
When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies.
Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we essentially desugar to `item()` calls and let it take care of unbacked symints.
Test Plan:
Some expected failures are gone now.
Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but not really doing much; repurposed it now to exercise more modes.
Differential Revision: D62197742
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131
Approved by: https://github.com/ezyang
A user wants to use the flop counter with meta devices. This previously caused problems for SDPA+NJT:
1. autocast check: `torch.is_autocast_enabled("meta")` fails because `meta` is not valid for autocasting. If we skip this, we run into the next error
2. math backend: conversion to NST requires getting concrete offsets in a list of python integers, which doesn't work on a meta tensor b2eb0e8c6a/torch/nested/_internal/sdpa.py (L809-L815)
3. (fixed in the previous PR, #134288) - if we force using flash attention backend for flop counting, `_flash_attention_forward` previously didn't support meta tensors.
In this PR, we check specifically for FlopCounterMode, and, if it's enabled and combined with meta tensors, (a) skip autocasting and (b) force it down the flash attention path. This isn't generally safe for tracing (e.g. if you actually care which kernels you are running), but in the absence of actual device information, we have to make some assumptions. By specifically checking for FlopCounterMode, this should reduce the chance of unintended side effects for other meta tensor users.
Note: fake tensor would solve a bunch of these issues, but it's not a viable solution right now for the user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134289
Approved by: https://github.com/soulitzer
ghstack dependencies: #134288
It's possible to construct an NJT with "holes" by specifying both `offsets` and `lengths` metadata. When `nt.clone(memory_format=torch.contiguous_format)` is called on such an NJT, the result should be an NJT without holes.
This PR fixes this in simplistic way using `unbind()`, which isn't really supported in `torch.compile`. The longer term solution involves writing a proper kernel to support this.
NB: Another limitation is that the returned NJT does not have the same ragged structure as the input. While we could manually hack the nested int registry (or update the union find when that lands), this is the first instance where a NJT with holes and an NJT without holes could have the same ragged structure, and getting those to play nicely together requires some fairly involved updates. For now, this PR punts on these updates until we can clean this up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132776
Approved by: https://github.com/ani300, https://github.com/soulitzer
ghstack dependencies: #131898, #131704, #131937
Summary:
Modify `softmax` on the ragged dimension, where `ragged_idx == 1`, to allow for 2D nested tensors. This diff now enables a `softmax` operation on tensors of shape `(B, *)`, where `*` is the ragged dimension.
Extend existing `softmax` unit tests to include 2D nested tensors using the `include_2d_tensor=True` keyword argument.
Test Plan:
Verify that existing and modified unit tests pass using the following commands:
```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_softmax
```
```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```
Reviewed By: davidberard98
Differential Revision: D60780975
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132812
Approved by: https://github.com/davidberard98
When autocasting is turned on, right now SDPA w/ NJT won't be autocasted. This PR adds manual "autocasting" logic in sdpa.py - at the beginning, it just checks if autocasting is enabled, and if so, it casts the inputs in the way you would expect if autocasting was actually running.
Why normal autocasting won't work:
* NJT intercepts the `__torch_function__` call for scaled_dot_product_attention, which, AFAIK, happens before we get to any dispatcher logic, and then calls efficient attention or flash attention. So autocasting the scaled_dot_product_attention op won't work; we never call the aten op for scaled_dot_product_attention, so we won't ever run autocasting for it.
* If we try to add autocasting handling for `_flash_attention_forward` or `_efficient_attention_forward`, then autocasting will _run_, but it will have the wrong semantics: sdpa.py's handling will run first, and it will do backend selection based on the uncasted inputs to SDPA. This also means that if the inputs to the SDPA call don't have uniform types, the sdpa.py implementation will fail checks (this is the specific issue we're targeting).
Alternative: "just change the backend selection logic for NJT to be autocast aware, but don't actually do the autocast; then, add `_(flash|efficient)_attention_forward` to autocasting rules". I think this would work too. But it's arguably better to make the backend-selection logic and actual-autocast-behavior use the same implementation, in case the implementations are different.
Differential Revision: [D60879916](https://our.internmc.facebook.com/intern/diff/D60879916)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132835
Approved by: https://github.com/soulitzer
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
Combines contributions from https://github.com/pytorch/pytorch/pull/130505
Some context can be found in this large comment block:
a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)
Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)
Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.
Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
mvlgamma backward trips DEBUG=1 asserts when trying to construct an empty tensor with `layout=torch.jagged`. This happens due to passing `self.options()` to `arange()` in `mvlgamma_backward()`. Fix in this PR unconditionally constructs `arange()` with the strided layout.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132422
Approved by: https://github.com/albanD