pytorch/tools/codegen/api/cpp.py
Edward Yang 9079aea1ac Rewrite implementation of faithful cpp signatures (#45890)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45890

This rewrite is as per my comments at https://github.com/pytorch/pytorch/pull/44087#issuecomment-701664506
I did the rewrite by reverting #44087 and then reimplementing it on top.
You may find it easier to review by diffing against master with only #44087
reverted.

There are two main ideas.

First, we now factor cpp argument processing into two phases operating
on three representations of data:

1. `FunctionSchema` - this is the source from native_functions.yaml
2. `Union[Argument, ThisArgument, TensorOptionsArguments]` - this is
   the arguments after doing some basic semantic analysis to group
   them (for TensorOptions) or identify the this argument (if this
   is a method).  There is only ever one of these per function.
3. `Union[CppArgument, CppThisArgument, CppTensorOptionsArguments]` -
   this is the arguments after we've elaborated them to C++.  There
   may be multiple of these per actual C++ signature.

You can think of (2) as common processing, whereas (3) bakes in specific
assumptions about whether or not you have a faithful or non-faithful
signature.
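
Schematically, as a sketch (these signatures appear in the code below;
the return type of argument is simplified to CppArgumentPack here):

```
from typing import Sequence, Union
from tools.codegen.model import Argument, FunctionSchema, TensorOptionsArguments, ThisArgument
from tools.codegen.api.types import CppArgumentPack

# (1) -> (2): semantic grouping, shared by every C++ signature
def group_arguments(
    func: FunctionSchema, *, method: bool
) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]: ...

# (2) -> (3): elaboration to C++, done once per signature variant
def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArgumentPack: ...
def argument_faithful(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArgumentPack: ...
```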

Second, we now have CppSignature and CppSignatureGroup representing
the *total* public C++ API signature.  So those dataclasses are what
know how to render definitions/declarations, and you no longer have
to manually type it out in the Functions/TensorMethods codegen.
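
As a rough sketch of the dataclass (take the field names here as
illustrative assumptions, not a precise statement of the API):

```
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class CppSignatureGroup:
    # FunctionSchema / CppSignature come from tools.codegen
    func: 'FunctionSchema'
    signature: 'CppSignature'                    # grouped (non-faithful) signature
    faithful_signature: 'Optional[CppSignature]' # flat variant, when one exists
```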

Here is an exhaustive accounting of the changes.

tools.codegen.api.types

- CppSignature and CppSignatureGroup got moved to tools.codegen.api.types
- Add new CppThisArgument and CppTensorOptionsArguments (modeled off
  of ThisArgument and TensorOptionsArguments) so that we can retain
  high level semantic structure even after elaborating terms with C++
  API information.  Once this is done, we can refine
  CppArgument.argument to no longer contain a ThisArgument (a
  ThisArgument is always translated to a CppThisArgument).  Note that
  this doesn't apply to TensorOptionsArguments, as those may or may not
  be expanded, so you could still get a single CppArgument for 'options'.
- Add a no_default() functional mutator to easily remove default arguments
  from CppArgument and friends
- Add an explicit_arguments() method to CppArgument and friends to
  extract the (flat) argument list that must be explicitly written in the
  signature.  This is everything except (Cpp)ThisArgument, and is also
  convenient when you don't care about the extra structure of
  CppTensorOptionsArguments.  (There is a sketch of both helpers after
  this list.)
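
A hedged sketch of those two helpers (the bodies are assumptions about
their shape, not the literal implementation):

```
import dataclasses
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass(frozen=True)
class CppArgument:
    type: str
    name: str
    default: Optional[str]
    argument: object  # an Argument or a TensorOptionsArguments

    def no_default(self) -> 'CppArgument':
        # Functional mutator: return a copy with the default stripped
        # (faithful signatures carry no defaults)
        return dataclasses.replace(self, default=None)

    def explicit_arguments(self) -> Sequence['CppArgument']:
        # A plain CppArgument must be written explicitly in the signature
        return [self]
```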

tools.codegen.api.cpp

- group_arguments is back, and it doesn't send things directly to a
  CppSignatureGroup; instead, it moves us from representation (1) to (2)
  (perhaps it should live in model).  Here I changed my mind from my
  PR comment; I discovered it was not necessary to do classification at
  grouping time, and it was simpler and easier to do it later.
- argument got split into argument_not_this/argument/argument_faithful.
  What argument and argument_faithful do is obvious enough, and I needed
  argument_not_this as a more refined version of argument so that I
  could get the types to work out on TensorOptionsArguments

tools.codegen.api.dispatcher

- Here we start seeing the payoff.  The old version of this code had a
  "scatter" mode and a "gather" mode.  We don't need that anymore:
  cppargument_exprs is 100% type-directed via the passed in cpp
  arguments.  I am able to write the functions without any reference
  to use_c10_dispatcher (see the sketch below).
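
The shape of the idea, as a schematic sketch (the pack names are from
tools.codegen.api.types; the branch bodies are elided):

```
from typing import Sequence
from tools.codegen.api.types import (CppArgumentPack, CppSingleArgumentPack,
                                     CppThisArgumentPack, CppTensorOptionsArgumentPack)

def cppargument_exprs(a: CppArgumentPack) -> Sequence[str]:
    if isinstance(a, CppSingleArgumentPack):
        # forward the one C++ argument as one dispatcher expression
        ...
    elif isinstance(a, CppThisArgumentPack):
        # `this` becomes the self Tensor argument
        ...
    elif isinstance(a, CppTensorOptionsArgumentPack):
        # gather or scatter dtype/layout/device/pin_memory, depending on
        # what the target dispatcher signature wants
        ...
    else:
        raise AssertionError(f"unrecognized pack {a}")
```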

tools.codegen.gen

- Instead of having exprs_str and types_str functions, I moved these to
  live directly on CppSignature, since it seemed pretty logical.
- The actual codegen for TensorMethods/Functions is greatly simplified,
  since (1) all of the heavy lifting is now happening in
  CppSignature(Group) construction, and (2) I don't need to proxy one
  way or another, the new dispatcher translation code is able to handle
  both cases no problem (a sketch follows this list).  There is a little
  faffing about with ordering to reduce the diff between the old and new
  output, which could be removed afterwards.
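
For instance, the per-function logic for a native function f can now read
roughly like this (the from_schema()/decl() names are illustrative
assumptions, not a promise of the API):

```
sig_group = CppSignatureGroup.from_schema(f.func, method=False)
result = f"CAFFE2_API {sig_group.signature.decl()};\n"
if sig_group.faithful_signature is not None:
    result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n"
```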

Here are codegen diffs.  For use_c10_dispatcher: full:

```
+// aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed, const TensorOptions & options) {
-    return _cudnn_init_dropout_state(dropout, train, dropout_seed, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::_cudnn_init_dropout_state", "")
+        .typed<Tensor (double, bool, int64_t, c10::optional<ScalarType>, c10::optional<Layout>, c10::optional<Device>, c10::optional<bool>)>();
+    return op.call(dropout, train, dropout_seed, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
 }
```

Otherwise:

```
+// aten::empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
 Tensor empty_meta(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format) {
-    return empty_meta(size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), memory_format);
+    static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::empty_meta", "")
+        .typed<Tensor (IntArrayRef, const TensorOptions &, c10::optional<MemoryFormat>)>();
+    return op.call(size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), memory_format);
 }
```

Things that I probably did not get right:

- The Union[Argument, TensorOptionsArguments, ThisArgument] and
  the Cpp variants are starting to get a little unwieldy.  Not sure if
  this means I should add a supertype (or at the very least an alias;
  see the sketch after this list); in some cases I do purposely omit
  one of these from the Union
- Code may not necessarily live in the most logical files.  There isn't
  very much rhyme or reason to it.
- The fields on CppSignature.  They're not very well constrained and
  it would be better if people didn't use them directly.
- Disambiguation.  We should do this properly per the discussion in
  #44087; we don't need special logic for deleting defaults on faithful
  signatures, as there is a more general story here.
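
For the first bullet, the alias would be something like:

```
from typing import Union

# Hypothetical alias; today the Union is spelled out at each use site
ArgumentGroup = Union['Argument', 'TensorOptionsArguments', 'ThisArgument']
```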

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: smessmer

Differential Revision: D24144035

Pulled By: ezyang

fbshipit-source-id: a185f8bf9df8b44ca5718a7a44dac23cefd11c0a

from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.local as local
from typing import Optional, Sequence, Union, Callable, List

# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
#   a single C++ type TensorOptions (the legacy dispatcher API
#   also has this, but tensor options is really most relevant
#   for the C++ API; it makes calling kwarg factory functions
#   pleasant)
#
# - for 'use_c10_dispatcher: full' functions, optional tensors are
#   represented explicitly using c10::optional
#
# - defaulting lives here (in fact, the dispatcher is completely
#   oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide
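#
# For example (taken from the codegen diff in the commit message): the schema
#
#   _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *,
#                             ScalarType? dtype=None, Layout? layout=None,
#                             Device? device=None, bool? pin_memory=False) -> Tensor
#
# surfaces in the C++ API as
#
#   Tensor _cudnn_init_dropout_state(double dropout, bool train,
#                                    int64_t dropout_seed, const TensorOptions & options)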

def name(func: FunctionSchema) -> str:
    name = str(func.name.name)
    if func.is_out_fn():
        name += '_out'
    return name
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types are return
# types. Returns None if the type in question is not a value type.
def valuetype_type(t: Type) -> Optional[str]:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
return None
elif t.name == BaseTy.int:
return 'int64_t'
elif t.name == BaseTy.float:
return 'double'
elif t.name == BaseTy.str:
return 'std::string'
elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar,
BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage,
BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat,
BaseTy.Dimname, BaseTy.ConstQuantizerPtr]:
# These C++ names line up with their schema names
return t.name.name
else:
raise AssertionError(f"unsupported type: {t}")
elif isinstance(t, OptionalType):
elem = valuetype_type(t.elem)
if elem is None:
return None
return f"c10::optional<{elem}>"
elif isinstance(t, ListType):
if str(t.elem) == 'bool':
assert t.size is not None
return f"std::array<bool,{t.size}>"
else:
return None
else:
raise AssertionError(f"unrecognized type {repr(t)}")

# Translation of types occurring in JIT arguments to a C++ argument type.
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    # If it's a value type, do the value type translation
    r = valuetype_type(t)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return 'Tensor &'
            else:
                return 'const Tensor &'
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return 'Tensor &'  # TODO: fix this discrepancy
            else:
                if local.use_c10_dispatcher().dispatcher_uses_new_style():
                    return 'const c10::optional<Tensor>&'
                else:
                    return 'const Tensor &'
        elem = argumenttype_type(t.elem, mutable=mutable)
        return f"c10::optional<{elem}>"
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == 'int':
            return "IntArrayRef"
        elif str(t.elem) == 'Tensor':
            return "TensorList"
        elif str(t.elem) == 'Dimname':
            return "DimnameList"
        # TODO: do something reasonable about lists of optional tensors
        elif (not local.use_c10_dispatcher().dispatcher_uses_new_style()) and str(t.elem) == 'Tensor?':
            return "TensorList"
        elem = argumenttype_type(t.elem, mutable=mutable)
        # TODO: explicitly qualify namespace here
        return f"ArrayRef<{elem}>"
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

# Translate a JIT argument into its C++ type
def argument_type(a: Argument) -> str:
    return argumenttype_type(a.type, mutable=a.is_write)

# Translation of a (non-multi) return type from JIT to C++
def returntype_type(t: Type, *, mutable: bool) -> str:
    r = valuetype_type(t)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return 'Tensor &'
            else:
                return 'Tensor'
    elif isinstance(t, ListType):
        elem = returntype_type(t.elem, mutable=mutable)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return f"std::vector<{elem}>"

    raise AssertionError(f"unrecognized return type {t}")

# Translation of a single return to its C++ type
def return_type(r: Return) -> str:
    return returntype_type(r.type, mutable=r.is_write)

# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> str:
    if len(rs) == 0:
        return 'void'
    elif len(rs) == 1:
        return return_type(rs[0])
    else:
        args = ','.join(map(return_type, rs))
        return f'std::tuple<{args}>'

JIT_TO_CPP_DEFAULT = {
    'False': 'false',
    'True': 'true',
    'None': 'c10::nullopt',  # UGH this one is type directed
    'Mean': 'at::Reduction::Mean',
    '[]': '{}',
    'contiguous_format': 'MemoryFormat::Contiguous',
    'long': 'at::kLong',
}

# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
    if d == 'None' and str(t) == 'Tensor?':
        return '{}'
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ''
            i = 1
            while i + 1 < len(d):
                if d[i] != '\\':
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i:i + 2]
                    i += 2
            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == 'None':
            return 'c10::nullopt'
        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if (d.startswith('[') and d.endswith(']')):
            return '{' + d[1:-1] + '}'
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
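
# For example (consequences of the rules above): a schema default of 'mean'
# on a str argument becomes the C++ string literal "mean"; 'None' on an
# optional argument becomes c10::nullopt; and a list default '[0, 1]'
# becomes the braced initializer '{0, 1}'.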

# Convert an argument into its C++ API form
def argument_not_this(
    a: Union[Argument, TensorOptionsArguments],
) -> CppArgument:
    if isinstance(a, Argument):
        return CppArgument(
            type=argument_type(a),
            name=a.name,
            default=default_expr(a.default, a.type) if a.default is not None else None,
            argument=a,
        )
    elif isinstance(a, TensorOptionsArguments):
        default = None
        if all(x.default == "None" for x in a.all()):
            default = '{}'
        elif a.dtype.default == "long":
            default = 'at::kLong'  # TODO: this is wrong
        return CppArgument(
            type='const TensorOptions &',
            name='options',
            default=default,
            argument=a,
        )
    else:
        assert_never(a)

def argument(
    a: Union[Argument, TensorOptionsArguments, ThisArgument],
) -> Union[CppSingleArgumentPack, CppThisArgumentPack]:
    if isinstance(a, ThisArgument):
        return CppThisArgumentPack(argument=a, type=argument_type(a.argument))
    else:
        return CppSingleArgumentPack(argument_not_this(a))

def argument_faithful(
    a: Union[Argument, TensorOptionsArguments, ThisArgument],
) -> CppArgumentPack:
    if isinstance(a, TensorOptionsArguments):
        return CppTensorOptionsArgumentPack(
            argument=a,
            dtype=argument_not_this(a.dtype),
            layout=argument_not_this(a.layout),
            device=argument_not_this(a.device),
            pin_memory=argument_not_this(a.pin_memory),
        )
    else:
        return argument(a)

# NB: this unconditionally groups arguments
def group_arguments(
    func: FunctionSchema, *, method: bool
) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]:
    args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = []
    args.extend(func.out_arguments)

    if method:
        args.extend(ThisArgument(a) if a.name == "self" else a for a in func.arguments)
    else:
        args.extend(func.arguments)

    # group up arguments for tensor options
    def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
        return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
    predicates = [  # order matters
        pred('dtype', Type.parse('ScalarType')),
        pred('layout', Type.parse('Layout')),
        pred('device', Type.parse('Device')),
        pred('pin_memory', Type.parse('bool')),
    ]

    i = 0
    while i < len(func.kwarg_only_arguments):
        # If there is enough space...
        if i <= len(func.kwarg_only_arguments) - len(predicates):
            # And the next len(predicates) arguments look like TensorOptions arguments
            if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])):
                # Group them together as one argument
                args.append(TensorOptionsArguments(
                    dtype=func.kwarg_only_arguments[i],
                    layout=func.kwarg_only_arguments[i + 1],
                    device=func.kwarg_only_arguments[i + 2],
                    pin_memory=func.kwarg_only_arguments[i + 3],
                ))
                i += len(predicates)
                continue
        args.append(func.kwarg_only_arguments[i])
        i += 1
    return args
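
# For example (illustrative), given a factory schema along the lines of
#   empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None,
#         Device? device=None, bool? pin_memory=None) -> Tensor
# the four trailing keyword-only arguments are collapsed into a single
# TensorOptionsArguments entry, so the grouped form is roughly
#   [Argument(size), TensorOptionsArguments(dtype, layout, device, pin_memory)]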