Due to implicit conversion shenanigans, having both IntArrayRef
and SymIntArrayRef overloads makes {} ambiguous. While we could
fix this by making a single unified type that accepts all the overloads
we want, an easier fix was to just push the SymIntArrayRef overload
to its own name.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79281
Approved by: https://github.com/suo
Summary:
Adding a feature to allow users to specify namespaces for operators and kernels.
# Feature
There's a feature request to allow the DSL to:
1. take in an operator namespace other than `aten`.
2. take in a kernel that is in a different namespace than `at::native`.
For both features, we only allow a single level of namespace for the sake of simplicity. If the user specifies `custom::function` as the kernel, the codegen will depend on `custom::native::function`, where `native` is hardcoded.
# Proposal
For feature 1, add a `namespace` attribute to the data class `NativeFunction`. The namespace is extracted by matching the pattern "::" on the `func` variable. For `NativeFunctionsGroup` there's an assumption that all variants (function, inplace, out) share the same namespace. By default (if not specified) the namespace is "aten".
For feature 2, add a `namespace` attribute to the `BackendMetadata` class, similarly matching the pattern "::" on the kernel field. Remove the `cpp_namespace` field from the `register_dispatch_key` data class. By default (if not specified) the namespace for a kernel is "at::native".
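As a rough illustration of the parsing described above (a minimal sketch, not the actual torchgen code; `split_namespace` is a made-up helper name):
```
# Minimal sketch of splitting an optional "custom::" prefix off a schema or
# kernel name; the real parsing in torchgen is more involved.
def split_namespace(name, default_ns):
    if "::" in name:
        ns, _, rest = name.partition("::")
        return ns, rest
    return default_ns, name

# Operator namespace defaults to "aten":
print(split_namespace("custom::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)", "aten")[0])
# custom

# Kernel namespace defaults to "at::native"; a "custom" kernel then resolves
# to custom::native::<kernel>:
print(split_namespace("custom::gelu_out_cpu", "at::native"))
# ('custom', 'gelu_out_cpu')
```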
Test Plan:
Example yaml entries:
```
- func: custom::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  python_module: nn
  dispatch:
    CPU: custom::gelu_out_cpu
    CUDA: custom::gelu_out_cuda
    MPS: custom::gelu_out_mps

- func: custom::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)
  structured_delegate: gelu.out
  device_check: NoCheck # TensorIterator
  python_module: nn
  dispatch:
    NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu_

- func: custom::gelu(Tensor self, *, str approximate='none') -> Tensor
  structured_delegate: gelu.out
  device_check: NoCheck # TensorIterator
  python_module: nn
  dispatch:
    MkldnnCPU: custom::mkldnn_gelu
    QuantizedCPU: custom::gelu_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu
```
See the generated code:
`RegisterCPU.cpp`:
```
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  ...
}

TORCH_LIBRARY_IMPL(custom, CPU, m) {
  m.impl("gelu", TORCH_FN(wrapper_gelu));
  m.impl("gelu.out", TORCH_FN(wrapper_gelu_out_out));
  m.impl("gelu_", TORCH_FN(wrapper_gelu_));
};
```
```
struct structured_gelu_out_cpu_inplace final : public custom::native::structured_gelu_out_cpu {
  structured_gelu_out_cpu_inplace(Tensor& self) : outputs_{std::ref(self)} {}

  void set_output_strided(
      int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
      TensorOptions options, DimnameList names
  ) override {
    const auto& out = outputs_[output_idx].get();
    check_inplace(out, sizes, options);
    auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
    if (C10_UNLIKELY(maybe_proxy.has_value())) {
      proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
    }
    if (!names.empty()) {
      namedinference::propagate_names(outputs_[output_idx], names);
    }
    // super must happen after, so that downstream can use maybe_get_output
    // to retrieve the output
    custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names);
  }

  void set_output_raw_strided(
      int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
      TensorOptions options, DimnameList names
  ) override {
    const auto& out = outputs_[output_idx].get();
    check_inplace(out, sizes, options);
    if (!names.empty()) {
      namedinference::propagate_names(outputs_[output_idx], names);
    }
    // super must happen after, so that downstream can use maybe_get_output
    // to retrieve the output
    custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names);
  }

  const Tensor& maybe_get_output(int64_t output_idx) override {
    return proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get();
  }

  std::array<std::reference_wrapper<Tensor>, 1> outputs_;
  std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, 1> proxy_outputs_;
};
```
`RegisterSchema.cpp`:
```
TORCH_LIBRARY(aten, m) {
  ...
}

TORCH_LIBRARY(custom, m) {
  m.def("gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)");
  m.def("gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)");
  m.def("gelu(Tensor self, *, str approximate='none') -> Tensor");
};
```
Differential Revision: D36558459
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78015
Approved by: https://github.com/bdhirsh
Package config/template files with torchgen
This PR packages native_functions.yaml, tags.yaml and ATen/templates
with torchgen.
This PR:
- adds a step to setup.py to copy the relevant files over into torchgen
- adds a docstring for torchgen (so `import torchgen; help(torchgen)`
says something)
- adds a helper function in torchgen so you can get the torchgen root
directory (and figure out where the packaged files are); a rough sketch
follows this list
- changes some scripts to explicitly pass the location of torchgen,
which will be helpful for the first item in the Future section.
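As a rough illustration of that root-directory helper (a minimal sketch; the real helper's name and location in torchgen may differ):
```
# Minimal sketch of a "where is torchgen installed?" helper, assuming the
# packaged files live under <torchgen root>/packaged.
import os
import torchgen

def torchgen_root():
    # torchgen.__file__ points at torchgen/__init__.py of the installed package.
    return os.path.dirname(os.path.abspath(torchgen.__file__))

packaged_dir = os.path.join(torchgen_root(), "packaged")
print(packaged_dir)  # e.g. .../site-packages/torchgen/packaged
```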
Future
======
- torchgen, when invoked from the command line, should use sources
in torchgen/packaged instead of aten/src. I'm unable to do this because
people (aka PyTorch CI) invoke `python -m torchgen.gen` without
installing torchgen.
- the source of truth for all of these files should be in torchgen.
This is a bit annoying to execute on due to potential merge conflicts
and having to deal with merge systems.
- CI and testing. The way things are set up right now is really fragile;
we should have a CI job for torchgen.
Test Plan
=========
I ran the following locally:
```
python -m torchgen.gen -s torchgen/packaged
```
and verified that it produced the generated files.
Furthermore, I ran `setup.py install` and checked that the files are
actually packaged with torchgen.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78942
Approved by: https://github.com/ezyang
Summary:
Add a script to go through the view ops in `native_functions.yaml`, auto-register them in static runtime, and auto-generate op unit tests for each.
Overall there are 96 grouped view ops, among which 21 are already registered by hand; 9 (including sparse ops, training-related ops, etc.) are not targets of static runtime; 30 have list args or list returns; and 7 have non-basic types such as `Dimname` and `MemoryFormat`. In summary, this script auto-generates 29 view ops for now.
Run `buck run //caffe2/torch/fb/jit:gen_static_runtime_ops` to generate the static runtime ops; the results with this script are:
```
total grouped native ops: 1582
grouped native ops with out variant: 548
generated functions groups with out variant: 241
view grouped native ops: 96
generated functions view groups: 29
overall generated : 270
```
The generated view ops are added in D36258968
Test Plan:
Generate static runtime ops: `buck run //caffe2/torch/fb/jit:gen_static_runtime_ops`
Unit tests: `buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest`
Differential Revision: D36258767
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77105
Approved by: https://github.com/mikeiovine
Add codegen infrastructure to generate IR nodes for non-native ops.
The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to those found in `native_functions.yaml`, e.g.
```
non_native:
  ...
  - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
  ...
```
These definitions are parsed into a `LazyIrSchema` that can be used to generate IR nodes via `GenLazyIR`.
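A rough sketch of consuming that key (illustrative only; the yaml path and the loading code are assumptions, only `LazyIrSchema` and `GenLazyIR` come from the description above):
```
# Illustrative sketch: pull the `non_native` entries out of a backend yaml and
# hand each schema string to the lazy IR codegen.
import yaml

with open("ts_native_functions.yaml") as f:  # hypothetical backend yaml path
    backend_yaml = yaml.safe_load(f)

for entry in backend_yaml.get("non_native", []):
    schema_str = entry["func"]
    # e.g. "expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor"
    # Each string would be parsed into a LazyIrSchema and passed to GenLazyIR
    # to emit the corresponding IR node class.
    print(schema_str)
```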
Fixes #74628
CC: @wconstab @desertfire @henrytwo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535
Approved by: https://github.com/wconstab
For static dispatch we currently hardcode the namespace to `at` for backend-specific C++ functions, e.g., `at::cpu::add()`. This change extends it to accept a namespace from the call site. This is a temporary solution; in the long run we want to introduce custom namespaces into the codegen system, e.g., we should be able to add `at::` to `native_functions.yaml` and parse it into `NativeFunction`. This needs a bit more design.
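A minimal sketch of the idea (illustrative Python, not the torchgen API; function and parameter names here are made up):
```
# Illustrative sketch only: emit a static-dispatch call with a configurable
# namespace instead of a hardcoded "at".
def static_dispatch_call(func_name, args, backend, ns="at"):
    # Previously this was effectively "at::<backend>::<func>(...)";
    # the call site can now supply e.g. ns="custom".
    return f"{ns}::{backend}::{func_name}({', '.join(args)})"

print(static_dispatch_call("add", ["self", "other"], backend="cpu"))
# at::cpu::add(self, other)
print(static_dispatch_call("add", ["self", "other"], backend="cpu", ns="custom"))
# custom::cpu::add(self, other)
```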
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77710
Approved by: https://github.com/ezyang
Previously, when codegening ops like `zeros_` or `ones_`, we'd hit a `Code below assumes there is at least one tensor arg` error. This check is not entirely correct, which is what causes the error to be thrown: ops like the ones mentioned pass in a `device` parameter that can be used in place of the "first tensor".
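A small illustrative sketch of the relaxed check (not the actual lazy codegen; the argument representation is made up): when there is no tensor argument, fall back to an explicit `device` argument instead of asserting.
```
# Illustrative sketch: choose where the device comes from. Ops that have no
# Tensor argument but do take a `device` parameter fall through to that
# parameter instead of triggering the assert.
def choose_device_source(args):
    # `args` is a made-up list of {"name": ..., "type": ...} dicts.
    tensor_args = [a for a in args if a["type"] == "Tensor"]
    if tensor_args:
        return f"{tensor_args[0]['name']}.device()"
    device_args = [a for a in args if a["type"] == "Device"]
    if device_args:
        return device_args[0]["name"]
    raise AssertionError("Code below assumes there is at least one tensor arg")

print(choose_device_source([{"name": "self", "type": "Tensor"}]))    # self.device()
print(choose_device_source([{"name": "device", "type": "Device"}]))  # device
```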
CC: @wconstab @desertfire @henrytwo @ke1337
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76917
Approved by: https://github.com/desertfire
Partially fix #69813
This PR mainly does 3 things:
1. Introduces new methods for the `MetaBase` API:
   - `set_output_strided`: creates proxy tensors with exact strides, if the strides don't match
   - `set_output_contiguous`: alias for `set_output_strided` with contiguous strides
   - `set_output_raw_strided`: does not create proxy tensors
2. Modifies codegen for handling proxy tensors (a conceptual sketch follows this list):
   - Creates a new field for out-of-place kernels: `proxy_output_`
   - Implements `set_output_strided` by creating a proxy tensor if necessary
   - Passes the proxy tensor to the `IMPL` function
   - Copies the result back to the real output at the end, whenever a proxy was created
3. Replaces `set_output` with `set_output_raw_strided` for `TensorIterator*`
   - Needed, since it overrides `set_output`
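For intuition, here is a conceptual Python analogue of the proxy-output flow (the real change is in the generated C++, similar to the `RegisterCPU.cpp` excerpt earlier; names here are illustrative):
```
# Conceptual sketch of the proxy-output flow, not the generated C++.
import torch

def run_structured_out_kernel(impl, out, sizes, strides):
    # set_output_strided: if `out` does not already have the sizes/strides the
    # meta function asked for, create a proxy tensor with exactly those strides.
    needs_proxy = tuple(out.shape) != tuple(sizes) or out.stride() != tuple(strides)
    proxy = torch.empty_strided(sizes, strides, dtype=out.dtype) if needs_proxy else None

    # The IMPL function writes into the proxy (or directly into `out`).
    impl(proxy if proxy is not None else out)

    # Whenever a proxy was created, copy the result back to the real output.
    if proxy is not None:
        out.copy_(proxy)
    return out
```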
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76096
Approved by: https://github.com/ezyang
Unfortunately, the built-in pprint module supports pretty-printing of dataclasses only from Python 3.10. The code I wrote in the `__str__` method of OpInfo should do the same job and should also work for any dataclass. For now I've put it there, but we could create a function and put it somewhere that is also accessible to other dataclasses. Also, the max width (80) is currently hardcoded, but it would ideally be a parameter of the function.
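A minimal sketch of that kind of generic dataclass pretty-printer (simplified; not the exact code added to OpInfo):
```
# Minimal sketch of a generic dataclass pretty-printer along the lines
# described above; the actual __str__ added to OpInfo may differ.
import dataclasses
import pprint

def pretty_dataclass(obj, width=80):
    name = type(obj).__name__
    indent = len(name) + 1
    lines = []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        rendered = pprint.pformat(value, width=width - indent, compact=False)
        # Indent continuation lines so values line up under the field name.
        rendered = rendered.replace("\n", "\n" + " " * (indent + len(f.name) + 3))
        lines.append(f"{f.name} = {rendered}")
    sep = ",\n" + " " * indent
    return f"{name}({sep.join(lines)})"
```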
When you call `print` on an OpInfo, you get:
```
OpInfo(name = '__getitem__',
ref = None,
aliases = (),
variant_test_name = '',
op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
inplace_variant = None,
skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>,
reference_inputs_func = None,
error_inputs_func = None,
sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>,
sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>,
dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
supports_out = False,
supports_autograd = True,
supports_gradgrad = True,
supports_fwgrad_bwgrad = True,
supports_inplace_autograd = False,
supports_forward_ad = True,
gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>,
check_batched_grad = True,
check_batched_gradgrad = True,
check_batched_forward_grad = True,
check_inplace_batched_forward_grad = True,
gradcheck_nondet_tol = 0.0,
gradcheck_fast_mode = None,
aten_name = '__getitem__',
decomp_aten_name = None,
aten_backward_name = None,
assert_autodiffed = False,
autodiff_nonfusible_nodes = ['aten::__getitem__'],
autodiff_fusible_nodes = [],
supports_sparse = False,
supports_scripting = False,
supports_sparse_csr = False,
test_conjugated_samples = True,
test_neg_view = True,
assert_jit_shape_analysis = False,
supports_expanded_weight = False)
```
cc @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810
Approved by: https://github.com/ezyang