pytorch/tools/codegen/api/cpp.py
Jiakai Liu e71a13e8a3 [pytorch][codegen] migrate gen_variable_type to new data model (#49735)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49735

This is the final wave of autograd codegen data model migration.

After this PR:
- autograd codegen no longer depends on Declarations.yaml;
- autograd codegen sources are fully type annotated and pass mypy-strict check;

To avoid potential merge conflicts with other pending PRs, some structural
changes are intentionally avoided: e.g., inner methods were not moved out, and
not all inner methods were changed to stop reading the outer function's variables.

Confirmed byte-for-byte compatible with the old codegen:
```
Run it before and after this PR:
  .jenkins/pytorch/codegen-test.sh <baseline_output_dir>
  .jenkins/pytorch/codegen-test.sh <test_output_dir>

Then run diff to compare the generated files:
  diff -Naur <baseline_output_dir> <test_output_dir>
```

Confirmed clean mypy-strict run:
```
mypy --config mypy-strict.ini
```

Test Plan: Imported from OSS

Reviewed By: ezyang, bhosmer

Differential Revision: D25678879

Pulled By: ljk53

fbshipit-source-id: ba6e2eb6b9fb744208f7f79a922d933fcc3bde9f
2021-01-05 14:12:39 -08:00

310 lines
12 KiB
Python

from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.local as local
from typing import Optional, Sequence, Union, List, Set
# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
# a single C++ type TensorOptions (the native functions API
# also has this, but tensor options is really most relevant
# for the C++ API; it makes calling kwarg factory functions
# pleasant)
#
# - for 'use_c10_dispatcher: full' functions, optional tensors are
# represented explicitly using c10::optional
#
# - defaulting lives here (in fact, the dispatcher is completely
# oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    """Return the C++ API function name for the schema ``func``.

    The base name is ``str(func.name.name)``.  When ``func.is_out_fn()`` is
    true a suffix is appended: ``_outf`` if ``faithful_name_for_out_overloads``
    is set, ``_out`` otherwise.
    """
    base = str(func.name.name)
    if not func.is_out_fn():
        return base
    suffix = '_outf' if faithful_name_for_out_overloads else '_out'
    return base + suffix
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(t: Type, *, binds: ArgName) -> Optional[CType]:
    """Translate a JIT schema "value type" into its C++ API type.

    Returns ``None`` when ``t`` is not a value type (a Tensor, an optional
    wrapping a non-value type, or any list other than a list of bool).

    Raises:
        AssertionError: for a base type that is not handled here, or for a
            ``Type`` that is none of BaseType / OptionalType / ListType.
    """
    if isinstance(t, BaseType):
        base = t.name
        if base == BaseTy.Tensor:
            # Tensors are never passed by value.
            return None
        if base == BaseTy.int:
            return BaseCType('int64_t', binds)
        if base == BaseTy.float:
            return BaseCType('double', binds)
        if base == BaseTy.str:
            return BaseCType('std::string', binds)
        # These C++ names line up with their schema names
        same_named = (
            BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar,
            BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage,
            BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat,
            BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr,
        )
        if base in same_named:
            return BaseCType(base.name, binds)
        raise AssertionError(f"unsupported type: {t}")

    if isinstance(t, OptionalType):
        # An optional is a value type exactly when its element is one.
        inner = valuetype_type(t.elem, binds=binds)
        return None if inner is None else OptionalCType(inner)

    if isinstance(t, ListType):
        # Only lists of bool are translated here; they need a size.
        if str(t.elem) != 'bool':
            return None
        assert t.size is not None
        return BaseCType(f"std::array<bool,{t.size}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")
# Translation of types occurring in JIT arguments to a C++ argument type.
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Translate the type of a JIT argument into a C++ argument type.

    Value types are delegated to ``valuetype_type``.  Tensors become a
    mutable or const reference depending on ``mutable``.  Optional tensors
    and lists of optional tensors additionally depend on whether
    ``local.use_c10_dispatcher().dispatcher_uses_new_style()`` is true.

    Raises:
        AssertionError: for a non-Tensor base type that was not a value type,
            or for a ``Type`` that is none of BaseType / OptionalType /
            ListType.
    """
    # If it's a value type, the value type translation is the answer.
    as_value = valuetype_type(t, binds=binds)
    if as_value is not None:
        return as_value

    if isinstance(t, BaseType):
        if t.name != BaseTy.Tensor:
            raise AssertionError(f"base type should have been value type {t}")
        if mutable:
            return MutRefCType(BaseCType('Tensor', binds))
        return ConstRefCType(BaseCType('Tensor', binds))

    if isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return MutRefCType(BaseCType('Tensor', binds))  # TODO: fix this discrepancy
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                return ConstRefCType(OptionalCType(BaseCType('Tensor', binds)))
            return ConstRefCType(BaseCType('Tensor', binds))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return OptionalCType(inner)

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, it needs to be
        # added back.
        elem_name = str(t.elem)
        if elem_name == 'int':
            return BaseCType("IntArrayRef", binds)
        if elem_name == 'Tensor':
            return BaseCType("TensorList", binds)
        if elem_name == 'Dimname':
            return BaseCType("DimnameList", binds)
        if elem_name == 'Tensor?':
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
            return BaseCType("TensorList", binds)
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{inner.cpp_type()}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> CType:
    """Return the C++ type for JIT argument ``a``.

    Forwards ``a.type`` to ``argumenttype_type``, using ``a.is_write`` as the
    ``mutable`` flag.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(schema_type, mutable=is_mutable, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
# NB: if translations on return types are ever needed, make this return CType
# too. Take care when doing so: ArgName is a misnomer now, and inputs are
# permitted to conflict with outputs, so make sure that does not cause trouble.
def returntype_type(t: Type, *, mutable: bool) -> str:
    """Return the C++ spelling of a single (non-multi) JIT return type.

    Value types use the ``valuetype_type`` translation.  A Tensor becomes
    ``'Tensor &'`` when ``mutable`` and ``'Tensor'`` otherwise.  An unsized
    list becomes ``std::vector<...>`` of the translated element type.

    Raises:
        AssertionError: if ``t`` is a fixed-size list, or is not handled by
            any of the cases above.
    """
    # The binding name is required by valuetype_type but ignored here.
    as_value = valuetype_type(t, binds="__placeholder__")
    if as_value is not None:
        return as_value.cpp_type()

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return 'Tensor &' if mutable else 'Tensor'
    elif isinstance(t, ListType):
        inner = returntype_type(t.elem, mutable=mutable)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return f"std::vector<{inner}>"

    raise AssertionError(f"unrecognized return type {t}")
# Translation of a single return to its C++ type
def return_type(r: Return) -> str:
    """Return the C++ type string for the single return ``r``.

    Forwards ``r.type`` to ``returntype_type``, using ``r.is_write`` as the
    ``mutable`` flag.
    """
    schema_type = r.type
    is_mutable = r.is_write
    return returntype_type(schema_type, mutable=is_mutable)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> str:
    """Return the C++ return type for a full (possibly multi) JIT return.

    No returns gives ``'void'``, a single return gives that return's own
    type, and several returns are packed into a ``std::tuple<...>`` whose
    members are joined with ``','`` (no space).
    """
    count = len(rs)
    if count == 0:
        return 'void'
    if count == 1:
        return return_type(rs[0])
    members = ','.join(return_type(r) for r in rs)
    return f'std::tuple<{members}>'
def return_names(f: NativeFunction) -> Sequence[str]:
    """Return one name per entry of ``f.func.returns``, in order.

    The name is chosen by the first rule that applies:
      * inplace function: ``'self'`` (only a single return is legal);
      * out function: the name of the out argument at the same index;
      * explicitly named return: that name, with ``_return`` appended when
        it collides with an argument name;
      * otherwise ``'result'`` for a single return, or ``'result<i>'``
        (zero-indexed) for multiple returns.
    """
    schema = f.func
    names: List[str] = []
    for idx, ret in enumerate(schema.returns):
        if schema.name.name.inplace:
            # If we have an inplace function, the return argument is
            # implicitly named self.
            # TODO: Consider incorporating this into the data model
            assert idx == 0, "illegal inplace function with multiple returns"
            chosen = 'self'
        elif schema.is_out_fn():
            # For an out function, use the name of the corresponding out
            # argument (r.name will get recorded in field_name later.)
            chosen = schema.arguments.out[idx].name
        elif ret.name:
            # The return argument is explicitly named.
            clashes = any(ret.name == arg.name for arg in schema.schema_order_arguments())
            chosen = f'{ret.name}_return' if clashes and not schema.is_out_fn() else ret.name
        else:
            # No explicit name: result, or result0, result1, ... for a
            # multi-return.
            chosen = f'result{idx}' if len(schema.returns) != 1 else 'result'
        names.append(chosen)
    return names
# Lookup table from the spelling of a default in the JIT schema to the C++
# expression emitted for it.  Keys are matched against the whole default
# string.
JIT_TO_CPP_DEFAULT = {
    # Literals
    'False': 'false',
    'True': 'true',
    'None': 'c10::nullopt',  # UGH this one is type directed
    # Named constants and the empty list
    'Mean': 'at::Reduction::Mean',
    '[]': '{}',
    'contiguous_format': 'MemoryFormat::Contiguous',
    'long': 'at::kLong',
}
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
    """Convert the JIT default ``d`` of an argument of type ``t`` into a C++ expression.

    Cases, checked in order:
      * ``None`` for a ``Tensor?`` argument becomes ``{}``;
      * a single-quoted default for a ``str`` argument is re-quoted with
        double quotes (see the inline comments for the escaping rules);
      * for optional types, ``None`` becomes ``c10::nullopt`` and anything
        else is translated as the element type;
      * for list types, ``[...]`` becomes ``{...}``;
      * everything else goes through ``JIT_TO_CPP_DEFAULT``, falling back to
        ``d`` unchanged.

    Raises:
        ValueError: if an unsized list type has a default that is not
            written as ``[...]``.
    """
    if d == 'None' and str(t) == 'Tensor?':
        return '{}'

    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        single_quoted = len(d) >= 2 and d[0] == "'" and d[-1] == "'"
        if single_quoted:
            pieces: List[str] = []
            pos = 1
            last = len(d) - 1  # index of the closing quote, never copied as-is
            while pos < last:
                ch = d[pos]
                if ch == '\\':
                    # An escaped single quote no longer needs escaping; any
                    # other escape sequence is copied through verbatim.
                    follower = d[pos + 1]
                    pieces.append("'" if follower == "'" else d[pos:pos + 2])
                    pos += 2
                else:
                    # A bare double quote must be escaped in the C++ literal.
                    pieces.append('\\"' if ch == '"' else ch)
                    pos += 1
            body = ''.join(pieces)
            return f'"{body}"'

    if isinstance(t, OptionalType):
        if d == 'None':
            return 'c10::nullopt'
        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if d.startswith('[') and d.endswith(']'):
            return '{' + d[1:-1] + '}'
        if t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
# Convert an argument into its C++ API form
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *, cpp_no_default_args: Set[str], method: bool, faithful: bool,
    has_tensor_options: bool
) -> List[Binding]:
    """Convert one schema-level argument into its C++ API binding(s).

    * ``Argument``: a single binding.  Its default is translated with
      ``default_expr`` unless the argument is named in
      ``cpp_no_default_args`` or has no default.
    * ``TensorOptionsArguments``: with ``faithful`` set, the bindings of
      dtype, layout, device and pin_memory, in that order; otherwise a
      single ``options`` binding of type ``const TensorOptions &``.
    * ``SelfArgument``: nothing when ``method`` is set, otherwise the
      bindings of the wrapped argument.
    """
    def expand(part: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
        # Recurse with the same keyword settings as the outer call.
        return argument(
            part, cpp_no_default_args=cpp_no_default_args, method=method,
            faithful=faithful, has_tensor_options=has_tensor_options)

    if isinstance(a, Argument):
        binds: ArgName = a.name
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        default: Optional[str] = (
            default_expr(a.default, a.type)
            if a.name not in cpp_no_default_args and a.default is not None
            else None
        )
        return [Binding(
            ctype=argument_type(a, binds=binds),
            name=a.name,
            default=default,
            argument=a,
        )]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            scattered = (a.dtype, a.layout, a.device, a.pin_memory)
            return [b for part in scattered for b in expand(part)]
        # Enforced by NativeFunction.__post_init__
        assert 'options' not in cpp_no_default_args
        options_default: Optional[str] = None
        if all(x.default == "None" for x in a.all()):
            options_default = '{}'
        elif a.dtype.default == "long":
            options_default = 'at::kLong'  # TODO: this is wrong
        return [Binding(
            ctype=ConstRefCType(BaseCType('TensorOptions', 'options')),
            name='options',
            default=options_default,
            argument=a,
        )]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        return expand(a.argument)
    else:
        assert_never(a)
def arguments(
    arguments: Arguments,
    *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
    """Return the C++ API bindings for a whole argument list.

    With ``faithful`` set, non-out arguments come before out arguments and
    every binding has its default removed via ``no_default()``.  Otherwise
    out arguments come first and defaults are kept.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]]
    if faithful:
        ordered = [*arguments.non_out, *arguments.out]
    else:
        ordered = [*arguments.out, *arguments.non_out]

    has_options = arguments.tensor_options is not None
    bindings: List[Binding] = []
    for a in ordered:
        expanded = argument(
            a, faithful=faithful, method=method,
            has_tensor_options=has_options,
            cpp_no_default_args=cpp_no_default_args)
        for b in expanded:
            bindings.append(b.no_default() if faithful else b)
    return bindings