Part 2 of implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).
Details:
* Codegen `rev_view_func()` alongside `view_func()`
* Reverse view_func gives you a "base" from a "view": `rev_view_func(new_view) -> new_base` AKA it plays the original view backwards
* Utilizes the functional inverses defined in `FunctionalInverses.cpp`, passing `InverseReturnMode::AlwaysView`
* Manually implements functional inverses for `narrow()` and `chunk()`
* **NB: Multi-output views now set view_func() / rev_view_func() for each of the output views!**
* Due to this, the `as_view()` overload that operates on a list of views is scrapped in favor of iteration via codegen
Example codegen in `ADInplaceOrViewTypeN.cpp`:
```cpp
at::Tensor narrow(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::_ops::narrow::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, dim, start, length);
})();
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (false || !self.unsafeGetTensorImpl()->support_as_strided() ||
c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
func = [=](const at::Tensor& input_base) {
return at::_ops::narrow::call(input_base, dim, start, length);
};
rev_func = [=](const at::Tensor& input_view) {
// NB: args from narrow() signature are passed along to the inverse
return at::functionalization::FunctionalInverses::narrow_copy_inverse(self, input_view, at::functionalization::InverseReturnMode::AlwaysView, dim, start, length);
};
}
auto result = as_view(/* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, /* is_fw_differentiable */ true, /* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE));
return result;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115894
Approved by: https://github.com/soulitzer
Decorates all NT tests with `@markDynamoStrictTest` to ensure we get the correct signal. Adds xfails where needed to get things passing.
Includes a fix in meta_utils.py for a bug that was breaking several python 3.11 tests. In particular, a dense tensor graph input that is a view of a strided NT would slip past Dynamo's check and break in meta-ification.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116111
Approved by: https://github.com/soulitzer, https://github.com/zou3519
ghstack dependencies: #115192
This PR removes the need for passing `ragged_size` into the `NestedTensor` constructor. This was an artifact of fake-ification, where sometimes we needed the NT to have a symbolic singleton symint shape for the ragged dimension. The new way of achieving this is to also store mappings between fake / functional tensors -> symbolic symints in the ragged structure registry. Now the `NestedTensor` constructor can just query this registry for the `ragged_size`.
Old: `NestedTensor(values, offsets, *, ragged_size=None, **kwargs)`
New: `NestedTensor(values, offsets, **kwargs)`
This makes it possible to have a `_nested_view_from_values_offsets(values, offsets)` without needing to pass a `ragged_size`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113491
Approved by: https://github.com/ezyang, https://github.com/soulitzer
Summary:
Add split and layer_norm_backward.
Note: It is non trivial to support split_with_sizes backward so adding the split operation to support the use case in the model.
Test Plan: unit tests
Differential Revision: D51052966
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113108
Approved by: https://github.com/soulitzer
This PR:
* Adds support for the `layout` kwarg to `torch.nested.as_nested_tensor()`
* Fixes `torch.nested.nested_tensor()`
* It should accept a list of lists of scalars
* It should not preserve autograd history
* Adds extensive testing for these two functions
Semantics for the two functions follow those of the strided layout:
* `torch.nested.nested_tensor(tensor_list, layout=torch.jagged)`: Creates a new jagged layout NT **with no autograd history**
* `tensor_list` can be a list of Tensors or list of lists of scalars
* `torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)`: Creates a new jagged layout NT **preserving autograd history of `tensor_list`**
* `tensor_list` must be a list of Tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112304
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
This PR has a number of changes that improve subclass support for AOTAutograd/Inductor in general:
- previously if a subclass does extra aliasing between graph outputs/inputs in a way, the partitioner would complain because grad_outputs are the outputs reused as-is. Now we do a view_as(self) to workaround this.
- Use dense -> dense metadata when working with fwd_output_strides during backward. This is important since the stride information comes from inductor which sees the dense to dense graph.
- Inductor requires that the inputs to the compiled backward to match some expected strides computed during compilation. We make sure to make the inner tensors of the subclass contiguous (previously, we only made the subclass itself contiguous)
Changes specific to NestedTensor relevant to compilation:
- Properly handle the case where `__tensor_unflatten__` is passed non-symbolic dense tensors and with meta extracted from fake subclasses.
- Skip var_to_range logic for singleton int
- Skip size hint logic in inductor for singleton int
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110529
Approved by: https://github.com/bdhirsh
In this PR:
- Adds support for strides for jagged tensor (design doc for this coming soon)
- NestedTensor skips automatic dynamic
- Make use of @bdhirsh's subclass fakification logic by adding the __tensor_{un,}flatten__ functions.
- Additional logic for fakification: since existing subclass fakification logic does not handle the case where the outer tensor has an additional dimension. We insert one-off logic to (1) insert an extra SingletonSymInt onto the fakified NestedTensor. (2) make sure we call track_symint on both the sizes on the inner and outer tensor during guard creation.
Remaining things that are weird:
- Still need to skip some logic in meta utils for some reason (I was going to write this up more, but decided not to since we're not able to do this anyway for a immediate reason: we cannot arbitrarily compare singleton ints. For now I'm just following Brian's advise from [here](https://github.com/pytorch/pytorch/pull/109171#discussion_r1328137070) )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109171
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
We want to be able to use SingletonSymNode to represent strides for Jagged layout tensor. The following is for 3D, but easily generalizable to higher dimensions.
Constraints:
- [B, x, D] (where x represents the "variably lengthed dim") can be strided in two ways [x, 1, sum(x)] and [dx, d, 1]. We need two different placeholder values depending on how the jagged tensor is strided.
- When doing operations we need the strides of output tensors to be expressable in terms of the strides and sizes of the inner tensors. Given [B, x, D] @ [D, D'], the output strides is [x * D', D', 1] rather than some opaque [x2, D', 1]. This constraint exists because if I'm tracing, I need a symint to represent the output stride. This symint needs to come from somewhere; I get it in several ways: (1) create a constant, (2) unbacked symint, (3) create a new input using a source, (4) output of an operation on an existing symint. It is clear that (4) is what we want here, which brings us to the design below.
Design:
Given the two constraints, the most straightforward way to implement this is actually to update SingletonSymNode to include some scalar factor, i.e. Morally, SingletonSymNode represents `factor * [s_0, s_1, …, s_n]` This enables us to symbolically compute strides from sizes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110369
Approved by: https://github.com/ezyang
ghstack dependencies: #110044
# Summary
Preivously we disallowd dis-contiguous NTs to passed into to empty_like. This was done out of an abundance of caution, :think:. However it should be safe to create an empty NT for dis-contiguous NTs. Empty like does account for offsets, strides, and sizes in construction of the result and therefore this should be safe.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98383
Approved by: https://github.com/cpuhrsch
This is needed for the HSTU model.
Details:
* ~~NT `chunk` now calls into NT `split_with_sizes` since the latter is more general~~ (removed; they're totally separate)
* Throws for backward
* Only operates over the last dim (`dim=-1`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97446
Approved by: https://github.com/cpuhrsch
# Summary
NestedTensors currenlty don't support non-identical strided addition. When accumulating grad it possible to try and accumulate a grad with different striding then the old var and there is no way to change this in user code. This is a solution.. probs should support strided addition for NT
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97195
Approved by: https://github.com/albanD, https://github.com/cpuhrsch
# Summary
In preparation for pt 2.0 launch this PR updates SDPA's API and makes the function a nn.funcitonal public function.
## Changes
### API
Previously the the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`
This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.
#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels e.g. (FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save on memory usage by not materializing the weights and it is unlikely that a fast fused implementation will enable this feature so we are removing.
Discussed with folks at FAIR/Xformers and +1 this API change.
#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start to generate user feedback
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
Summary: This diff modifies the implementation of the select operator so slices of the irregular dimension can be selected (e.g. nt[:,0,:]).
Test Plan:
Added new unit tests to test that the new functions work as intended (see them in diff). To test,
`buck test mode/dev-nosan //caffe2/test:nested`
Differential Revision: D41083993
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88585
Approved by: https://github.com/cpuhrsch
Summary: This diff merges both previous implementations of constructors for nested tensors, the one from lists of tensors and the one with arbitrary python lists, adn implements it in pytorch core so no extensions are needed to construct NT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88213
Approved by: https://github.com/cpuhrsch
Fixes#87713
BMM for cpu supports non-contiguous nested tensor inputs, while BMM for Cuda does not support currently non-contiguous inputs.
The derivative for BMM:
```
- name: bmm(Tensor self, Tensor mat2) -> Tensor
self: grad.bmm(mat2.transpose(1, 2).conj())
mat2: self.transpose(1, 2).conj().bmm(grad)
result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t)
```
When calling backward it was impossible for this function to succeed since the inputs were always discontiguous, regardless of the user input. This adds contiguous calls to BMM_cuda implementation for nested tensors.
This was not caught by tests because grad_check is currently only done on CPU in test_nestedtensors. This PR updates the autograd test to also be run on GPU.
As a result I found one more issue with the backward for to_padded_tensor erroring instead of calling the generic version.
cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88108
Approved by: https://github.com/cpuhrsch
Summary: This diff implements copy_ in order to allow pinned memory transfers for nested tensors, as well as fill_ and ones_like, to test whether nested tensors can be created with other factory functions.
Test Plan: Pass all CI and sandcastle jobs.
Reviewed By: mikekgfb
Differential Revision: D40689594
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87728
Approved by: https://github.com/cpuhrsch
Summary: This commit adds support for moving NestedTensors from CPU to GPU and back. The implementation includes requires implementing empty_like(), which is based on PR#83140.
Test Plan: Added a new unit test based on the unit test for the main .to() implementation. All unit tests must pass, as well as every sandcastle job.
Differential Revision: D40437585
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87146
Approved by: https://github.com/drisspg
Summary: In order to make the layer normalization implementation for nested tensors public, it needs to be generalized to accept a normalized_shape argument instead of assuming it to be the last dimension of the nested_tensor. This commit does that, as well as adding extra unit tests to ensure the implementation is correct.
Test Plan:
All unit tests designed to test different ways of using the function work:
`buck test //caffe2/test:nested -- test_layer_norm`
Differential Revision: D40105207
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86295
Approved by: https://github.com/drisspg
### this effectively means that we only allow reshaping/viewing of nt with ONE ragged dimension
Behavior before this PR:
1. `-1` allowed for implicit batch dimension
2. multiple `-1`s allowed for pre-existing dimensions
3. for new dimensions, `-1` is not allowed
it is worth noting that for the most part 3 is basically unreachable because assuming a nested tensor has at least 1 ragged dimension, you would expect at least one -1 to be in the proposed shape for the pre-existing dimensions
Behavior after this PR:
1. batch dimension **must be specified**
2. **only one** `-1` allowed for pre-existing dimensions **this effectively means that we only allow reshaping/viewing of nt with ONE ragged dimension**
3. unchanged
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85691
Approved by: https://github.com/cpuhrsch
### this effectively means that we only allow reshaping/viewing of nt with ONE ragged dimension
Behavior before this PR:
1. `-1` allowed for implicit batch dimension
2. multiple `-1`s allowed for pre-existing dimensions
3. for new dimensions, `-1` is not allowed
it is worth noting that for the most part 3 is basically unreachable because assuming a nested tensor has at least 1 ragged dimension, you would expect at least one -1 to be in the proposed shape for the pre-existing dimensions
Behavior after this PR:
1. batch dimension **must be specified**
2. **only one** `-1` allowed for pre-existing dimensions **this effectively means that we only allow reshaping/viewing of nt with ONE ragged dimension**
3. unchanged
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85691
Approved by: https://github.com/cpuhrsch
Previously indexing a nested tensor when it requires_grad would raise an error because the backward formula for `select.int` uses `self.sizes()`. This PR fixes that by temporarily registering a _nested_select_backward function which can be removed when we start using the symint approach to register kernels. For now this functionality is needed for creating a POC that nested tensor can be an API to `segment_coo` and `segment_csr` in the torch_scatter repo
```
a = torch.arange(10).reshape(2, 5).float()
b = torch.arange(12).reshape(2, 6).float()
nt = torch.nested_tensor([a, b], dtype=torch.float).requires_grad_(True)
nt[0]
# RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor
```
whereas
```
nt = torch.nested_tensor([a, b], dtype=torch.float).requires_grad_(False)
nt[0]
```
would succeed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83875
Approved by: https://github.com/albanD, https://github.com/drisspg