mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Use presence of _symint in kernel name to generate symint sig or not (#84579)
Something people found confusing was that whether a native:: signature got SymInt in its type was based on the dispatch key. This changes it so that whether the type uses SymInt is based on whether you have _symint in the name of the kernel. This means that even when we make operators support SymInt, you no longer have to go and update all the preexisting definitions; instead, you now selectively write _symint to opt individual kernels into SymInt support. I then go and update a bunch of kernels that don't have proper SymInt support to make use of this convention. There is some hacking around for view generation code. I also add support for external backends to specify 'symint' operators, for which we generate SymInt signatures instead of regular signatures. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: [D39310060](https://our.internmc.facebook.com/intern/diff/D39310060) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84579 Approved by: https://github.com/wconstab
This commit is contained in:
parent
18a31cc044
commit
93aef3a010
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
||||||
2ba7616e9070bd14ea34a5ef5459bac571198926
|
f00dd2f35ecf6455d97237d63c70c9c8ec190940
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace native {
|
namespace native {
|
||||||
|
|
||||||
Tensor empty_meta(
|
Tensor empty_meta_symint(
|
||||||
SymIntArrayRef size,
|
SymIntArrayRef size,
|
||||||
c10::optional<ScalarType> dtype_opt,
|
c10::optional<ScalarType> dtype_opt,
|
||||||
c10::optional<Layout> layout_opt,
|
c10::optional<Layout> layout_opt,
|
||||||
|
|
|
||||||
|
|
@ -214,12 +214,9 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
|
||||||
return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor& empty_out(SymIntArrayRef sym_size,
|
Tensor& empty_out(IntArrayRef size,
|
||||||
c10::optional<c10::MemoryFormat> optional_memory_format,
|
c10::optional<c10::MemoryFormat> optional_memory_format,
|
||||||
Tensor& result) {
|
Tensor& result) {
|
||||||
// TODO: support empty_out properly (I was forced to change this immediately
|
|
||||||
// with empty so that empty/empty.out had the same type signature)
|
|
||||||
auto size = c10::asIntArrayRefSlow(sym_size);
|
|
||||||
// Preferably, this argument would not be accepted by _out, but the code
|
// Preferably, this argument would not be accepted by _out, but the code
|
||||||
// generator requires the out and non-out overloads to match exactly
|
// generator requires the out and non-out overloads to match exactly
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
|
|
@ -386,7 +383,7 @@ Tensor empty_like_quantized(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor new_empty(
|
Tensor new_empty_symint(
|
||||||
const Tensor& self,
|
const Tensor& self,
|
||||||
SymIntArrayRef size,
|
SymIntArrayRef size,
|
||||||
c10::optional<ScalarType> dtype_opt,
|
c10::optional<ScalarType> dtype_opt,
|
||||||
|
|
@ -1077,7 +1074,7 @@ Tensor triu_indices_cpu(
|
||||||
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
Tensor zeros(SymIntArrayRef size,
|
Tensor zeros_symint(SymIntArrayRef size,
|
||||||
c10::optional<ScalarType> dtype,
|
c10::optional<ScalarType> dtype,
|
||||||
c10::optional<Layout> layout,
|
c10::optional<Layout> layout,
|
||||||
c10::optional<Device> device,
|
c10::optional<Device> device,
|
||||||
|
|
@ -1107,8 +1104,7 @@ Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) {
|
Tensor& zeros_out(IntArrayRef size, Tensor& result) {
|
||||||
auto size = c10::asIntArrayRefSlow(sym_size);
|
|
||||||
if (result.is_sparse()) {
|
if (result.is_sparse()) {
|
||||||
// TODO: I think this branch should be dead, but we don't have an easy
|
// TODO: I think this branch should be dead, but we don't have an easy
|
||||||
// way to cover all sparse kernels with zeros_sparse_out, so retain this
|
// way to cover all sparse kernels with zeros_sparse_out, so retain this
|
||||||
|
|
|
||||||
|
|
@ -844,9 +844,7 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool /*unused*/) {
|
Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) {
|
||||||
// TODO: properly support SymInt expand
|
|
||||||
auto size = asIntArrayRefSlow(sym_size);
|
|
||||||
TORCH_CHECK(size.size() >= (size_t)self.dim(),
|
TORCH_CHECK(size.size() >= (size_t)self.dim(),
|
||||||
"expand(", self.toString(), "{", self.sizes(), "}, size=", size,
|
"expand(", self.toString(), "{", self.sizes(), "}, size=", size,
|
||||||
"): the number of sizes provided (", size.size(), ") ",
|
"): the number of sizes provided (", size.size(), ") ",
|
||||||
|
|
@ -925,9 +923,8 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
|
Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
|
||||||
// TODO: properly support SymInt narrow_copy
|
return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
|
||||||
return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
|
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
|
||||||
|
|
@ -3204,12 +3201,6 @@ Tensor adjoint(const Tensor &self) {
|
||||||
return _adjoint(self, /*transpose=*/false, "adjoint()");
|
return _adjoint(self, /*transpose=*/false, "adjoint()");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor view_meta(const Tensor& self,
|
|
||||||
at::SymIntArrayRef size) {
|
|
||||||
// TODO: Properly support SymInt view
|
|
||||||
return view_impl(self, c10::asIntArrayRefSlow(size));
|
|
||||||
}
|
|
||||||
|
|
||||||
Tensor view(const Tensor& self,
|
Tensor view(const Tensor& self,
|
||||||
at::IntArrayRef size) {
|
at::IntArrayRef size) {
|
||||||
return view_impl(self, size);
|
return view_impl(self, size);
|
||||||
|
|
@ -3592,7 +3583,7 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
|
at::Tensor& expand_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
|
||||||
auto tmp = self.expand_symint(size, implicit);
|
auto tmp = self.expand_symint(size, implicit);
|
||||||
out.copy_(tmp);
|
out.copy_(tmp);
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -3748,7 +3739,7 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
|
at::Tensor& view_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
|
||||||
auto tmp = self.view_symint(size);
|
auto tmp = self.view_symint(size);
|
||||||
out.copy_(tmp);
|
out.copy_(tmp);
|
||||||
return out;
|
return out;
|
||||||
|
|
|
||||||
|
|
@ -2054,7 +2054,7 @@
|
||||||
CPU: empty_cpu
|
CPU: empty_cpu
|
||||||
CUDA: empty_cuda
|
CUDA: empty_cuda
|
||||||
MPS: empty_mps
|
MPS: empty_mps
|
||||||
Meta: empty_meta
|
Meta: empty_meta_symint
|
||||||
MkldnnCPU: empty_mkldnn
|
MkldnnCPU: empty_mkldnn
|
||||||
SparseCPU, SparseCUDA, SparseMeta: empty_sparse
|
SparseCPU, SparseCUDA, SparseMeta: empty_sparse
|
||||||
SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
|
SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
|
||||||
|
|
@ -2065,7 +2065,7 @@
|
||||||
- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||||
variants: method
|
variants: method
|
||||||
dispatch:
|
dispatch:
|
||||||
CompositeExplicitAutograd: new_empty
|
CompositeExplicitAutograd: new_empty_symint
|
||||||
autogen: new_empty.out
|
autogen: new_empty.out
|
||||||
|
|
||||||
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||||
|
|
@ -5548,7 +5548,7 @@
|
||||||
|
|
||||||
- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||||
dispatch:
|
dispatch:
|
||||||
CompositeExplicitAutograd: zeros
|
CompositeExplicitAutograd: zeros_symint
|
||||||
|
|
||||||
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
|
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
dispatch:
|
dispatch:
|
||||||
|
|
@ -6889,8 +6889,7 @@
|
||||||
device_check: NoCheck
|
device_check: NoCheck
|
||||||
device_guard: False
|
device_guard: False
|
||||||
dispatch:
|
dispatch:
|
||||||
Meta: view_meta
|
ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
|
||||||
ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
|
|
||||||
MkldnnCPU: mkldnn_view
|
MkldnnCPU: mkldnn_view
|
||||||
NestedTensorCPU, NestedTensorCUDA: view_nested
|
NestedTensorCPU, NestedTensorCUDA: view_nested
|
||||||
|
|
||||||
|
|
@ -12938,7 +12937,7 @@
|
||||||
- func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
|
- func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
|
||||||
variants: function
|
variants: function
|
||||||
dispatch:
|
dispatch:
|
||||||
CompositeExplicitAutograd: expand_copy_out
|
CompositeExplicitAutograd: expand_copy_out_symint
|
||||||
|
|
||||||
|
|
||||||
- func: permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
|
- func: permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
|
|
@ -13058,7 +13057,7 @@
|
||||||
- func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
|
- func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
variants: function
|
variants: function
|
||||||
dispatch:
|
dispatch:
|
||||||
CompositeExplicitAutograd: view_copy_out
|
CompositeExplicitAutograd: view_copy_out_symint
|
||||||
|
|
||||||
|
|
||||||
- func: view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
|
- func: view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
|
|
|
||||||
|
|
@ -199,6 +199,13 @@ supported:
|
||||||
- _trilinear
|
- _trilinear
|
||||||
- linalg_pinv.atol_rtol_tensor
|
- linalg_pinv.atol_rtol_tensor
|
||||||
- logsumexp.out
|
- logsumexp.out
|
||||||
|
symint:
|
||||||
|
- empty.memory_format
|
||||||
|
- expand
|
||||||
|
- expand_copy
|
||||||
|
- narrow_copy
|
||||||
|
- view
|
||||||
|
- view_copy
|
||||||
autograd:
|
autograd:
|
||||||
- max_pool3d
|
- max_pool3d
|
||||||
- native_group_norm
|
- native_group_norm
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) {
|
||||||
for (const auto dim : c10::irange(3)) {
|
for (const auto dim : c10::irange(3)) {
|
||||||
const int64_t start = 1, length = 4;
|
const int64_t start = 1, length = 4;
|
||||||
auto y_ref = x.narrow(dim, start, length);
|
auto y_ref = x.narrow(dim, start, length);
|
||||||
auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length));
|
auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
|
||||||
ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
|
ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import distutils.command.clean
|
import distutils.command.clean
|
||||||
|
import sys
|
||||||
import shutil
|
import shutil
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
@ -129,21 +130,25 @@ class BuildExtension_(BuildExtension):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print("Building wheel {}-{}".format(package_name, version))
|
print("Building wheel {}-{}".format(package_name, version))
|
||||||
write_version_file()
|
write_version_file()
|
||||||
setup(
|
try:
|
||||||
# Metadata
|
setup(
|
||||||
name=package_name,
|
# Metadata
|
||||||
version=version,
|
name=package_name,
|
||||||
author='PyTorch Core Team',
|
version=version,
|
||||||
url="https://github.com/pytorch/functorch",
|
author='PyTorch Core Team',
|
||||||
description='JAX-like composable function transforms for PyTorch',
|
url="https://github.com/pytorch/functorch",
|
||||||
license='BSD',
|
description='JAX-like composable function transforms for PyTorch',
|
||||||
|
license='BSD',
|
||||||
|
|
||||||
# Package info
|
# Package info
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
ext_modules=get_extensions(),
|
ext_modules=get_extensions(),
|
||||||
cmdclass={
|
cmdclass={
|
||||||
"build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
|
"build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
|
||||||
'clean': clean,
|
'clean': clean,
|
||||||
})
|
})
|
||||||
|
except Exception as e:
|
||||||
|
print(e, file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
|
||||||
|
|
@ -584,7 +584,8 @@ def gen_inplace_or_view_type(
|
||||||
[fn for fn in fns_with_infos if use_derived(fn)],
|
[fn for fn in fns_with_infos if use_derived(fn)],
|
||||||
key_fn=lambda fn: fn.func.root_name,
|
key_fn=lambda fn: fn.func.root_name,
|
||||||
base_env={
|
base_env={
|
||||||
"generated_comment": f"@generated from {template_path}/ADInplaceOrViewType.cpp",
|
"generated_comment": "@"
|
||||||
|
+ f"generated from {template_path}/ADInplaceOrViewType.cpp",
|
||||||
},
|
},
|
||||||
env_callable=gen_inplace_or_view_type_env,
|
env_callable=gen_inplace_or_view_type_env,
|
||||||
num_shards=2,
|
num_shards=2,
|
||||||
|
|
|
||||||
|
|
@ -535,7 +535,7 @@ def gen_trace_type(
|
||||||
[fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
|
[fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
|
||||||
key_fn=lambda fn: fn.root_name,
|
key_fn=lambda fn: fn.root_name,
|
||||||
base_env={
|
base_env={
|
||||||
"generated_comment": f"@generated from {template_path}/TraceType.cpp",
|
"generated_comment": "@" + f"generated from {template_path}/TraceType.cpp",
|
||||||
},
|
},
|
||||||
env_callable=gen_trace_type_func,
|
env_callable=gen_trace_type_func,
|
||||||
num_shards=5,
|
num_shards=5,
|
||||||
|
|
|
||||||
|
|
@ -314,7 +314,6 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
|
||||||
dispatch_key=k,
|
dispatch_key=k,
|
||||||
use_out_as_primary=True,
|
use_out_as_primary=True,
|
||||||
external=False,
|
external=False,
|
||||||
symint=False,
|
|
||||||
device_guard=False,
|
device_guard=False,
|
||||||
index=backend_indices[k],
|
index=backend_indices[k],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -238,7 +238,7 @@ invalid_key: invalid_val"""
|
||||||
output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
|
output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
output_error,
|
output_error,
|
||||||
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen""", # noqa: B950
|
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen, symint""", # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
# if use_out_as_primary is provided, it must be a bool
|
# if use_out_as_primary is provided, it must be a bool
|
||||||
|
|
|
||||||
|
|
@ -269,7 +269,7 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::empty(
|
at::Tensor LazyNativeFunctions::empty_symint(
|
||||||
at::SymIntArrayRef sym_size,
|
at::SymIntArrayRef sym_size,
|
||||||
c10::optional<at::ScalarType> dtype,
|
c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout,
|
c10::optional<at::Layout> layout,
|
||||||
|
|
@ -307,7 +307,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
|
||||||
c10::optional<at::Device> device,
|
c10::optional<at::Device> device,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<bool> pin_memory) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
at::Tensor t = empty(
|
at::Tensor t = empty_symint(
|
||||||
c10::SymIntArrayRef::fromIntArrayRef(size),
|
c10::SymIntArrayRef::fromIntArrayRef(size),
|
||||||
dtype,
|
dtype,
|
||||||
layout,
|
layout,
|
||||||
|
|
@ -409,7 +409,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
|
||||||
const at::Tensor& self,
|
const at::Tensor& self,
|
||||||
at::IntArrayRef size) {
|
at::IntArrayRef size) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
return LazyNativeFunctions::view_copy(
|
return LazyNativeFunctions::view_copy_symint(
|
||||||
self, c10::SymIntArrayRef::fromIntArrayRef(size));
|
self, c10::SymIntArrayRef::fromIntArrayRef(size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -449,7 +449,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
|
||||||
self, size, stride, dtype, layout, device, pin_memory);
|
self, size, stride, dtype, layout, device, pin_memory);
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::narrow_copy(
|
at::Tensor LazyNativeFunctions::narrow_copy_symint(
|
||||||
const at::Tensor& self,
|
const at::Tensor& self,
|
||||||
int64_t dim,
|
int64_t dim,
|
||||||
c10::SymInt start,
|
c10::SymInt start,
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,12 @@ def name(func: FunctionSchema) -> str:
|
||||||
|
|
||||||
|
|
||||||
def argumenttype_type(
|
def argumenttype_type(
|
||||||
t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
|
t: Type,
|
||||||
|
*,
|
||||||
|
mutable: bool,
|
||||||
|
binds: ArgName,
|
||||||
|
remove_non_owning_ref_types: bool = False,
|
||||||
|
symint: bool = True,
|
||||||
) -> NamedCType:
|
) -> NamedCType:
|
||||||
# This is a faux amis. If it makes sense in the future to add
|
# This is a faux amis. If it makes sense in the future to add
|
||||||
# more special cases here, or invert things so cpp.argument_type
|
# more special cases here, or invert things so cpp.argument_type
|
||||||
|
|
@ -45,25 +50,30 @@ def argumenttype_type(
|
||||||
t,
|
t,
|
||||||
mutable=mutable,
|
mutable=mutable,
|
||||||
binds=binds,
|
binds=binds,
|
||||||
symint=True,
|
symint=symint,
|
||||||
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def argument_type(
|
def argument_type(
|
||||||
a: Argument, *, binds: ArgName, remove_non_owning_ref_types: bool = False
|
a: Argument,
|
||||||
|
*,
|
||||||
|
binds: ArgName,
|
||||||
|
remove_non_owning_ref_types: bool = False,
|
||||||
|
symint: bool = True,
|
||||||
) -> NamedCType:
|
) -> NamedCType:
|
||||||
return argumenttype_type(
|
return argumenttype_type(
|
||||||
a.type,
|
a.type,
|
||||||
mutable=a.is_write,
|
mutable=a.is_write,
|
||||||
binds=binds,
|
binds=binds,
|
||||||
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
||||||
|
symint=symint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def returns_type(rs: Sequence[Return]) -> CType:
|
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
|
||||||
# At present, there is no difference. But there could be!
|
# At present, there is no difference. But there could be!
|
||||||
return cpp.returns_type(rs, symint=True)
|
return cpp.returns_type(rs, symint=symint)
|
||||||
|
|
||||||
|
|
||||||
def jit_arguments(func: FunctionSchema) -> List[Argument]:
|
def jit_arguments(func: FunctionSchema) -> List[Argument]:
|
||||||
|
|
@ -89,15 +99,20 @@ def jit_arguments(func: FunctionSchema) -> List[Argument]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def argument(a: Argument, *, remove_non_owning_ref_types: bool = False) -> Binding:
|
def argument(
|
||||||
|
a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
|
||||||
|
) -> Binding:
|
||||||
return Binding(
|
return Binding(
|
||||||
nctype=argument_type(
|
nctype=argument_type(
|
||||||
a, binds=a.name, remove_non_owning_ref_types=remove_non_owning_ref_types
|
a,
|
||||||
|
binds=a.name,
|
||||||
|
remove_non_owning_ref_types=remove_non_owning_ref_types,
|
||||||
|
symint=symint,
|
||||||
),
|
),
|
||||||
name=a.name,
|
name=a.name,
|
||||||
argument=a,
|
argument=a,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def arguments(func: FunctionSchema) -> List[Binding]:
|
def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
|
||||||
return [argument(a) for a in jit_arguments(func)]
|
return [argument(a, symint=symint) for a in jit_arguments(func)]
|
||||||
|
|
|
||||||
|
|
@ -577,8 +577,10 @@ class DispatcherSignature:
|
||||||
# and need to avoid naming collisions.
|
# and need to avoid naming collisions.
|
||||||
prefix: str = ""
|
prefix: str = ""
|
||||||
|
|
||||||
|
symint: bool = True
|
||||||
|
|
||||||
def arguments(self) -> List[Binding]:
|
def arguments(self) -> List[Binding]:
|
||||||
return dispatcher.arguments(self.func)
|
return dispatcher.arguments(self.func, symint=self.symint)
|
||||||
|
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self.prefix + dispatcher.name(self.func)
|
return self.prefix + dispatcher.name(self.func)
|
||||||
|
|
@ -604,7 +606,7 @@ class DispatcherSignature:
|
||||||
return [Expr(a.name, a.nctype) for a in self.arguments()]
|
return [Expr(a.name, a.nctype) for a in self.arguments()]
|
||||||
|
|
||||||
def returns_type(self) -> CType:
|
def returns_type(self) -> CType:
|
||||||
return dispatcher.returns_type(self.func.returns)
|
return dispatcher.returns_type(self.func.returns, symint=self.symint)
|
||||||
|
|
||||||
def ptr_type(self) -> str:
|
def ptr_type(self) -> str:
|
||||||
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
|
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
|
||||||
|
|
@ -616,8 +618,10 @@ class DispatcherSignature:
|
||||||
return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
|
return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_schema(func: FunctionSchema, *, prefix: str = "") -> "DispatcherSignature":
|
def from_schema(
|
||||||
return DispatcherSignature(func, prefix)
|
func: FunctionSchema, *, prefix: str = "", symint: bool = True
|
||||||
|
) -> "DispatcherSignature":
|
||||||
|
return DispatcherSignature(func, prefix, symint)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -778,15 +782,16 @@ def kernel_signature(
|
||||||
# so we'd like to keep the differences as small as possible.
|
# so we'd like to keep the differences as small as possible.
|
||||||
# With external backends, we'd like to enforce that they write their kernels with schemas
|
# With external backends, we'd like to enforce that they write their kernels with schemas
|
||||||
# that match the Dispatcher API directly, if they can.
|
# that match the Dispatcher API directly, if they can.
|
||||||
|
meta = backend_index.get_kernel(f)
|
||||||
|
symint = meta is not None and meta.supports_symint()
|
||||||
|
if symint:
|
||||||
|
assert (
|
||||||
|
f.func.has_symint()
|
||||||
|
), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
|
||||||
if backend_index.external:
|
if backend_index.external:
|
||||||
# Dispatcher signature faithfully does SymInt, which is good for XLA,
|
return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
|
||||||
# not so good for more conventional backends but we don't have any of
|
|
||||||
# those. If we do, that's time to add a new Signature that is a cross
|
|
||||||
# between DispatcherSignature and NativeSignature
|
|
||||||
assert backend_index.symint
|
|
||||||
return DispatcherSignature.from_schema(f.func, prefix=prefix)
|
|
||||||
else:
|
else:
|
||||||
return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint)
|
return NativeSignature(f.func, prefix=prefix, symint=symint)
|
||||||
|
|
||||||
|
|
||||||
# Functions only, no types
|
# Functions only, no types
|
||||||
|
|
|
||||||
|
|
@ -751,8 +751,11 @@ resize_out(out, sizes, strides, options);
|
||||||
)
|
)
|
||||||
|
|
||||||
# Signature of the wrapper function we'll register to the dispatcher
|
# Signature of the wrapper function we'll register to the dispatcher
|
||||||
|
kern = self.backend_index.get_kernel(f)
|
||||||
sig = NativeSignature(
|
sig = NativeSignature(
|
||||||
f.func, prefix="wrapper_", symint=self.backend_index.symint
|
f.func,
|
||||||
|
prefix="wrapper_",
|
||||||
|
symint=kern is not None and kern.supports_symint(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.target is Target.NAMESPACED_DECLARATION:
|
if self.target is Target.NAMESPACED_DECLARATION:
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from torchgen.api import cpp
|
||||||
from torchgen.api.translate import translate
|
from torchgen.api.translate import translate
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
Binding,
|
Binding,
|
||||||
|
CppSignature,
|
||||||
CppSignatureGroup,
|
CppSignatureGroup,
|
||||||
DispatcherSignature,
|
DispatcherSignature,
|
||||||
NamedCType,
|
NamedCType,
|
||||||
|
|
@ -161,7 +162,6 @@ def parse_native_yaml_struct(
|
||||||
device_guard=False,
|
device_guard=False,
|
||||||
# I'm actually not sure about this; undefined could be hit on
|
# I'm actually not sure about this; undefined could be hit on
|
||||||
# empty TensorList, hypothetically that could have sizes in it
|
# empty TensorList, hypothetically that could have sizes in it
|
||||||
symint=False,
|
|
||||||
index={},
|
index={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -176,16 +176,6 @@ def parse_native_yaml_struct(
|
||||||
# Only cuda-like devices in tree require device guards
|
# Only cuda-like devices in tree require device guards
|
||||||
device_guard=is_cuda_dispatch_key(k),
|
device_guard=is_cuda_dispatch_key(k),
|
||||||
index=v,
|
index=v,
|
||||||
# Which dispatch keys natively support symint
|
|
||||||
# Note: DispatchKey.CompositeExplicitAutograd has to match out
|
|
||||||
# composites; I think there's some factoring problem here
|
|
||||||
symint=k
|
|
||||||
in [
|
|
||||||
DispatchKey.Meta,
|
|
||||||
DispatchKey.CompositeImplicitAutograd,
|
|
||||||
DispatchKey.CompositeExplicitAutograd,
|
|
||||||
DispatchKey.CompositeExplicitAutogradNonFunctional,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
return ParsedYaml(rs, indices)
|
return ParsedYaml(rs, indices)
|
||||||
|
|
||||||
|
|
@ -363,6 +353,8 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
|
||||||
# supporting usecases even when there is a memory_format argument along with tensor_option arguments.
|
# supporting usecases even when there is a memory_format argument along with tensor_option arguments.
|
||||||
# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
|
# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
|
||||||
def translate_args_dispatcher_to_cpp(
|
def translate_args_dispatcher_to_cpp(
|
||||||
|
sig: DispatcherSignature,
|
||||||
|
cpp_sig: CppSignature,
|
||||||
f: NativeFunction,
|
f: NativeFunction,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
|
|
@ -385,10 +377,7 @@ def translate_args_dispatcher_to_cpp(
|
||||||
output_bindings.append(binding)
|
output_bindings.append(binding)
|
||||||
return output_bindings
|
return output_bindings
|
||||||
|
|
||||||
disp_sig = DispatcherSignature.from_schema(f.func)
|
disp_sig = sig
|
||||||
cpp_sig = CppSignatureGroup.from_native_function(
|
|
||||||
f, method=False, fallback_binding=False
|
|
||||||
).signature
|
|
||||||
disp_bindings = disp_sig.arguments()
|
disp_bindings = disp_sig.arguments()
|
||||||
# When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
|
# When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
|
||||||
# get memory_format bindings of dispatcher signature to have the same NCType as well
|
# get memory_format bindings of dispatcher signature to have the same NCType as well
|
||||||
|
|
@ -401,11 +390,20 @@ def translate_args_dispatcher_to_cpp(
|
||||||
|
|
||||||
|
|
||||||
def generate_static_dispatch_backend_call(
|
def generate_static_dispatch_backend_call(
|
||||||
|
sig: DispatcherSignature,
|
||||||
f: NativeFunction,
|
f: NativeFunction,
|
||||||
backend_index: BackendIndex,
|
backend_index: BackendIndex,
|
||||||
) -> str:
|
) -> str:
|
||||||
name = DispatcherSignature.from_schema(f.func).name()
|
cpp_sigs = CppSignatureGroup.from_native_function(
|
||||||
exprs = translate_args_dispatcher_to_cpp(f)
|
f, method=False, fallback_binding=False
|
||||||
|
)
|
||||||
|
if sig.symint and f.func.has_symint():
|
||||||
|
cpp_sig = cpp_sigs.symint_signature
|
||||||
|
else:
|
||||||
|
cpp_sig = cpp_sigs.signature
|
||||||
|
assert cpp_sig is not None
|
||||||
|
name = cpp_sig.name()
|
||||||
|
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
|
||||||
backend_metadata = backend_index.get_kernel(f)
|
backend_metadata = backend_index.get_kernel(f)
|
||||||
kernel_ns = (
|
kernel_ns = (
|
||||||
backend_metadata.cpp_namespace
|
backend_metadata.cpp_namespace
|
||||||
|
|
@ -417,11 +415,20 @@ def generate_static_dispatch_backend_call(
|
||||||
|
|
||||||
|
|
||||||
def generate_static_dispatch_fallback_call(
|
def generate_static_dispatch_fallback_call(
|
||||||
|
sig: DispatcherSignature,
|
||||||
f: NativeFunction,
|
f: NativeFunction,
|
||||||
backend_indices: List[BackendIndex],
|
backend_indices: List[BackendIndex],
|
||||||
) -> str:
|
) -> str:
|
||||||
name = DispatcherSignature.from_schema(f.func).name()
|
cpp_sigs = CppSignatureGroup.from_native_function(
|
||||||
exprs = translate_args_dispatcher_to_cpp(f)
|
f, method=False, fallback_binding=False
|
||||||
|
)
|
||||||
|
if sig.symint and f.func.has_symint():
|
||||||
|
cpp_sig = cpp_sigs.symint_signature
|
||||||
|
else:
|
||||||
|
cpp_sig = cpp_sigs.signature
|
||||||
|
assert cpp_sig is not None
|
||||||
|
name = cpp_sig.name()
|
||||||
|
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
|
||||||
ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
|
ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
|
||||||
if f.has_composite_explicit_autograd_kernel:
|
if f.has_composite_explicit_autograd_kernel:
|
||||||
return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
|
return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
|
||||||
|
|
@ -437,6 +444,7 @@ def generate_static_dispatch_fallback_call(
|
||||||
|
|
||||||
|
|
||||||
def static_dispatch(
|
def static_dispatch(
|
||||||
|
sig: DispatcherSignature,
|
||||||
f: NativeFunction,
|
f: NativeFunction,
|
||||||
backend_indices: List[BackendIndex],
|
backend_indices: List[BackendIndex],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -453,11 +461,10 @@ def static_dispatch(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
if len(keys) == 1:
|
if len(keys) == 1:
|
||||||
return generate_static_dispatch_backend_call(f, keys[0])
|
return generate_static_dispatch_backend_call(sig, f, keys[0])
|
||||||
elif len(keys) == 0:
|
elif len(keys) == 0:
|
||||||
return generate_static_dispatch_fallback_call(f, backend_indices)
|
return generate_static_dispatch_fallback_call(sig, f, backend_indices)
|
||||||
|
|
||||||
sig = DispatcherSignature.from_schema(f.func)
|
|
||||||
native_tensor_args = [
|
native_tensor_args = [
|
||||||
a.name
|
a.name
|
||||||
for a in sig.arguments()
|
for a in sig.arguments()
|
||||||
|
|
@ -483,10 +490,10 @@ def static_dispatch(
|
||||||
for index in keys:
|
for index in keys:
|
||||||
dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
|
dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
|
||||||
dispatch_code.append(
|
dispatch_code.append(
|
||||||
f"""\t{generate_static_dispatch_backend_call(f, index)};"""
|
f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
|
||||||
)
|
)
|
||||||
|
|
||||||
fallback = generate_static_dispatch_fallback_call(f, backend_indices)
|
fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
|
||||||
connector = "\n\t\t"
|
connector = "\n\t\t"
|
||||||
|
|
||||||
return f"""
|
return f"""
|
||||||
|
|
@ -528,8 +535,6 @@ class ComputeOperators:
|
||||||
def __call__(self, f: NativeFunction) -> str:
|
def __call__(self, f: NativeFunction) -> str:
|
||||||
sig = DispatcherSignature.from_schema(f.func)
|
sig = DispatcherSignature.from_schema(f.func)
|
||||||
name = f.func.name.unambiguous_name()
|
name = f.func.name.unambiguous_name()
|
||||||
call_method_name = "call"
|
|
||||||
redispatch_method_name = "redispatch"
|
|
||||||
|
|
||||||
if self.target is Target.DECLARATION:
|
if self.target is Target.DECLARATION:
|
||||||
# Note [The ATen Operators API]
|
# Note [The ATen Operators API]
|
||||||
|
|
@ -563,8 +568,8 @@ struct TORCH_API {name} {{
|
||||||
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
|
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
|
||||||
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
|
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
|
||||||
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
|
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
|
||||||
static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
|
static {sig.defn(name="call", is_redispatching_fn=False)};
|
||||||
static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
|
static {sig.defn(name="redispatch", is_redispatching_fn=True)};
|
||||||
}};"""
|
}};"""
|
||||||
|
|
||||||
elif self.target is Target.DEFINITION:
|
elif self.target is Target.DEFINITION:
|
||||||
|
|
@ -585,12 +590,13 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
|
||||||
dispatcher_exprs_str = ", ".join(
|
dispatcher_exprs_str = ", ".join(
|
||||||
["dispatchKeySet"] + [a.name for a in sig.arguments()]
|
["dispatchKeySet"] + [a.name for a in sig.arguments()]
|
||||||
)
|
)
|
||||||
dispatcher_call = "redispatch"
|
method_base = "redispatch"
|
||||||
method_name = f"{name}::{redispatch_method_name}"
|
|
||||||
else:
|
else:
|
||||||
method_name = f"{name}::{call_method_name}"
|
|
||||||
dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
|
dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
|
||||||
dispatcher_call = "call"
|
method_base = "call"
|
||||||
|
|
||||||
|
dispatcher_call = method_base
|
||||||
|
method_name = f"{name}::{method_base}"
|
||||||
|
|
||||||
fn_body = f"""
|
fn_body = f"""
|
||||||
static auto op = create_{name}_typed_handle();
|
static auto op = create_{name}_typed_handle();
|
||||||
|
|
@ -602,7 +608,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
|
||||||
):
|
):
|
||||||
# call() should go through static dispatch
|
# call() should go through static dispatch
|
||||||
fn_body = static_dispatch(
|
fn_body = static_dispatch(
|
||||||
f, backend_indices=self.static_dispatch_backend_indices
|
sig, f, backend_indices=self.static_dispatch_backend_indices
|
||||||
)
|
)
|
||||||
defns += f"""
|
defns += f"""
|
||||||
// aten::{f.func}
|
// aten::{f.func}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
from collections import Counter, defaultdict, namedtuple
|
from collections import Counter, defaultdict, namedtuple
|
||||||
from typing import Dict, List, Optional, Sequence, Union
|
from typing import Dict, List, Optional, Sequence, Set, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -68,6 +68,7 @@ def parse_backend_yaml(
|
||||||
"full_codegen",
|
"full_codegen",
|
||||||
"non_native",
|
"non_native",
|
||||||
"ir_gen",
|
"ir_gen",
|
||||||
|
"symint",
|
||||||
]
|
]
|
||||||
|
|
||||||
backend = yaml_values.pop("backend", None)
|
backend = yaml_values.pop("backend", None)
|
||||||
|
|
@ -96,6 +97,14 @@ def parse_backend_yaml(
|
||||||
supported, list
|
supported, list
|
||||||
), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
|
), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
|
||||||
|
|
||||||
|
symint = yaml_values.pop("symint", [])
|
||||||
|
if symint is None:
|
||||||
|
symint = [] # Allow an empty list of symint ops
|
||||||
|
assert isinstance(
|
||||||
|
symint, list
|
||||||
|
), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
|
||||||
|
symint_set = set(symint)
|
||||||
|
|
||||||
supported_autograd = yaml_values.pop("autograd", [])
|
supported_autograd = yaml_values.pop("autograd", [])
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
supported_autograd, list
|
supported_autograd, list
|
||||||
|
|
@ -118,6 +127,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
||||||
|
|
||||||
def create_backend_index(
|
def create_backend_index(
|
||||||
backend_ops: List[str],
|
backend_ops: List[str],
|
||||||
|
symint_ops: Set[str],
|
||||||
dispatch_key: DispatchKey,
|
dispatch_key: DispatchKey,
|
||||||
*,
|
*,
|
||||||
use_out_as_primary: bool,
|
use_out_as_primary: bool,
|
||||||
|
|
@ -131,6 +141,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
||||||
), f"Found an invalid operator name: {op_name}"
|
), f"Found an invalid operator name: {op_name}"
|
||||||
# See Note [External Backends Follow Dispatcher API]
|
# See Note [External Backends Follow Dispatcher API]
|
||||||
kernel_name = dispatcher.name(native_functions_map[op_name].func)
|
kernel_name = dispatcher.name(native_functions_map[op_name].func)
|
||||||
|
if op in symint_ops:
|
||||||
|
kernel_name += "_symint"
|
||||||
# TODO: allow structured external backends later.
|
# TODO: allow structured external backends later.
|
||||||
m = BackendMetadata(
|
m = BackendMetadata(
|
||||||
kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
|
kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
|
||||||
|
|
@ -140,7 +152,6 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
||||||
dispatch_key=dispatch_key,
|
dispatch_key=dispatch_key,
|
||||||
use_out_as_primary=use_out_as_primary,
|
use_out_as_primary=use_out_as_primary,
|
||||||
external=True,
|
external=True,
|
||||||
symint=True, # TODO: make this configurable
|
|
||||||
device_guard=use_device_guard,
|
device_guard=use_device_guard,
|
||||||
index=metadata,
|
index=metadata,
|
||||||
)
|
)
|
||||||
|
|
@ -154,6 +165,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
|
||||||
|
|
||||||
backend_idx = create_backend_index(
|
backend_idx = create_backend_index(
|
||||||
supported,
|
supported,
|
||||||
|
symint_set,
|
||||||
backend_key,
|
backend_key,
|
||||||
use_out_as_primary=use_out_as_primary,
|
use_out_as_primary=use_out_as_primary,
|
||||||
use_device_guard=use_device_guard,
|
use_device_guard=use_device_guard,
|
||||||
|
|
@ -171,6 +183,7 @@ the behavior of autograd for some operators on your backend. However "Autograd{b
|
||||||
|
|
||||||
autograd_idx = create_backend_index(
|
autograd_idx = create_backend_index(
|
||||||
supported_autograd,
|
supported_autograd,
|
||||||
|
symint_set,
|
||||||
autograd_key,
|
autograd_key,
|
||||||
use_out_as_primary=use_out_as_primary,
|
use_out_as_primary=use_out_as_primary,
|
||||||
use_device_guard=use_device_guard,
|
use_device_guard=use_device_guard,
|
||||||
|
|
|
||||||
|
|
@ -84,21 +84,21 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
|
||||||
# clone() calls in their graph (which is normally needed by reshape).
|
# clone() calls in their graph (which is normally needed by reshape).
|
||||||
if str(g.view_copy.func.name) == "view_copy":
|
if str(g.view_copy.func.name) == "view_copy":
|
||||||
return """\
|
return """\
|
||||||
at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
|
at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
|
||||||
// TODO: don't cast to int array ref
|
DimVector shape = infer_size_dv(size, self.numel());
|
||||||
auto int_size = c10::asIntArrayRefSlow(size);
|
|
||||||
DimVector shape = infer_size_dv(int_size, self.numel());
|
|
||||||
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
|
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
|
||||||
return self.reshape(int_size);
|
return self.reshape(size);
|
||||||
} else {
|
} else {
|
||||||
auto output = at::_ops::view::call(self, size);
|
auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
|
||||||
return output.clone();
|
return output.clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
# view_copy is a native signature, since we're generating an at::native:: kernel
|
# view_copy is a native signature, since we're generating an at::native:: kernel
|
||||||
# Functionalization always operates on symints though
|
# Functionalization always operates on symints though
|
||||||
view_copy_sig = NativeSignature(g.view_copy.func, symint=True)
|
view_copy_sig = NativeSignature(
|
||||||
|
g.view_copy.func, symint=False
|
||||||
|
) # TODO: flag day this True
|
||||||
|
|
||||||
# view is a dispatcher signature, since we're calling into the at::_ops API
|
# view is a dispatcher signature, since we're calling into the at::_ops API
|
||||||
view_sig = DispatcherSignature(g.view.func)
|
view_sig = DispatcherSignature(g.view.func)
|
||||||
|
|
@ -641,7 +641,7 @@ def gen_functionalization_registration(
|
||||||
metadata = composite_implicit_autograd_index.get_kernel(f)
|
metadata = composite_implicit_autograd_index.get_kernel(f)
|
||||||
assert metadata is not None
|
assert metadata is not None
|
||||||
native_api_name = metadata.kernel
|
native_api_name = metadata.kernel
|
||||||
sig = DispatcherSignature.from_schema(f.func)
|
sig = NativeSignature(f.func, symint=metadata.supports_symint())
|
||||||
# Note [Composite view ops in the functionalization pass]
|
# Note [Composite view ops in the functionalization pass]
|
||||||
# We don't need to worry about implementing functionalization kernels for views with
|
# We don't need to worry about implementing functionalization kernels for views with
|
||||||
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
|
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ def gen_unwraps(
|
||||||
|
|
||||||
|
|
||||||
def gen_case_where_all_bdims_are_none(
|
def gen_case_where_all_bdims_are_none(
|
||||||
schema: FunctionSchema, cur_level_var: str
|
outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
|
||||||
) -> str:
|
) -> str:
|
||||||
conditions = []
|
conditions = []
|
||||||
flat_args = schema.arguments.flat_all
|
flat_args = schema.arguments.flat_all
|
||||||
|
|
@ -90,7 +90,7 @@ def gen_case_where_all_bdims_are_none(
|
||||||
|
|
||||||
sig = DispatcherSignature.from_schema(schema)
|
sig = DispatcherSignature.from_schema(schema)
|
||||||
translated_args = ", ".join(
|
translated_args = ", ".join(
|
||||||
e.expr for e in translate(sig.arguments(), sig.arguments())
|
e.expr for e in translate(outer_sig.arguments(), sig.arguments())
|
||||||
)
|
)
|
||||||
return f"""\
|
return f"""\
|
||||||
if ({' && '.join(conditions)}) {{
|
if ({' && '.join(conditions)}) {{
|
||||||
|
|
@ -160,7 +160,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
|
||||||
cur_level_var = "cur_level"
|
cur_level_var = "cur_level"
|
||||||
|
|
||||||
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
|
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
|
||||||
bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
|
bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
|
||||||
|
|
||||||
return f"""\
|
return f"""\
|
||||||
template <typename batch_rule_t, batch_rule_t batch_rule>
|
template <typename batch_rule_t, batch_rule_t batch_rule>
|
||||||
|
|
@ -182,7 +182,7 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
|
||||||
cur_level_var = "cur_level"
|
cur_level_var = "cur_level"
|
||||||
|
|
||||||
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
|
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
|
||||||
bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
|
bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
|
||||||
|
|
||||||
return f"""\
|
return f"""\
|
||||||
template <typename batch_rule_t, batch_rule_t batch_rule>
|
template <typename batch_rule_t, batch_rule_t batch_rule>
|
||||||
|
|
@ -224,7 +224,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
|
||||||
cur_level_var = "cur_level"
|
cur_level_var = "cur_level"
|
||||||
|
|
||||||
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
|
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
|
||||||
bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
|
bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
|
||||||
|
|
||||||
wrapped_returns = gen_returns(returns, cur_level_var, results_var)
|
wrapped_returns = gen_returns(returns, cur_level_var, results_var)
|
||||||
return f"""\
|
return f"""\
|
||||||
|
|
|
||||||
|
|
@ -1098,6 +1098,9 @@ class BackendMetadata:
|
||||||
# The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
|
# The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
|
||||||
cpp_namespace: str
|
cpp_namespace: str
|
||||||
|
|
||||||
|
def supports_symint(self) -> bool:
|
||||||
|
return "_symint" in self.kernel
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class UfuncInnerLoop:
|
class UfuncInnerLoop:
|
||||||
|
|
@ -1141,8 +1144,6 @@ class BackendIndex:
|
||||||
external: bool
|
external: bool
|
||||||
# Other backend-specific information that is on a per-operator basis
|
# Other backend-specific information that is on a per-operator basis
|
||||||
index: Dict["OperatorName", BackendMetadata]
|
index: Dict["OperatorName", BackendMetadata]
|
||||||
# Whether or not this backend handles symbolic ints or not
|
|
||||||
symint: bool
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def grow_index(
|
def grow_index(
|
||||||
|
|
|
||||||
|
|
@ -304,6 +304,8 @@ def generate_function(
|
||||||
if func.kind() == SchemaKind.out
|
if func.kind() == SchemaKind.out
|
||||||
else cpp.name(func)
|
else cpp.name(func)
|
||||||
)
|
)
|
||||||
|
if f.func.has_symint():
|
||||||
|
kernel_name += "_symint"
|
||||||
backend_metadata = {
|
backend_metadata = {
|
||||||
DispatchKey.CompositeExplicitAutograd: {
|
DispatchKey.CompositeExplicitAutograd: {
|
||||||
func.name: BackendMetadata(
|
func.name: BackendMetadata(
|
||||||
|
|
@ -555,7 +557,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
|
||||||
|
|
||||||
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
|
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
|
||||||
return f"""
|
return f"""
|
||||||
{sig.defn()} {{
|
{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
|
||||||
{clone_mutable_inputs_str}
|
{clone_mutable_inputs_str}
|
||||||
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
|
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
|
||||||
{ret_str}
|
{ret_str}
|
||||||
|
|
@ -615,7 +617,7 @@ def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
|
||||||
|
|
||||||
# Kernel name needs to follow the naming convention defined in `generate_function()`
|
# Kernel name needs to follow the naming convention defined in `generate_function()`
|
||||||
return f"""
|
return f"""
|
||||||
{sig.defn(name=g.out.func.name.unambiguous_name())} {{
|
{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
|
||||||
auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
|
auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
|
||||||
{copy_outs_str}
|
{copy_outs_str}
|
||||||
{return_str(g.out.func.returns, rets)}
|
{return_str(g.out.func.returns, rets)}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user