[Pytorch][4/4 Static dispatch] Support multiple backends with multiple kernels (#76059)
Summary:
- Supports multiple backends with multiple kernels in static dispatch
- Refactor static dispatch generators

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76059

ghstack-source-id: 154735166

Test Plan:
```
(pytorch) ~/fbsource
└─ $ buck build --config pt.enable_lightweight_dispatch=1 --config pt.static_dispatch_backend="CPU;QuantizedCPU;CompositeExplicitAutograd" //xplat/caffe2/fb/lite_predictor:lite_predictor_flatbuffer
```

Reviewed By: bdhirsh

Differential Revision: D35727473

fbshipit-source-id: 986ba3390c6e585fcf8477b6d069720ee1fbc90b

(cherry picked from commit 6473990c208a78879985e4cdfb50960f5727ad5e)
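In effect, the generated `call()` body now selects a backend at runtime from the statically linked set instead of assuming a single backend. A minimal Python sketch of that selection logic (hypothetical names and priority order, not the torchgen API or the generated C++ itself):

```
from typing import Dict, List, Optional

# Hypothetical model of the C++ stub that static dispatch now generates:
# pick the highest-priority backend key present among the arguments, call its
# statically linked kernel if one exists, otherwise use the fallback kernel.
PRIORITY = ["QuantizedCPU", "CPU", "CompositeExplicitAutograd"]  # assumed order, illustration only


def static_call(op_kernels: Dict[str, str],
                arg_keys: List[str],
                fallback: Optional[str] = None) -> str:
    for key in PRIORITY:  # plays the role of c10::highestPriorityBackendTypeId
        if key in arg_keys and key in op_kernels:
            return op_kernels[key]  # case DispatchKey::<key>: return at::<key>::op(...);
    if fallback is not None:
        return fallback  # default: composite kernel, if the op has one
    raise RuntimeError("Static dispatch does not support this op for the given backends")


# An op with CPU and QuantizedCPU kernels, as in the Test Plan's
# pt.static_dispatch_backend="CPU;QuantizedCPU;CompositeExplicitAutograd" build:
kernels = {"CPU": "at::cpu::add", "QuantizedCPU": "at::quantized_cpu::add"}
print(static_call(kernels, arg_keys=["QuantizedCPU"]))  # -> at::quantized_cpu::add
```

The quantized kernel wins here because its key is the highest-priority one present; with only dense CPU tensor arguments the same stub would pick `at::cpu::add`.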
This commit is contained in:
parent 1df2d6a959
commit f954c0a774
148 torchgen/gen.py
@@ -301,75 +301,116 @@ def static_dispatch_extra_headers(
     return [f'#include <ATen/{dispatch_key}Functions.h>'
             for dispatch_key in static_dispatch_keys(backends)]
 
 
-def generate_static_dispatch(
+# Translates arguments of a native function from DispatcherSignature form to CppSignature form with support for
+# supporting usecases even when there is a memory_format argument along with tensor_option arguments.
+# This usecase is not covered by tools.codegen.api.translate() yet as its application is limited to static dispatch
+def translate_args_dispatcher_to_cpp(
     f: NativeFunction,
-    sig: DispatcherSignature, *,
-    method: bool,
-    backend_index: Optional[BackendIndex]
 ) -> str:
-    if backend_index is None or f.manual_kernel_registration:
-        return ""
-    target_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
-    name = target_sig.name()
-
-    dp_sig_args = []
-    # TranslateAPI doesn't support translation of operators from DispatcherAPI->CppAPI when
-    # there is a memory_format argument after TensorOption arguments. For operators with such arguments,
-    # amend the dispatcher signature's memory_format argument to have the same nctype as the CPP signature
-    if len(target_sig.arguments()) > 0 and \
-            target_sig.arguments()[-1].nctype.name == SpecialArgName.possibly_redundant_memory_format:
-        last_disp_arg = sig.arguments()[-1]
-        dp_sig_args = sig.arguments()[:-1]
-        mem_format_arg = Binding(
-            nctype=NamedCType(SpecialArgName.possibly_redundant_memory_format, last_disp_arg.nctype.type),
-            name=last_disp_arg.name,
-            default=last_disp_arg.default,
-            argument=last_disp_arg.argument,
-        )
-        dp_sig_args.append(mem_format_arg)
-    else:
-        dp_sig_args = sig.arguments()
-    exprs = translate(dp_sig_args, target_sig.arguments(), method=method)
-    exprs_str = ', '.join(a.expr for a in exprs)
-    if f.structured_delegate is not None:
-        # TODO: for ops with structured_delegate it should check the dispatch table of
-        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
-        # so we always dispatch to the `backend`, but this could be wrong when we
-        # migrate math/default_backend ops to use structured delegate.
-        if backend_index.has_kernel(f) or backend_index.dispatch_key in STRUCTURED_DISPATCH_KEYS:
-            return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
-        else:
-            return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
-
-    if backend_index.has_kernel(f):
-        return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs_str});'
-    elif f.has_composite_explicit_autograd_kernel:
-        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs_str});'
-    elif f.has_composite_implicit_autograd_kernel:
-        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs_str});'
-
-    return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backend_index.dispatch_key}.");'
+
+    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
+    def add_spl_memory_format_binding(
+        input_bindings: List[Binding]
+    ) -> List[Binding]:
+        output_bindings: List[Binding] = []
+        for binding in input_bindings:
+            if binding.name == 'memory_format':
+                spl_mem_format_binding = Binding(
+                    nctype=NamedCType(SpecialArgName.possibly_redundant_memory_format, binding.nctype.type),
+                    name=binding.name,
+                    default=binding.default,
+                    argument=binding.argument,
+                )
+                output_bindings.append(spl_mem_format_binding)
+            else:
+                output_bindings.append(binding)
+        return output_bindings
+
+    disp_sig = DispatcherSignature.from_schema(f.func)
+    cpp_sig = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False).signature
+    disp_bindings = disp_sig.arguments()
+    # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
+    # get memory_format bindings of dispatcher signature to have the same NCType as well
+    for arg in cpp_sig.arguments():
+        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
+            disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())
+            break
+    exprs = translate(disp_bindings, cpp_sig.arguments())
+    return ', '.join(a.expr for a in exprs)
+
+
+def generate_static_dispatch_backend_call(
+    f: NativeFunction,
+    backend_index: BackendIndex,
+) -> str:
+    name = DispatcherSignature.from_schema(f.func).name()
+    exprs = translate_args_dispatcher_to_cpp(f)
+    return f'return at::{backend_index.dispatch_key.lower()}::{name}({exprs});'
+
+
+def generate_static_dispatch_fallback_call(
+    f: NativeFunction,
+    backend_indices: List[BackendIndex],
+) -> str:
+    name = DispatcherSignature.from_schema(f.func).name()
+    exprs = translate_args_dispatcher_to_cpp(f)
+    if f.has_composite_explicit_autograd_kernel:
+        return f'return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});'
+    elif f.has_composite_implicit_autograd_kernel:
+        return f'return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});'
+    else:
+        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
+{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
 
 
 def static_dispatch(
     f: NativeFunction,
-    sig: DispatcherSignature,
-    *,
-    method: bool,
     backend_indices: List[BackendIndex],
 ) -> str:
     if len(backend_indices) == 0 or f.manual_kernel_registration:
         return ""
-    keys = [b for b in backend_indices if b.has_kernel(f) or f.structured_delegate is not None]
+
+    keys = [b for b in backend_indices if b.has_kernel(f) or (f.structured_delegate is not None
+            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS)]
     if len(keys) == 1:
-        return generate_static_dispatch(f, sig, method=method, backend_index=keys[0])
+        return generate_static_dispatch_backend_call(f, keys[0])
     elif len(keys) == 0:
-        return generate_static_dispatch(f, sig, method=method, backend_index=backend_indices[0])
-    else:
-        return f"""TORCH_CHECK(false, "Static dispatch does not support {f.func.name.unambiguous_name()} for\
-{', '.join([str(index.dispatch_key)for index in backend_indices])} as they have with multiple \
-kernels {', '.join([str(k.get_kernel(f)) for k in keys])} ");"""
+        return generate_static_dispatch_fallback_call(f, backend_indices)
+
+    sig = DispatcherSignature.from_schema(f.func)
+    native_tensor_args = [
+        a.name for a in sig.arguments()
+        if isinstance(a.argument, SelfArgument) or isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
+    ]
+    tensor_args = ', '.join(native_tensor_args)
+    tensor_opts = f.func.arguments.tensor_options
+
+    stmts = []
+    subexprs: List[str] = []
+    if tensor_opts is not None:
+        subexprs.append('DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))')
+    if tensor_args != "":
+        subexprs.append(f'c10::detail::multi_dispatch_key_set({tensor_args})')
+    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
+    stmts.append('DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);')
+
+    dispatch_code = []
+    for index in keys:
+        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
+        dispatch_code.append(f"""\t{generate_static_dispatch_backend_call(f, index)};""")
+
+    fallback = generate_static_dispatch_fallback_call(f, backend_indices)
+    connector = '\n\t\t'
+
+    return f"""
+    {connector.join(stmts)}
+    switch (_dk) {{
+        {connector.join(dispatch_code)}
+        default:
+            {fallback}
+    }}
+    """
 
 
 # Generates RegisterSchema.cpp. Depending on the selector, either
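Concretely, for an operator that ends up with more than one usable backend kernel, the switch-based template above yields a body of roughly this shape (illustrative only, shown as a Python string; a hypothetical `op(Tensor self)` with CPU and QuantizedCPU kernels and no composite fallback is assumed):

```
# Illustrative only: the general shape of the C++ body that static_dispatch()
# would emit for a hypothetical op(Tensor self) with CPU and QuantizedCPU
# kernels and no composite kernel (so the default case is the TORCH_CHECK fallback).
expected_body = r"""
DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(self);
DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);
switch (_dk) {
    case DispatchKey::CPU:
        return at::cpu::op(self);
    case DispatchKey::QuantizedCPU:
        return at::quantized_cpu::op(self);
    default:
        TORCH_CHECK(false, "Static dispatch does not support op for CPU, QuantizedCPU ");
}
"""
print(expected_body)
```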
@@ -473,8 +514,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
         if not is_redispatching_fn and len(self.static_dispatch_backend_indices) > 0:
             # call() should go through static dispatch
-            fn_body = static_dispatch(f, sig, method=False,
-                                      backend_indices=self.static_dispatch_backend_indices)
+            fn_body = static_dispatch(f, backend_indices=self.static_dispatch_backend_indices)
         defns += f"""
 // aten::{f.func}
 {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
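A side effect of the refactor visible in this second hunk: callers no longer pass `sig` or `method`, because argument translation now lives in `translate_args_dispatcher_to_cpp`, which re-derives both signatures from the `NativeFunction` and re-tags the `memory_format` binding so `translate()` can line the two signatures up. A stripped-down sketch of that re-tagging pattern, using stand-in types rather than torchgen's real `Binding`/`NamedCType`:

```
from typing import List, NamedTuple


# Stand-ins for torchgen's Binding / NamedCType, simplified for illustration.
class Binding(NamedTuple):
    name: str        # argument name, e.g. "memory_format"
    ctype_name: str  # the label translate() uses to match bindings across signatures


POSSIBLY_REDUNDANT_MEMORY_FORMAT = "possibly_redundant_memory_format"


def retag_memory_format(bindings: List[Binding]) -> List[Binding]:
    # Mirrors the idea of add_spl_memory_format_binding: only the memory_format
    # binding gets the special label; every other binding passes through unchanged.
    return [
        b._replace(ctype_name=POSSIBLY_REDUNDANT_MEMORY_FORMAT)
        if b.name == "memory_format" else b
        for b in bindings
    ]


dispatcher_args = [Binding("self", "self"), Binding("memory_format", "memory_format")]
print(retag_memory_format(dispatcher_args))
```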