Commit Graph

135 Commits

Author SHA1 Message Date
IvanKobzarev
661d1f0372 [aotd] non-contiguous NestedTensor mutation in compile (#139630)
Allow mutations mutations for subclasses that are non-contiguous.

Changes:

Removing assert in collect_metadata_analysis

Main requested testcase:
Compilation of NJT.index_put()

Adding test in test_nestedtensor.py, that compiles NJT.index_put()

It is  decomposed to NJT split,unbind, which  needed additional `torch._check`, `torch._check_is_size` for NJT.unbind()  and guard_size_oblivious() usage in _meta_registrations and _inductor/lowering.py.

Special case:
If tangent is mutated outside of the graph, it does not participate in backward graph. Autograd in this case will set this tangent to zeros tensor.

We handle it separately in CompiledFunction.backward: not doing any processing for this tangent and broadcast to number of expected subclass unwrapped arguments.

disabling for dynamo 2 tests:
1/ For nested tensor - symbolic shapes issue on nested_tensor index operation that does splits [0, 0, 0] - there is a failure with "pending unbacked symints". This PR does not add more .tolist()/item() ops than it was before.

2/ As we do not fail with exception in collect_metadata_analysis new paths for dynamo started working and it started failing with smth strange that set_ in storage_offset (because of test for views) handling updates storage "cpu" -> "meta"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139630
Approved by: https://github.com/bdhirsh
2024-12-06 12:18:46 +00:00
Joel Schlosser
e803a3d83a Fix reductions for NJTs with ragged_idx != 1 (#142173)
**Background:** conversion from outer dim -> inner dim makes the (previously valid) assumption that the ragged dim is immediately next to the batch dim. This is no longer the case after #137125.

This PR:
* Updates the outer dim -> inner dim conversion logic to match the actual ragged_idx. Since ragged_idx tells us where the packed ragged / batch dim is, both ragged and batch outer dims should map to this inner dim. The conversion logic must now take in `ragged_idx` to make this possible, so the PR updates all call-sites to pass this.
* Fixes outputs across keepdim settings when reducing over ragged / batch dims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142173
Approved by: https://github.com/drisspg
2024-12-06 01:23:17 +00:00
Joel Schlosser
c9e2b3fefe NJT: Return correct number of outputs for chunk() on the batch dim (#141604)
Old logic was completely wrong, returning `chunk_size` chunks instead of the intended number. The original test didn't catch this because `chunk_size == num_chunks` :p New OpInfo-based testing covers it though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141604
Approved by: https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392, #141506
2024-11-27 02:31:23 +00:00
Joel Schlosser
43121b6f0d Adjust output NJT ragged_idx for reductions and select() (#141506)
This fixes some bugs when performing reductions / select() on dims before the ragged dim. In this case, the output NJT has a smaller number of dims, and its ragged_idx should reflect that correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141506
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392
2024-11-27 02:25:53 +00:00
Joel Schlosser
23793cf93d NJT unsqueeze() fixes (#141392)
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736, #140161
2024-11-26 22:38:35 +00:00
Joel Schlosser
9ee5d6f83c Initial NJT testing over dim type / views (#140161)
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
2024-11-26 22:08:08 +00:00
Joel Schlosser
869d629c0f Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500
2024-11-26 21:19:58 +00:00
Joel Schlosser
8ba555ec8a Fix where() for NJT (#141500)
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).

**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
    * Uses limited `broadcast_tensors()` / `broadcast_to()` support
    * Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`

**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-26 20:13:27 +00:00
PyTorch MergeBot
cffeb83f15 Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit daaecb96d6.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/malfet due to Take 2, of stack revert your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498479702))
2024-11-25 16:27:00 +00:00
PyTorch MergeBot
e0f9ec4a25 Revert "Initial NJT testing over dim type / views (#140161)"
This reverts commit 730caf0aed.

Reverted https://github.com/pytorch/pytorch/pull/140161 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498358652))
2024-11-25 15:40:54 +00:00
PyTorch MergeBot
58727b6f5f Revert "NJT unsqueeze() fixes (#141392)"
This reverts commit 48409a5cc6.

Reverted https://github.com/pytorch/pytorch/pull/141392 on behalf of https://github.com/malfet due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498358652))
2024-11-25 15:40:54 +00:00
Joel Schlosser
48409a5cc6 NJT unsqueeze() fixes (#141392)
This PR contains three `unsqueeze()`-related fixes for NJT:
1. Adjusts the output's `_ragged_idx` when `unsqueeze()` inserts a dim before the ragged dim
2. Corrects the unbind reference for `unsqueeze()` after the last input dim. For this case, the dim kwarg canonicalization logic needs to be applied wrt `inp.dim() + 1` to account for `dim=-1` properly
3. Adds ragged dim support to `unsqueeze()`, allowing for e.g. `(B, j1, D) -> (B, 1, j1, D)`. This is okay now after #137125

Note that `unsqueeze()` still doesn't support batch dim operation, and arguably should never support this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141392
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #140736, #140161
2024-11-25 08:08:38 +00:00
Joel Schlosser
730caf0aed Initial NJT testing over dim type / views (#140161)
This PR introduces `ExtraOpData`, a structure that contains op metadata regarding whether the op is a view and the dim-related args it accepts. It also populates a huge database for dim-wise / view ops with this info.

Test logic (sample input generation, references) have been updated to utilize this data. It allows for a fairly generic set of sample inputs & a reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not operate on the ragged / batch dims or on non-contiguous NJTs, so the proper xfails are added as needed.

I also slipped in a couple minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper now avoids assuming the `nt._ragged_idx == 1` and allows for a batch dim to be a valid input, disambiguating the converted inner dim as necessary through an additional `operating_on_batch` return value (i.e. both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim for all batch items)
2. Padded dense -> NJT conversion requires shape gymnastics to operate with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #140736
2024-11-25 08:08:38 +00:00
Joel Schlosser
daaecb96d6 Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-11-25 08:08:31 +00:00
PyTorch MergeBot
0be0c944b1 Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit af70f5e04c.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/huydhn due to Sorry for reverting your change but its tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2495075871))
2024-11-22 23:15:55 +00:00
Joel Schlosser
af70f5e04c Forward / backward NJT support for several activation functions (#140736)
Several activation functions were unimplemented due to missing `pointwise` tags. This PR adds them and corresponding backwards implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140736
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-11-22 22:05:53 +00:00
Joel Schlosser
41f315417c Fix NJT linear_backward() memory usage (#141163)
Fixes #141112

The formula we're using for `linear_backward()` is inefficient for higher dim input sizes, even if the input is trivially higher dim (e.g. via use of `unsqueeze()`). This PR updates the formula to match the more efficient version employed by NST. Specifically, note the leading dim collapse for `grad_output`'s values before we compute the various matmuls.
d5ee1d1b58/aten/src/ATen/native/nested/NestedTensorBackward.cpp (L37-L70)

Testing for correctness is done via existing gradcheck tests (e.g. `test_backward_nn_functional_linear`). I added a memory usage test but I think it's likely there's a better way to do this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141163
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-21 15:22:45 +00:00
chilli
c1f21bf2b6 Made FlexAttention error on subgraph lowering failure (#140331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140331
Approved by: https://github.com/drisspg
2024-11-17 02:43:58 +00:00
Joel Schlosser
9c678af9f9 Misc. non-contig NJT fixes (#140160)
This PR contains several fixes related to non-contiguous NJTs:
1. Propagates `lengths` through op calls appropriately (see desc of #138098)
    * SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()`
2. Allows non-contig NJTs in unsqueeze / transpose / select
3. Expands padded dense -> NJT conversion to support non-contig NJTs
4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140160
Approved by: https://github.com/cpuhrsch
2024-11-09 01:18:26 +00:00
Joel Schlosser
ddb291a881 Fix and test several NJT reductions (#139317)
I'm sick of reductions not working properly - spotty dim coverage, missing backwards, etc. This PR fixes quite a bit.

It applies to the following ops:
* `sum` / `mean` / `prod`
* `all` / `any`
* `amin` / `amax`
* `min` / `max`
* `argmin` / `argmax`

The general reduction logic has been factored out into a helper `_apply_reduction(func, func_name, identity_element, *args, **kwargs)`. The idea is that by providing a valid identity element, we can utilize conversions to padded dense when needed for reducing over the ragged dim.

Extensive test coverage includes:
* reductions across ragged dim
* reductions across non-batch, non-ragged dims
* reductions across both batch and ragged dims
* multiple dim reductions (for ops that support this)
* full reduction -> scalar

Bonus: the PR includes backwards fixes for `sum` and `mean`, which have never worked.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139317
Approved by: https://github.com/cpuhrsch
2024-10-31 20:55:38 +00:00
Antoni Viros
ad637a4c5c Add support for index_put_ in NT (#135722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135722
Approved by: https://github.com/jbschlosser
2024-10-30 17:17:59 +00:00
PyTorch MergeBot
5861279f47 Revert "Add support for index_put_ in NT (#135722)"
This reverts commit b4836e5b5c.

Reverted https://github.com/pytorch/pytorch/pull/135722 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing on ROCm ([comment](https://github.com/pytorch/pytorch/pull/135722#issuecomment-2445651914))
2024-10-30 01:53:55 +00:00
Antoni Viros
b4836e5b5c Add support for index_put_ in NT (#135722)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135722
Approved by: https://github.com/jbschlosser
2024-10-30 00:03:21 +00:00
Jake Schmidt
2b577ae58f Implement NJT embedding backward (#138627)
Fixes #138352

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138627
Approved by: https://github.com/jbschlosser
2024-10-29 18:44:58 +00:00
Joel Schlosser
8ba9063002 FlexAttention support for NJT (#136792)
This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR handles conversions for indices within the giant "stacked sequence" -> sequence relative indices automatically.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
    * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
      * Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.

Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```

TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
    * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
2024-10-28 20:01:27 +00:00
Joel Schlosser
f089d5ffef Improve input validation for NJT pointwise ops (#138602)
Before this PR, NJT would dispatch e.g. `NJT * nested_int` to `mul.Tensor`, wrongly interpreting the SymInt as a tensor and outputting garbage. This PR verifies that there are no nested ints in the list of args before dispatching for pointwise ops.

I originally tried checking that `the number of passed tensor args == the number of func schema tensor args`, but this wrongly disallows `nt * 2`, which (non-intuitively to me at least at first) dispatches via the `mul.Tensor` overload.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138602
Approved by: https://github.com/soulitzer
2024-10-22 20:13:12 +00:00
Joel Schlosser
134f6cda7e Support record_stream() for NJT (#137099)
Does what it says on the tin. I believe the right behavior here is to ensure that `record_stream()` is called on all tensor components of the NJT to ensure they all live until stream computation is complete.

This is an ask from torchrec as the op is used there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137099
Approved by: https://github.com/ngimel
2024-10-21 21:10:42 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Joel Schlosser
ecc5e05854 Refactor NJT min / max seqlen handling for convenience (#138130)
There's an annoying pattern emerging for pulling out the NJT min / max seqlen ints if they exist without computing / caching if they don't. This PR introduces private convenience functions to simplify handling this and avoiding redundant checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138130
Approved by: https://github.com/soulitzer
2024-10-17 17:28:39 +00:00
Joel Schlosser
906fe05895 Naive impls for NJT matmul (#138121)
Our matmul support is abysmal - let's at least get this working and do it performantly later.

Bonus: implements `bmm` as well.

jagged <-> padded dense conversions are utilized when possible, and an unbind-based fallback otherwise (the former works with torch.compile and the latter doesn't). Some testing is missing because we don't have factory function support yet :(
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138121
Approved by: https://github.com/cpuhrsch
2024-10-17 01:31:46 +00:00
Joel Schlosser
3e2f276a14 Fix to() on non-contiguous NJTs (#137124)
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): #137275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137124
Approved by: https://github.com/soulitzer
ghstack dependencies: #137030, #137031
2024-10-08 15:11:05 +00:00
Joel Schlosser
991f8f8ec3 Bias gradient calculation for NJT linear backward (#136660)
Previously NYI - @mikaylagawarecki needs it for Transformers.

Fixes #136652
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136660
Approved by: https://github.com/soulitzer
2024-09-26 21:38:10 +00:00
Joel Schlosser
4150ab44a4 Fix composite op redispatch for NJT in inference mode (#134683)
Prior to this PR, calling `reshape()` under `inference_mode()` would throw a `NotImplementedError`. This is because `inference_mode()` disables autograd key dispatch, incidentally preventing the decomposition of reshape for NJT.

This PR fixes this by redispatching on the `CompositeImplicitAutogradNestedTensor` key whenever a composite implicit op is encountered in `NJT.__torch_dispatch__()`. This fixes reshape and any other composite implicit ops underneath `inference_mode()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134683
Approved by: https://github.com/soulitzer, https://github.com/albanD
ghstack dependencies: #136566
2024-09-26 14:10:53 +00:00
Joel Schlosser
888744bd36 NJT binary pointwise broadcasting support via jagged <-> padded dense conversion (#133021)
Related: #132695

This PR uses padded dense <-> jagged conversions to handle binary pointwise broadcasting of (NT, T) and (T, NT). This includes:
* `(B, j0, D) + (1, 1, 1)`
* `(B, j0, D) + (B, 1, 1)`
* `(B, j0, D) + (B, 1, D)`
* etc.

This PR also adds (hacky) support for bool inputs to the jagged <-> padded dense conversions. The underlying CUDA kernels do not support integer / bool inputs; so the following workaround is employed: `convert input -> half, run conversion kernel, convert output -> bool`. Note that this bool support is needed specifically for the backward formula of `fmax`, and likely others.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133021
Approved by: https://github.com/cpuhrsch
2024-09-24 19:11:49 +00:00
Joel Schlosser
a8382847f4 Support rms_norm() for NJT (#135872)
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
2024-09-17 18:09:20 +00:00
Joel Schlosser
06bc717410 Fix sum() forward for NJT (#131945)
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
2024-09-14 00:58:03 +00:00
Joel Schlosser
525bec804c NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-12 17:54:25 +00:00
PyTorch MergeBot
70a65a8bd5 Revert "NJT <-> padded dense conversions (#125947)"
This reverts commit 09a5e88bef.

Reverted https://github.com/pytorch/pytorch/pull/125947 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing dynamo test 09a5e88bef, maybe a landrace ([comment](https://github.com/pytorch/pytorch/pull/125947#issuecomment-2339228570))
2024-09-09 22:01:09 +00:00
Joel Schlosser
09a5e88bef NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-09 19:37:32 +00:00
yuqingj
defb515306 [NJT]Add permute ops support (#135336)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
2024-09-08 21:00:41 +00:00
Aaron Orenstein
ed86ac2f25 [BE] typing for decorators - fx/_compatibility (#134054)
Summary: See #131429

Test Plan: unit tests pass

Differential Revision: D61493706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134054
Approved by: https://github.com/oulgen
2024-08-26 04:00:27 +00:00
Aaron Orenstein
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
yuqingj
44fa9f991c [NJT] add aten.to.dtype support (#134164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134164
Approved by: https://github.com/davidberard98
2024-08-22 16:59:38 +00:00
krzysztofjordan
2e1830c7c8 Implement 2D version of masked_select for nestedtensors (#133889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133889
Approved by: https://github.com/soulitzer
2024-08-20 21:46:32 +00:00
soulitzer
4af4910b1a Reland "Construct NJT without graph breaks" (#133196)
This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
2024-08-14 01:11:13 +00:00
PyTorch MergeBot
656465fc77 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit ed97fb77f9.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to fails internal jobs, see [S440348](https://www.internalfb.com/sevmanager/view/440348) ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2285051164))
2024-08-12 23:14:19 +00:00
soulitzer
05de2b2d0f Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
2024-08-10 03:11:16 +00:00
Joel Schlosser
75eb66afc0 Support 'non-contiguous with holes' NJTs for contiguous clone() (#132776)
It's possible to construct an NJT with "holes" by specifying both `offsets` and `lengths` metadata. When `nt.clone(memory_format=torch.contiguous_format)` is called on such an NJT, the result should be an NJT without holes.

This PR fixes this in simplistic way using `unbind()`, which isn't really supported in `torch.compile`. The longer term solution involves writing a proper kernel to support this.

NB: Another limitation is that the returned NJT does not have the same ragged structure as the input. While we could manually hack the nested int registry (or update the union find when that lands), this is the first instance where a NJT with holes and an NJT without holes could have the same ragged structure, and getting those to play nicely together requires some fairly involved updates. For now, this PR punts on these updates until we can clean this up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132776
Approved by: https://github.com/ani300, https://github.com/soulitzer
ghstack dependencies: #131898, #131704, #131937
2024-08-08 17:08:11 +00:00
Janani Sriram
0ca8f66e3a [NestedTensor] Modify softmax on ragged dimension to allow for 2D nested tensors (#132812)
Summary:
Modify `softmax` on the ragged dimension, where `ragged_idx == 1`, to allow for 2D nested tensors. This diff now enables a `softmax` operation on tensors of shape `(B, *)`, where `*` is the ragged dimension.

Extend existing `softmax` unit tests to include 2D nested tensors using the `include_2d_tensor=True` keyword argument.

Test Plan:
Verify that existing and modified unit tests pass using the following commands:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_softmax
```

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```

Reviewed By: davidberard98

Differential Revision: D60780975

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132812
Approved by: https://github.com/davidberard98
2024-08-08 15:41:28 +00:00
Antoni Viros
ed97fb77f9 Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-08-07 14:18:53 +00:00