pytorch/tools/codegen/api/native.py
Edward Yang d705083c2b Refactor dispatcher and native to use Signature structure. (#45990)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45990

In #45890 we introduced the concept of a CppSignature, which bundled
up all of the information necessary to declare a C++ signature for
the cpp API.  This PR introduces analogous concepts for dispatcher
and native: DispatcherSignature and NativeSignature.

The three interfaces are not particularly well coupled right now,
but they do have some duck-typing coincidences:

- defn() which renders the C++ definition "bool f(int x)"
- decl() which renders the C++ declaration "bool f(int x = 2)"
- type() which renders the C++ function type "bool(int)"

Maybe at some point we'll introduce a Protocol, or a supertype.
Many other methods (like arguments()) have varying types.  These
signatures also have some helper methods that forward back to real
implementations in the api modules.  Something to think about is
whether we should attempt to reduce boilerplate here; I'm not too
sure about it yet.
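To make the duck typing concrete, here is a minimal sketch of what such a Protocol could look like, together with a toy signature that renders the three example strings from the list of shared methods. `CodegenSignature`, `ToyArg`, `ToySignature` and `render_all` are hypothetical names for illustration only; the real signature classes carry more state, and their defn() also takes name/prefix arguments that this sketch leaves out.

```python
from dataclasses import dataclass
from typing import Optional, Protocol, Tuple


class CodegenSignature(Protocol):
    # Hypothetical Protocol capturing the three duck-typed methods.
    def defn(self) -> str: ...
    def decl(self) -> str: ...
    def type(self) -> str: ...


@dataclass(frozen=True)
class ToyArg:
    type: str
    name: str
    default: Optional[str] = None


@dataclass(frozen=True)
class ToySignature:
    name: str
    returns: str
    args: Tuple[ToyArg, ...]

    def _args(self, with_defaults: bool) -> str:
        parts = []
        for a in self.args:
            part = f"{a.type} {a.name}"
            if with_defaults and a.default is not None:
                part += f" = {a.default}"
            parts.append(part)
        return ", ".join(parts)

    def defn(self) -> str:
        # Definitions omit default values: "bool f(int x)"
        return f"{self.returns} {self.name}({self._args(False)})"

    def decl(self) -> str:
        # Declarations keep them: "bool f(int x = 2)"
        return f"{self.returns} {self.name}({self._args(True)})"

    def type(self) -> str:
        # The bare C++ function type: "bool(int)"
        return f"{self.returns}({', '.join(a.type for a in self.args)})"


def render_all(sig: CodegenSignature) -> Tuple[str, str, str]:
    # Works for any object that structurally provides defn/decl/type.
    return sig.defn(), sig.decl(), sig.type()


sig = ToySignature("f", "bool", (ToyArg("int", "x", "2"),))
print(render_all(sig))  # ('bool f(int x)', 'bool f(int x = 2)', 'bool(int)')
```

Because `typing.Protocol` is structural, `ToySignature` never has to inherit from `CodegenSignature`; a type checker accepts it wherever the Protocol is expected as long as the method shapes match.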

The net effect is we get to reduce the number of variables we
have to explicitly write out in the codegen, since now these are all
bundled together into a signature.  Something extra special happens
in BackendSelect, where we now dynamically select between dispatcher_sig
and native_sig to decide how the BackendSelect kernel is implemented.
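A rough sketch of the BackendSelect idea: because both signature types expose the same defn() method, the emitting code can pick one and render it without caring which it got. The classes, the `backend_select_defn` helper and the C++ argument strings are hypothetical. The assumption that the choice keys on whether the operator is `use_c10_dispatcher: full` is borrowed from the header comment in native.py, not from the BackendSelect code itself.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Two deliberately unrelated classes that happen to share a defn() method,
# mirroring how DispatcherSignature and NativeSignature are only duck-typed.
@dataclass(frozen=True)
class ToyDispatcherSignature:
    returns: str
    args: Tuple[str, ...]  # already-rendered "type name" strings

    def defn(self, name: str) -> str:
        return f"{self.returns} {name}({', '.join(self.args)})"


@dataclass(frozen=True)
class ToyNativeSignature:
    returns: str
    args: Tuple[str, ...]

    def defn(self, name: str) -> str:
        return f"{self.returns} {name}({', '.join(self.args)})"


def backend_select_defn(name: str, c10_full: bool,
                        dispatcher_sig: ToyDispatcherSignature,
                        native_sig: ToyNativeSignature) -> str:
    # Assumption: use_c10_dispatcher: full ops get the dispatcher signature,
    # everything else the native one.
    sig: Union[ToyDispatcherSignature, ToyNativeSignature] = (
        dispatcher_sig if c10_full else native_sig
    )
    # The caller only relies on defn(), so either signature works here.
    return sig.defn(name)


d = ToyDispatcherSignature("Tensor", ("IntArrayRef size", "c10::optional<ScalarType> dtype"))
n = ToyNativeSignature("Tensor", ("IntArrayRef size", "const TensorOptions & options"))
print(backend_select_defn("my_op", True, d, n))
print(backend_select_defn("my_op", False, d, n))
```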

A little bit of extra cleanup:
- Some places where we previously advertised Sequence, we now advertise
  a more informative Tuple.
- defn() may take an optional positional parameter overriding the entire
  name, or a kwarg-only prefix parameter to just add a prefix to the
  name.
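A sketch of the defn() calling convention from the last bullet, together with the Tuple annotation from the first one. `ToyArg`, `ToySig` and the `wrapper_` prefix are hypothetical. In this sketch an explicit name wins and the prefix is only applied when no name is given, which is an assumption about how the two parameters interact.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyArg:
    type: str
    name: str


@dataclass(frozen=True)
class ToySig:
    base_name: str
    returns: str
    # Tuple[ToyArg, ...] pins down a concrete, immutable (and hashable)
    # container; Sequence[ToyArg] would also accept a list.
    args: Tuple[ToyArg, ...]

    def defn(self, name: Optional[str] = None, *, prefix: str = "") -> str:
        # An explicit positional name replaces the whole name; otherwise the
        # kwarg-only prefix is prepended to the signature's own name.
        # (Assumption: prefix is ignored when an explicit name is passed.)
        if name is None:
            name = prefix + self.base_name
        args_str = ", ".join(f"{a.type} {a.name}" for a in self.args)
        return f"{self.returns} {name}({args_str})"


sig = ToySig("add", "Tensor", (ToyArg("const Tensor &", "self"),
                               ToyArg("const Tensor &", "other")))
print(sig.defn())                   # Tensor add(const Tensor & self, const Tensor & other)
print(sig.defn(prefix="wrapper_"))  # Tensor wrapper_add(const Tensor & self, const Tensor & other)
print(sig.defn("my_add"))           # Tensor my_add(const Tensor & self, const Tensor & other)
```

Making `prefix` keyword-only keeps a call like `defn("x")` unambiguous: a bare positional string is always a full replacement name.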

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: smessmer

Differential Revision: D24223100

Pulled By: ezyang

fbshipit-source-id: f985eced08af4a60ba9641d125d0f260f8cda9eb
2020-10-13 08:34:48 -07:00

from tools.codegen.model import *
from tools.codegen.api.types import TensorOptionsArguments, NativeArgument, ThisArgument
import tools.codegen.api.cpp as cpp
from typing import Union, Sequence, Tuple

# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make the native API and the dispatcher
# API line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# When a function is not use_c10_dispatcher: full, the dispatcher API actually
# coincides with the native:: API (i.e., we do as dumb a pass-through as
# possible).

def name(func: FunctionSchema) -> str:
    name = str(func.name.name)
    # TODO: delete this!
    if func.is_out_fn():
        name += '_out'
    if func.name.overload_name:
        name += f'_{func.name.overload_name}'
    return name

def argumenttype_type(t: Type, *, mutable: bool) -> str:
    if str(t) == 'Tensor?':
        if mutable:
            return 'Tensor &'
        else:
            return 'const Tensor &'
    elif str(t) == 'Tensor?[]':
        return 'TensorList'
    return cpp.argumenttype_type(t, mutable=mutable)

def returns_type(rs: Sequence[Return]) -> str:
    return cpp.returns_type(rs)

def argument_type(a: Argument) -> str:
    return argumenttype_type(a.type, mutable=a.is_write)

def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> NativeArgument:
    if isinstance(a, Argument):
        return NativeArgument(
            type=argument_type(a),
            name=a.name,
            default=cpp.default_expr(a.default, a.type) if a.default is not None else None,
            argument=a,
        )
    elif isinstance(a, ThisArgument):
        # Erase ThisArgument from the distinction
        return NativeArgument(
            type=argument_type(a.argument),
            name=a.argument.name,
            default=None,
            argument=a.argument,
        )
    elif isinstance(a, TensorOptionsArguments):
        # TODO: expunge this logic entirely
        default = None
        if all(x.default == "None" for x in a.all()):
            default = '{}'
        elif a.dtype.default == "long":
            default = 'at::kLong'  # TODO: this is wrong
        return NativeArgument(
            type='const TensorOptions &',
            name='options',
            default=default,
            argument=a,
        )
    else:
        assert_never(a)

def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]:
    return tuple(map(argument, cpp.group_arguments(func, method=False)))
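The naming and type rules in name() and argumenttype_type() are easy to misread, so here is a self-contained sketch that reimplements just their string manipulation over hypothetical stand-in types (`ToyOperatorName`, `ToySchema`, and a `fallback` callback in place of cpp.argumenttype_type). It is not a replacement for the real model classes.

```python
from dataclasses import dataclass
from typing import Callable


# Hypothetical stand-ins for the model types; only the fields that
# name() reads are modelled.
@dataclass(frozen=True)
class ToyOperatorName:
    name: str           # base operator name, e.g. "add"
    overload_name: str  # e.g. "Tensor"; "" for the default overload


@dataclass(frozen=True)
class ToySchema:
    name: ToyOperatorName
    out: bool = False   # stands in for FunctionSchema.is_out_fn()

    def is_out_fn(self) -> bool:
        return self.out


def toy_native_name(func: ToySchema) -> str:
    # Same rule as name(): base name, then "_out" for out functions,
    # then "_<overload>" whenever there is an overload name.
    name = func.name.name
    if func.is_out_fn():
        name += '_out'
    if func.name.overload_name:
        name += f'_{func.name.overload_name}'
    return name


def toy_argumenttype_type(t: str, *, mutable: bool,
                          fallback: Callable[[str, bool], str]) -> str:
    # Same special cases as argumenttype_type(): an optional tensor is passed
    # as a plain Tensor reference (an undefined Tensor standing in for None),
    # a list of optional tensors as TensorList; everything else defers to the
    # C++ API rule, modelled here by the hypothetical `fallback` callback.
    if t == 'Tensor?':
        return 'Tensor &' if mutable else 'const Tensor &'
    elif t == 'Tensor?[]':
        return 'TensorList'
    return fallback(t, mutable)


print(toy_native_name(ToySchema(ToyOperatorName("add", "Tensor"))))         # add_Tensor
print(toy_native_name(ToySchema(ToyOperatorName("add", "out"), out=True)))  # add_out_out
print(toy_native_name(ToySchema(ToyOperatorName("relu", ""))))              # relu
```

An out overload named `out` picks up the suffix twice (`add_out_out`): once from is_out_fn() and once from its overload name. That follows mechanically from the two independent `if` statements in name().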