pytorch/tools/codegen/api/dispatcher.py
Sebastian Messmer 6e2eee2b9d Add faithful C++ API (#44087)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44087

Each op taking a TensorOptions argument now has an additional overload in the C++ frontend where it takes scattered ScalarType, Layout, Device, bool instead of one TensorOptions argument.

If it is a c10-full op, then the scattered version calls into the dispatcher and the gathered version is a proxy calling into the scattered version.
If it is a non-c10-full op, then the gathered version calls into the dispatcher and the scattered version is a proxy calling into the gathered version.

This should minimize the amount of gathering and scattering needed.

This PR is also a prerequisite to remove the re-gathering of arguments that is currently happening in VariableKernel. Currently, VariableKernels gather arguments into a TensorOptions object
to call into the C++ API. In a PR stacked on top of this, VariableKernel will just directly call into the scattered C++ API introduced here and avoid the gathering step.
ghstack-source-id: 113355689

Test Plan:
waitforsandcastle

vs master: https://www.internalfb.com/intern/fblearner/details/216169815/

vs previous diff: https://www.internalfb.com/intern/fblearner/details/216169957/

Reviewed By: ezyang

Differential Revision: D23492188

fbshipit-source-id: 3e84c467545ad9371e98e09075a311bd18411c5a
2020-10-02 04:08:53 -07:00

148 lines
7.0 KiB
Python

from tools.codegen.model import *
from tools.codegen.api.types import CppArgument, DispatcherExpr, TensorOptionsArguments, \
DispatcherArgument, ThisArgument, LegacyDispatcherArgument
from tools.codegen.api import cpp
import tools.codegen.api.legacy_dispatcher as legacy_dispatcher
import tools.codegen.local as local
from enum import Enum
import itertools
from typing import Sequence, Optional
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched
# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API to so that the unboxed API
# better aligns with the boxed API. The dispatcher API hooks heavily
# into our template based boxing/unboxing machinery, so changes
# to this convention will usually need template updates too.
#
# Prominent characteristics of the dispatcher API:
#
# - 'use_c10_dispatcher: full' controls whether or not we actually
# use the modern calling convention or not. When use_c10_dispatcher
# is not enabled, we don't use the template machinery.
#
# - dtype, layout, device and pin_memory are represented as separate
# arguments.
#
def argumenttype_type(t: Type, *, mutable: bool) -> str:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
# This is a faux amis. If it makes sense in the future to add
# more special cases here, or invert things so cpp.argument_type
# calls this, or just completely inline the function, please do
# it.
return cpp.argumenttype_type(t, mutable=mutable)
else:
# This is real sharing. If you're modifying this path, ask
# yourself why you are changing the legacy dispatcher protocol
# here and not in legacy_dispatcher.
return legacy_dispatcher.argumenttype_type(t, mutable=mutable)
def argument_type(a: Argument) -> str:
return argumenttype_type(a.type, mutable=a.is_write)
def returns_type(rs: Sequence[Return]) -> str:
# At present, there is no difference. But there could be!
return cpp.returns_type(rs)
def argument(a: Argument) -> DispatcherArgument:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
return DispatcherArgument(
type=argument_type(a),
name=a.name,
argument=a,
)
else:
la = legacy_dispatcher.argument(a)
return DispatcherArgument(
type=la.type,
name=la.name,
argument=la.argument,
)
def name(func: FunctionSchema) -> str:
return cpp.name(func)
def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
else:
return [
DispatcherArgument(type=la.type, name=la.name, argument=la.argument)
for la in legacy_dispatcher.arguments(func)
]
# TODO GATHER is only needed for non-c10-full ops, remove later.
ProcessTensoroptions = Enum('ProcessTensoroptions', ('GATHER', 'SCATTER', 'PASS_THROUGH'))
# Given a set of CppArguments in scope, return a sequence of dispatcher
# expressions that translate the cpp API into dispatcher API
def cppargument_exprs(a: CppArgument,
*,
tensor_options: Optional[CppArgument],
process_tensoroptions: ProcessTensoroptions = ProcessTensoroptions.PASS_THROUGH
) -> Sequence[DispatcherExpr]:
if isinstance(a.argument, TensorOptionsArguments):
if process_tensoroptions == ProcessTensoroptions.SCATTER:
ta = a.argument
return [
DispatcherExpr(type=argument_type(ta.dtype), expr=f'optTypeMetaToScalarType({a.name}.dtype_opt())'),
DispatcherExpr(type=argument_type(ta.layout), expr=f'{a.name}.layout_opt()'),
DispatcherExpr(type=argument_type(ta.device), expr=f'{a.name}.device_opt()'),
DispatcherExpr(type=argument_type(ta.pin_memory), expr=f'{a.name}.pinned_memory_opt()'), # weird discrep
]
elif process_tensoroptions == ProcessTensoroptions.GATHER:
return [
DispatcherExpr(
type='const TensorOptions &',
expr="TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)")]
else:
assert process_tensoroptions == ProcessTensoroptions.PASS_THROUGH
return [DispatcherExpr(type='const TensorOptions &', expr=a.name)]
elif isinstance(a.argument, ThisArgument):
return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)]
elif isinstance(a.argument, Argument):
if a.name == 'memory_format' and tensor_options is not None and local.use_c10_dispatcher() is UseC10Dispatcher.full:
return [DispatcherExpr(
type=argument_type(a.argument),
expr=f'c10::impl::check_tensor_options_and_extract_memory_format({tensor_options.name}, {a.name})')
]
else:
return [DispatcherExpr(type=argument_type(a.argument), expr=a.name)]
else:
assert_never(a.argument)
def cpparguments_exprs(args: Sequence[CppArgument], process_tensoroptions: ProcessTensoroptions) -> Sequence[DispatcherExpr]:
tensor_options = next((a for a in args if isinstance(a.argument, TensorOptionsArguments)), None)
return [r for a in args for r in cppargument_exprs(a,
tensor_options=tensor_options,
process_tensoroptions=process_tensoroptions)]
# I don't think this is entirely sound, but it should be reasonably
# close
def legacydispatcherarguments_exprs(args: Sequence[LegacyDispatcherArgument]) -> Sequence[DispatcherExpr]:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
process_tensoroptions = ProcessTensoroptions.SCATTER
else:
process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
return cpparguments_exprs([CppArgument(type=a.type,
name=a.name,
default=None,
argument=a.argument) for a in args],
process_tensoroptions=process_tensoroptions)
def exprs(args: Sequence[DispatcherArgument]) -> Sequence[DispatcherExpr]:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
process_tensoroptions = ProcessTensoroptions.SCATTER
else:
process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
return cpparguments_exprs([CppArgument(type=a.type,
name=a.name,
default=None,
argument=a.argument) for a in args],
process_tensoroptions=process_tensoroptions)