Summary:
Previously we don't generate out variant (both schema and kernel) for an operator with functional variant only. This adds support for that and adds test.
## Changes on `native_function_generation.py`
We are generating out variant for all functional variants if possible. This PR introduces a lot of newly generated out variants and `native_functions.yaml` needs to incorporate the changes by adding `autogen` keywords.
The logic for determining what operators we should generate an out variant for is the following:
1. No existing out variant for this `NativeFunction`
2. Contains an existing in place, mutable or functional variant
3. Contains at least 1 tensor like return(s)
For operators matching the first two conditions but failing the third, I listed them in `FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT`.
## Special handling
The following operators satisfy all 3 criteria above but we chose to not autogen them, with some reasons.
* `mkldnn_adaptive_avg_pool2d`, the generated out variant `mkldnn_adaptive_avg_pool2d.out` is colliding with the `mkldnn_adaptive_avg_pool2d_out` kernel in `adaptive_avg_pool2d.out` operator. I manually created `mkldnn_adaptive_avg_pool2d.out` and renamed `mkldnn_adaptive_avg_pool2d_out` to `mkldnn_adaptive_avg_pool2d_out_stub`.
* `min`, `max` and `mean`. There already exist `min.out`, `max.out` and `mean.out` but they are having different semantics with the functional ones. I manually created `min.unary_out`, `max.unary_out` and `mean.dtype_out` to disambiguate.
## Autograd Changes
We introduced a logic to not match derivatives info in `derivatives.yaml` to out variant, since we are generating `NOT_IMPLEMENTED` kernels for those out variants anyway. The issue we are seeing with the original logic is that it doesn't handle `TensorOption` arguments really well. For example we have these two operators:
* `_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor`
* `_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)`
If we uses `_to_copy` derivative info, there will be compilation error since `dtype` is missing from `_to_copy.out` signature.
Test Plan: Rely on unit test
Differential Revision: D37832342
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81437
Approved by: https://github.com/iseeyuan, https://github.com/bdhirsh
`derivatives.yaml` can now take a `dispatch` entry which registers per-autograd dispatch key derivatives such as
```
name: foo(Tensor self, Tensor y) -> Tensor
dispatch:
Default:
x: grad
y: grad.expand(y.sizes())
AutogradNestedTensor:
x: grad
y: NestedTensor_foo_backward(grad, y)
output_differentiabilty: [True]
```
However the old schema where there is no `dispatch` entry is still supported.
Would greatly appreciate feedback on *how to improve the testing strategy* of this PR, currently have registered an aten test op in TestOps.cpp with dummy gradients in derivatives.yaml and have some tests in test_autograd.py:TestAutogradMultipleDispatch but I am not sure whether these are sufficiently rigorous.
Additionally, this PR also makes the assumption that sets like [VIEW_FUNCTIONS](ff5399e528/tools/autograd/gen_inplace_or_view_type.py (L60)) are per-native-function and not per-native-function-and-dispatch-key. I'm not sure whether this is necessarily the case, *would there ever be a situation where (e.g. a nested_tensor op is a view op but the aten function is not or vice versa?)*
* __->__ #82801
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82801
Approved by: https://github.com/bhosmer, https://github.com/albanD
Fixes#81774
`TensorOptions` arguments in the JIT schema are optional, but in the Python API these were being translated to non-optional but with a default value. This change makes the arguments accept `None` for consistency with the JIT schema. However, it also means that `dtype=c10::nullopt` was previously completely untested so this also fixes several related bugs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82241
Approved by: https://github.com/ngimel
Deprecated signatures are currently "parsed" manually to find the
relative order of the argument names and all other information is
inferred from the aten schema for the non-deprecated overload.
However, this leads to problems if the argument names don't match or
if there are multiple candidates that match the ATen function call.
Instead, this makes the deprecated function a full FunctionSchema and
so the entire python signature comes solely from the deprecated
schema, with the `aten:` clause only used for the dispatch lambda call.
I have confirmed locally that there is no change to
`python_torch_functionsEverything.cpp`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82179
Approved by: https://github.com/albanD
Currently any function with a default dtype other than None has to be
manually entered into this function. Instead, this reads the default
directly from `native_functions.yaml`. In order to do this, I also
change `PythonSignatureGroup` to take `tensor_options_args` from the
functional variant since the out variant doesn't actually have tensor
options arguments to take the default values from.
Also note that we need to use `default_init` instead of `default`
because the out argument version doesn't have a `tensor_options`
argument to extract the default value from and so the PythonSignature
objects wouldn't match.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81479
Approved by: https://github.com/albanD
This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.
Today if you try to call an `OpOverloadPacket` from python with some arguments, we will use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.
Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it would potentially affect e.g. `native_batch_norm` in the future.
Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 argument, mutates 4 of its inputs inplace)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 argument, mutates none of its inputs)
# We pick the wrong one - no way to know that we should pick the functional one, just from the call site.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# raises an error - tries to call the overload with only 2 returns
return _fused_moving_avg_obs_fq_helper_functional[5]
```
Specifically, functionalization will bake `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to remove the overload name (TS doesn't know how to parse overload names directly, so we need to remove the overload name and let it infer the right overload at runtime later- so it picks the wrong one).
The situation is pretty similar to inplace; `ops.aten.add` and `ops.aten.add_` represent two different `OverloadPacket` objects; they can't be overloads of the same op, because their schemas would be ambiguous - the alias annotations are different, but that isn't enough to disambiguate).
In this PR, I try to fix the situation in a pretty similar way to how we handle `inplace` in the data model: `inplace` ops get their own base operator name, but they are represented as a flag inside of `BaseOperatorName` in the data model.
Two other important changes that I made as part of this PR:
(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only necessary for operators that **also** have a `SchemaKind.mutable` variant, where `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops. I also added a bunch of assertions to force this restriction.
I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.
(2) I noticed that we actually still weren't pairing up a bunch of `_foreach` operators correctly, because their input arguments were different (`self` vs. `tensors`). Since they're private API's, I went ahead and changed the argument names directly so they get matched up. Before this PR, we were generating a separate `_foreach_add` and `_foreach_add.functional` variant in a bunch of cases, that really did the same thing (but happened to have a different name for the first argument).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
Due to implicit conversion shenanigans, having both IntArrayRef
and SymIntArrayRef overloads makes {} ambiguous. While we could
fix this by making a single unified type that accepts all the overloads
we want, an easier fix was to just push the SymIntArrayRef overload
to its own name.
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79281
Approved by: https://github.com/suo
Add codegen infrastructure to generate IR nodes for non-native ops.
The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to what is found in `native_functions.yaml`. e.g.
```
non_native:
...
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
...
```
these definitions are parsed into a `LazyIrSchema` that can be used for generating IR nodes using `GenLazyIR`.
Fixes#74628
CC: @wconstab @desertfire @henrytwo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535
Approved by: https://github.com/wconstab
Previously when codegening ops like `zeros_` or `ones_` we'd hit a `Code below assumes there is at least one tensor arg error`. This check is not entirely correct which is what is causing the error to be thrown. There are ops like the ones mentioned that pass in a `device` parameter that can be used in place of the "first tensor".
CC: @wconstab @desertfire @henrytwo @ke1337
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76917
Approved by: https://github.com/desertfire
This PR turns the previously introduced `ITensorList` into a more general `IList`
class. It is a container wrapper for arbitrary types (given their appropriate
implementations).
In summary, I have:
- Renamed `ITensorList` (its iterators and macros, for consistency) to `IList`
- Made `IList` a templated function (for an arbitrary type `T`), given that they:
- Specialize `IListTagImpl<T, Tag>`, for all `IListTag`
- Introduced type aliases (for both list and iterator types):
- `at::ITensorList` -> `c10::IList<at::Tensor>`
- `at::IOptTensorRefList` -> `c10::IList<at::OptionalTensorRef>`
- Added support for `Tensor?[]` in the structured codegen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69606
Approved by: https://github.com/ezyang
This **roughly** corresponds to Goal 3.2 in https://docs.google.com/document/d/1iiLNwR5ohAsw_ymfnOpDsyF6L9RTUaHMpD8YLw-jxEw/edit#
Namely:
It adds the following:
* SymbolicIntNode interface
* LazySymbolicIntNode implementation
* Lazy `narrow_copy` implementation
* Need add support for SymInt in codegen
* Test (below)
```cpp
TEST(LazyDynamicOpsTest, NarrowCopy) {
auto x = torch::rand({5, 10, 10}).to(kLazy);
const size_t Y_DIM = 3;
const size_t X_DIM_INDEX = 2;
auto y = torch::rand({Y_DIM}).to(kLazy);
auto ly = torch::lazy::TryGetLtcTensor(y);
auto dim_node = MakeNode<SizeNode>(ly->GetIrValue(), 0);
auto lmn = new torch::lazy::SymbolicIntNode(dim_node);
auto z = x.narrow_copy(X_DIM_INDEX, 0, lmn->toSymInt());
AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75759
Approved by: https://github.com/wconstab