Commit Graph

265 Commits

Author SHA1 Message Date
Joel Schlosser
09a5e88bef NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-09 19:37:32 +00:00
yuqingj
defb515306 [NJT]Add permute ops support (#135336)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
2024-09-08 21:00:41 +00:00
Avik Chaudhuri
43f4947d44 fix fake tensor tolist implementation (#135131)
Summary:
When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies.

Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we essentially desugar to `item()` calls and let it take care of unbacked symints.

Test Plan:
Some expected failures are gone now.
Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but not really doing much; repurposed it now to exercise more modes.

Differential Revision: D62197742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131
Approved by: https://github.com/ezyang
2024-09-05 23:20:31 +00:00
David Berard
4b4ba7ab06 [NJT] Support NJT SDPA + meta-device flop counting (#134289)
A user wants to use the flop counter with meta devices. This previously caused problems for SDPA+NJT:

1. autocast check: `torch.is_autocast_enabled("meta")` fails because `meta` is not valid for autocasting. If we skip this, we run into the next error
2. math backend: conversion to NST requires getting concrete offsets in a list of python integers, which doesn't work on a meta tensor b2eb0e8c6a/torch/nested/_internal/sdpa.py (L809-L815)
3. (fixed in the previous PR, #134288) - if we force using flash attention backend for flop counting, `_flash_attention_forward` previously didn't support meta tensors.

In this PR, we check specifically for FlopCounterMode, and, if it's enabled and combined with meta tensors, (a) skip autocasting and (b) force it down the flash attention path. This isn't generally safe for tracing (e.g. if you actually care which kernels you are running), but in the absence of actual device information, we have to make some assumptions. By specifically checking for FlopCounterMode, this should reduce the chance of unintended side effects for other meta tensor users.

Note: fake tensor would solve a bunch of these issues, but it's not a viable solution right now for the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134289
Approved by: https://github.com/soulitzer
ghstack dependencies: #134288
2024-08-29 03:43:42 +00:00
yuqingj
44fa9f991c [NJT] add aten.to.dtype support (#134164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134164
Approved by: https://github.com/davidberard98
2024-08-22 16:59:38 +00:00
yuqingj
b459ca78eb [NJT]Add unit tests that cover the internal use cases using new NJT API (#133513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133513
Approved by: https://github.com/davidberard98, https://github.com/soulitzer
2024-08-22 13:54:40 +00:00
krzysztofjordan
2e1830c7c8 Implement 2D version of masked_select for nestedtensors (#133889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133889
Approved by: https://github.com/soulitzer
2024-08-20 21:46:32 +00:00
David Berard
bb0bf09aff [easy] skip test_sdpa_autocast on windows (#134009)
test is failing because torch.compile doesn't work on windows
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134009
Approved by: https://github.com/YuqingJ, https://github.com/Skylion007, https://github.com/ZainRizvi
2024-08-20 19:51:55 +00:00
soulitzer
4af4910b1a Reland "Construct NJT without graph breaks" (#133196)
This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
2024-08-14 01:11:13 +00:00
PyTorch MergeBot
656465fc77 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit ed97fb77f9.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to fails internal jobs, see [S440348](https://www.internalfb.com/sevmanager/view/440348) ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2285051164))
2024-08-12 23:14:19 +00:00
soulitzer
05de2b2d0f Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
2024-08-10 03:11:16 +00:00
Joel Schlosser
75eb66afc0 Support 'non-contiguous with holes' NJTs for contiguous clone() (#132776)
It's possible to construct an NJT with "holes" by specifying both `offsets` and `lengths` metadata. When `nt.clone(memory_format=torch.contiguous_format)` is called on such an NJT, the result should be an NJT without holes.

This PR fixes this in simplistic way using `unbind()`, which isn't really supported in `torch.compile`. The longer term solution involves writing a proper kernel to support this.

NB: Another limitation is that the returned NJT does not have the same ragged structure as the input. While we could manually hack the nested int registry (or update the union find when that lands), this is the first instance where a NJT with holes and an NJT without holes could have the same ragged structure, and getting those to play nicely together requires some fairly involved updates. For now, this PR punts on these updates until we can clean this up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132776
Approved by: https://github.com/ani300, https://github.com/soulitzer
ghstack dependencies: #131898, #131704, #131937
2024-08-08 17:08:11 +00:00
Janani Sriram
0ca8f66e3a [NestedTensor] Modify softmax on ragged dimension to allow for 2D nested tensors (#132812)
Summary:
Modify `softmax` on the ragged dimension, where `ragged_idx == 1`, to allow for 2D nested tensors. This diff now enables a `softmax` operation on tensors of shape `(B, *)`, where `*` is the ragged dimension.

Extend existing `softmax` unit tests to include 2D nested tensors using the `include_2d_tensor=True` keyword argument.

Test Plan:
Verify that existing and modified unit tests pass using the following commands:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_softmax
```

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```

Reviewed By: davidberard98

Differential Revision: D60780975

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132812
Approved by: https://github.com/davidberard98
2024-08-08 15:41:28 +00:00
David Berard
59f4725b49 [NJT] manually autocast in SDPA handling (#132835)
When autocasting is turned on, right now SDPA w/ NJT won't be autocasted. This PR adds manual "autocasting" logic in sdpa.py - at the beginning, it just checks if autocasting is enabled, and if so, it casts the inputs in the way you would expect if autocasting was actually running.

Why normal autocasting won't work:
* NJT intercepts the `__torch_function__` call for scaled_dot_product_attention, which, AFAIK, happens before we get to any dispatcher logic, and then calls efficient attention or flash attention. So autocasting the scaled_dot_product_attention op won't work; we never call the aten op for scaled_dot_product_attention, so we won't ever run autocasting for it.
* If we try to add autocasting handling for `_flash_attention_forward` or `_efficient_attention_forward`, then autocasting will _run_, but it will have the wrong semantics: sdpa.py's handling will run first, and it will do backend selection based on the uncasted inputs to SDPA. This also means that if the inputs to the SDPA call don't have uniform types, the sdpa.py implementation will fail checks (this is the specific issue we're targeting).

Alternative: "just change the backend selection logic for NJT to be autocast aware, but don't actually do the autocast; then, add `_(flash|efficient)_attention_forward` to autocasting rules". I think this would work too. But it's arguably better to make the backend-selection logic and actual-autocast-behavior use the same implementation, in case the implementations are different.

Differential Revision: [D60879916](https://our.internmc.facebook.com/intern/diff/D60879916)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132835
Approved by: https://github.com/soulitzer
2024-08-08 01:36:57 +00:00
Antoni Viros
ed97fb77f9 Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-08-07 14:18:53 +00:00
yuqingj
623d0204f0 [NJT] Support Chunk backward for simple cases (#132193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132193
Approved by: https://github.com/soulitzer
2024-08-06 21:20:09 +00:00
soulitzer
f50621989b Construct NJT without graph breaks (#130292)
Combines contributions from https://github.com/pytorch/pytorch/pull/130505

Some context can be found in this large comment block:

a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)

Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)

Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.

Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
2024-08-06 17:03:39 +00:00
PyTorch MergeBot
38674bcb45 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit eca0cb0fbe.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to breaks test_overrides.py::TestTorchFunctionWarning::test_warn_on_invalid_torch_function_tensor_subclass ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2270213988))
2024-08-06 01:55:41 +00:00
Antoni Viros
eca0cb0fbe Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-08-05 23:45:48 +00:00
Aaron Gokaslan
fd4b649e6c [BE]: Simplify some list comps to generators C419 (#132578)
Simplifies some list comprehensions to generator which is more efficient. Automatically applied diffs for the most part with ruff

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132578
Approved by: https://github.com/ezyang
2024-08-04 17:46:26 +00:00
Xuehai Pan
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
Joel Schlosser
a356a03f4a Fix DEBUG=1 asserts for mvlgamma backward with NJT (#132422)
mvlgamma backward trips DEBUG=1 asserts when trying to construct an empty tensor with `layout=torch.jagged`. This happens due to passing `self.options()` to `arange()` in `mvlgamma_backward()`. Fix in this PR unconditionally constructs `arange()` with the strided layout.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132422
Approved by: https://github.com/albanD
2024-08-01 21:53:16 +00:00
Oguz Ulgen
221350e3a4 Add None return type to init -- tests (#132352)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352
Approved by: https://github.com/ezyang
ghstack dependencies: #132335, #132351
2024-08-01 15:44:51 +00:00
Joel Schlosser
7eb2a99585 Fix to support unary pointwise ops when an NJT is not the first arg (#131937)
**Background:** NJT utilizes a `jagged_unary_pointwise()` fallback that historically has assumed blindly that the first arg is an NJT. This assumption breaks certain ops; for example `pow(scalar, Tensor)` has an NJT as the second arg.

This PR expands `jagged_unary_pointwise()` and the associated schema validation logic to handle an NJT in args other than the first position.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131937
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898, #131704
2024-07-31 17:51:03 +00:00
Janani Sriram
46994e753b [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#132172)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132172
Approved by: https://github.com/davidberard98
ghstack dependencies: #132170
2024-07-31 10:51:46 +00:00
Janani Sriram
89053e382a [NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#132170)
Modify the existing `softmax` operator in PyTorch, invoked by `torch.softmax`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the aten padding operator, enables PyTorch users to invoke `torch.softmax` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Write unit tests based on the `sum` and `mean` jagged operators to verify the accuracy of the ragged reduction implementation for `torch.softmax`. Add unit tests to verify error handling for unsupported features in `NestedTensor` `torch.softmax`.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. In addition, the `softmax` operator is required to take in as input an integer for the reduction dimension `dim`, requiring new unit tests heavily inspired by the `sum` and `mean` jagged operator unit tests. `Softmax` also allows for reducing along the batch dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132170
Approved by: https://github.com/davidberard98
2024-07-31 10:51:46 +00:00
Joel Schlosser
524aac413c Initial OpInfo-based testing for NJTs (#131704)
This PR utilizes the info from the existing OpInfo database `op_db` to contribute to general NJT testing.
* New tests in `TestNestedTensorOpInfo`
    * `test_forward()` - compares forward output to an unbind-based reference
    * `test_backward()` - compares forward output and grads to an unbind-based reference
    * `test_forward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) to eager
    * `test_backward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) and grads to eager
* To avoid adding a bunch of NJT-specific stuff to the `OpInfo` structure, this PR translates `op_db` -> a NJT-specific `njt_op_db`.
    * `UnaryUfuncInfo`s utilize a new `sample_inputs_unary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `BinaryUfuncInfo`s utilize a new `sample_inputs_binary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `ReductionOpInfo`s utilize a new `sample_inputs_njt_reduction()` which covers full reductions, reductions over the jagged dim, and reductions over the non-jagged dim
* Several xfails were added to get things passing

TODO (future PRs):
* Pass non-contiguous / non-contiguous with holes NJTs (maybe we should have separate tests for these? most ops don't support NJTs with holes today)
* Mixed (NT, T), (T, NT) inputs for binary ops
* Handle other types of OpInfos (beyond unary pointwise, binary pointwise, and reduction) by manually by writing sample_inputs_funcs
* Address all xfails via fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131704
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898
2024-07-30 23:02:24 +00:00
Joel Schlosser
d53b11bb6e Strict shape checking for NJTs with TestCase.assertEqual() (#131898)
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131898
Approved by: https://github.com/soulitzer
2024-07-30 20:05:48 +00:00
Animesh Jain
f806128619 [dynamo] Skip <frozen abc> to skip __isisintance__ check on abc objects (#131956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131956
Approved by: https://github.com/williamwen42, https://github.com/mlazos
ghstack dependencies: #131827
2024-07-30 05:49:58 +00:00
PyTorch MergeBot
7a7dd8c29e Revert "[NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#131518)"
This reverts commit bcf5c68c18.

Reverted https://github.com/pytorch/pytorch/pull/131518 on behalf of https://github.com/ZainRizvi due to Sorry, reverting this since this is based on an internal diff that has diverged from actual internal commit (the final PR and diff must always be identical). Conflicts arise when that happens which block the diff train. Let's revert both this PR and the internal diff, and then reland them as a proper new codev diff ([comment](https://github.com/pytorch/pytorch/pull/131518#issuecomment-2257259839))
2024-07-30 00:55:10 +00:00
PyTorch MergeBot
be5e44192d Revert "[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)"
This reverts commit 8fe2bf212d.

Reverted https://github.com/pytorch/pytorch/pull/131519 on behalf of https://github.com/ZainRizvi due to Sorry, reverting this since this is based on an internal diff that has diverged from actual internal commit.  Weird conflicts arise when that happens.  Let's revert both this PR and the internal diff, and then reland them as a proper new codev diff ([comment](https://github.com/pytorch/pytorch/pull/131519#issuecomment-2257230717))
2024-07-30 00:18:22 +00:00
yuqingj
e3dc20c94b [NJT] support cat backward (#132076)
cat_tensors_backward use narrow_symint, so we need to support aten::narrow for NJT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132076
Approved by: https://github.com/davidberard98
2024-07-29 23:49:26 +00:00
Janani Sriram
8fe2bf212d [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131519
Approved by: https://github.com/davidberard98
ghstack dependencies: #131518
2024-07-29 22:16:32 +00:00
PyTorch MergeBot
8cdfdb41bc Revert "[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)"
This reverts commit f862f45730.

Reverted https://github.com/pytorch/pytorch/pull/131519 on behalf of https://github.com/atalman due to broke CI: test_nestedtensor.py::TestNestedTensorSubclassCPU::test_layer_norm_with_lengths_requires_grad_False_components_require_grad_False_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10121747545/job/27996722731) [HUD commit link](f862f45730) ([comment](https://github.com/pytorch/pytorch/pull/131519#issuecomment-2254167994))
2024-07-27 14:45:47 +00:00
Janani Sriram
f862f45730 [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131519
Approved by: https://github.com/davidberard98
ghstack dependencies: #131518
2024-07-27 07:09:10 +00:00
Janani Sriram
bcf5c68c18 [NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#131518)
Modify the existing `softmax` operator in PyTorch, invoked by `torch.softmax`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the aten padding operator, enables PyTorch users to invoke `torch.softmax` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Write unit tests based on the `sum` and `mean` jagged operators to verify the accuracy of the ragged reduction implementation for `torch.softmax`. Add unit tests to verify error handling for unsupported features in `NestedTensor` `torch.softmax`.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. In addition, the `softmax` operator is required to take in as input an integer for the reduction dimension `dim`, requiring new unit tests heavily inspired by the `sum` and `mean` jagged operator unit tests. `Softmax` also allows for reducing along the batch dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131518
Approved by: https://github.com/davidberard98
2024-07-27 07:09:10 +00:00
Janani Sriram
13e806a591 [NestedTensor] Add support for transposed NestedTensors where ragged_idx > 1 for sum and mean operators (#131517)
Add support for transposed, non-contiguous `NestedTensor`s, where `ragged_idx > 1`, for the aten operators `sum` and `mean`. This diff enables reducing along the jagged dimension for non-contiguous `NestedTensor`s, transposed between non-batch dimensions as well as between a ragged and a non-batch dimension. For example, users can now reduce a `NestedTensor` of shape `(B, M, *, N)` along `*` or `(B, N, M, *)` along `*`.

Parametrize existing unit tests and add new unit tests verifying the accuracy of implementations on `NestedTensor`s that transpose between 2 non-batch dimensions as well as between a ragged and a non-batch dimension.

Differential Revision: [D59847927](https://our.internmc.facebook.com/intern/diff/D59847927/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131517
Approved by: https://github.com/davidberard98
2024-07-26 07:21:32 +00:00
Janani Sriram
e782918b8e [NestedTensor] Add example NestedTensor objects with inner dimension of size 1 to tests reducing along jagged dimension for NestedTensor (#131516)
Add example `NestedTensor`s with inner dimension of size `1` to `_get_example_tensor_lists` with `include_inner_dim_size_1=True`. This diff creates `NestedTensor`s of sizes `(B, *, 1)` and `(B, *, 5, 1)`, ensuring that the current implementations of jagged reductions for `sum` and `mean` hold for tensors of effective shape `(B, *)` and `(B, *, 5)`.

Differential Revision: [D59846023](https://our.internmc.facebook.com/intern/diff/D59846023/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131516
Approved by: https://github.com/davidberard98
2024-07-24 07:01:39 +00:00
jananisriram
faddb0f30c [NestedTensor] Integrate the mean operator along the jagged dimension into NestedTensor (#131132)
Summary:
Modify the existing `mean` operator in PyTorch, invoked by `torch.mean`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff enables PyTorch users to invoke `torch.mean` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Parametrize unit tests from `sum` to verify the accuracy of the ragged reduction implementation for `torch.mean`. Add unit tests and parametrize `sum` unit tests to verify error handling for unsupported features in `NestedTensor` `torch.mean`.

Test Plan:
Verify that the new unit test passes via the following command:
```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_mean
```

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```

Differential Revision: D59654668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131132
Approved by: https://github.com/davidberard98, https://github.com/jbschlosser
2024-07-23 18:48:34 +00:00
jananisriram
28a74b9fa4 [NestedTensor] Integrate sum along the jagged dimension into NestedTensor (#130425)
Summary: Modify the existing `sum` operator in PyTorch, invoked by `torch.sum`, to allow for reductions along the ragged dimension of a nested tensor. This diff enables PyTorch users to invoke `torch.sum` on a nested tensor with `dim=1`, where `ragged_idx=1`.

Functions modified in `caffe2/torch/nested/_internal/ops.py`:
- `sum_dim_IntList()`: The function assumes that `ragged_idx=1`; in the case that `dim=1` as well, where `dim` is the dimension on which we reduce, this diff invokes the PyTorch benchmark found in D58423489. Specifically, this diff pads a nested tensor, e.g. of logical shape `(B, *, M)`, using [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), then reduces across the `*` dimension (`dim == 1`) to a `(B, M)` output tensor.
- `_wrap_jagged_dims()`: This diff adds special handling to allow for the case where `dim` contains `1` and not `0`, but to continue disallowing the case where `dim` contains `0` and not `1`. In this function's creation, I created a helper function, `_get_condition_for_invalid_jagged_reductions()`, which makes it clearer which conditions apply to which operators. Specifically, operators which are enabled with jagged reductions are specified at the top of the file in `SUPPORTED_JAGGED_REDUCTIONS` and have a different set of conditions that need to be tested, as reducing along `dim == 1` without `dim == 0` is now possible.

Functions modified in `caffe2/test/test_nestedtensor.py`:
- `test_sum_int_DimList()`: This diff adds special handling in the `sum` unit test to allow for the case where `dim` contains `1` and not `0`, but to continue disallowing the case where `dim` contains `0` and not `1`.
- `test_sum_int_DimList_ragged_dim_1()`: This diff adds a new unit test which verifies the accuracy and feasibility of reducing along the jagged dimension of a nested tensor.

Notes:
- This diff solely adds functionality for the case in which we reduce only along the ragged dimension. Cases in which we reduce along both the ragged and another dimension, like `dim == (1, 2)`, are not permitted, as this set of diffs focuses primarily on the former.
- The `sum` operator is the only operator which uses the function `_wrap_jagged_dims()`; all other operators use `_wrap_jagged_dim()`. I would like to later look into why this is the case and if we can consolidate this!
- I modified some of the comments in the `sum` function as well as the unit tests for more clarity.

Test Plan:
Verify that existing (`test_sum_int_DimList`) and new (`test_sum_int_DimList_ragged_dim_1`) unit tests pass via the following command:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_sum_int_DimList
```

Differential Revision: D59571209

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130425
Approved by: https://github.com/davidberard98
2024-07-18 10:48:18 +00:00
Xuehai Pan
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
David Berard
d548417d95 [NJT] throw an exception if nested_tensor_from_jagged is fx-traced without being fx.wrapped (#130702)
The NJT constructor can't be fx-traced safely due to the dummy nt used:

774ca93fd2/torch/nested/_internal/nested_tensor.py (L501-L508)

The error doesn't appear immediately, but appears if you try to move a module with an fx-traced NJT constructor onto a different device, or try to serialize it. Let's throw an error if we try to fx-trace the NJT constructor so users know to wrap the call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130702
Approved by: https://github.com/jbschlosser, https://github.com/soulitzer
2024-07-16 19:21:10 +00:00
Joel Schlosser
09b1b113f5 Cache min / max seq len for torch.nested.as_nested_tensor(t) (#130766)
For the `torch.nested.as_nested_tensor(t)` constructor, computing min / max seq len is trivial since the sequence lengths are all the same. Might as well cache them during construction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130766
Approved by: https://github.com/YuqingJ, https://github.com/soulitzer
2024-07-16 18:32:47 +00:00
yuqingj
ea4f310ff1 [Nested Tensor][easy] Add softmax backward support (#130602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130602
Approved by: https://github.com/davidberard98, https://github.com/jbschlosser
2024-07-16 00:07:42 +00:00
yuqingj
0e79e1f958 [NJT+SDPA]Fix flash_attention output when batch_size=1 and seq_len=1 (#130652)
fix issue  #130196

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130652
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jbschlosser
2024-07-15 19:44:04 +00:00
Xuehai Pan
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
Joel Schlosser
00335a27b4 Accept min / max sequence length in nested_tensor_from_jagged() constructor (#130175)
This PR updates the public API for NJT construction `torch.nested.nested_tensor_from_jagged()` to accept values for min / max sequence length. It's useful to provide these ahead of time to avoid GPU -> CPU syncs from on-demand computation later on.

NB: The test changes are extensive because I reworked the existing `_validate_nt()` helper function used throughout our NJT construction tests to verify more (specifically: expected cached min / max seq len and contiguity).

API design question: should we additionally provide an option to compute these from `offsets` at construction time? I can think of three possible cases during construction:
1. Min / max seq len has already been obtained from *somewhere* (manual calculation, static values, etc.) and they should be used in the cache
2. Min / max seq len should be computed immediately at construction time for use in the cache (ideally, the caller wouldn't have to do this computation manually)
3. Min / max seq len are not needed at all (i.e. SDPA isn't ever called) and computation should be skipped
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130175
Approved by: https://github.com/davidberard98, https://github.com/soulitzer
2024-07-08 22:14:52 +00:00
Joel Schlosser
7192ee0735 Default to input tensor device for as_nested_tensor(t) (#130050)
Fixes #129647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130050
Approved by: https://github.com/YuqingJ
2024-07-05 17:50:08 +00:00
soulitzer
eeef68671d [autograd] Do not detach when unpacking tensors that do not require grad (#127959)
In this PR:
- Ensure that if a tensor not requiring grad is saved for backward unpacking does not trigger a detach (unless the user installs a saved tensor pack hook that returns a tensor requiring grad).
- Update non-reentrant checkpoint to also no longer detach for this case.

Alternatives:
- For custom autograd Function, you could directly save on ctx to work around this, but that would not work for when we switch to using custom ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127959
Approved by: https://github.com/YuqingJ
ghstack dependencies: #125795, #128545, #129262
2024-07-01 21:57:36 +00:00
PyTorch MergeBot
fa6c0fe3e4 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit 9450e198aa.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2197790226))
2024-06-29 00:16:47 +00:00