Commit Graph

166 Commits

Author SHA1 Message Date
Joel Schlosser
a8382847f4 Support rms_norm() for NJT (#135872)
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
2024-09-17 18:09:20 +00:00
Aaron Gokaslan
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
PyTorch MergeBot
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc22.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
Aaron Gokaslan
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
Joel Schlosser
06bc717410 Fix sum() forward for NJT (#131945)
This PR solves two problems with `sum()` support in NJT:
* `sum()` over a dim with `keepdim=True` returns the wrong shape (i.e. it'll keep the wrong dim). This is a long-standing bug from way back in #112519.
* Historically, we've only supported `sum()` over a dim and not a full reduction. This PR adds the full reduction form (forward only, backward still fails).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131945
Approved by: https://github.com/davidberard98, https://github.com/jananisriram
2024-09-14 00:58:03 +00:00
Joel Schlosser
525bec804c NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-12 17:54:25 +00:00
PyTorch MergeBot
70a65a8bd5 Revert "NJT <-> padded dense conversions (#125947)"
This reverts commit 09a5e88bef.

Reverted https://github.com/pytorch/pytorch/pull/125947 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing dynamo test 09a5e88bef, maybe a landrace ([comment](https://github.com/pytorch/pytorch/pull/125947#issuecomment-2339228570))
2024-09-09 22:01:09 +00:00
Joel Schlosser
09a5e88bef NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-09 19:37:32 +00:00
yuqingj
defb515306 [NJT]Add permute ops support (#135336)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135336
Approved by: https://github.com/davidberard98
2024-09-08 21:00:41 +00:00
David Berard
4b4ba7ab06 [NJT] Support NJT SDPA + meta-device flop counting (#134289)
A user wants to use the flop counter with meta devices. This previously caused problems for SDPA+NJT:

1. autocast check: `torch.is_autocast_enabled("meta")` fails because `meta` is not valid for autocasting. If we skip this, we run into the next error
2. math backend: conversion to NST requires getting concrete offsets in a list of python integers, which doesn't work on a meta tensor b2eb0e8c6a/torch/nested/_internal/sdpa.py (L809-L815)
3. (fixed in the previous PR, #134288) - if we force using flash attention backend for flop counting, `_flash_attention_forward` previously didn't support meta tensors.

In this PR, we check specifically for FlopCounterMode, and, if it's enabled and combined with meta tensors, (a) skip autocasting and (b) force it down the flash attention path. This isn't generally safe for tracing (e.g. if you actually care which kernels you are running), but in the absence of actual device information, we have to make some assumptions. By specifically checking for FlopCounterMode, this should reduce the chance of unintended side effects for other meta tensor users.

Note: fake tensor would solve a bunch of these issues, but it's not a viable solution right now for the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134289
Approved by: https://github.com/soulitzer
ghstack dependencies: #134288
2024-08-29 03:43:42 +00:00
Aaron Orenstein
ed86ac2f25 [BE] typing for decorators - fx/_compatibility (#134054)
Summary: See #131429

Test Plan: unit tests pass

Differential Revision: D61493706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134054
Approved by: https://github.com/oulgen
2024-08-26 04:00:27 +00:00
Aaron Orenstein
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
yuqingj
44fa9f991c [NJT] add aten.to.dtype support (#134164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134164
Approved by: https://github.com/davidberard98
2024-08-22 16:59:38 +00:00
krzysztofjordan
2e1830c7c8 Implement 2D version of masked_select for nestedtensors (#133889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133889
Approved by: https://github.com/soulitzer
2024-08-20 21:46:32 +00:00
soulitzer
4af4910b1a Reland "Construct NJT without graph breaks" (#133196)
This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
2024-08-14 01:11:13 +00:00
PyTorch MergeBot
656465fc77 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit ed97fb77f9.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to fails internal jobs, see [S440348](https://www.internalfb.com/sevmanager/view/440348) ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2285051164))
2024-08-12 23:14:19 +00:00
soulitzer
05de2b2d0f Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
2024-08-10 03:11:16 +00:00
Joel Schlosser
75eb66afc0 Support 'non-contiguous with holes' NJTs for contiguous clone() (#132776)
It's possible to construct an NJT with "holes" by specifying both `offsets` and `lengths` metadata. When `nt.clone(memory_format=torch.contiguous_format)` is called on such an NJT, the result should be an NJT without holes.

This PR fixes this in simplistic way using `unbind()`, which isn't really supported in `torch.compile`. The longer term solution involves writing a proper kernel to support this.

NB: Another limitation is that the returned NJT does not have the same ragged structure as the input. While we could manually hack the nested int registry (or update the union find when that lands), this is the first instance where a NJT with holes and an NJT without holes could have the same ragged structure, and getting those to play nicely together requires some fairly involved updates. For now, this PR punts on these updates until we can clean this up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132776
Approved by: https://github.com/ani300, https://github.com/soulitzer
ghstack dependencies: #131898, #131704, #131937
2024-08-08 17:08:11 +00:00
Janani Sriram
0ca8f66e3a [NestedTensor] Modify softmax on ragged dimension to allow for 2D nested tensors (#132812)
Summary:
Modify `softmax` on the ragged dimension, where `ragged_idx == 1`, to allow for 2D nested tensors. This diff now enables a `softmax` operation on tensors of shape `(B, *)`, where `*` is the ragged dimension.

Extend existing `softmax` unit tests to include 2D nested tensors using the `include_2d_tensor=True` keyword argument.

Test Plan:
Verify that existing and modified unit tests pass using the following commands:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_softmax
```

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```

Reviewed By: davidberard98

Differential Revision: D60780975

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132812
Approved by: https://github.com/davidberard98
2024-08-08 15:41:28 +00:00
Edward Z. Yang
aec6332356 Only thunkify proxies in some situations (#132421)
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks, and then evaluate them (e.g., as occurs if you sum(long list of symint)). The basic idea behind this PR is to only thunkify proxies if they're being created in places where they may or may not be used--crucially, symint operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead.

I annotated the PR with explanation of changes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421
Approved by: https://github.com/Skylion007, https://github.com/zou3519
ghstack dependencies: #132674, #132675
2024-08-08 12:03:06 +00:00
David Berard
59f4725b49 [NJT] manually autocast in SDPA handling (#132835)
When autocasting is turned on, right now SDPA w/ NJT won't be autocasted. This PR adds manual "autocasting" logic in sdpa.py - at the beginning, it just checks if autocasting is enabled, and if so, it casts the inputs in the way you would expect if autocasting was actually running.

Why normal autocasting won't work:
* NJT intercepts the `__torch_function__` call for scaled_dot_product_attention, which, AFAIK, happens before we get to any dispatcher logic, and then calls efficient attention or flash attention. So autocasting the scaled_dot_product_attention op won't work; we never call the aten op for scaled_dot_product_attention, so we won't ever run autocasting for it.
* If we try to add autocasting handling for `_flash_attention_forward` or `_efficient_attention_forward`, then autocasting will _run_, but it will have the wrong semantics: sdpa.py's handling will run first, and it will do backend selection based on the uncasted inputs to SDPA. This also means that if the inputs to the SDPA call don't have uniform types, the sdpa.py implementation will fail checks (this is the specific issue we're targeting).

Alternative: "just change the backend selection logic for NJT to be autocast aware, but don't actually do the autocast; then, add `_(flash|efficient)_attention_forward` to autocasting rules". I think this would work too. But it's arguably better to make the backend-selection logic and actual-autocast-behavior use the same implementation, in case the implementations are different.

Differential Revision: [D60879916](https://our.internmc.facebook.com/intern/diff/D60879916)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132835
Approved by: https://github.com/soulitzer
2024-08-08 01:36:57 +00:00
Antoni Viros
ed97fb77f9 Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-08-07 14:18:53 +00:00
Apurva Jain
8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00
yuqingj
623d0204f0 [NJT] Support Chunk backward for simple cases (#132193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132193
Approved by: https://github.com/soulitzer
2024-08-06 21:20:09 +00:00
soulitzer
f50621989b Construct NJT without graph breaks (#130292)
Combines contributions from https://github.com/pytorch/pytorch/pull/130505

Some context can be found in this large comment block:

a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)

Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)

Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.

Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
2024-08-06 17:03:39 +00:00
PyTorch MergeBot
38674bcb45 Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit eca0cb0fbe.

Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/izaitsevfb due to breaks test_overrides.py::TestTorchFunctionWarning::test_warn_on_invalid_torch_function_tensor_subclass ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2270213988))
2024-08-06 01:55:41 +00:00
Antoni Viros
eca0cb0fbe Conversions between strided and jagged layouts for Nested Tensors (#115749)
This PR does 3 things:
1. Adds a copy-free strided->jagged layout conversion for NT
2. Adds a copy-free jagged->strided layout conversion for NT
3. Modifies and expands the .to() API to support the layout argument for the specific case of NT layout conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115749
Approved by: https://github.com/jbschlosser
2024-08-05 23:45:48 +00:00
Aaron Gokaslan
fd4b649e6c [BE]: Simplify some list comps to generators C419 (#132578)
Simplifies some list comprehensions to generator which is more efficient. Automatically applied diffs for the most part with ruff

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132578
Approved by: https://github.com/ezyang
2024-08-04 17:46:26 +00:00
Xuehai Pan
d2dc173664 Remove lint dependency ufmt (#132573)
`ufmt` is a combination of `black + usort`.

This PR removes `ufmt` and run `black` and `usort` separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132573
Approved by: https://github.com/ezyang
ghstack dependencies: #129769, #132572
2024-08-04 10:24:09 +00:00
Xuehai Pan
f3fce597e9 [BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
2024-08-04 10:24:09 +00:00
PyTorch MergeBot
bcb4f7c172 Revert "Grouped Query Attention (#128898)"
This reverts commit 6b28af1b79.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/ZainRizvi due to Sorry, this broke a bunch of tests internally. See D60638265 ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2265961038))
2024-08-02 18:58:46 +00:00
jainapurva
6b28af1b79 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-31 22:58:51 +00:00
Joel Schlosser
7eb2a99585 Fix to support unary pointwise ops when an NJT is not the first arg (#131937)
**Background:** NJT utilizes a `jagged_unary_pointwise()` fallback that historically has assumed blindly that the first arg is an NJT. This assumption breaks certain ops; for example `pow(scalar, Tensor)` has an NJT as the second arg.

This PR expands `jagged_unary_pointwise()` and the associated schema validation logic to handle an NJT in args other than the first position.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131937
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898, #131704
2024-07-31 17:51:03 +00:00
Janani Sriram
46994e753b [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#132172)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132172
Approved by: https://github.com/davidberard98
ghstack dependencies: #132170
2024-07-31 10:51:46 +00:00
Janani Sriram
89053e382a [NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#132170)
Modify the existing `softmax` operator in PyTorch, invoked by `torch.softmax`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the aten padding operator, enables PyTorch users to invoke `torch.softmax` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Write unit tests based on the `sum` and `mean` jagged operators to verify the accuracy of the ragged reduction implementation for `torch.softmax`. Add unit tests to verify error handling for unsupported features in `NestedTensor` `torch.softmax`.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. In addition, the `softmax` operator is required to take in as input an integer for the reduction dimension `dim`, requiring new unit tests heavily inspired by the `sum` and `mean` jagged operator unit tests. `Softmax` also allows for reducing along the batch dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132170
Approved by: https://github.com/davidberard98
2024-07-31 10:51:46 +00:00
Joel Schlosser
524aac413c Initial OpInfo-based testing for NJTs (#131704)
This PR utilizes the info from the existing OpInfo database `op_db` to contribute to general NJT testing.
* New tests in `TestNestedTensorOpInfo`
    * `test_forward()` - compares forward output to an unbind-based reference
    * `test_backward()` - compares forward output and grads to an unbind-based reference
    * `test_forward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) to eager
    * `test_backward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) and grads to eager
* To avoid adding a bunch of NJT-specific stuff to the `OpInfo` structure, this PR translates `op_db` -> a NJT-specific `njt_op_db`.
    * `UnaryUfuncInfo`s utilize a new `sample_inputs_unary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `BinaryUfuncInfo`s utilize a new `sample_inputs_binary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `ReductionOpInfo`s utilize a new `sample_inputs_njt_reduction()` which covers full reductions, reductions over the jagged dim, and reductions over the non-jagged dim
* Several xfails were added to get things passing

TODO (future PRs):
* Pass non-contiguous / non-contiguous with holes NJTs (maybe we should have separate tests for these? most ops don't support NJTs with holes today)
* Mixed (NT, T), (T, NT) inputs for binary ops
* Handle other types of OpInfos (beyond unary pointwise, binary pointwise, and reduction) by manually by writing sample_inputs_funcs
* Address all xfails via fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131704
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898
2024-07-30 23:02:24 +00:00
Joel Schlosser
d53b11bb6e Strict shape checking for NJTs with TestCase.assertEqual() (#131898)
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131898
Approved by: https://github.com/soulitzer
2024-07-30 20:05:48 +00:00
PyTorch MergeBot
499ead96ff Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
2024-07-30 13:11:24 +00:00
PyTorch MergeBot
7a7dd8c29e Revert "[NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#131518)"
This reverts commit bcf5c68c18.

Reverted https://github.com/pytorch/pytorch/pull/131518 on behalf of https://github.com/ZainRizvi due to Sorry, reverting this since this is based on an internal diff that has diverged from actual internal commit (the final PR and diff must always be identical). Conflicts arise when that happens which block the diff train. Let's revert both this PR and the internal diff, and then reland them as a proper new codev diff ([comment](https://github.com/pytorch/pytorch/pull/131518#issuecomment-2257259839))
2024-07-30 00:55:10 +00:00
PyTorch MergeBot
be5e44192d Revert "[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)"
This reverts commit 8fe2bf212d.

Reverted https://github.com/pytorch/pytorch/pull/131519 on behalf of https://github.com/ZainRizvi due to Sorry, reverting this since this is based on an internal diff that has diverged from actual internal commit.  Weird conflicts arise when that happens.  Let's revert both this PR and the internal diff, and then reland them as a proper new codev diff ([comment](https://github.com/pytorch/pytorch/pull/131519#issuecomment-2257230717))
2024-07-30 00:18:22 +00:00
yuqingj
e3dc20c94b [NJT] support cat backward (#132076)
cat_tensors_backward use narrow_symint, so we need to support aten::narrow for NJT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132076
Approved by: https://github.com/davidberard98
2024-07-29 23:49:26 +00:00
Janani Sriram
8fe2bf212d [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131519
Approved by: https://github.com/davidberard98
ghstack dependencies: #131518
2024-07-29 22:16:32 +00:00
jainapurva
d039b14207 Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
2024-07-29 21:49:06 +00:00
PyTorch MergeBot
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde9.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
PyTorch MergeBot
8cdfdb41bc Revert "[NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)"
This reverts commit f862f45730.

Reverted https://github.com/pytorch/pytorch/pull/131519 on behalf of https://github.com/atalman due to broke CI: test_nestedtensor.py::TestNestedTensorSubclassCPU::test_layer_norm_with_lengths_requires_grad_False_components_require_grad_False_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10121747545/job/27996722731) [HUD commit link](f862f45730) ([comment](https://github.com/pytorch/pytorch/pull/131519#issuecomment-2254167994))
2024-07-27 14:45:47 +00:00
Janani Sriram
f862f45730 [NestedTensor] Integrate the layer normalization operator along the jagged dimension into NestedTensor (#131519)
Modify the existing `layer normalization` operator in PyTorch, invoked by `torch.layer_norm`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the `aten` padding operator, enables PyTorch users to invoke `torch.nn.functional.layer_norm` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` or `(B, *, M, N)` nested tensor.

Write unit tests based on the `softmax` jagged operator to verify the accuracy of the ragged reduction implementation for `torch.nn.functional.layer_norm`. Add unit tests to verify error handling for unsupported features.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. The layer normalization operator also requires an operation on a 2-dimensional layer; for nested tensors with 4 or more dimensions, I flatten the extra dimensions, then unflatten them after performing layer normalization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131519
Approved by: https://github.com/davidberard98
ghstack dependencies: #131518
2024-07-27 07:09:10 +00:00
Janani Sriram
bcf5c68c18 [NestedTensor] Integrate the softmax operator along the jagged dimension into NestedTensor (#131518)
Modify the existing `softmax` operator in PyTorch, invoked by `torch.softmax`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff, which uses the aten padding operator, enables PyTorch users to invoke `torch.softmax` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Write unit tests based on the `sum` and `mean` jagged operators to verify the accuracy of the ragged reduction implementation for `torch.softmax`. Add unit tests to verify error handling for unsupported features in `NestedTensor` `torch.softmax`.

Note that this implementation is limited to nested tensors with `ragged_idx == 1`, i.e. the ragged dimension is not transposed. In addition, the `softmax` operator is required to take in as input an integer for the reduction dimension `dim`, requiring new unit tests heavily inspired by the `sum` and `mean` jagged operator unit tests. `Softmax` also allows for reducing along the batch dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131518
Approved by: https://github.com/davidberard98
2024-07-27 07:09:10 +00:00
Janani Sriram
13e806a591 [NestedTensor] Add support for transposed NestedTensors where ragged_idx > 1 for sum and mean operators (#131517)
Add support for transposed, non-contiguous `NestedTensor`s, where `ragged_idx > 1`, for the aten operators `sum` and `mean`. This diff enables reducing along the jagged dimension for non-contiguous `NestedTensor`s, transposed between non-batch dimensions as well as between a ragged and a non-batch dimension. For example, users can now reduce a `NestedTensor` of shape `(B, M, *, N)` along `*` or `(B, N, M, *)` along `*`.

Parametrize existing unit tests and add new unit tests verifying the accuracy of implementations on `NestedTensor`s that transpose between 2 non-batch dimensions as well as between a ragged and a non-batch dimension.

Differential Revision: [D59847927](https://our.internmc.facebook.com/intern/diff/D59847927/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131517
Approved by: https://github.com/davidberard98
2024-07-26 07:21:32 +00:00
Aaron Orenstein
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
jananisriram
faddb0f30c [NestedTensor] Integrate the mean operator along the jagged dimension into NestedTensor (#131132)
Summary:
Modify the existing `mean` operator in PyTorch, invoked by `torch.mean`, to allow for reductions along the jagged dimension of a nested tensor. The function originally had a basic implementation for reducing along 1 non-ragged dimension. This diff enables PyTorch users to invoke `torch.mean` on a nested tensor when reducing along the ragged dimension, e.g. `*` in a `(B, *, M)` nested tensor.

Parametrize unit tests from `sum` to verify the accuracy of the ragged reduction implementation for `torch.mean`. Add unit tests and parametrize `sum` unit tests to verify error handling for unsupported features in `NestedTensor` `torch.mean`.

Test Plan:
Verify that the new unit test passes via the following command:
```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_mean
```

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_jagged_op
```

Differential Revision: D59654668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131132
Approved by: https://github.com/davidberard98, https://github.com/jbschlosser
2024-07-23 18:48:34 +00:00