mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45990 In #45890 we introduced the concept of a CppSignature, which bundled up all of the information necessary to declare a C++ signature for the cpp API. This PR introduces analogous concepts for dispatcher and native: DispatcherSignature and NativeSignature. The three interfaces are not particularly well coupled right now, but they do have some duck typing coincidences: - defn() which renders the C++ definition "bool f(int x)" - decl() which renders the C++ declaration "bool f(int x = 2)" - type() which renders the C++ function type "bool(int)" Maybe at some point we'll introduce a Protocol, or a supertype. Many other methods (like arguments()) have varying types. These signatures also have some helper methods that forward back to real implementations in the api modules. Something to think about is whether or not we should attempt to reduce boilerplate here or not; I'm not too sure about it yet. The net effect is we get to reduce the number of variables we have to explicitly write out in the codegen, since now these are all bundled together into a signature. Something extra special happens in BackendSelect, where we now dynamically select between dispatcher_sig and native_sig as "how" the backend select is implemented. A little bit of extra cleanup: - Some places where we previously advertised Sequence, we now advertise a more informative Tuple. - defn() may take an optional positional parameter overriding the entire name, or a kwarg-only prefix parameter to just add a prefix to the name. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: smessmer Differential Revision: D24223100 Pulled By: ezyang fbshipit-source-id: f985eced08af4a60ba9641d125d0f260f8cda9eb
79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
from tools.codegen.model import *
|
|
|
|
from tools.codegen.api.types import TensorOptionsArguments, NativeArgument, ThisArgument
|
|
import tools.codegen.api.cpp as cpp
|
|
|
|
from typing import Union, Sequence, Tuple
|
|
|
|
# This file describes the translation of JIT schema to the native functions API.
|
|
# This looks a lot like the C++ API (which makes historical sense, because the
|
|
# idea was you wrote native functions to implement functions in the C++ API),
|
|
# but over time we have evolved the C++ API without actually changing our
|
|
# native:: kernels. The intention is to make native API and dispatcher API
|
|
# line up as closely as possible, since this results in the least overhead
|
|
# (no translation is needed from dispatcher API to native API).
|
|
#
|
|
# When a function is not use_c10_dispatcher: full, the dispatcher API actually
|
|
# coincides with the native:: API (i.e., we do as dumb a pass-through as
|
|
# possible).
|
|
|
|
def name(func: FunctionSchema) -> str:
    """Return the native:: kernel name for the schema ``func``.

    The name is the base operator name, then ``_out`` when ``func`` is an
    out= variant, then ``_<overload_name>`` when the schema carries a
    non-empty overload name.
    """
    base = str(func.name.name)
    # TODO: delete this!
    out_suffix = '_out' if func.is_out_fn() else ''
    overload = func.name.overload_name
    overload_suffix = f'_{overload}' if overload else ''
    return f'{base}{out_suffix}{overload_suffix}'
|
|
|
|
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    """Render the native:: C++ parameter type for the schema type ``t``.

    ``Tensor?`` is rendered as a plain tensor reference (non-const when
    ``mutable`` is true, const otherwise), and ``Tensor?[]`` as
    ``TensorList``. Every other type uses the C++ API rendering from
    ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?':
        return 'Tensor &' if mutable else 'const Tensor &'
    if spelled == 'Tensor?[]':
        return 'TensorList'
    return cpp.argumenttype_type(t, mutable=mutable)
|
|
|
|
def returns_type(rs: Sequence[Return]) -> str:
    """Render the C++ return type for the schema returns ``rs``.

    The native:: API uses the same return type as the C++ API, so this
    delegates to ``cpp.returns_type``.
    """
    rendered = cpp.returns_type(rs)
    return rendered
|
|
|
|
def argument_type(a: Argument) -> str:
    """Render the native:: C++ type of the schema argument ``a``.

    ``a.is_write`` is forwarded as the ``mutable`` flag of
    ``argumenttype_type``.
    """
    schema_type = a.type
    return argumenttype_type(schema_type, mutable=a.is_write)
|
|
|
|
def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> NativeArgument:
    """Translate one grouped C++ argument into its native:: API form.

    ``a`` is typically one element of ``cpp.group_arguments(...)`` (see
    ``arguments`` below).

    - A plain ``Argument`` keeps its name; its schema default, if any, is
      translated with ``cpp.default_expr``.
    - A ``ThisArgument`` is unwrapped into the argument it holds, with no
      default.
    - ``TensorOptionsArguments`` collapse into a single
      ``const TensorOptions & options`` parameter.

    Any other input is rejected by ``assert_never``.
    """
    if isinstance(a, Argument):
        native_type = argument_type(a)
        native_default = None if a.default is None else cpp.default_expr(a.default, a.type)
        return NativeArgument(type=native_type, name=a.name, default=native_default, argument=a)

    if isinstance(a, ThisArgument):
        # Erase ThisArgument from the distinction: native:: sees the wrapped
        # argument as an ordinary parameter without a default.
        wrapped = a.argument
        return NativeArgument(type=argument_type(wrapped), name=wrapped.name, default=None, argument=wrapped)

    if isinstance(a, TensorOptionsArguments):
        # TODO: expunge this logic entirely
        if all(field.default == "None" for field in a.all()):
            # Every options field defaults to None: use default-constructed options.
            options_default = '{}'
        elif a.dtype.default == "long":
            options_default = 'at::kLong'  # TODO: this is wrong
        else:
            options_default = None
        return NativeArgument(type='const TensorOptions &', name='options', default=options_default, argument=a)

    assert_never(a)
|
|
|
|
def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]:
    """Return the native:: parameter list for the schema ``func``.

    The schema's arguments are grouped by ``cpp.group_arguments`` with
    ``method=False``. Each group is then translated, in order, with
    ``argument``.
    """
    grouped = cpp.group_arguments(func, method=False)
    return tuple(argument(group) for group in grouped)
|