Commit Graph

84 Commits

Author SHA1 Message Date
soulitzer
f50621989b Construct NJT without graph breaks (#130292)
Combines contributions from https://github.com/pytorch/pytorch/pull/130505

Some context can be found in this large comment block:

a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)

Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)

Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.

Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
2024-08-06 17:03:39 +00:00
PyTorch MergeBot
38674bcb45 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit eca0cb0fbe.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to breaks test_overrides.py::TestTorchFunctionWarning::test_warn_on_invalid_torch_function_tensor_subclass ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2270213988))
2024-08-06 01:55:41 +00:00
Antoni Viros
eca0cb0fbe Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-08-05 23:45:48 +00:00
Xuehai Pan
d2dc173664 Remove lint dependency ufmt (#132573)
`ufmt` is a combination of `black + usort`.

This PR removes `ufmt` and run `black` and `usort` separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132573
Approved by: https://github.com/ezyang
ghstack dependencies: #129769, #132572
2024-08-04 10:24:09 +00:00
Xuehai Pan
f3fce597e9 [BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
2024-08-04 10:24:09 +00:00
Joel Schlosser
7eb2a99585 Fix to support unary pointwise ops when an NJT is not the first arg (#131937)
**Background:** NJT utilizes a `jagged_unary_pointwise()` fallback that historically has assumed blindly that the first arg is an NJT. This assumption breaks certain ops; for example `pow(scalar, Tensor)` has an NJT as the second arg.

This PR expands `jagged_unary_pointwise()` and the associated schema validation logic to handle an NJT in args other than the first position.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131937
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898, #131704
2024-07-31 17:51:03 +00:00
Janani Sriram
46994e753b [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#132172)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132172
Approved by: https://github.com/davidberard98
ghstack dependencies: #132170
2024-07-31 10:51:46 +00:00
Janani Sriram
89053e382a [NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#132170)
Modify the existing `softmax` operator in PyTorch, invoked by `torch.softmax`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the aten padding operator, enables PyTorch users to invoke `torch.softmax` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Write unit tests based on the `sum` and `mean` jagged operators to verify the accuracy of the ragged reduction implementation for `torch.softmax`. Add unit tests to verify error handling for unsupported features in `NestedTensor` `torch.softmax`.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. In addition, the `softmax` operator is required to take in as input an integer for the reduction dimension `dim`, requiring new unit tests heavily inspired by the `sum` and `mean` jagged operator unit tests. `Softmax` also allows for reducing along the batch dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132170
Approved by: https://github.com/davidberard98
2024-07-31 10:51:46 +00:00
Joel Schlosser
524aac413c Initial OpInfo-based testing for NJTs (#131704)
This PR utilizes the info from the existing OpInfo database `op_db` to contribute to general NJT testing.
* New tests in `TestNestedTensorOpInfo`
    * `test_forward()` - compares forward output to an unbind-based reference
    * `test_backward()` - compares forward output and grads to an unbind-based reference
    * `test_forward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) to eager
    * `test_backward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) and grads to eager
* To avoid adding a bunch of NJT-specific stuff to the `OpInfo` structure, this PR translates `op_db` -> a NJT-specific `njt_op_db`.
    * `UnaryUfuncInfo`s utilize a new `sample_inputs_unary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `BinaryUfuncInfo`s utilize a new `sample_inputs_binary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `ReductionOpInfo`s utilize a new `sample_inputs_njt_reduction()` which covers full reductions, reductions over the jagged dim, and reductions over the non-jagged dim
* Several xfails were added to get things passing

TODO (future PRs):
* Pass non-contiguous / non-contiguous with holes NJTs (maybe we should have separate tests for these? most ops don't support NJTs with holes today)
* Mixed (NT, T), (T, NT) inputs for binary ops
* Handle other types of OpInfos (beyond unary pointwise, binary pointwise, and reduction) by manually by writing sample_inputs_funcs
* Address all xfails via fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131704
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898
2024-07-30 23:02:24 +00:00
Joel Schlosser
d53b11bb6e Strict shape checking for NJTs with TestCase.assertEqual() (#131898)
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131898
Approved by: https://github.com/soulitzer
2024-07-30 20:05:48 +00:00
PyTorch MergeBot
7a7dd8c29e Revert "[NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#131518)"
This reverts commit bcf5c68c18.

Reverted https://github.com/pytorch/pytorch/pull/131518 on behalf of https://github.com/ZainRizvi due to Sorry, reverting this since this is based on an internal diff that has diverged from actual internal commit (the final PR and diff must always be identical). Conflicts arise when that happens which block the diff train. Let's revert both this PR and the internal diff, and then reland them as a proper new codev diff ([comment](https://github.com/pytorch/pytorch/pull/131518#issuecomment-2257259839))
2024-07-30 00:55:10 +00:00
PyTorch MergeBot
be5e44192d Revert "[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)"
This reverts commit 8fe2bf212d.

Reverted https://github.com/pytorch/pytorch/pull/131519 on behalf of https://github.com/ZainRizvi due to Sorry, reverting this since this is based on an internal diff that has diverged from actual internal commit.  Weird conflicts arise when that happens.  Let's revert both this PR and the internal diff, and then reland them as a proper new codev diff ([comment](https://github.com/pytorch/pytorch/pull/131519#issuecomment-2257230717))
2024-07-30 00:18:22 +00:00
yuqingj
e3dc20c94b [NJT] support cat backward (#132076)
cat_tensors_backward use narrow_symint, so we need to support aten::narrow for NJT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132076
Approved by: https://github.com/davidberard98
2024-07-29 23:49:26 +00:00
Janani Sriram
8fe2bf212d [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131519
Approved by: https://github.com/davidberard98
ghstack dependencies: #131518
2024-07-29 22:16:32 +00:00
PyTorch MergeBot
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde9.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
PyTorch MergeBot
8cdfdb41bc Revert "[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)"
This reverts commit f862f45730.

Reverted https://github.com/pytorch/pytorch/pull/131519 on behalf of https://github.com/atalman due to broke CI: test_nestedtensor.py::TestNestedTensorSubclassCPU::test_layer_norm_with_lengths_requires_grad_False_components_require_grad_False_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10121747545/job/27996722731) [HUD commit link](f862f45730) ([comment](https://github.com/pytorch/pytorch/pull/131519#issuecomment-2254167994))
2024-07-27 14:45:47 +00:00
Janani Sriram
f862f45730 [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131519
Approved by: https://github.com/davidberard98
ghstack dependencies: #131518
2024-07-27 07:09:10 +00:00
Janani Sriram
bcf5c68c18 [NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#131518)
Modify the existing `softmax` operator in PyTorch, invoked by `torch.softmax`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the aten padding operator, enables PyTorch users to invoke `torch.softmax` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Write unit tests based on the `sum` and `mean` jagged operators to verify the accuracy of the ragged reduction implementation for `torch.softmax`. Add unit tests to verify error handling for unsupported features in `NestedTensor` `torch.softmax`.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. In addition, the `softmax` operator is required to take in as input an integer for the reduction dimension `dim`, requiring new unit tests heavily inspired by the `sum` and `mean` jagged operator unit tests. `Softmax` also allows for reducing along the batch dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131518
Approved by: https://github.com/davidberard98
2024-07-27 07:09:10 +00:00
Janani Sriram
13e806a591 [NestedTensor] Add support for transposed NestedTensors where ragged_idx > 1 for sum and mean operators (#131517)
Add support for transposed, non-contiguous `NestedTensor`s, where `ragged_idx > 1`, for the aten operators `sum` and `mean`. This diff enables reducing along the jagged dimension for non-contiguous `NestedTensor`s, transposed between non-batch dimensions as well as between a ragged and a non-batch dimension. For example, users can now reduce a `NestedTensor` of shape `(B, M, *, N)` along `*` or `(B, N, M, *)` along `*`.

Parametrize existing unit tests and add new unit tests verifying the accuracy of implementations on `NestedTensor`s that transpose between 2 non-batch dimensions as well as between a ragged and a non-batch dimension.

Differential Revision: [D59847927](https://our.internmc.facebook.com/intern/diff/D59847927/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131517
Approved by: https://github.com/davidberard98
2024-07-26 07:21:32 +00:00
Aaron Orenstein
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
jananisriram
faddb0f30c [NestedTensor] Integrate the mean operator along the jagged dimension into NestedTensor (#131132)
Summary:
Modify the existing `mean` operator in PyTorch, invoked by `torch.mean`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff enables PyTorch users to invoke `torch.mean` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Parametrize unit tests from `sum` to verify the accuracy of the ragged reduction implementation for `torch.mean`. Add unit tests and parametrize `sum` unit tests to verify error handling for unsupported features in `NestedTensor` `torch.mean`.

Test Plan:
Verify that the new unit test passes via the following command:
```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_mean
```

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```

Differential Revision: D59654668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131132
Approved by: https://github.com/davidberard98, https://github.com/jbschlosser
2024-07-23 18:48:34 +00:00
jananisriram
28a74b9fa4 [NestedTensor] Integrate sum along the jagged dimension into NestedTensor (#130425)
Summary: Modify the existing `sum` operator in PyTorch, invoked by `torch.sum`, to allow for reductions along the ragged dimension of a nested tensor. This diff enables PyTorch users to invoke `torch.sum` on a nested tensor with `dim=1`, where `ragged_idx=1`.

Functions modified in `caffe2/torch/nested/_internal/ops.py`:
- `sum_dim_IntList()`: The function assumes that `ragged_idx=1`; in the case that `dim=1` as well, where `dim` is the dimension on which we reduce, this diff invokes the PyTorch benchmark found in D58423489. Specifically, this diff pads a nested tensor, e.g. of logical shape `(B, *, M)`, using [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), then reduces across the `*` dimension (`dim == 1`) to a `(B, M)` output tensor.
- `_wrap_jagged_dims()`: This diff adds special handling to allow for the case where `dim` contains `1` and not `0`, but to continue disallowing the case where `dim` contains `0` and not `1`. In this function's creation, I created a helper function, `_get_condition_for_invalid_jagged_reductions()`, which makes it clearer which conditions apply to which operators. Specifically, operators which are enabled with jagged reductions are specified at the top of the file in `SUPPORTED_JAGGED_REDUCTIONS` and have a different set of conditions that need to be tested, as reducing along `dim == 1` without `dim == 0` is now possible.

Functions modified in `caffe2/test/test_nestedtensor.py`:
- `test_sum_int_DimList()`: This diff adds special handling in the `sum` unit test to allow for the case where `dim` contains `1` and not `0`, but to continue disallowing the case where `dim` contains `0` and not `1`.
- `test_sum_int_DimList_ragged_dim_1()`: This diff adds a new unit test which verifies the accuracy and feasibility of reducing along the jagged dimension of a nested tensor.

Notes:
- This diff solely adds functionality for the case in which we reduce only along the ragged dimension. Cases in which we reduce along both the ragged and another dimension, like `dim == (1, 2)`, are not permitted, as this set of diffs focuses primarily on the former.
- The `sum` operator is the only operator which uses the function `_wrap_jagged_dims()`; all other operators use `_wrap_jagged_dim()`. I would like to later look into why this is the case and if we can consolidate this!
- I modified some of the comments in the `sum` function as well as the unit tests for more clarity.

Test Plan:
Verify that existing (`test_sum_int_DimList`) and new (`test_sum_int_DimList_ragged_dim_1`) unit tests pass via the following command:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_sum_int_DimList
```

Differential Revision: D59571209

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130425
Approved by: https://github.com/davidberard98
2024-07-18 10:48:18 +00:00
yuqingj
ea4f310ff1 [Nested Tensor][easy] Add softmax backward support (#130602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130602
Approved by: https://github.com/davidberard98, https://github.com/jbschlosser
2024-07-16 00:07:42 +00:00
Joel Schlosser
257b9c7936 Fix layout for *_like() factories on NJTs (#129879)
Background: this bug was triggering DEBUG=1 asserts in the backward for `unbind()`, which calls `empty_like()`. I found that the NJT implementation of `empty_like()` was redispatching on `values` while blindly passing along all kwargs. This resulted in `empty_like(values, ..., layout=torch.jagged)`, which is incorrect since `values` is strided, tripping the debug assert here:

433b691f98/aten/src/ATen/EmptyTensor.cpp (L305)

This PR explicitly sets `layout=torch.strided` when redispatching `*_like()` factories on `values`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129879
Approved by: https://github.com/soulitzer
2024-07-02 14:51:23 +00:00
PyTorch MergeBot
fa6c0fe3e4 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit 9450e198aa.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2197790226))
2024-06-29 00:16:47 +00:00
Antoni Viros
9450e198aa Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-06-27 03:41:28 +00:00
Joel Schlosser
e364290718 Support linear backward for NJT with dim > 3 (#129393)
Replaces usage of `torch.mm()` with `torch.matmul()` in NJT's impl of linear_backward to support higher dims. See [here](https://github.com/pytorch/pytorch/issues/125214#issuecomment-2184968703) for more context.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129393
Approved by: https://github.com/soulitzer
2024-06-25 16:06:23 +00:00
Joel Schlosser
e1c1052829 Backward support for unbind() with NJT (#128032)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
2024-06-21 14:05:23 +00:00
Joel Schlosser
31d5753247 Short-term fix to preserve NJT metadata cache in torch.compile (#122836)
Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122836
Approved by: https://github.com/soulitzer
2024-06-20 23:15:53 +00:00
PyTorch MergeBot
b0d2fe6299 Revert "Short-term fix to preserve NJT metadata cache in torch.compile (#122836)"
This reverts commit 2a41fc0390.

Reverted https://github.com/pytorch/pytorch/pull/122836 on behalf of https://github.com/jbschlosser due to internal test failures with DEBUG=1 asserts ([comment](https://github.com/pytorch/pytorch/pull/122836#issuecomment-2177298245))
2024-06-19 00:28:53 +00:00
PyTorch MergeBot
5ffb032be6 Revert "Backward support for unbind() with NJT (#128032)"
This reverts commit 5dc4f652bc.

Reverted https://github.com/pytorch/pytorch/pull/128032 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128032#issuecomment-2177296325))
2024-06-19 00:26:40 +00:00
Joel Schlosser
5dc4f652bc Backward support for unbind() with NJT (#128032)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
2024-06-18 20:29:00 +00:00
Joel Schlosser
2a41fc0390 Short-term fix to preserve NJT metadata cache in torch.compile (#122836)
Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122836
Approved by: https://github.com/soulitzer
ghstack dependencies: #127007, #128057
2024-06-17 15:25:09 +00:00
Joel Schlosser
67e6c76a18 Support apply_(callable) sugar for CPU NJTs (#125416)
Example:
```python
nt.apply_(lambda x: x * 2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125416
Approved by: https://github.com/soulitzer
2024-06-12 20:30:57 +00:00
Joel Schlosser
ec1fdda196 Fix jagged NT softmax semantics (#119459)
Before: `softmax` definition uses `jagged_unary_pointwise()` (wrong)
After: `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119459
Approved by: https://github.com/soulitzer
2024-06-12 19:12:03 +00:00
Aaron Orenstein
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
Janani Sriram
c1a43a69e4 [NestedTensor] Add error checks for unbind operator coverage when ragged_idx != 1 (#128058)
Summary:
Add the following error checks for the `unbind` operator on `NestedTensor`s when `ragged_idx != 1`:

- The current implementation allows the creation of `NestedTensor` instances from the class definition with an `offsets` tensor that applies to a dimension other than the jagged dimension. This diff ensures that `unbind` fails when the `offsets` exceed the length of the jagged dimension.

Test Plan:
Added the following unit tests:

`test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

Reviewed By: davidberard98

Differential Revision: D57989082

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128058
Approved by: https://github.com/davidberard98
2024-06-06 01:56:12 +00:00
Janani Sriram
7c3740d388 [NestedTensor] Extend coverage for unbind when ragged_idx != 1 (#127493)
Summary:
Extend coverage for the `NestedTensor` `unbind` operator to cases in which `ragged_idx != 1`.

Currently, the `unbind` operator in the `NestedTensor` class splits a tensor along the 0-th dimension, where the `ragged_idx` property, which controls the jagged dimension upon which `unbind` splits, is 1. This diff extends support for `ragged_idx != 1` in `NestedTensor`s, allowing `unbind` to split a tensor along a jagged dimension greater than 0 for `NestedTensor`s with and without the `lengths` property.

Test Plan:
Added the following unit tests:

`test_unbind_ragged_idx_equals_2_cpu`, `test_unbind_ragged_idx_equals_3_cpu`, and `test_unbind_ragged_idx_equals_last_dim_cpu` verify that `unbind` works for all jagged dimensions greater than 1, for `NestedTensor`s without `lengths`.
```
test_unbind_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_ragged_idx_equals_last_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_cpu` and `test_unbind_with_lengths_ragged_idx_equals_1_cpu` verify that `unbind` works when the jagged dimension is 1, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_with_lengths_ragged_idx_equals_1_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_2_cpu` and `test_unbind_with_lengths_ragged_idx_equals_3_cpu` verify that `unbind` works when the jagged dimension is greater than 1, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
test_unbind_with_lengths_ragged_idx_equals_3_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_0_cpu` verifies that `unbind` fails when the jagged dimension is 0 (the batch dimension), for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_0_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu` verifies that `unbind` fails when there is a mismatch between the offsets and the jagged dimension, for `NestedTensor`s with `lengths`.
```
test_unbind_with_lengths_ragged_idx_equals_2_bad_dim_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

`test_unbind_with_wrong_lengths_cpu` verifies that `unbind` fails when the lengths exceed the limitations set by offsets, for `NestedTensor`s with `lengths`.

```
test_unbind_with_wrong_lengths_cpu (test_nestedtensor.TestNestedTensorSubclassCPU) ... ok
```

Differential Revision: D57942686

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127493
Approved by: https://github.com/davidberard98
2024-06-03 17:46:12 +00:00
David Berard
82edc8b5d5 [NT] Make NestedTensor register as having symbolic sizes/strides (#124687)
Fixes #123698

This PR makes TensorImpl::has_symbolic_sizes_strides return false for NestedTensors.

1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True`
2. It adds a field to `_make_wrapper_subclass` where an explicit `numel` can be provided. This allows us to skip the numel computation for the storage, which previously fails due to arithmetic on NestedInts.
3. Implements `aten::numel` for NJT - this is separate from the overridden numel in `make_wrapper_subclass` for now. Note also that this means that we leave `dispatch_sizes_strides_policy="sizes"`, so that we call into the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT.

Note also that this depends on #121361, because calling TensorImpl::set_sizes_and_strides() tries to clone the sizes into the tensor, which means that we need `clone` to be implemented on NestedInt.

Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124687
Approved by: https://github.com/albanD
2024-05-13 16:50:25 +00:00
PyTorch MergeBot
7ffa5558ee Revert "[FX] Update type hints in torch.fx._compatibility.py (#125469)"
This reverts commit 235b4d6ec2.

Reverted https://github.com/pytorch/pytorch/pull/125469 on behalf of https://github.com/izaitsevfb due to breaks pyre in dependent projects (internal: see D56986361) ([comment](https://github.com/pytorch/pytorch/pull/125469#issuecomment-2096665396))
2024-05-06 18:36:43 +00:00
Xuehai Pan
235b4d6ec2 [FX] Update type hints in torch.fx._compatibility.py (#125469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125469
Approved by: https://github.com/Skylion007
ghstack dependencies: #125468
2024-05-05 19:30:22 +00:00
Aaron Gokaslan
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
soulitzer
638b003cb7 [NJT] .to() properly updates device of offsets (#122797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122797
Approved by: https://github.com/jbschlosser
2024-04-02 16:07:27 +00:00
PyTorch MergeBot
4290a57e9c Revert "[NJT] .to() properly updates device of offsets (#122797)"
This reverts commit 3e7fd45b40.

Reverted https://github.com/pytorch/pytorch/pull/122797 on behalf of https://github.com/jeffdaily due to Sorry for reverting your change but it is failing CUDA and ROCm jobs in trunk. Please help take a look and reland the change ([comment](https://github.com/pytorch/pytorch/pull/122797#issuecomment-2025473181))
2024-03-28 15:17:45 +00:00
soulitzer
3e7fd45b40 [NJT] .to() properly updates device of offsets (#122797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122797
Approved by: https://github.com/jbschlosser
2024-03-28 00:56:23 +00:00
Joel Schlosser
6767c04fde Forward fix for broken internal tests related to NJT view dummy (#122704)
(internal link) [example test breakage](https://www.internalfb.com/intern/test/562950061753019?ref_report_id=0)

Symptom: `type stub not overridden` for SymInt. The global NJT dummy relies on `SymInt.__mul__()` in its constructor. Lazily constructing the dummy avoids the race.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122704
Approved by: https://github.com/soulitzer
2024-03-26 21:22:12 +00:00
Joel Schlosser
cd6bfc7965 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-22 02:12:36 +00:00
PyTorch MergeBot
224beecee6 Revert "Proper view support for jagged layout NestedTensor (#113279)"
This reverts commit 5855c490f0.

Reverted https://github.com/pytorch/pytorch/pull/113279 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](https://github.com/pytorch/pytorch/pull/113279#issuecomment-2013899762))
2024-03-21 22:03:01 +00:00
Joel Schlosser
5855c490f0 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-20 23:45:34 +00:00
Chengji Yao
0e604becc5 [NJT] support chunk on batch dim (#119713)
- support chunk op on batch dim
- support empty_like op
- add tests for the like ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119713
Approved by: https://github.com/jbschlosser
2024-03-05 17:57:50 +00:00