diff --git a/torchgen/gen.py b/torchgen/gen.py
index 4eb22bd9180..3934beb1daa 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -301,75 +301,116 @@ def static_dispatch_extra_headers(
     return [f'#include ' for dispatch_key in static_dispatch_keys(backends)]

-def generate_static_dispatch(
+
+# Translates arguments of a native function from DispatcherSignature form to CppSignature form,
+# including the case where a memory_format argument appears alongside tensor_options arguments.
+# That case is not covered by tools.codegen.api.translate() yet, as it only arises in static dispatch.
+def translate_args_dispatcher_to_cpp(
     f: NativeFunction,
-    sig: DispatcherSignature,
-    *,
-    method: bool,
-    backend_index: Optional[BackendIndex]
 ) -> str:
-    if backend_index is None or f.manual_kernel_registration:
-        return ""
-    target_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
-    name = target_sig.name()
-    dp_sig_args = []
+    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
+    def add_spl_memory_format_binding(
+        input_bindings: List[Binding]
+    ) -> List[Binding]:
+        output_bindings: List[Binding] = []
+        for binding in input_bindings:
+            if binding.name == 'memory_format':
+                spl_mem_format_binding = Binding(
+                    nctype=NamedCType(SpecialArgName.possibly_redundant_memory_format, binding.nctype.type),
+                    name=binding.name,
+                    default=binding.default,
+                    argument=binding.argument,
+                )
+                output_bindings.append(spl_mem_format_binding)
+            else:
+                output_bindings.append(binding)
+        return output_bindings

-    # TranslateAPI doesn't support translation of operators from DispatcherAPI->CppAPI when
-    # there is a memory_format argument after TensorOption arguments. For operators with such arguments,
-    # amend the dispatcher signature's memory_format argument to have the same nctype as the CPP signature
-    if len(target_sig.arguments()) > 0 and \
-            target_sig.arguments()[-1].nctype.name == SpecialArgName.possibly_redundant_memory_format:
-        last_disp_arg = sig.arguments()[-1]
-        dp_sig_args = sig.arguments()[:-1]
-        mem_format_arg = Binding(
-            nctype=NamedCType(SpecialArgName.possibly_redundant_memory_format, last_disp_arg.nctype.type),
-            name=last_disp_arg.name,
-            default=last_disp_arg.default,
-            argument=last_disp_arg.argument,
-        )
-        dp_sig_args.append(mem_format_arg)
-    else:
-        dp_sig_args = sig.arguments()
+    disp_sig = DispatcherSignature.from_schema(f.func)
+    cpp_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
+    disp_bindings = disp_sig.arguments()
+    # When the CPP signature has an argument with the SpecialArgName.possibly_redundant_memory_format NCType,
+    # rewrite the dispatcher signature's memory_format binding to carry the same NCType as well
+    for arg in cpp_sig.arguments():
+        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
+            disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())
+            break
+    exprs = translate(disp_bindings, cpp_sig.arguments())
+    return ', '.join(a.expr for a in exprs)

-    exprs = translate(dp_sig_args, target_sig.arguments(), method=method)
-    exprs_str = ', '.join(a.expr for a in exprs)
-    if f.structured_delegate is not None:
-        # TODO: for ops with structured_delegate it should check the dispatch table of
-        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
-        # so we always dispatch to the `backend`, but this could be wrong when we
-        # migrate math/default_backend ops to use structured delegate.
-        if backend_index.has_kernel(f) or backend_index.dispatch_key in STRUCTURED_DISPATCH_KEYS:
-            return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
-        else:
-            return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
-    if backend_index.has_kernel(f):
-        return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
-    elif f.has_composite_explicit_autograd_kernel:
-        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs_str});'
+
+def generate_static_dispatch_backend_call(
+    f: NativeFunction,
+    backend_index: BackendIndex,
+) -> str:
+    name = DispatcherSignature.from_schema(f.func).name()
+    exprs = translate_args_dispatcher_to_cpp(f)
+    return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs});'
+
+
+def generate_static_dispatch_fallback_call(
+    f: NativeFunction,
+    backend_indices: List[BackendIndex],
+) -> str:
+    name = DispatcherSignature.from_schema(f.func).name()
+    exprs = translate_args_dispatcher_to_cpp(f)
+    if f.has_composite_explicit_autograd_kernel:
+        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});'
     elif f.has_composite_implicit_autograd_kernel:
-        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs_str});'
-    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
+        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});'
+    else:
+        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for \
+{', '.join([str(index.dispatch_key) for index in backend_indices])} ");"""


 def static_dispatch(
     f: NativeFunction,
-    sig: DispatcherSignature,
-    *,
-    method: bool,
     backend_indices: List[BackendIndex],
 ) -> str:
     if len(backend_indices) == 0 or f.manual_kernel_registration:
         return ""
-    keys = [b for b in backend_indices if b.has_kernel(f) or f.structured_delegate is not None]
+
+    keys = [b for b in backend_indices if b.has_kernel(f) or (f.structured_delegate is not None
+            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS)]
     if len(keys) == 1:
-        return generate_static_dispatch(f, sig, method=method, backend_index=keys[0])
+        return generate_static_dispatch_backend_call(f, keys[0])
     elif len(keys) == 0:
-        return generate_static_dispatch(f, sig, method=method, backend_index=backend_indices[0])
-    else:
-        return f"""TORCH_CHECK(false, "Static dispatch does not support {f.func.name.unambiguous_name()} for\
-{', '.join([str(index.dispatch_key)for index in backend_indices])} as they have with multiple \
-kernels {', '.join([str(k.get_kernel(f)) for k in keys])} ");"""
+        return generate_static_dispatch_fallback_call(f, backend_indices)
+
+    sig = DispatcherSignature.from_schema(f.func)
+    native_tensor_args = [
+        a.name for a in sig.arguments()
+        if isinstance(a.argument, SelfArgument) or isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
+    ]
+    tensor_args = ', '.join(native_tensor_args)
+    tensor_opts = f.func.arguments.tensor_options
+
+    stmts = []
+    subexprs: List[str] = []
+    if tensor_opts is not None:
+        subexprs.append('DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))')
+    if tensor_args != "":
+        subexprs.append(f'c10::detail::multi_dispatch_key_set({tensor_args})')
+    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
+    stmts.append('DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);')
+
+    dispatch_code = []
+    for index in keys:
+        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
+        dispatch_code.append(f"""\t{generate_static_dispatch_backend_call(f, index)};""")
+
+    fallback = generate_static_dispatch_fallback_call(f, backend_indices)
+    connector = '\n\t\t'
+
+    return f"""
+    {connector.join(stmts)}
+    switch (_dk) {{
+        {connector.join(dispatch_code)}
+        default:
+            {fallback}
+    }}
+    """


 # Generates RegisterSchema.cpp. Depending on the selector, either
@@ -473,8 +514,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed

         if not is_redispatching_fn and len(self.static_dispatch_backend_indices) > 0:
             # call() should go through static dispatch
-            fn_body = static_dispatch(f, sig, method=False,
-                                      backend_indices=self.static_dispatch_backend_indices)
+            fn_body = static_dispatch(f, backend_indices=self.static_dispatch_backend_indices)
         defns += f"""
 // aten::{f.func}
 {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
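To make the effect of the new `static_dispatch()` path concrete, here is a rough, hand-written sketch of the C++ body it would emit for a hypothetical op `add(Tensor self, Tensor other, Scalar alpha)` when ATen is built with static dispatch backends CPU and CUDA. The operator, signature, and formatting are illustrative assumptions, not copied from the PR; only the `multi_dispatch_key_set` / `highestPriorityBackendTypeId` / `switch` shape mirrors what the codegen above produces.

```cpp
// Illustrative sketch only: approximates the wrapper body generated by static_dispatch()
// for an op with two static dispatch backends and no composite fallback kernel.
#include <ATen/CPUFunctions.h>   // declares at::cpu::add (CPU structured kernel entry point)
#include <ATen/CUDAFunctions.h>  // declares at::cuda::add (CUDA build only)
#include <c10/core/DispatchKeySet.h>

at::Tensor add(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  // Collect the dispatch keys of all tensor-like arguments...
  c10::DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(self, other);
  // ...and pick the highest-priority backend key to branch on.
  c10::DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);
  switch (_dk) {
    case c10::DispatchKey::CPU:
      return at::cpu::add(self, other, alpha);
    case c10::DispatchKey::CUDA:
      return at::cuda::add(self, other, alpha);
    default:
      // No backend kernel and no composite fallback for this op: fail loudly.
      TORCH_CHECK(false, "Static dispatch does not support add for CPU, CUDA");
  }
}
```

For factory functions that carry TensorOptions, the `tensor_opts` branch above would additionally OR `DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))` into `_dk_set`, and ops with a CompositeExplicitAutograd or CompositeImplicitAutograd kernel would get that composite call in the `default:` arm instead of the `TORCH_CHECK`.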