mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45990 In #45890 we introduced the concept of a CppSignature, which bundled up all of the information necessary to declare a C++ signature for the cpp API. This PR introduces analogous concepts for dispatcher and native: DispatcherSignature and NativeSignature. The three interfaces are not particularly well coupled right now, but they do have some duck typing coincidences: - defn() which renders the C++ definition "bool f(int x)" - decl() which renders the C++ declaration "bool f(int x = 2)" - type() which renders the C++ function type "bool(int)" Maybe at some point we'll introduce a Protocol, or a supertype. Many other methods (like arguments()) have varying types. These signatures also have some helper methods that forward back to real implementations in the api modules. Something to think about is whether or not we should attempt to reduce boilerplate here or not; I'm not too sure about it yet. The net effect is we get to reduce the number of variables we have to explicitly write out in the codegen, since now these are all bundled together into a signature. Something extra special happens in BackendSelect, where we now dynamically select between dispatcher_sig and native_sig as "how" the backend select is implemented. A little bit of extra cleanup: - Some places where we previously advertised Sequence, we now advertise a more informative Tuple. - defn() may take an optional positional parameter overriding the entire name, or a kwarg-only prefix parameter to just add a prefix to the name. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: smessmer Differential Revision: D24223100 Pulled By: ezyang fbshipit-source-id: f985eced08af4a60ba9641d125d0f260f8cda9eb
160 lines
6.7 KiB
Python
160 lines
6.7 KiB
Python
from tools.codegen.model import *
|
|
|
|
from tools.codegen.api.types import *
|
|
import tools.codegen.api.cpp as cpp
|
|
import tools.codegen.api.native as native
|
|
import tools.codegen.local as local
|
|
|
|
import itertools
|
|
from typing import Sequence, Optional, Tuple
|
|
|
|
# This file describes the translation of JIT schema to the dispatcher
|
|
# API, the *unboxed* calling convention by which invocations through
|
|
# the dispatcher are made. Historically, the dispatcher API matched
|
|
# the C++ API, but with the establishment of the boxed API, we've
|
|
# made changes to the dispatcher API so that the unboxed API
|
|
# better aligns with the boxed API. The dispatcher API hooks heavily
|
|
# into our template based boxing/unboxing machinery, so changes
|
|
# to this convention will usually need template updates too.
|
|
#
|
|
# Prominent characteristics of the dispatcher API:
|
|
#
|
|
# - 'use_c10_dispatcher: full' controls whether or not we actually
|
|
# use the modern calling convention. When use_c10_dispatcher
|
|
# is not enabled, we don't use the template machinery.
|
|
#
|
|
# - dtype, layout, device and pin_memory are represented as separate
|
|
# arguments.
|
|
#
|
|
|
|
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    """Render the dispatcher API C++ type for the JIT type ``t``.

    ``mutable`` is forwarded unchanged to the delegate. The delegate depends on
    the active ``use_c10_dispatcher`` setting: the new-style dispatcher borrows
    the C++ API type (``cpp``), the old-style one borrows the native functions
    type (``native``).
    """
    if not local.use_c10_dispatcher().dispatcher_uses_new_style():
        # This is real sharing: the old-style dispatcher convention *is* the
        # native functions convention. If you're modifying this path, ask
        # yourself why you are changing the native functions protocol here
        # and not in native.
        return native.argumenttype_type(t, mutable=mutable)
    # This is a faux ami ("false friend"): the two conventions merely happen
    # to coincide today. If it makes sense in the future to add more special
    # cases here, or invert things so cpp.argumenttype_type calls this, or
    # just completely inline the function, please do it.
    return cpp.argumenttype_type(t, mutable=mutable)
|
|
|
|
def argument_type(a: Argument) -> str:
    """Render the dispatcher API C++ type for the schema argument ``a``.

    Whether a mutable type is requested is decided by ``a.is_write``.
    """
    jit_type = a.type
    return argumenttype_type(jit_type, mutable=a.is_write)
|
|
|
|
def returns_type(rs: Sequence[Return]) -> str:
    """Render the dispatcher API C++ return type for the schema returns ``rs``.

    The dispatcher currently uses exactly the C++ API return type, so this
    delegates to ``cpp.returns_type``. It is kept as a separate hook so the
    two conventions can diverge later without touching callers.
    """
    return cpp.returns_type(rs)
|
|
|
|
def argument(a: Argument) -> DispatcherArgument:
    """Convert the schema argument ``a`` into a dispatcher API argument.

    Old-style dispatcher: reuse the native functions translation of ``a`` and
    copy its type, name and originating argument verbatim.
    New-style dispatcher: compute the type with ``argument_type`` and take the
    name and originating argument directly from ``a``.
    """
    if not local.use_c10_dispatcher().dispatcher_uses_new_style():
        native_arg = native.argument(a)
        return DispatcherArgument(
            type=native_arg.type,
            name=native_arg.name,
            argument=native_arg.argument,
        )
    return DispatcherArgument(
        type=argument_type(a),
        name=a.name,
        argument=a,
    )
|
|
|
|
def name(func: FunctionSchema) -> str:
    """Return the C++ function name used for ``func`` in the dispatcher API.

    Dispatcher names are currently identical to C++ API names, so this
    delegates to ``cpp.name``.
    """
    return cpp.name(func)
|
|
|
|
def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]:
    """Return the dispatcher API arguments of ``func``, in order.

    New-style dispatcher: each schema argument is translated individually with
    ``argument``, taking out arguments first, then positional arguments, then
    keyword-only arguments.
    Old-style dispatcher: the native functions argument list is reused and
    each entry is rewrapped as a ``DispatcherArgument``.
    """
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        ordered = itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)
        return tuple(argument(schema_arg) for schema_arg in ordered)
    native_args = native.arguments(func)
    return tuple(
        DispatcherArgument(type=na.type, name=na.name, argument=na.argument)
        for na in native_args
    )
|
|
|
|
# Given a set of CppArguments in scope, return a sequence of dispatcher
|
|
# expressions that translate the cpp API into dispatcher API
|
|
#
|
|
# WARNING: This is unsound if you pass it CppArgument when you were
|
|
# supposed to pass it CppTensorOptionsArguments, it will directly
|
|
# translate device to device, which will give you the wrong signature
|
|
# for dispatcher. If Argument "knew" that it was part of a
# TensorOptions, that would help us dynamically test for this case.
|
|
def cppargument_exprs(
    a: CppArgumentPack,
    *, tensor_options: Optional[CppArgument]
) -> Sequence[DispatcherExpr]:
    """Translate one C++ API argument pack into dispatcher API expressions.

    Args:
        a: The argument pack as it appears in the C++ API signature.
        tensor_options: The TensorOptions argument in scope for the same call,
            if any. It is only consulted when ``a`` is a ``memory_format``
            argument under the new-style dispatcher.

    Returns:
        The dispatcher expressions (C++ type plus C++ expression text) that
        ``a`` expands to. A single pack may expand into several expressions.
    """
    if isinstance(a, CppSingleArgumentPack):
        this = a.this
        if isinstance(this.argument, TensorOptionsArguments):
            if not local.use_c10_dispatcher().dispatcher_uses_new_style():
                # Old-style dispatcher: no-op, the TensorOptions object is
                # passed through as-is.
                return [DispatcherExpr(type='const TensorOptions &', expr=this.name)]
            # New-style dispatcher: scatter the TensorOptions object into its
            # four component arguments.
            opts = this.argument
            var = this.name
            return [
                DispatcherExpr(type=argument_type(opts.dtype), expr=f'optTypeMetaToScalarType({var}.dtype_opt())'),
                DispatcherExpr(type=argument_type(opts.layout), expr=f'{var}.layout_opt()'),
                DispatcherExpr(type=argument_type(opts.device), expr=f'{var}.device_opt()'),
                # NB: the schema argument is named pin_memory, but the
                # TensorOptions accessor is spelled pinned_memory_opt().
                DispatcherExpr(type=argument_type(opts.pin_memory), expr=f'{var}.pinned_memory_opt()'),
            ]
        if isinstance(this.argument, Argument):
            if this.name == 'memory_format' and \
                    tensor_options is not None and \
                    local.use_c10_dispatcher().dispatcher_uses_new_style():
                # Route memory_format through the c10 helper, passing the
                # TensorOptions argument in scope alongside it.
                expr = ('c10::impl::check_tensor_options_and_extract_memory_format('
                        f'{tensor_options.name}, {this.name})')
            else:
                expr = this.name
            return [DispatcherExpr(type=argument_type(this.argument), expr=expr)]
        assert_never(this.argument)
    if isinstance(a, CppTensorOptionsArgumentPack):
        if local.use_c10_dispatcher().dispatcher_uses_new_style():
            # New-style dispatcher: no-op, translate each component argument
            # as its own single pack. NB: explicitness doesn't matter here.
            return list(itertools.chain.from_iterable(
                cppargument_exprs(CppSingleArgumentPack(sub_arg), tensor_options=tensor_options)
                for sub_arg in a.explicit_arguments()
            ))
        # Old-style dispatcher: gather the separate components back into a
        # single TensorOptions object.
        gathered = (f'TensorOptions().dtype({a.dtype.name}).layout({a.layout.name})'
                    f'.device({a.device.name}).pinned_memory({a.pin_memory.name})')
        return [DispatcherExpr(type='const TensorOptions &', expr=gathered)]
    if isinstance(a, CppThisArgumentPack):
        # Pass the method receiver explicitly, casting away constness.
        return [DispatcherExpr(type=a.type, expr='const_cast<Tensor&>(*this)')]
    assert_never(a)
|
|
|
|
def cpparguments_exprs(args: Sequence[CppArgumentPack]) -> Sequence[DispatcherExpr]:
    """Translate a full C++ API argument list into dispatcher API expressions.

    The first single-argument pack that wraps ``TensorOptionsArguments`` (if
    any) is located up front. Its C++ argument is handed to every per-pack
    translation, so that ``cppargument_exprs`` can use it when translating a
    ``memory_format`` argument. The resulting expression lists are
    concatenated in argument order.
    """
    tensor_options: Optional[CppArgument] = None
    for pack in args:
        if isinstance(pack, CppSingleArgumentPack) and \
                isinstance(pack.this.argument, TensorOptionsArguments):
            tensor_options = pack.this
            break
    return list(itertools.chain.from_iterable(
        cppargument_exprs(pack, tensor_options=tensor_options) for pack in args
    ))
|
|
|
|
# I don't think this is entirely sound, but it should be reasonably
|
|
# close
|
|
def nativearguments_exprs(args: Sequence[NativeArgument]) -> Sequence[DispatcherExpr]:
    """Translate native function arguments into dispatcher API expressions.

    Each native argument is repackaged as a single C++ argument pack with no
    default, and the packs are translated together by ``cpparguments_exprs``.
    """
    packs = []
    for native_arg in args:
        cpp_arg = CppArgument(
            type=native_arg.type,
            name=native_arg.name,
            default=None,
            argument=native_arg.argument,
        )
        packs.append(CppSingleArgumentPack(cpp_arg))
    return cpparguments_exprs(packs)
|
|
|
|
def exprs(args: Sequence[DispatcherArgument]) -> Sequence[DispatcherExpr]:
    """Translate dispatcher API arguments into dispatcher API expressions.

    Each dispatcher argument is wrapped as a single C++ argument pack with no
    default, and the packs are translated together by ``cpparguments_exprs``.
    """
    def as_pack(d: DispatcherArgument) -> CppArgumentPack:
        # Re-express a dispatcher argument as a default-less C++ argument.
        return CppSingleArgumentPack(
            CppArgument(type=d.type, name=d.name, default=None, argument=d.argument)
        )
    return cpparguments_exprs(list(map(as_pack, args)))
|