[Pytorch][4/4 Static dispatch] Support multiple backends with multiple kernels (#76059)

Summary:
- Supports multiple backends with multiple kernels in static dispatch
- Refactor static dispatch generators
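
When an op has kernels registered for more than one of the selected backends, the refactored generator now emits a runtime switch over the backend dispatch key instead of bailing out. A minimal sketch of what the emitted `call()` body could look like, assuming a hypothetical `add(self, other, alpha)` op with CPU and QuantizedCPU kernels (op, argument names, and includes are illustrative, not taken from this diff):

```cpp
#include <ATen/CPUFunctions.h>
#include <ATen/QuantizedCPUFunctions.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>  // c10::detail::multi_dispatch_key_set (approximate include)

// Sketch only: roughly the shape of the code static_dispatch() generates for a
// hypothetical op with both CPU and QuantizedCPU kernels.
at::Tensor add_static_dispatch(const at::Tensor& self, const at::Tensor& other,
                               const at::Scalar& alpha) {
  // Union of the dispatch keys carried by every tensor argument.
  c10::DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(self, other);
  // The highest-priority backend wins, mirroring the dynamic dispatcher's ordering.
  c10::DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);
  switch (_dk) {
    case c10::DispatchKey::CPU:
      return at::cpu::add(self, other, alpha);
    case c10::DispatchKey::QuantizedCPU:
      return at::quantizedcpu::add(self, other, alpha);
    default:
      // Fallback: a CompositeExplicit/ImplicitAutograd kernel if the op has one,
      // otherwise a hard failure like the one generated here.
      TORCH_CHECK(false, "Static dispatch does not support add for CPU, QuantizedCPU");
  }
}
```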

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76059
ghstack-source-id: 154735166

Test Plan:
```
(pytorch)  ~/fbsource
└─ $ buck build --config pt.enable_lightweight_dispatch=1 --config pt.static_dispatch_backend="CPU;QuantizedCPU;CompositeExplicitAutograd" //xplat/caffe2/fb/lite_predictor:lite_predictor_flatbuffer
```
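
The `pt.static_dispatch_backend` config takes a semicolon-separated list of dispatch keys. Per `static_dispatch_extra_headers` in the diff below, each selected backend contributes one `<ATen/{DispatchKey}Functions.h>` include to the generated code; for the backend list used above, that preamble would look roughly like this (a sketch; the exact generated file is determined by the codegen):

```cpp
// Per-backend headers pulled in by the generated code when static dispatch is
// enabled for CPU, QuantizedCPU and CompositeExplicitAutograd.
#include <ATen/CPUFunctions.h>
#include <ATen/QuantizedCPUFunctions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
```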

Reviewed By: bdhirsh

Differential Revision: D35727473

fbshipit-source-id: 986ba3390c6e585fcf8477b6d069720ee1fbc90b
(cherry picked from commit 6473990c208a78879985e4cdfb50960f5727ad5e)
Priya Ramani 2022-04-25 14:06:13 -07:00 committed by PyTorch MergeBot
parent 1df2d6a959
commit f954c0a774


```diff
@@ -301,75 +301,116 @@ def static_dispatch_extra_headers(
     return [f'#include <ATen/{dispatch_key}Functions.h>'
             for dispatch_key in static_dispatch_keys(backends)]
 
-def generate_static_dispatch(
+# Translates arguments of a native function from DispatcherSignature form to CppSignature form with support for
+# supporting usecases even when there is a memory_format argument along with tensor_option arguments.
+# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
+def translate_args_dispatcher_to_cpp(
     f: NativeFunction,
-    sig: DispatcherSignature, *,
-    method: bool,
-    backend_index: Optional[BackendIndex]
 ) -> str:
-    if backend_index is None or f.manual_kernel_registration:
-        return ""
-    target_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
-    name = target_sig.name()
-    dp_sig_args = []
+    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
+    def add_spl_memory_format_binding(
+        input_bindings: List[Binding]
+    ) -> List[Binding]:
+        output_bindings: List[Binding] = []
+        for binding in input_bindings:
+            if binding.name == 'memory_format':
+                spl_mem_format_binding = Binding(
+                    nctype=NamedCType(SpecialArgName.possibly_redundant_memory_format, binding.nctype.type),
+                    name=binding.name,
+                    default=binding.default,
+                    argument=binding.argument,
+                )
+                output_bindings.append(spl_mem_format_binding)
+            else:
+                output_bindings.append(binding)
+        return output_bindings
 
-    # TranslateAPI doesn't support translation of operators from DispatcherAPI->CppAPI when
-    # there is a memory_format argument after TensorOption arguments. For operators with such arguments,
-    # amend the dispatcher signature's memory_format argument to have the same nctype as the CPP signature
-    if len(target_sig.arguments()) > 0 and \
-            target_sig.arguments()[-1].nctype.name == SpecialArgName.possibly_redundant_memory_format:
-        last_disp_arg = sig.arguments()[-1]
-        dp_sig_args = sig.arguments()[:-1]
-        mem_format_arg = Binding(
-            nctype=NamedCType(SpecialArgName.possibly_redundant_memory_format, last_disp_arg.nctype.type),
-            name=last_disp_arg.name,
-            default=last_disp_arg.default,
-            argument=last_disp_arg.argument,
-        )
-        dp_sig_args.append(mem_format_arg)
-    else:
-        dp_sig_args = sig.arguments()
+    disp_sig = DispatcherSignature.from_schema(f.func)
+    cpp_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
+    disp_bindings = disp_sig.arguments()
+    # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
+    # get memory_format bindings of dispatcher signature to have the same NCType as well
+    for arg in cpp_sig.arguments():
+        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
+            disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())
+            break
+    exprs = translate(disp_bindings, cpp_sig.arguments())
+    return ', '.join(a.expr for a in exprs)
 
-    exprs = translate(dp_sig_args, target_sig.arguments(), method=method)
-    exprs_str = ', '.join(a.expr for a in exprs)
-    if f.structured_delegate is not None:
-        # TODO: for ops with structured_delegate it should check the dispatch table of
-        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
-        # so we always dispatch to the `backend`, but this could be wrong when we
-        # migrate math/default_backend ops to use structured delegate.
-        if backend_index.has_kernel(f) or backend_index.dispatch_key in STRUCTURED_DISPATCH_KEYS:
-            return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
-        else:
-            return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
-    if backend_index.has_kernel(f):
-        return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
-    elif f.has_composite_explicit_autograd_kernel:
-        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs_str});'
+def generate_static_dispatch_backend_call(
+    f: NativeFunction,
+    backend_index: BackendIndex,
+) -> str:
+    name = DispatcherSignature.from_schema(f.func).name()
+    exprs = translate_args_dispatcher_to_cpp(f)
+    return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs});'
+
+def generate_static_dispatch_fallback_call(
+    f: NativeFunction,
+    backend_indices: List[BackendIndex],
+) -> str:
+    name = DispatcherSignature.from_schema(f.func).name()
+    exprs = translate_args_dispatcher_to_cpp(f)
+    if f.has_composite_explicit_autograd_kernel:
+        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});'
     elif f.has_composite_implicit_autograd_kernel:
-        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs_str});'
-    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
+        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});'
+    else:
+        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
+{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
 
 def static_dispatch(
     f: NativeFunction,
-    sig: DispatcherSignature,
-    *,
-    method: bool,
     backend_indices: List[BackendIndex],
 ) -> str:
     if len(backend_indices) == 0 or f.manual_kernel_registration:
         return ""
-    keys = [b for b in backend_indices if b.has_kernel(f) or f.structured_delegate is not None]
+    keys = [b for b in backend_indices if b.has_kernel(f) or (f.structured_delegate is not None
+            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS)]
     if len(keys) == 1:
-        return generate_static_dispatch(f, sig, method=method, backend_index=keys[0])
+        return generate_static_dispatch_backend_call(f, keys[0])
     elif len(keys) == 0:
-        return generate_static_dispatch(f, sig, method=method, backend_index=backend_indices[0])
-    else:
-        return f"""TORCH_CHECK(false, "Static dispatch does not support {f.func.name.unambiguous_name()} for\
-{', '.join([str(index.dispatch_key)for index in backend_indices])} as they have with multiple \
-kernels {', '.join([str(k.get_kernel(f)) for k in keys])} ");"""
+        return generate_static_dispatch_fallback_call(f, backend_indices)
+    sig = DispatcherSignature.from_schema(f.func)
+    native_tensor_args = [
+        a.name for a in sig.arguments()
+        if isinstance(a.argument, SelfArgument) or isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
+    ]
+    tensor_args = ', '.join(native_tensor_args)
+    tensor_opts = f.func.arguments.tensor_options
+    stmts = []
+    subexprs: List[str] = []
+    if tensor_opts is not None:
+        subexprs.append('DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))')
+    if tensor_args != "":
+        subexprs.append(f'c10::detail::multi_dispatch_key_set({tensor_args})')
+    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
+    stmts.append('DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);')
+    dispatch_code = []
+    for index in keys:
+        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
+        dispatch_code.append(f"""\t{generate_static_dispatch_backend_call(f, index)};""")
+    fallback = generate_static_dispatch_fallback_call(f, backend_indices)
+    connector = '\n\t\t'
+    return f"""
+    {connector.join(stmts)}
+    switch (_dk) {{
+        {connector.join(dispatch_code)}
+        default:
+            {fallback}
+    }}
+    """
 
 # Generates RegisterSchema.cpp. Depending on the selector, either
@@ -473,8 +514,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
         if not is_redispatching_fn and len(self.static_dispatch_backend_indices) > 0:
             # call() should go through static dispatch
-            fn_body = static_dispatch(f, sig, method=False,
-                                      backend_indices=self.static_dispatch_backend_indices)
+            fn_body = static_dispatch(f, backend_indices=self.static_dispatch_backend_indices)
         defns += f"""
 // aten::{f.func}
 {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
```