Use presence of _symint in kernel name to generate symint sig or not (#84579)

Something people found confusing was that whether a native::
signature used SymInt in its types was determined by the dispatch
key.  This changes it so that SymInt in the signature is determined by
whether the kernel's name contains _symint.  This means
that even when we make operators support SymInt, you no longer have to
go and update all the preexisting definitions; instead, you now
selectively add _symint to a kernel's name to opt that individual kernel into SymInt support.

I then go and update a bunch of kernels that don't have proper SymInt
support to make use of this convention.  There is some hacking around
for view generation code.

I also add support for external backends to specify 'symint' operators, for which we generate SymInt signatures instead of regular signatures.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39310060](https://our.internmc.facebook.com/intern/diff/D39310060)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84579
Approved by: https://github.com/wconstab
This commit is contained in:
Eli Uriegas 2022-09-09 11:29:07 -07:00 committed by PyTorch MergeBot
parent 18a31cc044
commit 93aef3a010
22 changed files with 173 additions and 130 deletions

View File

@ -1 +1 @@
2ba7616e9070bd14ea34a5ef5459bac571198926 f00dd2f35ecf6455d97237d63c70c9c8ec190940

View File

@ -12,7 +12,7 @@
namespace at { namespace at {
namespace native { namespace native {
Tensor empty_meta( Tensor empty_meta_symint(
SymIntArrayRef size, SymIntArrayRef size,
c10::optional<ScalarType> dtype_opt, c10::optional<ScalarType> dtype_opt,
c10::optional<Layout> layout_opt, c10::optional<Layout> layout_opt,

View File

@ -214,12 +214,9 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
} }
Tensor& empty_out(SymIntArrayRef sym_size, Tensor& empty_out(IntArrayRef size,
c10::optional<c10::MemoryFormat> optional_memory_format, c10::optional<c10::MemoryFormat> optional_memory_format,
Tensor& result) { Tensor& result) {
// TODO: support empty_out properly (I was forced to change this immediately
// with empty so that empty/empty.out had the same type signature)
auto size = c10::asIntArrayRefSlow(sym_size);
// Preferably, this argument would not be accepted by _out, but the code // Preferably, this argument would not be accepted by _out, but the code
// generator requires the out and non-out overloads to match exactly // generator requires the out and non-out overloads to match exactly
TORCH_CHECK( TORCH_CHECK(
@ -386,7 +383,7 @@ Tensor empty_like_quantized(
} }
} }
Tensor new_empty( Tensor new_empty_symint(
const Tensor& self, const Tensor& self,
SymIntArrayRef size, SymIntArrayRef size,
c10::optional<ScalarType> dtype_opt, c10::optional<ScalarType> dtype_opt,
@ -1077,7 +1074,7 @@ Tensor triu_indices_cpu(
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor zeros(SymIntArrayRef size, Tensor zeros_symint(SymIntArrayRef size,
c10::optional<ScalarType> dtype, c10::optional<ScalarType> dtype,
c10::optional<Layout> layout, c10::optional<Layout> layout,
c10::optional<Device> device, c10::optional<Device> device,
@ -1107,8 +1104,7 @@ Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
return result; return result;
} }
Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) { Tensor& zeros_out(IntArrayRef size, Tensor& result) {
auto size = c10::asIntArrayRefSlow(sym_size);
if (result.is_sparse()) { if (result.is_sparse()) {
// TODO: I think this branch should be dead, but we don't have an easy // TODO: I think this branch should be dead, but we don't have an easy
// way to cover all sparse kernels with zeros_sparse_out, so retain this // way to cover all sparse kernels with zeros_sparse_out, so retain this

View File

@ -844,9 +844,7 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim
return result; return result;
} }
Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool /*unused*/) { Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) {
// TODO: properly support SymInt expand
auto size = asIntArrayRefSlow(sym_size);
TORCH_CHECK(size.size() >= (size_t)self.dim(), TORCH_CHECK(size.size() >= (size_t)self.dim(),
"expand(", self.toString(), "{", self.sizes(), "}, size=", size, "expand(", self.toString(), "{", self.sizes(), "}, size=", size,
"): the number of sizes provided (", size.size(), ") ", "): the number of sizes provided (", size.size(), ") ",
@ -925,9 +923,8 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri
return self; return self;
} }
Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) { Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
// TODO: properly support SymInt narrow_copy return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous);
} }
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
@ -3204,12 +3201,6 @@ Tensor adjoint(const Tensor &self) {
return _adjoint(self, /*transpose=*/false, "adjoint()"); return _adjoint(self, /*transpose=*/false, "adjoint()");
} }
Tensor view_meta(const Tensor& self,
at::SymIntArrayRef size) {
// TODO: Properly support SymInt view
return view_impl(self, c10::asIntArrayRefSlow(size));
}
Tensor view(const Tensor& self, Tensor view(const Tensor& self,
at::IntArrayRef size) { at::IntArrayRef size) {
return view_impl(self, size); return view_impl(self, size);
@ -3592,7 +3583,7 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef
} }
at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) { at::Tensor& expand_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
auto tmp = self.expand_symint(size, implicit); auto tmp = self.expand_symint(size, implicit);
out.copy_(tmp); out.copy_(tmp);
return out; return out;
@ -3748,7 +3739,7 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o
} }
at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) { at::Tensor& view_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
auto tmp = self.view_symint(size); auto tmp = self.view_symint(size);
out.copy_(tmp); out.copy_(tmp);
return out; return out;

View File

@ -2054,7 +2054,7 @@
CPU: empty_cpu CPU: empty_cpu
CUDA: empty_cuda CUDA: empty_cuda
MPS: empty_mps MPS: empty_mps
Meta: empty_meta Meta: empty_meta_symint
MkldnnCPU: empty_mkldnn MkldnnCPU: empty_mkldnn
SparseCPU, SparseCUDA, SparseMeta: empty_sparse SparseCPU, SparseCUDA, SparseMeta: empty_sparse
SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
@ -2065,7 +2065,7 @@
- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method variants: method
dispatch: dispatch:
CompositeExplicitAutograd: new_empty CompositeExplicitAutograd: new_empty_symint
autogen: new_empty.out autogen: new_empty.out
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -5548,7 +5548,7 @@
- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch: dispatch:
CompositeExplicitAutograd: zeros CompositeExplicitAutograd: zeros_symint
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) - func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
dispatch: dispatch:
@ -6889,8 +6889,7 @@
device_check: NoCheck device_check: NoCheck
device_guard: False device_guard: False
dispatch: dispatch:
Meta: view_meta ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
MkldnnCPU: mkldnn_view MkldnnCPU: mkldnn_view
NestedTensorCPU, NestedTensorCUDA: view_nested NestedTensorCPU, NestedTensorCUDA: view_nested
@ -12938,7 +12937,7 @@
- func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) - func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
variants: function variants: function
dispatch: dispatch:
CompositeExplicitAutograd: expand_copy_out CompositeExplicitAutograd: expand_copy_out_symint
- func: permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) - func: permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
@ -13058,7 +13057,7 @@
- func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) - func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
variants: function variants: function
dispatch: dispatch:
CompositeExplicitAutograd: view_copy_out CompositeExplicitAutograd: view_copy_out_symint
- func: view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) - func: view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)

View File

@ -199,6 +199,13 @@ supported:
- _trilinear - _trilinear
- linalg_pinv.atol_rtol_tensor - linalg_pinv.atol_rtol_tensor
- logsumexp.out - logsumexp.out
symint:
- empty.memory_format
- expand
- expand_copy
- narrow_copy
- view
- view_copy
autograd: autograd:
- max_pool3d - max_pool3d
- native_group_norm - native_group_norm

View File

@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) {
for (const auto dim : c10::irange(3)) { for (const auto dim : c10::irange(3)) {
const int64_t start = 1, length = 4; const int64_t start = 1, length = 4;
auto y_ref = x.narrow(dim, start, length); auto y_ref = x.narrow(dim, start, length);
auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length)); auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0); ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
} }
} }

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import distutils.command.clean import distutils.command.clean
import sys
import shutil import shutil
import glob import glob
import os import os
@ -129,21 +130,25 @@ class BuildExtension_(BuildExtension):
if __name__ == '__main__': if __name__ == '__main__':
print("Building wheel {}-{}".format(package_name, version)) print("Building wheel {}-{}".format(package_name, version))
write_version_file() write_version_file()
setup( try:
# Metadata setup(
name=package_name, # Metadata
version=version, name=package_name,
author='PyTorch Core Team', version=version,
url="https://github.com/pytorch/functorch", author='PyTorch Core Team',
description='JAX-like composable function transforms for PyTorch', url="https://github.com/pytorch/functorch",
license='BSD', description='JAX-like composable function transforms for PyTorch',
license='BSD',
# Package info # Package info
packages=find_packages(), packages=find_packages(),
install_requires=requirements, install_requires=requirements,
extras_require=extras, extras_require=extras,
ext_modules=get_extensions(), ext_modules=get_extensions(),
cmdclass={ cmdclass={
"build_ext": BuildExtension_.with_options(no_python_abi_suffix=True), "build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
'clean': clean, 'clean': clean,
}) })
except Exception as e:
print(e, file=sys.stderr)
sys.exit(1)

View File

@ -584,7 +584,8 @@ def gen_inplace_or_view_type(
[fn for fn in fns_with_infos if use_derived(fn)], [fn for fn in fns_with_infos if use_derived(fn)],
key_fn=lambda fn: fn.func.root_name, key_fn=lambda fn: fn.func.root_name,
base_env={ base_env={
"generated_comment": f"@generated from {template_path}/ADInplaceOrViewType.cpp", "generated_comment": "@"
+ f"generated from {template_path}/ADInplaceOrViewType.cpp",
}, },
env_callable=gen_inplace_or_view_type_env, env_callable=gen_inplace_or_view_type_env,
num_shards=2, num_shards=2,

View File

@ -535,7 +535,7 @@ def gen_trace_type(
[fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
key_fn=lambda fn: fn.root_name, key_fn=lambda fn: fn.root_name,
base_env={ base_env={
"generated_comment": f"@generated from {template_path}/TraceType.cpp", "generated_comment": "@" + f"generated from {template_path}/TraceType.cpp",
}, },
env_callable=gen_trace_type_func, env_callable=gen_trace_type_func,
num_shards=5, num_shards=5,

View File

@ -314,7 +314,6 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
dispatch_key=k, dispatch_key=k,
use_out_as_primary=True, use_out_as_primary=True,
external=False, external=False,
symint=False,
device_guard=False, device_guard=False,
index=backend_indices[k], index=backend_indices[k],
) )

View File

@ -238,7 +238,7 @@ invalid_key: invalid_val"""
output_error = self.get_errors_from_gen_backend_stubs(yaml_str) output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
self.assertExpectedInline( self.assertExpectedInline(
output_error, output_error,
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen""", # noqa: B950 """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen, symint""", # noqa: B950
) )
# if use_out_as_primary is provided, it must be a bool # if use_out_as_primary is provided, it must be a bool

View File

@ -269,7 +269,7 @@ at::Tensor LazyNativeFunctions::_to_copy(
} }
}; };
at::Tensor LazyNativeFunctions::empty( at::Tensor LazyNativeFunctions::empty_symint(
at::SymIntArrayRef sym_size, at::SymIntArrayRef sym_size,
c10::optional<at::ScalarType> dtype, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Layout> layout,
@ -307,7 +307,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
c10::optional<at::Device> device, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) { c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::"); TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty( at::Tensor t = empty_symint(
c10::SymIntArrayRef::fromIntArrayRef(size), c10::SymIntArrayRef::fromIntArrayRef(size),
dtype, dtype,
layout, layout,
@ -409,7 +409,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self, const at::Tensor& self,
at::IntArrayRef size) { at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::"); TORCH_LAZY_FN_COUNTER("lazy::");
return LazyNativeFunctions::view_copy( return LazyNativeFunctions::view_copy_symint(
self, c10::SymIntArrayRef::fromIntArrayRef(size)); self, c10::SymIntArrayRef::fromIntArrayRef(size));
} }
@ -449,7 +449,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
self, size, stride, dtype, layout, device, pin_memory); self, size, stride, dtype, layout, device, pin_memory);
} }
at::Tensor LazyNativeFunctions::narrow_copy( at::Tensor LazyNativeFunctions::narrow_copy_symint(
const at::Tensor& self, const at::Tensor& self,
int64_t dim, int64_t dim,
c10::SymInt start, c10::SymInt start,

View File

@ -35,7 +35,12 @@ def name(func: FunctionSchema) -> str:
def argumenttype_type( def argumenttype_type(
t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False t: Type,
*,
mutable: bool,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = True,
) -> NamedCType: ) -> NamedCType:
# This is a faux amis. If it makes sense in the future to add # This is a faux amis. If it makes sense in the future to add
# more special cases here, or invert things so cpp.argument_type # more special cases here, or invert things so cpp.argument_type
@ -45,25 +50,30 @@ def argumenttype_type(
t, t,
mutable=mutable, mutable=mutable,
binds=binds, binds=binds,
symint=True, symint=symint,
remove_non_owning_ref_types=remove_non_owning_ref_types, remove_non_owning_ref_types=remove_non_owning_ref_types,
) )
def argument_type( def argument_type(
a: Argument, *, binds: ArgName, remove_non_owning_ref_types: bool = False a: Argument,
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = True,
) -> NamedCType: ) -> NamedCType:
return argumenttype_type( return argumenttype_type(
a.type, a.type,
mutable=a.is_write, mutable=a.is_write,
binds=binds, binds=binds,
remove_non_owning_ref_types=remove_non_owning_ref_types, remove_non_owning_ref_types=remove_non_owning_ref_types,
symint=symint,
) )
def returns_type(rs: Sequence[Return]) -> CType: def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
# At present, there is no difference. But there could be! # At present, there is no difference. But there could be!
return cpp.returns_type(rs, symint=True) return cpp.returns_type(rs, symint=symint)
def jit_arguments(func: FunctionSchema) -> List[Argument]: def jit_arguments(func: FunctionSchema) -> List[Argument]:
@ -89,15 +99,20 @@ def jit_arguments(func: FunctionSchema) -> List[Argument]:
) )
def argument(a: Argument, *, remove_non_owning_ref_types: bool = False) -> Binding: def argument(
a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
return Binding( return Binding(
nctype=argument_type( nctype=argument_type(
a, binds=a.name, remove_non_owning_ref_types=remove_non_owning_ref_types a,
binds=a.name,
remove_non_owning_ref_types=remove_non_owning_ref_types,
symint=symint,
), ),
name=a.name, name=a.name,
argument=a, argument=a,
) )
def arguments(func: FunctionSchema) -> List[Binding]: def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
return [argument(a) for a in jit_arguments(func)] return [argument(a, symint=symint) for a in jit_arguments(func)]

View File

@ -577,8 +577,10 @@ class DispatcherSignature:
# and need to avoid naming collisions. # and need to avoid naming collisions.
prefix: str = "" prefix: str = ""
symint: bool = True
def arguments(self) -> List[Binding]: def arguments(self) -> List[Binding]:
return dispatcher.arguments(self.func) return dispatcher.arguments(self.func, symint=self.symint)
def name(self) -> str: def name(self) -> str:
return self.prefix + dispatcher.name(self.func) return self.prefix + dispatcher.name(self.func)
@ -604,7 +606,7 @@ class DispatcherSignature:
return [Expr(a.name, a.nctype) for a in self.arguments()] return [Expr(a.name, a.nctype) for a in self.arguments()]
def returns_type(self) -> CType: def returns_type(self) -> CType:
return dispatcher.returns_type(self.func.returns) return dispatcher.returns_type(self.func.returns, symint=self.symint)
def ptr_type(self) -> str: def ptr_type(self) -> str:
dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
@ -616,8 +618,10 @@ class DispatcherSignature:
return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})" return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
@staticmethod @staticmethod
def from_schema(func: FunctionSchema, *, prefix: str = "") -> "DispatcherSignature": def from_schema(
return DispatcherSignature(func, prefix) func: FunctionSchema, *, prefix: str = "", symint: bool = True
) -> "DispatcherSignature":
return DispatcherSignature(func, prefix, symint)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -778,15 +782,16 @@ def kernel_signature(
# so we'd like to keep the differences as small as possible. # so we'd like to keep the differences as small as possible.
# With external backends, we'd like to enforce that they write their kernels with schemas # With external backends, we'd like to enforce that they write their kernels with schemas
# that match the Dispatcher API directly, if they can. # that match the Dispatcher API directly, if they can.
meta = backend_index.get_kernel(f)
symint = meta is not None and meta.supports_symint()
if symint:
assert (
f.func.has_symint()
), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
if backend_index.external: if backend_index.external:
# Dispatcher signature faithfully does SymInt, which is good for XLA, return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
# not so good for more conventional backends but we don't have any of
# those. If we do, that's time to add a new Signature that is a cross
# between DispatcherSignature and NativeSignature
assert backend_index.symint
return DispatcherSignature.from_schema(f.func, prefix=prefix)
else: else:
return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint) return NativeSignature(f.func, prefix=prefix, symint=symint)
# Functions only, no types # Functions only, no types

View File

@ -751,8 +751,11 @@ resize_out(out, sizes, strides, options);
) )
# Signature of the wrapper function we'll register to the dispatcher # Signature of the wrapper function we'll register to the dispatcher
kern = self.backend_index.get_kernel(f)
sig = NativeSignature( sig = NativeSignature(
f.func, prefix="wrapper_", symint=self.backend_index.symint f.func,
prefix="wrapper_",
symint=kern is not None and kern.supports_symint(),
) )
if self.target is Target.NAMESPACED_DECLARATION: if self.target is Target.NAMESPACED_DECLARATION:

View File

@ -20,6 +20,7 @@ from torchgen.api import cpp
from torchgen.api.translate import translate from torchgen.api.translate import translate
from torchgen.api.types import ( from torchgen.api.types import (
Binding, Binding,
CppSignature,
CppSignatureGroup, CppSignatureGroup,
DispatcherSignature, DispatcherSignature,
NamedCType, NamedCType,
@ -161,7 +162,6 @@ def parse_native_yaml_struct(
device_guard=False, device_guard=False,
# I'm actually not sure about this; undefined could be hit on # I'm actually not sure about this; undefined could be hit on
# empty TensorList, hypothetically that could have sizes in it # empty TensorList, hypothetically that could have sizes in it
symint=False,
index={}, index={},
) )
) )
@ -176,16 +176,6 @@ def parse_native_yaml_struct(
# Only cuda-like devices in tree require device guards # Only cuda-like devices in tree require device guards
device_guard=is_cuda_dispatch_key(k), device_guard=is_cuda_dispatch_key(k),
index=v, index=v,
# Which dispatch keys natively support symint
# Note: DispatchKey.CompositeExplicitAutograd has to match out
# composites; I think there's some factoring problem here
symint=k
in [
DispatchKey.Meta,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
],
) )
return ParsedYaml(rs, indices) return ParsedYaml(rs, indices)
@ -363,6 +353,8 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
# supporting usecases even when there is a memory_format argument along with tensor_option arguments. # supporting usecases even when there is a memory_format argument along with tensor_option arguments.
# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch # This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
def translate_args_dispatcher_to_cpp( def translate_args_dispatcher_to_cpp(
sig: DispatcherSignature,
cpp_sig: CppSignature,
f: NativeFunction, f: NativeFunction,
) -> str: ) -> str:
@ -385,10 +377,7 @@ def translate_args_dispatcher_to_cpp(
output_bindings.append(binding) output_bindings.append(binding)
return output_bindings return output_bindings
disp_sig = DispatcherSignature.from_schema(f.func) disp_sig = sig
cpp_sig = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=False
).signature
disp_bindings = disp_sig.arguments() disp_bindings = disp_sig.arguments()
# When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType, # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
# get memory_format bindings of dispatcher signature to have the same NCType as well # get memory_format bindings of dispatcher signature to have the same NCType as well
@ -401,11 +390,20 @@ def translate_args_dispatcher_to_cpp(
def generate_static_dispatch_backend_call( def generate_static_dispatch_backend_call(
sig: DispatcherSignature,
f: NativeFunction, f: NativeFunction,
backend_index: BackendIndex, backend_index: BackendIndex,
) -> str: ) -> str:
name = DispatcherSignature.from_schema(f.func).name() cpp_sigs = CppSignatureGroup.from_native_function(
exprs = translate_args_dispatcher_to_cpp(f) f, method=False, fallback_binding=False
)
if sig.symint and f.func.has_symint():
cpp_sig = cpp_sigs.symint_signature
else:
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
name = cpp_sig.name()
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
backend_metadata = backend_index.get_kernel(f) backend_metadata = backend_index.get_kernel(f)
kernel_ns = ( kernel_ns = (
backend_metadata.cpp_namespace backend_metadata.cpp_namespace
@ -417,11 +415,20 @@ def generate_static_dispatch_backend_call(
def generate_static_dispatch_fallback_call( def generate_static_dispatch_fallback_call(
sig: DispatcherSignature,
f: NativeFunction, f: NativeFunction,
backend_indices: List[BackendIndex], backend_indices: List[BackendIndex],
) -> str: ) -> str:
name = DispatcherSignature.from_schema(f.func).name() cpp_sigs = CppSignatureGroup.from_native_function(
exprs = translate_args_dispatcher_to_cpp(f) f, method=False, fallback_binding=False
)
if sig.symint and f.func.has_symint():
cpp_sig = cpp_sigs.symint_signature
else:
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
name = cpp_sig.name()
exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "") ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
if f.has_composite_explicit_autograd_kernel: if f.has_composite_explicit_autograd_kernel:
return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});" return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
@ -437,6 +444,7 @@ def generate_static_dispatch_fallback_call(
def static_dispatch( def static_dispatch(
sig: DispatcherSignature,
f: NativeFunction, f: NativeFunction,
backend_indices: List[BackendIndex], backend_indices: List[BackendIndex],
) -> str: ) -> str:
@ -453,11 +461,10 @@ def static_dispatch(
) )
] ]
if len(keys) == 1: if len(keys) == 1:
return generate_static_dispatch_backend_call(f, keys[0]) return generate_static_dispatch_backend_call(sig, f, keys[0])
elif len(keys) == 0: elif len(keys) == 0:
return generate_static_dispatch_fallback_call(f, backend_indices) return generate_static_dispatch_fallback_call(sig, f, backend_indices)
sig = DispatcherSignature.from_schema(f.func)
native_tensor_args = [ native_tensor_args = [
a.name a.name
for a in sig.arguments() for a in sig.arguments()
@ -483,10 +490,10 @@ def static_dispatch(
for index in keys: for index in keys:
dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""") dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
dispatch_code.append( dispatch_code.append(
f"""\t{generate_static_dispatch_backend_call(f, index)};""" f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
) )
fallback = generate_static_dispatch_fallback_call(f, backend_indices) fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
connector = "\n\t\t" connector = "\n\t\t"
return f""" return f"""
@ -528,8 +535,6 @@ class ComputeOperators:
def __call__(self, f: NativeFunction) -> str: def __call__(self, f: NativeFunction) -> str:
sig = DispatcherSignature.from_schema(f.func) sig = DispatcherSignature.from_schema(f.func)
name = f.func.name.unambiguous_name() name = f.func.name.unambiguous_name()
call_method_name = "call"
redispatch_method_name = "redispatch"
if self.target is Target.DECLARATION: if self.target is Target.DECLARATION:
# Note [The ATen Operators API] # Note [The ATen Operators API]
@ -563,8 +568,8 @@ struct TORCH_API {name} {{
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}") STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}") STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))}) STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
static {sig.defn(name=call_method_name, is_redispatching_fn=False)}; static {sig.defn(name="call", is_redispatching_fn=False)};
static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)}; static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};""" }};"""
elif self.target is Target.DEFINITION: elif self.target is Target.DEFINITION:
@ -585,12 +590,13 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
dispatcher_exprs_str = ", ".join( dispatcher_exprs_str = ", ".join(
["dispatchKeySet"] + [a.name for a in sig.arguments()] ["dispatchKeySet"] + [a.name for a in sig.arguments()]
) )
dispatcher_call = "redispatch" method_base = "redispatch"
method_name = f"{name}::{redispatch_method_name}"
else: else:
method_name = f"{name}::{call_method_name}"
dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()]) dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
dispatcher_call = "call" method_base = "call"
dispatcher_call = method_base
method_name = f"{name}::{method_base}"
fn_body = f""" fn_body = f"""
static auto op = create_{name}_typed_handle(); static auto op = create_{name}_typed_handle();
@ -602,7 +608,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
): ):
# call() should go through static dispatch # call() should go through static dispatch
fn_body = static_dispatch( fn_body = static_dispatch(
f, backend_indices=self.static_dispatch_backend_indices sig, f, backend_indices=self.static_dispatch_backend_indices
) )
defns += f""" defns += f"""
// aten::{f.func} // aten::{f.func}

View File

@ -3,7 +3,7 @@ import os
import pathlib import pathlib
import re import re
from collections import Counter, defaultdict, namedtuple from collections import Counter, defaultdict, namedtuple
from typing import Dict, List, Optional, Sequence, Union from typing import Dict, List, Optional, Sequence, Set, Union
import yaml import yaml
@ -68,6 +68,7 @@ def parse_backend_yaml(
"full_codegen", "full_codegen",
"non_native", "non_native",
"ir_gen", "ir_gen",
"symint",
] ]
backend = yaml_values.pop("backend", None) backend = yaml_values.pop("backend", None)
@ -96,6 +97,14 @@ def parse_backend_yaml(
supported, list supported, list
), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})' ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
# Optional "symint" key: a list of operator names for this backend.
# A missing key defaults to an empty list.
symint = yaml_values.pop("symint", [])
if symint is None:
    symint = []  # Allow an empty list of symint ops
# Report the offending `symint` value itself; the previous message interpolated
# `supported`, which hid the real input in the error text.
assert isinstance(
    symint, list
), f'expected "symint" to be a list, but got: {symint} (of type {type(symint)})'
# Deduplicate into a set for the later membership checks.
symint_set = set(symint)
supported_autograd = yaml_values.pop("autograd", []) supported_autograd = yaml_values.pop("autograd", [])
assert isinstance( assert isinstance(
supported_autograd, list supported_autograd, list
@ -118,6 +127,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
def create_backend_index( def create_backend_index(
backend_ops: List[str], backend_ops: List[str],
symint_ops: Set[str],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
*, *,
use_out_as_primary: bool, use_out_as_primary: bool,
@ -131,6 +141,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
), f"Found an invalid operator name: {op_name}" ), f"Found an invalid operator name: {op_name}"
# See Note [External Backends Follow Dispatcher API] # See Note [External Backends Follow Dispatcher API]
kernel_name = dispatcher.name(native_functions_map[op_name].func) kernel_name = dispatcher.name(native_functions_map[op_name].func)
if op in symint_ops:
kernel_name += "_symint"
# TODO: allow structured external backends later. # TODO: allow structured external backends later.
m = BackendMetadata( m = BackendMetadata(
kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
@ -140,7 +152,6 @@ Only the following keys are supported: {", ".join(valid_keys)}'
dispatch_key=dispatch_key, dispatch_key=dispatch_key,
use_out_as_primary=use_out_as_primary, use_out_as_primary=use_out_as_primary,
external=True, external=True,
symint=True, # TODO: make this configurable
device_guard=use_device_guard, device_guard=use_device_guard,
index=metadata, index=metadata,
) )
@ -154,6 +165,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
backend_idx = create_backend_index( backend_idx = create_backend_index(
supported, supported,
symint_set,
backend_key, backend_key,
use_out_as_primary=use_out_as_primary, use_out_as_primary=use_out_as_primary,
use_device_guard=use_device_guard, use_device_guard=use_device_guard,
@ -171,6 +183,7 @@ the behavior of autograd for some operators on your backend. However "Autograd{b
autograd_idx = create_backend_index( autograd_idx = create_backend_index(
supported_autograd, supported_autograd,
symint_set,
autograd_key, autograd_key,
use_out_as_primary=use_out_as_primary, use_out_as_primary=use_out_as_primary,
use_device_guard=use_device_guard, use_device_guard=use_device_guard,

View File

@ -84,21 +84,21 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
# clone() calls in their graph (which is normally needed by reshape). # clone() calls in their graph (which is normally needed by reshape).
if str(g.view_copy.func.name) == "view_copy": if str(g.view_copy.func.name) == "view_copy":
return """\ return """\
at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) { at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
// TODO: don't cast to int array ref DimVector shape = infer_size_dv(size, self.numel());
auto int_size = c10::asIntArrayRefSlow(size);
DimVector shape = infer_size_dv(int_size, self.numel());
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) { if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
return self.reshape(int_size); return self.reshape(size);
} else { } else {
auto output = at::_ops::view::call(self, size); auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
return output.clone(); return output.clone();
} }
} }
""" """
# view_copy is a native signature, since we're generating an at::native:: kernel # view_copy is a native signature, since we're generating an at::native:: kernel
# Functionalization always operates on symints though # Functionalization always operates on symints though
view_copy_sig = NativeSignature(g.view_copy.func, symint=True) view_copy_sig = NativeSignature(
g.view_copy.func, symint=False
) # TODO: flag day this True
# view is a dispatcher signature, since we're calling into the at::_ops API # view is a dispatcher signature, since we're calling into the at::_ops API
view_sig = DispatcherSignature(g.view.func) view_sig = DispatcherSignature(g.view.func)
@ -641,7 +641,7 @@ def gen_functionalization_registration(
metadata = composite_implicit_autograd_index.get_kernel(f) metadata = composite_implicit_autograd_index.get_kernel(f)
assert metadata is not None assert metadata is not None
native_api_name = metadata.kernel native_api_name = metadata.kernel
sig = DispatcherSignature.from_schema(f.func) sig = NativeSignature(f.func, symint=metadata.supports_symint())
# Note [Composite view ops in the functionalization pass] # Note [Composite view ops in the functionalization pass]
# We don't need to worry about implementing functionalization kernels for views with # We don't need to worry about implementing functionalization kernels for views with
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators. # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.

View File

@ -79,7 +79,7 @@ def gen_unwraps(
def gen_case_where_all_bdims_are_none( def gen_case_where_all_bdims_are_none(
schema: FunctionSchema, cur_level_var: str outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str: ) -> str:
conditions = [] conditions = []
flat_args = schema.arguments.flat_all flat_args = schema.arguments.flat_all
@ -90,7 +90,7 @@ def gen_case_where_all_bdims_are_none(
sig = DispatcherSignature.from_schema(schema) sig = DispatcherSignature.from_schema(schema)
translated_args = ", ".join( translated_args = ", ".join(
e.expr for e in translate(sig.arguments(), sig.arguments()) e.expr for e in translate(outer_sig.arguments(), sig.arguments())
) )
return f"""\ return f"""\
if ({' && '.join(conditions)}) {{ if ({' && '.join(conditions)}) {{
@ -160,7 +160,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
cur_level_var = "cur_level" cur_level_var = "cur_level"
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var) bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
return f"""\ return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule> template <typename batch_rule_t, batch_rule_t batch_rule>
@ -182,7 +182,7 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
cur_level_var = "cur_level" cur_level_var = "cur_level"
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var) bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
return f"""\ return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule> template <typename batch_rule_t, batch_rule_t batch_rule>
@ -224,7 +224,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
cur_level_var = "cur_level" cur_level_var = "cur_level"
unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var) unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var) bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
wrapped_returns = gen_returns(returns, cur_level_var, results_var) wrapped_returns = gen_returns(returns, cur_level_var, results_var)
return f"""\ return f"""\

View File

@ -1098,6 +1098,9 @@ class BackendMetadata:
# The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
cpp_namespace: str cpp_namespace: str
def supports_symint(self) -> bool:
    """Return True when the kernel name contains the "_symint" marker."""
    marker = "_symint"
    # str.find returns -1 only when the marker is absent anywhere in the name.
    return self.kernel.find(marker) != -1
@dataclass(frozen=True) @dataclass(frozen=True)
class UfuncInnerLoop: class UfuncInnerLoop:
@ -1141,8 +1144,6 @@ class BackendIndex:
external: bool external: bool
# Other backend-specific information that is on a per-operator basis # Other backend-specific information that is on a per-operator basis
index: Dict["OperatorName", BackendMetadata] index: Dict["OperatorName", BackendMetadata]
# Whether or not this backend handles symbolic ints or not
symint: bool
@staticmethod @staticmethod
def grow_index( def grow_index(

View File

@ -304,6 +304,8 @@ def generate_function(
if func.kind() == SchemaKind.out if func.kind() == SchemaKind.out
else cpp.name(func) else cpp.name(func)
) )
if f.func.has_symint():
kernel_name += "_symint"
backend_metadata = { backend_metadata = {
DispatchKey.CompositeExplicitAutograd: { DispatchKey.CompositeExplicitAutograd: {
func.name: BackendMetadata( func.name: BackendMetadata(
@ -555,7 +557,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
clone_mutable_inputs_str = "\n".join(clone_mutable_inputs) clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
return f""" return f"""
{sig.defn()} {{ {sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
{clone_mutable_inputs_str} {clone_mutable_inputs_str}
{maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs}); {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
{ret_str} {ret_str}
@ -615,7 +617,7 @@ def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# Kernel name needs to follow the naming convention defined in `generate_function()` # Kernel name needs to follow the naming convention defined in `generate_function()`
return f""" return f"""
{sig.defn(name=g.out.func.name.unambiguous_name())} {{ {sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs}); auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
{copy_outs_str} {copy_outs_str}
{return_str(g.out.func.returns, rets)} {return_str(g.out.func.returns, rets)}