Commit Graph

195 Commits

Author SHA1 Message Date
PyTorch MergeBot
c34f8c7f91 Revert "Lift jagged -> padded dense forward / backward kernels from fbgemm_gpu (#125946)"
This reverts commit 5e69e11d09.

Reverted https://github.com/pytorch/pytorch/pull/125946 on behalf of https://github.com/clee2000 due to sorry the Dr CI fix hasn't been merged yet and its still failing 5e69e11d09 https://github.com/pytorch/pytorch/actions/runs/9228887299/job/25393895252 ([comment](https://github.com/pytorch/pytorch/pull/125946#issuecomment-2130305958))
2024-05-24 20:26:07 +00:00
Joel Schlosser
5e69e11d09 Lift jagged -> padded dense forward / backward kernels from fbgemm_gpu (#125946)
PyTorch can't depend on `fbgemm_gpu` as a dependency because `fbgemm_gpu` already has a dependency on PyTorch. So this PR copy / pastes kernels from `fbgemm_gpu`:
* `dense_to_jagged_forward()` as CUDA registration for new ATen op `_padded_dense_to_jagged_forward()`
* `jagged_to_padded_dense_forward()` as CUDA registration for new ATen op `_jagged_to_padded_dense_forward()`

CPU impls for these new ATen ops will be added in a follow-up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125946
Approved by: https://github.com/davidberard98
2024-05-24 19:16:29 +00:00
PyTorch MergeBot
99a11efc8a Revert "Lift jagged -> padded dense forward / backward kernels from fbgemm_gpu (#125946)"
This reverts commit e2f081837f.

Reverted https://github.com/pytorch/pytorch/pull/125946 on behalf of https://github.com/clee2000 due to I think dr ci is wrong and the windows build failure is real e2f081837f https://github.com/pytorch/pytorch/actions/runs/9216826622/job/25357819877 ([comment](https://github.com/pytorch/pytorch/pull/125946#issuecomment-2128388126))
2024-05-24 02:37:46 +00:00
Joel Schlosser
e2f081837f Lift jagged -> padded dense forward / backward kernels from fbgemm_gpu (#125946)
PyTorch can't depend on `fbgemm_gpu` as a dependency because `fbgemm_gpu` already has a dependency on PyTorch. So this PR copy / pastes kernels from `fbgemm_gpu`:
* `dense_to_jagged_forward()` as CUDA registration for new ATen op `_padded_dense_to_jagged_forward()`
* `jagged_to_padded_dense_forward()` as CUDA registration for new ATen op `_jagged_to_padded_dense_forward()`

CPU impls for these new ATen ops will be added in a follow-up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125946
Approved by: https://github.com/davidberard98
2024-05-24 00:42:59 +00:00
eqy
5f64086d08 [NT][SDPA] Bump tolerances for test_sdpa_with_packed_in_proj_cuda_bfloat16 (#126356)
Current tolerances fail on RTX 6000 (Ada) with `Mismatched elements: 2 / 144 (1.4%)`

```
AssertionError: Tensor-likes are not close!

Mismatched elements: 2 / 144 (1.4%)
Greatest absolute difference: 0.002197265625 at index (5, 0, 0) (up to 0.001 allowed)
Greatest relative difference: 0.08203125 at index (3, 0, 0) (up to 0.016 allowed)

To execute this test, run the following from the base repo dir:
     python test/test_nestedtensor.py -k test_sdpa_with_packed_in_proj_cuda_bfloat16

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126356
Approved by: https://github.com/drisspg
2024-05-21 03:25:30 +00:00
David Berard
82edc8b5d5 [NT] Make NestedTensor register as having symbolic sizes/strides (#124687)
Fixes #123698

This PR makes TensorImpl::has_symbolic_sizes_strides return false for NestedTensors.

1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True`
2. It adds a field to `_make_wrapper_subclass` where an explicit `numel` can be provided. This allows us to skip the numel computation for the storage, which previously fails due to arithmetic on NestedInts.
3. Implements `aten::numel` for NJT - this is separate from the overridden numel in `make_wrapper_subclass` for now. Note also that this means that we leave `dispatch_sizes_strides_policy="sizes"`, so that we call into the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT.

Note also that this depends on #121361, because calling TensorImpl::set_sizes_and_strides() tries to clone the sizes into the tensor, which means that we need `clone` to be implemented on NestedInt.

Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124687
Approved by: https://github.com/albanD
2024-05-13 16:50:25 +00:00
soulitzer
4440d0755a Support custom layout call under torch dispatch mode (#125379)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125379
Approved by: https://github.com/jbschlosser
2024-05-02 23:44:12 +00:00
yuqingj
9a70e7f58c [Nested Tensor]Add unit test that cover the internal use cases (#124880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124880
Approved by: https://github.com/jbschlosser
2024-04-25 05:04:27 +00:00
David Berard
868e5ced5d [Dispatcher] Collect autograd sequence numbers on PythonTLSSnapshot dispatch keys (#123304)
Fixes #121758

**TL;DR**: When profiling is turned on, the dispatcher will sometimes attach the autograd sequence number to the recorded profiler event. This PR expands the set of profiler events onto which we attach sequence numbers. Before, we'd only attach a sequence number if the current dispatch key was an Autograd dispatch key. Now, we attach a sequence number if the current dispatch key **set** contains Autograd.

**Context**:
The use case for this is torch.profiler for python subclasses.

Autograd attaches a "sequence number" to all ops that it encounters during the forward pass. Then, the corresponding sequence number can be associated with a backward kernel when backward is executed. This is used by the profiler to associate the forward ops to the backward ops; a forward op and a backward op with the same sequence number are "linked" in some post-processing step.

Prior to this PR, this profiler feature didn't work for python subclasses. The reason is that we don't collect profiler information for all the dispatches for a given kernel; we only dispatch the initial `call`, and not the subsequent `redispatch` invocations. Normally, an Autograd key (if we're running with autograd) is the highest dispatch key, so the initial `call` that we profile is an Autograd key, and we collect the sequence number. But when we're dealing with a python subclass, the first dispatch key is PythonTLSSnapshot, which eventually redispatches into Autograd. We don't record the Autograd sequence number in that case because we don't record redispatches.

To fix this, this PR collects a sequence number whenever the dispatch key **set** contains an Autograd key. That means we might sometimes collect multiple events with the same sequence number, or possibly attach sequence numbers when we won't actually use them? (e.g. maybe if the initial dispatch key handler removes Autograd for some reason). Although this might be a bit confusing for users looking directly at the sequence_nr directly, I think the main use case is for the profiler to create fwd-bwd links; and those should still be generated correctly in these cases.

Differential Revision: [D55724190](https://our.internmc.facebook.com/intern/diff/D55724190)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123304
Approved by: https://github.com/soulitzer
2024-04-12 02:01:15 +00:00
Grace Zhang
3e43dc086a implement bmm to support nested tensor and normal tensor (#119762)
implement bmm to support nested and normal tensor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119762
Approved by: https://github.com/cyyever, https://github.com/ezyang
2024-04-11 01:10:04 +00:00
William Wen
cbde0f048b [dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support (#123300)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123300
Approved by: https://github.com/jansel, https://github.com/malfet, https://github.com/zou3519
2024-04-05 20:13:17 +00:00
Joel Schlosser
721dcaff94 Revert usage of NJT views in SDPA (#123215)
For internal purposes, this PR reverts the use of real views in SDPA -> autograd.Function "views" (i.e. `ViewBufferFromNested` and `ViewNestedFromBuffer`). This is a temporary fix to get the FIRST model launched and working.

**Note: this breaks some other Dynamo tests related to SDPA that rely on real views, but the breakage there isn't expected to be likely in a real-world scenario.**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123215
Approved by: https://github.com/YuqingJ
2024-04-04 18:45:47 +00:00
PyTorch MergeBot
63d17d3c90 Revert "Revert usage of NJT views in SDPA (#123215)"
This reverts commit 0fcddb5625.

Reverted https://github.com/pytorch/pytorch/pull/123215 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but I think it needs to be skipped on ROCm 0fcddb5625 ([comment](https://github.com/pytorch/pytorch/pull/123215#issuecomment-2036080570))
2024-04-04 02:57:09 +00:00
Joel Schlosser
0fcddb5625 Revert usage of NJT views in SDPA (#123215)
For internal purposes, this PR reverts the use of real views in SDPA -> autograd.Function "views" (i.e. `ViewBufferFromNested` and `ViewNestedFromBuffer`). This is a temporary fix to get the FIRST model launched and working.

**Note: this breaks some other Dynamo tests related to SDPA that rely on real views, but the breakage there isn't expected to be likely in a real-world scenario.**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123215
Approved by: https://github.com/YuqingJ
2024-04-03 23:25:31 +00:00
soulitzer
638b003cb7 [NJT] .to() properly updates device of offsets (#122797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122797
Approved by: https://github.com/jbschlosser
2024-04-02 16:07:27 +00:00
PyTorch MergeBot
4290a57e9c Revert "[NJT] .to() properly updates device of offsets (#122797)"
This reverts commit 3e7fd45b40.

Reverted https://github.com/pytorch/pytorch/pull/122797 on behalf of https://github.com/jeffdaily due to Sorry for reverting your change but it is failing CUDA and ROCm jobs in trunk. Please help take a look and reland the change ([comment](https://github.com/pytorch/pytorch/pull/122797#issuecomment-2025473181))
2024-03-28 15:17:45 +00:00
soulitzer
3e7fd45b40 [NJT] .to() properly updates device of offsets (#122797)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122797
Approved by: https://github.com/jbschlosser
2024-03-28 00:56:23 +00:00
Joel Schlosser
e6986e4317 Public API for NJT construction from jagged components (#121518)
This PR introduces `torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None, jagged_dim=1)` (bikeshedding welcome). This is intended to be the main entrypoint for getting an NJT from the `(values, offsets, lengths)` components. The returned NJT is a view of the `values` component.

Note that `torch.nested.nested_tensor()` / `torch.nested.as_nested_tensor()` already exist for constructing an NJT from a list of tensors.

TODO:
* Some doc formatting; suggestions welcome there
* Tests / examples using `jagged_dim != 1`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121518
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #113279, #113280
2024-03-22 14:48:22 +00:00
Joel Schlosser
470b44c048 Support for torch.nested.as_nested_tensor(t) (#113280)
This PR adds support for tensor inputs to `as_nested_tensor()`. The tensor is treated as a batch of consistently-sized constituents. It utilizes `_nested_view_from_values_offsets()` to return a real view that allows for propagating gradients into inputs.
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113280
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #113279
2024-03-22 02:12:37 +00:00
Joel Schlosser
cd6bfc7965 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-22 02:12:36 +00:00
PyTorch MergeBot
224beecee6 Revert "Proper view support for jagged layout NestedTensor (#113279)"
This reverts commit 5855c490f0.

Reverted https://github.com/pytorch/pytorch/pull/113279 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](https://github.com/pytorch/pytorch/pull/113279#issuecomment-2013899762))
2024-03-21 22:03:01 +00:00
PyTorch MergeBot
12e7602cf9 Revert "Support for torch.nested.as_nested_tensor(t) (#113280)"
This reverts commit 17c9c70265.

Reverted https://github.com/pytorch/pytorch/pull/113280 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](https://github.com/pytorch/pytorch/pull/113280#issuecomment-2013893099))
2024-03-21 22:00:44 +00:00
PyTorch MergeBot
816db3bd29 Revert "Public API for NJT construction from jagged components (#121518)"
This reverts commit d4dff9cf5e.

Reverted https://github.com/pytorch/pytorch/pull/121518 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](https://github.com/pytorch/pytorch/pull/121518#issuecomment-2013879641))
2024-03-21 21:56:29 +00:00
Christian Puhrsch
16935de961 Support alias for NestedTensorCPU/CUDA (#117711)
Fixes #ISSUE_NUMBER

Co-authored-by: Vincent Moens <vmoens@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117711
Approved by: https://github.com/albanD
2024-03-21 16:05:52 +00:00
Joel Schlosser
d4dff9cf5e Public API for NJT construction from jagged components (#121518)
This PR introduces `torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None, jagged_dim=1)` (bikeshedding welcome). This is intended to be the main entrypoint for getting an NJT from the `(values, offsets, lengths)` components. The returned NJT is a view of the `values` component.

Note that `torch.nested.nested_tensor()` / `torch.nested.as_nested_tensor()` already exist for constructing an NJT from a list of tensors.

TODO:
* Some doc formatting; suggestions welcome there
* Tests / examples using `jagged_dim != 1`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121518
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #113280
2024-03-21 04:14:17 +00:00
Joel Schlosser
17c9c70265 Support for torch.nested.as_nested_tensor(t) (#113280)
This PR adds support for tensor inputs to `as_nested_tensor()`. The tensor is treated as a batch of consistently-sized constituents. It utilizes `_nested_view_from_values_offsets()` to return a real view that allows for propagating gradients into inputs.
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113280
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
2024-03-21 04:13:55 +00:00
Joel Schlosser
5855c490f0 Proper view support for jagged layout NestedTensor (#113279)
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
    * `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
    * This ops is implemented on the Python side using torch.library so we can return a subclass instance
    * `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
    * The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
    * `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
    * `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
    * Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)

With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.

Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
2024-03-20 23:45:34 +00:00
Joel Schlosser
ea8f6e2e54 Subclass view fake-ification via reified ViewFuncs (#118405)
This PR:
* Uses reified ViewFuncs to swap in fake tensors / symbolic SymInts for view replay during subclass view fake-ification
* Enables automatic dynamic on view bases -> fakeifies according to the resultant symbolic context instead of the old "all-static" approach
* Covers the following view types:
    * subclass -> dense
    * dense -> subclass
    * subclass -> subclass
* Dense -> dense views are handled the old way via an `as_strided()` call, as it's likely there is no view func available

Differential Revision: [D54269082](https://our.internmc.facebook.com/intern/diff/D54269082)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118405
Approved by: https://github.com/ezyang
2024-03-07 19:56:16 +00:00
Chengji Yao
0e604becc5 [NJT] support chunk on batch dim (#119713)
- support chunk op on batch dim
- support empty_like op
- add tests for the like ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119713
Approved by: https://github.com/jbschlosser
2024-03-05 17:57:50 +00:00
soulitzer
312ce35c1f Rename singleton int to nested int (#119661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119661
Approved by: https://github.com/ezyang
2024-02-16 19:21:17 +00:00
Joel Schlosser
756cf2913d Fix NJT stride access in SDPA dispatcher logic (#119846)
`._stride` -> `._strides`

Adds test to cover this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119846
Approved by: https://github.com/drisspg, https://github.com/ani300, https://github.com/soulitzer
ghstack dependencies: #119910
2024-02-14 22:37:52 +00:00
Joel Schlosser
0560c193a6 Fix meta registration for _flash_attention_forward() [ROCm forward fix] (#119910)
Addresses ROCm failures from #119812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119910
Approved by: https://github.com/drisspg
2024-02-14 22:37:52 +00:00
Joel Schlosser
31e59766e7 Fix meta registration for _flash_attention_forward() (#119812)
Meta registration wrongly assumes 4D inputs, while the underlying op allows 3D inputs for the `mha_varlen_fwd()` case.
Testing: I added `detach()`es so the NJT test `test_sdpa_compile()` won't fail for a view-related reason. It should pass now with this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
2024-02-14 02:38:53 +00:00
PyTorch MergeBot
8994f2367d Revert "Fix jagged NT softmax semantics (#119459)"
This reverts commit 6adadbaf79.

Reverted https://github.com/pytorch/pytorch/pull/119459 on behalf of https://github.com/malfet due to broke dynamo, see https://github.com/pytorch/pytorch/actions/runs/7835402753/job/21386634602 ([comment](https://github.com/pytorch/pytorch/pull/119459#issuecomment-1935246413))
2024-02-09 02:31:49 +00:00
Joel Schlosser
6adadbaf79 Fix jagged NT softmax semantics (#119459)
Before: `softmax` definition uses `jagged_unary_pointwise()` (wrong)
After: `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119459
Approved by: https://github.com/soulitzer
2024-02-08 20:13:12 +00:00
David Berard
278a0e1600 [NestedTensor] Support binary pointwise ops with >2 inputs (if inputs are non-tensors) (#119419)
It should usually be safe to run pointwise binary ops with >2 inputs. e.g. threshold_backward(tensor, tensor, scalar): we just operate on the values of the nested tensors, and pass in the other args as-is.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119419
Approved by: https://github.com/soulitzer
2024-02-08 20:06:40 +00:00
Huy Do
3ed9df36a9 Clean up some obsolete TODOs in run_test and several test files (#119113)
* The TODOs in `test/test_nestedtensor.py` has been mitigated, I keep the issue for reference.
* ~~The TODOs in `test/test_ops_fwd_gradients.py` doesn't apply anymore~~
* The TODOs in `run_test.py` to support disabling C++ tests is probably not going to happen.  I have never seen a flaky C++ test that needs to be disabled before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119113
Approved by: https://github.com/kit1980
2024-02-03 23:54:30 +00:00
David Berard
460950d3aa [Nested Tensor] Support ragged_idx != 1 on aten::is_same_size, aten::_to_copy (#118442)
is_same_size is needed internally; `_to_copy` should be easy because it doesn't support new layouts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118442
Approved by: https://github.com/cpuhrsch
2024-01-30 01:32:51 +00:00
David Berard
2842d3c9d3 [Nested Tensor] view: basic support for ragged_idx != 1 and _unsafe_view (#118317)
Uses case: `_unsafe_view` is used in aot_autograd to create a view that doesn't register as a view:

eebe7e1d37/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py (L470-L476)

If a transposed nested tensor (i.e. NT with ragged_idx != 1) encounters this code path, it previously would fail for two reasons: 1) because `_unsafe_view` isn't registered, and 2) because ragged_idx != 1 is not supported. This PR adds support for `_unsafe_view` (completely reusing the implementation of `view`; this just registers `_unsafe_view` as another op using the same implementation). It also adds support for ragged_idx != 1, but only for trivial cases where inp._size == size (the use case used by aot_autograd).

Tests: verify that the result of `_unsafe_view` doesn't have a `_base`, and that simple views on transposed NTs work.

Differential Revision: [D53096814](https://our.internmc.facebook.com/intern/diff/D53096814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118317
Approved by: https://github.com/soulitzer
2024-01-26 17:29:37 +00:00
David Berard
52c5803088 [NestedTensor] Support ragged_idx != 1 in pointwise ops (#118157)
This PR allows pointwise ops to operate on tensors with ragged_idx != 1. It does this by passing the ragged_idx metadata into the construction of the returned NestedTensor when computing pointwise ops. The assumption is that: pointwise ops can operate directly on the values tensors, and the resulting tensor should have all the same metadata properties as the input tensors. For binary ops, a test is added to verify that adding two tensors with different ragged_idx cannot be added.

Previously:
* unary pointwise ops would error out when performed on nested tensors with ragged_idx != 1
* binary pointwise ops would produce tensors with nonsense shapes

Differential Revision: [D53032641](https://our.internmc.facebook.com/intern/diff/D53032641)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118157
Approved by: https://github.com/jbschlosser
2024-01-25 23:34:15 +00:00
YuqingJ
a97d00cca5 [Nested Tensor]Support SDPA math fallback for jagged layout nested tensor (#116445)
Support this fallback by converting the jagged layout NT to strided layout NT, and the convert the result back to jagged layout NT.
This fallback might not be efficient since it uses unbind, contiguous and split.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116445
Approved by: https://github.com/soulitzer
2024-01-12 17:30:40 +00:00
PyTorch MergeBot
9f87760160 Revert "[Nested Tensor]Support SDPA math fallback for jagged layout nested tensor (#116445)"
This reverts commit e55a778cbb.

Reverted https://github.com/pytorch/pytorch/pull/116445 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but i see it fails ROCm test in trunk due to an unsupported use case e55a778cbb ([comment](https://github.com/pytorch/pytorch/pull/116445#issuecomment-1888060036))
2024-01-11 22:21:45 +00:00
YuqingJ
e55a778cbb [Nested Tensor]Support SDPA math fallback for jagged layout nested tensor (#116445)
Support this fallback by converting the jagged layout NT to strided layout NT, and the convert the result back to jagged layout NT.
This fallback might not be efficient since it uses unbind, contiguous and split.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116445
Approved by: https://github.com/soulitzer
2024-01-11 20:28:40 +00:00
Joel Schlosser
f70aeb4ffd Fix backward for reshape() on jagged layout NT (#117137)
Provides symbolic C++-side `reshape_as()` / `reshape()` decomps for jagged layout NTs to make the backwards pass work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117137
Approved by: https://github.com/soulitzer
2024-01-10 23:35:07 +00:00
Joel Schlosser
0b0c76bace Support squeeze.dim for jagged NT (#116891)
As title. Needed for `rev_view_func()` of `unsqueeze()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116891
Approved by: https://github.com/soulitzer
ghstack dependencies: #115894, #116512
2024-01-06 01:00:53 +00:00
Joel Schlosser
3c21264c9b Introduce reverse view_funcs (#115894)
Part 2 of implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).

Details:
* Codegen `rev_view_func()` alongside `view_func()`
    * Reverse view_func gives you a "base" from a "view": `rev_view_func(new_view) -> new_base` AKA it plays the original view backwards
* Utilizes the functional inverses defined in `FunctionalInverses.cpp`, passing `InverseReturnMode::AlwaysView`
* Manually implements functional inverses for `narrow()` and `chunk()`
* **NB: Multi-output views now set view_func() / rev_view_func() for each of the output views!**
    * Due to this, the `as_view()` overload that operates on a list of views is scrapped in favor of iteration via codegen

Example codegen in `ADInplaceOrViewTypeN.cpp`:
```cpp
at::Tensor narrow(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
  auto _tmp = ([&]() {
    at::AutoDispatchBelowADInplaceOrView guard;
    return at::_ops::narrow::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, dim, start, length);
  })();
  std::function<at::Tensor(const at::Tensor&)> func=nullptr;
  std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
  if (false || !self.unsafeGetTensorImpl()->support_as_strided() ||
      c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
    func = [=](const at::Tensor& input_base) {
      return at::_ops::narrow::call(input_base, dim, start, length);
    };
    rev_func = [=](const at::Tensor& input_view) {
      // NB: args from narrow() signature are passed along to the inverse
      return at::functionalization::FunctionalInverses::narrow_copy_inverse(self, input_view, at::functionalization::InverseReturnMode::AlwaysView, dim, start, length);
    };
  }
  auto result = as_view(/* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, /* is_fw_differentiable */ true, /* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE));
  return result;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115894
Approved by: https://github.com/soulitzer
2024-01-05 16:48:12 +00:00
Joel Schlosser
ea3a5f8ddc Add chunk for jagged layout NT (#115842)
Nice to have for the [SDPA tutorial](https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115842
Approved by: https://github.com/soulitzer
ghstack dependencies: #115192, #116111
2023-12-20 20:13:20 +00:00
Joel Schlosser
29b198dcf8 Add markDynamoStrictTest to NT tests (#116111)
Decorates all NT tests with `@markDynamoStrictTest` to ensure we get the correct signal. Adds xfails where needed to get things passing.

Includes a fix in meta_utils.py for a bug that was breaking several python 3.11 tests. In particular, a dense tensor graph input that is a view of a strided NT would slip past Dynamo's check and break in meta-ification.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116111
Approved by: https://github.com/soulitzer, https://github.com/zou3519
ghstack dependencies: #115192
2023-12-20 20:13:20 +00:00
Joel Schlosser
1474eb5f29 Fix jagged composite impl of flatten() (#115192)
Need to handle this in `NestedTensor.__torch_function__()` since it's CompositeImplicit
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115192
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
2023-12-19 19:15:21 +00:00
Joel Schlosser
bf62511e07 Reshape decomposition for jagged layout NT (#115191)
No more segfault from using `reshape()` on jagged NT :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115191
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
2023-12-18 22:34:41 +00:00