diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 3cc7e2854a1..0effbf7984c 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-2ba7616e9070bd14ea34a5ef5459bac571198926
+f00dd2f35ecf6455d97237d63c70c9c8ec190940
diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp
index a58b18c786e..e0fe9df9387 100644
--- a/aten/src/ATen/native/MetaTensor.cpp
+++ b/aten/src/ATen/native/MetaTensor.cpp
@@ -12,7 +12,7 @@
 namespace at {
 namespace native {
 
-Tensor empty_meta(
+Tensor empty_meta_symint(
   SymIntArrayRef size,
   c10::optional<ScalarType> dtype_opt,
  c10::optional<Layout> layout_opt,
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 6ccbbbac03a..0a0befbe18d 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -214,12 +214,9 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride,
 }
 
-Tensor& empty_out(SymIntArrayRef sym_size,
+Tensor& empty_out(IntArrayRef size,
     c10::optional<c10::MemoryFormat> optional_memory_format,
     Tensor& result) {
-  // TODO: support empty_out properly (I was forced to change this immediately
-  // with empty so that empty/empty.out had the same type signature)
-  auto size = c10::asIntArrayRefSlow(sym_size);
   // Preferably, this argument would not be accepted by _out, but the code
   // generator requires the out and non-out overloads to match exactly
   TORCH_CHECK(
@@ -386,7 +383,7 @@ Tensor empty_like_quantized(
   }
 }
 
-Tensor new_empty(
+Tensor new_empty_symint(
     const Tensor& self,
     SymIntArrayRef size,
     c10::optional<ScalarType> dtype_opt,
@@ -1077,7 +1074,7 @@ Tensor triu_indices_cpu(
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Tensor zeros(SymIntArrayRef size,
+Tensor zeros_symint(SymIntArrayRef size,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
     c10::optional<Device> device,
@@ -1107,8 +1104,7 @@ Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
   return result;
 }
 
-Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) {
-  auto size = c10::asIntArrayRefSlow(sym_size);
+Tensor& zeros_out(IntArrayRef size, Tensor& result) {
   if (result.is_sparse()) {
     // TODO: I think this branch should be dead, but we don't have an easy
     // way to cover all sparse kernels with zeros_sparse_out, so retain this
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 64af552c7a8..cdadf44df30 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -844,9 +844,7 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
   return result;
 }
 
-Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool /*unused*/) {
-  // TODO: properly support SymInt expand
-  auto size = asIntArrayRefSlow(sym_size);
+Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) {
   TORCH_CHECK(size.size() >= (size_t)self.dim(),
               "expand(", self.toString(), "{", self.sizes(), "}, size=", size,
               "): the number of sizes provided (", size.size(), ") ",
@@ -925,9 +923,8 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stride,
   return self;
 }
 
-Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
-  // TODO: properly support SymInt narrow_copy
-  return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous);
+Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
+  return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
 }
 
 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
@@ -3204,12 +3201,6 @@ Tensor adjoint(const Tensor &self) {
   return _adjoint(self, /*transpose=*/false, "adjoint()");
 }
 
-Tensor view_meta(const Tensor& self,
-                 at::SymIntArrayRef size) {
-  // TODO: Properly support SymInt view
-  return view_impl(self, c10::asIntArrayRefSlow(size));
-}
-
 Tensor view(const Tensor& self,
             at::IntArrayRef size) {
   return view_impl(self, size);
@@ -3592,7 +3583,7 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef
 }
 
 
-at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
+at::Tensor& expand_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) {
   auto tmp = self.expand_symint(size, implicit);
   out.copy_(tmp);
   return out;
@@ -3748,7 +3739,7 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) {
 }
 
 
-at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
+at::Tensor& view_copy_out_symint(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) {
   auto tmp = self.view_symint(size);
   out.copy_(tmp);
   return out;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index dfb3bddc523..73634685a43 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2054,7 +2054,7 @@
     CPU: empty_cpu
     CUDA: empty_cuda
     MPS: empty_mps
-    Meta: empty_meta
+    Meta: empty_meta_symint
     MkldnnCPU: empty_mkldnn
     SparseCPU, SparseCUDA, SparseMeta: empty_sparse
     SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
@@ -2065,7 +2065,7 @@
 - func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   variants: method
   dispatch:
-    CompositeExplicitAutograd: new_empty
+    CompositeExplicitAutograd: new_empty_symint
   autogen: new_empty.out
 
 - func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -5548,7 +5548,7 @@
 - func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
-    CompositeExplicitAutograd: zeros
+    CompositeExplicitAutograd: zeros_symint
 
 - func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
    CompositeExplicitAutograd: zeros_out
@@ -6889,8 +6889,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    Meta: view_meta
-    ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
+    ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
     MkldnnCPU: mkldnn_view
     NestedTensorCPU, NestedTensorCUDA: view_nested
 
@@ -12938,7 +12937,7 @@
 - func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CompositeExplicitAutograd: expand_copy_out
+    CompositeExplicitAutograd: expand_copy_out_symint
 
 
 - func: permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
@@ -13058,7 +13057,7 @@
 - func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
-    CompositeExplicitAutograd: view_copy_out
+    CompositeExplicitAutograd: view_copy_out_symint
 
 - func: view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml
index b110aa75c83..b948a898f21 100644
--- a/aten/src/ATen/native/ts_native_functions.yaml
+++ b/aten/src/ATen/native/ts_native_functions.yaml
@@ -199,6 +199,13 @@ supported:
   - _trilinear
   - linalg_pinv.atol_rtol_tensor
   - logsumexp.out
+symint:
+  - empty.memory_format
+  - expand
+  - expand_copy
+  - narrow_copy
+  - view
+  - view_copy
 autograd:
   - max_pool3d
   - native_group_norm
diff --git a/aten/src/ATen/test/math_kernel_test.cpp b/aten/src/ATen/test/math_kernel_test.cpp
index 29c33889909..15ce0af4001 100644
--- a/aten/src/ATen/test/math_kernel_test.cpp
+++ b/aten/src/ATen/test/math_kernel_test.cpp
@@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) {
   for (const auto dim : c10::irange(3)) {
     const int64_t start = 1, length = 4;
     auto y_ref = x.narrow(dim, start, length);
-    auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length));
+    auto y_test = at::native::narrow_copy_dense(x, dim, start, length);
     ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0);
   }
 }
diff --git a/functorch/setup.py b/functorch/setup.py
index e200cbd1fcc..3f56c078bca 100644
--- a/functorch/setup.py
+++ b/functorch/setup.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import distutils.command.clean
+import sys
 import shutil
 import glob
 import os
@@ -129,21 +130,25 @@ class BuildExtension_(BuildExtension):
 if __name__ == '__main__':
     print("Building wheel {}-{}".format(package_name, version))
     write_version_file()
-    setup(
-        # Metadata
-        name=package_name,
-        version=version,
-        author='PyTorch Core Team',
-        url="https://github.com/pytorch/functorch",
-        description='JAX-like composable function transforms for PyTorch',
-        license='BSD',
+    try:
+        setup(
+            # Metadata
+            name=package_name,
+            version=version,
+            author='PyTorch Core Team',
+            url="https://github.com/pytorch/functorch",
+            description='JAX-like composable function transforms for PyTorch',
+            license='BSD',
 
-        # Package info
-        packages=find_packages(),
-        install_requires=requirements,
-        extras_require=extras,
-        ext_modules=get_extensions(),
-        cmdclass={
-            "build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
-            'clean': clean,
-        })
+            # Package info
+            packages=find_packages(),
+            install_requires=requirements,
+            extras_require=extras,
+            ext_modules=get_extensions(),
+            cmdclass={
+                "build_ext": BuildExtension_.with_options(no_python_abi_suffix=True),
+                'clean': clean,
+            })
+    except Exception as e:
+        print(e, file=sys.stderr)
+        sys.exit(1)
diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index acfd2ac796c..3a8c24f38ea 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -584,7 +584,8 @@ def gen_inplace_or_view_type(
         [fn for fn in fns_with_infos if use_derived(fn)],
         key_fn=lambda fn: fn.func.root_name,
         base_env={
-            "generated_comment": f"@generated from {template_path}/ADInplaceOrViewType.cpp",
+            "generated_comment": "@"
+            + f"generated from {template_path}/ADInplaceOrViewType.cpp",
         },
         env_callable=gen_inplace_or_view_type_env,
         num_shards=2,
diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py
index 21739bb8051..6dc0fcab575 100644
--- a/tools/autograd/gen_trace_type.py
+++ b/tools/autograd/gen_trace_type.py
@@ -535,7 +535,7 @@ def gen_trace_type(
         [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
         key_fn=lambda fn: fn.root_name,
         base_env={
-            "generated_comment": f"@generated from {template_path}/TraceType.cpp",
+            "generated_comment": "@" + f"generated from {template_path}/TraceType.cpp",
         },
         env_callable=gen_trace_type_func,
         num_shards=5,
diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py
index 781dde46fe7..cbce8b3bc5e 100644
--- a/tools/test/test_codegen.py
+++ b/tools/test/test_codegen.py
@@ -314,7 +314,6 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
                 dispatch_key=k,
                 use_out_as_primary=True,
                 external=False,
-                symint=False,
                 device_guard=False,
                 index=backend_indices[k],
             )
diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 9091cca6ddd..377db49ccaf 100644
--- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -238,7 +238,7 @@ invalid_key: invalid_val"""
         output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
         self.assertExpectedInline(
             output_error,
-            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen""",  # noqa: B950
+            """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen, symint""",  # noqa: B950
         )
 
     # if use_out_as_primary is provided, it must be a bool
diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
index 5787ebc62a4..1f29fb6e040 100644
--- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
+++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
@@ -269,7 +269,7 @@ at::Tensor LazyNativeFunctions::_to_copy(
   }
 };
 
-at::Tensor LazyNativeFunctions::empty(
+at::Tensor LazyNativeFunctions::empty_symint(
     at::SymIntArrayRef sym_size,
     c10::optional<at::ScalarType> dtype,
     c10::optional<at::Layout> layout,
@@ -307,7 +307,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
     c10::optional<at::Device> device,
     c10::optional<bool> pin_memory) {
   TORCH_LAZY_FN_COUNTER("lazy::");
-  at::Tensor t = empty(
+  at::Tensor t = empty_symint(
       c10::SymIntArrayRef::fromIntArrayRef(size),
       dtype,
       layout,
@@ -409,7 +409,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
     const at::Tensor& self,
     at::IntArrayRef size) {
   TORCH_LAZY_FN_COUNTER("lazy::");
-  return LazyNativeFunctions::view_copy(
+  return LazyNativeFunctions::view_copy_symint(
       self, c10::SymIntArrayRef::fromIntArrayRef(size));
 }
 
@@ -449,7 +449,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
       self, size, stride, dtype, layout, device, pin_memory);
 }
 
-at::Tensor LazyNativeFunctions::narrow_copy(
+at::Tensor LazyNativeFunctions::narrow_copy_symint(
     const at::Tensor& self,
     int64_t dim,
     c10::SymInt start,
diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py
index aaab73ef737..58816959f7c 100644
--- a/torchgen/api/dispatcher.py
+++ b/torchgen/api/dispatcher.py
@@ -35,7 +35,12 @@ def name(func: FunctionSchema) -> str:
 
 
 def argumenttype_type(
-    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
+    t: Type,
+    *,
+    mutable: bool,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = True,
 ) -> NamedCType:
     # This is a faux amis. If it makes sense in the future to add
     # more special cases here, or invert things so cpp.argument_type
@@ -45,25 +50,30 @@ def argumenttype_type(
         t,
         mutable=mutable,
         binds=binds,
-        symint=True,
+        symint=symint,
         remove_non_owning_ref_types=remove_non_owning_ref_types,
     )
 
 
 def argument_type(
-    a: Argument, *, binds: ArgName, remove_non_owning_ref_types: bool = False
+    a: Argument,
+    *,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = True,
 ) -> NamedCType:
     return argumenttype_type(
         a.type,
         mutable=a.is_write,
         binds=binds,
         remove_non_owning_ref_types=remove_non_owning_ref_types,
+        symint=symint,
     )
 
 
-def returns_type(rs: Sequence[Return]) -> CType:
+def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
     # At present, there is no difference. But there could be!
-    return cpp.returns_type(rs, symint=True)
+    return cpp.returns_type(rs, symint=symint)
 
 
 def jit_arguments(func: FunctionSchema) -> List[Argument]:
@@ -89,15 +99,20 @@ def jit_arguments(func: FunctionSchema) -> List[Argument]:
     )
 
 
-def argument(a: Argument, *, remove_non_owning_ref_types: bool = False) -> Binding:
+def argument(
+    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
+) -> Binding:
     return Binding(
         nctype=argument_type(
-            a, binds=a.name, remove_non_owning_ref_types=remove_non_owning_ref_types
+            a,
+            binds=a.name,
+            remove_non_owning_ref_types=remove_non_owning_ref_types,
+            symint=symint,
         ),
         name=a.name,
         argument=a,
     )
 
 
-def arguments(func: FunctionSchema) -> List[Binding]:
-    return [argument(a) for a in jit_arguments(func)]
+def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
+    return [argument(a, symint=symint) for a in jit_arguments(func)]
diff --git a/torchgen/api/types.py b/torchgen/api/types.py
index 9eacacf2fd9..e8741c0e8f6 100644
--- a/torchgen/api/types.py
+++ b/torchgen/api/types.py
@@ -577,8 +577,10 @@ class DispatcherSignature:
     # and need to avoid naming collisions.
     prefix: str = ""
 
+    symint: bool = True
+
     def arguments(self) -> List[Binding]:
-        return dispatcher.arguments(self.func)
+        return dispatcher.arguments(self.func, symint=self.symint)
 
     def name(self) -> str:
         return self.prefix + dispatcher.name(self.func)
@@ -604,7 +606,7 @@ class DispatcherSignature:
         return [Expr(a.name, a.nctype) for a in self.arguments()]
 
     def returns_type(self) -> CType:
-        return dispatcher.returns_type(self.func.returns)
+        return dispatcher.returns_type(self.func.returns, symint=self.symint)
 
     def ptr_type(self) -> str:
         dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
@@ -616,8 +618,10 @@ class DispatcherSignature:
         return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
 
     @staticmethod
-    def from_schema(func: FunctionSchema, *, prefix: str = "") -> "DispatcherSignature":
-        return DispatcherSignature(func, prefix)
+    def from_schema(
+        func: FunctionSchema, *, prefix: str = "", symint: bool = True
+    ) -> "DispatcherSignature":
+        return DispatcherSignature(func, prefix, symint)
 
 
 @dataclass(frozen=True)
@@ -778,15 +782,16 @@ def kernel_signature(
     # so we'd like to keep the differences as small as possible.
     # With external backends, we'd like to enforce that they write their kernels with schemas
     # that match the Dispatcher API directly, if they can.
+    meta = backend_index.get_kernel(f)
+    symint = meta is not None and meta.supports_symint()
+    if symint:
+        assert (
+            f.func.has_symint()
+        ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
     if backend_index.external:
-        # Dispatcher signature faithfully does SymInt, which is good for XLA,
-        # not so good for more conventional backends but we don't have any of
-        # those. If we do, that's time to add a new Signature that is a cross
-        # between DispatcherSignature and NativeSignature
-        assert backend_index.symint
-        return DispatcherSignature.from_schema(f.func, prefix=prefix)
+        return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
     else:
-        return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint)
+        return NativeSignature(f.func, prefix=prefix, symint=symint)
 
 
 # Functions only, no types
diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py
index 2d9a78a912a..88032bb8719 100644
--- a/torchgen/dest/register_dispatch_key.py
+++ b/torchgen/dest/register_dispatch_key.py
@@ -751,8 +751,11 @@ resize_out(out, sizes, strides, options);
         )
 
         # Signature of the wrapper function we'll register to the dispatcher
+        kern = self.backend_index.get_kernel(f)
         sig = NativeSignature(
-            f.func, prefix="wrapper_", symint=self.backend_index.symint
+            f.func,
+            prefix="wrapper_",
+            symint=kern is not None and kern.supports_symint(),
         )
 
         if self.target is Target.NAMESPACED_DECLARATION:
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 4fa73a02150..0d535275660 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -20,6 +20,7 @@ from torchgen.api import cpp
 from torchgen.api.translate import translate
 from torchgen.api.types import (
     Binding,
+    CppSignature,
     CppSignatureGroup,
     DispatcherSignature,
     NamedCType,
@@ -161,7 +162,6 @@ def parse_native_yaml_struct(
                 device_guard=False,
                 # I'm actually not sure about this; undefined could be hit on
                 # empty TensorList, hypothetically that could have sizes in it
-                symint=False,
                 index={},
             )
         )
@@ -176,16 +176,6 @@ def parse_native_yaml_struct(
             # Only cuda-like devices in tree require device guards
             device_guard=is_cuda_dispatch_key(k),
             index=v,
-            # Which dispatch keys natively support symint
-            # Note: DispatchKey.CompositeExplicitAutograd has to match out
-            # composites; I think there's some factoring problem here
-            symint=k
-            in [
-                DispatchKey.Meta,
-                DispatchKey.CompositeImplicitAutograd,
-                DispatchKey.CompositeExplicitAutograd,
-                DispatchKey.CompositeExplicitAutogradNonFunctional,
-            ],
         )
 
     return ParsedYaml(rs, indices)
@@ -363,6 +353,8 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
 # supporting usecases even when there is a memory_format argument along with tensor_option arguments.
 # This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
 def translate_args_dispatcher_to_cpp(
+    sig: DispatcherSignature,
+    cpp_sig: CppSignature,
     f: NativeFunction,
 ) -> str:
 
@@ -385,10 +377,7 @@ def translate_args_dispatcher_to_cpp(
             output_bindings.append(binding)
         return output_bindings
 
-    disp_sig = DispatcherSignature.from_schema(f.func)
-    cpp_sig = CppSignatureGroup.from_native_function(
-        f, method=False, fallback_binding=False
-    ).signature
+    disp_sig = sig
     disp_bindings = disp_sig.arguments()
     # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
     # get memory_format bindings of dispatcher signature to have the same NCType as well
@@ -401,11 +390,20 @@ def translate_args_dispatcher_to_cpp(
 
 
 def generate_static_dispatch_backend_call(
+    sig: DispatcherSignature,
     f: NativeFunction,
     backend_index: BackendIndex,
 ) -> str:
-    name = DispatcherSignature.from_schema(f.func).name()
-    exprs = translate_args_dispatcher_to_cpp(f)
+    cpp_sigs = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    if sig.symint and f.func.has_symint():
+        cpp_sig = cpp_sigs.symint_signature
+    else:
+        cpp_sig = cpp_sigs.signature
+    assert cpp_sig is not None
+    name = cpp_sig.name()
+    exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
     backend_metadata = backend_index.get_kernel(f)
     kernel_ns = (
         backend_metadata.cpp_namespace
@@ -417,11 +415,20 @@ def generate_static_dispatch_backend_call(
 
 
 def generate_static_dispatch_fallback_call(
+    sig: DispatcherSignature,
     f: NativeFunction,
     backend_indices: List[BackendIndex],
 ) -> str:
-    name = DispatcherSignature.from_schema(f.func).name()
-    exprs = translate_args_dispatcher_to_cpp(f)
+    cpp_sigs = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    if sig.symint and f.func.has_symint():
+        cpp_sig = cpp_sigs.symint_signature
+    else:
+        cpp_sig = cpp_sigs.signature
+    assert cpp_sig is not None
+    name = cpp_sig.name()
+    exprs = translate_args_dispatcher_to_cpp(sig, cpp_sig, f)
     ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
     if f.has_composite_explicit_autograd_kernel:
         return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
@@ -437,6 +444,7 @@ def generate_static_dispatch_fallback_call(
 
 
 def static_dispatch(
+    sig: DispatcherSignature,
     f: NativeFunction,
     backend_indices: List[BackendIndex],
 ) -> str:
@@ -453,11 +461,10 @@ def static_dispatch(
         )
     ]
     if len(keys) == 1:
-        return generate_static_dispatch_backend_call(f, keys[0])
+        return generate_static_dispatch_backend_call(sig, f, keys[0])
     elif len(keys) == 0:
-        return generate_static_dispatch_fallback_call(f, backend_indices)
+        return generate_static_dispatch_fallback_call(sig, f, backend_indices)
 
-    sig = DispatcherSignature.from_schema(f.func)
     native_tensor_args = [
         a.name
         for a in sig.arguments()
@@ -483,10 +490,10 @@ def static_dispatch(
     for index in keys:
         dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
         dispatch_code.append(
-            f"""\t{generate_static_dispatch_backend_call(f, index)};"""
+            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
         )
 
-    fallback = generate_static_dispatch_fallback_call(f, backend_indices)
+    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
     connector = "\n\t\t"
 
     return f"""
@@ -528,8 +535,6 @@ class ComputeOperators:
     def __call__(self, f: NativeFunction) -> str:
         sig = DispatcherSignature.from_schema(f.func)
         name = f.func.name.unambiguous_name()
-        call_method_name = "call"
-        redispatch_method_name = "redispatch"
 
         if self.target is Target.DECLARATION:
             # Note [The ATen Operators API]
@@ -563,8 +568,8 @@ struct TORCH_API {name} {{
   STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
   STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
-  static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
-  static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
+  static {sig.defn(name="call", is_redispatching_fn=False)};
+  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
 }};"""
 
         elif self.target is Target.DEFINITION:
@@ -585,12 +590,13 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
                 dispatcher_exprs_str = ", ".join(
                     ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                 )
-                dispatcher_call = "redispatch"
-                method_name = f"{name}::{redispatch_method_name}"
+                method_base = "redispatch"
             else:
-                method_name = f"{name}::{call_method_name}"
                 dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
-                dispatcher_call = "call"
+                method_base = "call"
+
+            dispatcher_call = method_base
+            method_name = f"{name}::{method_base}"
 
             fn_body = f"""
     static auto op = create_{name}_typed_handle();
@@ -602,7 +608,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
             ):
                 # call() should go through static dispatch
                 fn_body = static_dispatch(
-                    f, backend_indices=self.static_dispatch_backend_indices
+                    sig, f, backend_indices=self.static_dispatch_backend_indices
                 )
             defns += f"""
 // aten::{f.func}
diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index a8108a51411..447d9a72156 100644
--- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -3,7 +3,7 @@ import os
 import pathlib
 import re
 from collections import Counter, defaultdict, namedtuple
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Sequence, Set, Union
 
 import yaml
 
@@ -68,6 +68,7 @@ def parse_backend_yaml(
         "full_codegen",
         "non_native",
         "ir_gen",
+        "symint",
     ]
 
     backend = yaml_values.pop("backend", None)
@@ -96,6 +97,14 @@ def parse_backend_yaml(
         supported, list
     ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
 
+    symint = yaml_values.pop("symint", [])
+    if symint is None:
+        symint = []  # Allow an empty list of symint ops
+    assert isinstance(
+        symint, list
+    ), f'expected "symint" to be a list, but got: {symint} (of type {type(symint)})'
+    symint_set = set(symint)
+
     supported_autograd = yaml_values.pop("autograd", [])
     assert isinstance(
         supported_autograd, list
@@ -118,6 +127,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
 
     def create_backend_index(
         backend_ops: List[str],
+        symint_ops: Set[str],
         dispatch_key: DispatchKey,
         *,
         use_out_as_primary: bool,
@@ -131,6 +141,8 @@ Only the following keys are supported: {", ".join(valid_keys)}'
             ), f"Found an invalid operator name: {op_name}"
             # See Note [External Backends Follow Dispatcher API]
             kernel_name = dispatcher.name(native_functions_map[op_name].func)
+            if op in symint_ops:
+                kernel_name += "_symint"
             # TODO: allow structured external backends later.
             m = BackendMetadata(
                 kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
             )
@@ -140,7 +152,6 @@ Only the following keys are supported: {", ".join(valid_keys)}'
             dispatch_key=dispatch_key,
             use_out_as_primary=use_out_as_primary,
             external=True,
-            symint=True,  # TODO: make this configurable
             device_guard=use_device_guard,
             index=metadata,
         )
@@ -154,6 +165,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
 
         backend_idx = create_backend_index(
             supported,
+            symint_set,
             backend_key,
             use_out_as_primary=use_out_as_primary,
             use_device_guard=use_device_guard,
@@ -171,6 +183,7 @@ the behavior of autograd for some operators on your backend. However "Autograd{b
 
         autograd_idx = create_backend_index(
             supported_autograd,
+            symint_set,
             autograd_key,
             use_out_as_primary=use_out_as_primary,
             use_device_guard=use_device_guard,
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 107d5737c3f..f4fe8bbd625 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -84,21 +84,21 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]:
     # clone() calls in their graph (which is normally needed by reshape).
     if str(g.view_copy.func.name) == "view_copy":
         return """\
-at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) {
-  // TODO: don't cast to int array ref
-  auto int_size = c10::asIntArrayRefSlow(size);
-  DimVector shape = infer_size_dv(int_size, self.numel());
+at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) {
+  DimVector shape = infer_size_dv(size, self.numel());
   if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
-    return self.reshape(int_size);
+    return self.reshape(size);
   } else {
-    auto output = at::_ops::view::call(self, size);
+    auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
     return output.clone();
   }
 }
 """
 
     # view_copy is a native signature, since we're generating an at::native:: kernel
     # Functionalization always operates on symints though
-    view_copy_sig = NativeSignature(g.view_copy.func, symint=True)
+    view_copy_sig = NativeSignature(
+        g.view_copy.func, symint=False
+    )  # TODO: flag day this True
 
     # view is a dispatcher signature, since we're calling into the at::_ops API
     view_sig = DispatcherSignature(g.view.func)
@@ -641,7 +641,7 @@ def gen_functionalization_registration(
         metadata = composite_implicit_autograd_index.get_kernel(f)
         assert metadata is not None
         native_api_name = metadata.kernel
-        sig = DispatcherSignature.from_schema(f.func)
+        sig = NativeSignature(f.func, symint=metadata.supports_symint())
         # Note [Composite view ops in the functionalization pass]
         # We don't need to worry about implementing functionalization kernels for views with
         # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
diff --git a/torchgen/gen_vmap_plumbing.py b/torchgen/gen_vmap_plumbing.py
index fe614326f7a..263a4842ad4 100644
--- a/torchgen/gen_vmap_plumbing.py
+++ b/torchgen/gen_vmap_plumbing.py
@@ -79,7 +79,7 @@ def gen_unwraps(
 
 
 def gen_case_where_all_bdims_are_none(
-    schema: FunctionSchema, cur_level_var: str
+    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
 ) -> str:
     conditions = []
     flat_args = schema.arguments.flat_all
@@ -90,7 +90,7 @@ def gen_case_where_all_bdims_are_none(
 
     sig = DispatcherSignature.from_schema(schema)
     translated_args = ", ".join(
-        e.expr for e in translate(sig.arguments(), sig.arguments())
+        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
     )
     return f"""\
 if ({' && '.join(conditions)}) {{
@@ -160,7 +160,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
     cur_level_var = "cur_level"
 
     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
-    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
 
     return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
@@ -182,7 +182,7 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
     cur_level_var = "cur_level"
 
     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
-    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
 
     return f"""\
 template <typename batch_rule_t, batch_rule_t batch_rule>
@@ -224,7 +224,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
     cur_level_var = "cur_level"
 
     unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
-    bdims_all_none_case = gen_case_where_all_bdims_are_none(schema, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
     wrapped_returns = gen_returns(returns, cur_level_var, results_var)
 
     return f"""\
diff --git a/torchgen/model.py b/torchgen/model.py
index 81fc05760af..707b7c48c6e 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -1098,6 +1098,9 @@ class BackendMetadata:
     # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
     cpp_namespace: str
 
+    def supports_symint(self) -> bool:
+        return "_symint" in self.kernel
+
 
 @dataclass(frozen=True)
 class UfuncInnerLoop:
@@ -1141,8 +1144,6 @@ class BackendIndex:
     external: bool
     # Other backend-specific information that is on a per-operator basis
     index: Dict["OperatorName", BackendMetadata]
-    # Whether or not this backend handles symbolic ints or not
-    symint: bool
 
     @staticmethod
     def grow_index(
diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py
index bf8503ed640..62981ccd82b 100644
--- a/torchgen/native_function_generation.py
+++ b/torchgen/native_function_generation.py
@@ -304,6 +304,8 @@ def generate_function(
         if func.kind() == SchemaKind.out
         else cpp.name(func)
     )
+    if f.func.has_symint():
+        kernel_name += "_symint"
     backend_metadata = {
         DispatchKey.CompositeExplicitAutograd: {
             func.name: BackendMetadata(
@@ -555,7 +557,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
     clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
 
     return f"""
-{sig.defn()} {{
+{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
   {clone_mutable_inputs_str}
   {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
   {ret_str}
@@ -615,7 +617,7 @@ def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
     # Kernel name needs to follow the naming convention defined in `generate_function()`
     return f"""
-{sig.defn(name=g.out.func.name.unambiguous_name())} {{
+{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
   auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
   {copy_outs_str}
   {return_str(g.out.func.returns, rets)}