mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49735 This is the final wave of autograd codegen data model migration. After this PR: - autograd codegen no longer depends on Declarations.yaml; - autograd codegen sources are fully type annotated and pass mypy-strict check; To avoid potential merge conflicts with other pending PRs, some structural changes are intentionally avoided, e.g. didn't move inner methods out, didn't change all inner methods to avoid reading outer function's variables, and etc. Confirmed byte-for-byte compatible with the old codegen: ``` Run it before and after this PR: .jenkins/pytorch/codegen-test.sh <baseline_output_dir> .jenkins/pytorch/codegen-test.sh <test_output_dir> Then run diff to compare the generated files: diff -Naur <baseline_output_dir> <test_output_dir> ``` Confirmed clean mypy-strict run: ``` mypy --config mypy-strict.ini ``` Test Plan: Imported from OSS Reviewed By: ezyang, bhosmer Differential Revision: D25678879 Pulled By: ljk53 fbshipit-source-id: ba6e2eb6b9fb744208f7f79a922d933fcc3bde9f
310 lines
12 KiB
Python
310 lines
12 KiB
Python
from tools.codegen.model import *
|
|
from tools.codegen.api.types import *
|
|
import tools.codegen.local as local
|
|
from typing import Optional, Sequence, Union, List, Set
|
|
|
|
# This file describes the translation of JIT schema to the public C++
|
|
# API, which is what people use when they call functions like at::add.
|
|
#
|
|
# Prominent characteristics of the C++ API:
|
|
#
|
|
# - dtype, layout, device and pin_memory are collected into
|
|
# a single C++ type TensorOptions (the native functions API
|
|
# also has this, but tensor options is really most relevant
|
|
# for the C++ API; it makes calling kwarg factory functions
|
|
# pleasant)
|
|
#
|
|
# - for 'use_c10_dispatcher: full' functions, optional tensors are
|
|
# represented explicitly using c10::optional
|
|
#
|
|
# - defaulting lives here (in fact, the dispatcher is completely
|
|
# oblivious of defaults!)
|
|
#
|
|
# BTW: policy on name collisions: we try not to have types with
|
|
# collisions, but functions are fair game to collide
|
|
|
|
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    """Compute the public C++ API name for a function schema.

    Out-overloads get an '_out' suffix, or '_outf' when the faithful
    (schema-argument-order) variant is requested.
    """
    base = str(func.name.name)
    if not func.is_out_fn():
        return base
    return base + ('_outf' if faithful_name_for_out_overloads else '_out')
|
|
|
|
# Translation of "value types" in JIT schema to C++ API type. Value
|
|
# types look the same no matter if they are argument types or return
|
|
# types. Returns None if the type in question is not a value type.
|
|
def valuetype_type(t: Type, *, binds: ArgName) -> Optional[CType]:
    """Translate a JIT "value type" into its C++ API CType.

    Value types render the same whether they appear as arguments or as
    returns.  Returns None when t is not a value type (e.g. Tensor).
    """
    if isinstance(t, OptionalType):
        inner = valuetype_type(t.elem, binds=binds)
        return None if inner is None else OptionalCType(inner)

    if isinstance(t, ListType):
        # Only fixed-size bool lists are value types; all other lists are
        # reference types handled by argumenttype_type.
        if str(t.elem) != 'bool':
            return None
        assert t.size is not None
        return BaseCType(f"std::array<bool,{t.size}>", binds)

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return None  # Tensor is not a value type
        if t.name == BaseTy.int:
            return BaseCType('int64_t', binds)
        if t.name == BaseTy.float:
            return BaseCType('double', binds)
        if t.name == BaseTy.str:
            return BaseCType('std::string', binds)
        if t.name in (BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar,
                      BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage,
                      BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat,
                      BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr):
            # These C++ names line up with their schema names
            return BaseCType(t.name.name, binds)
        raise AssertionError(f"unsupported type: {t}")

    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
# Translation of types occurring in JIT arguments to a C++ argument type.
|
|
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Translate a type occurring in a JIT argument to a C++ argument type."""
    # Value types translate uniformly regardless of argument position.
    value = valuetype_type(t, binds=binds)
    if value is not None:
        return value

    if isinstance(t, BaseType):
        if t.name != BaseTy.Tensor:
            raise AssertionError(f"base type should have been value type {t}")
        tensor = BaseCType('Tensor', binds)
        return MutRefCType(tensor) if mutable else ConstRefCType(tensor)

    if isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return MutRefCType(BaseCType('Tensor', binds))  # TODO: fix this discrepancy
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                return ConstRefCType(OptionalCType(BaseCType('Tensor', binds)))
            return ConstRefCType(BaseCType('Tensor', binds))
        return OptionalCType(argumenttype_type(t.elem, mutable=mutable, binds=binds))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, need to add back
        elem_str = str(t.elem)
        if elem_str == 'int':
            return BaseCType("IntArrayRef", binds)
        if elem_str == 'Tensor':
            return BaseCType("TensorList", binds)
        if elem_str == 'Dimname':
            return BaseCType("DimnameList", binds)
        if elem_str == 'Tensor?':
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
            return BaseCType("TensorList", binds)
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)

    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
# Translate a JIT argument into its C++ type
|
|
def argument_type(a: Argument, *, binds: ArgName) -> CType:
    """Translate a single JIT argument into its C++ API type."""
    # The schema's write flag drives the mutability of the C++ reference.
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
|
|
|
|
# Translation of a (non-multi) return type from JIT to C++
|
|
# NB: if need translations on return types, make this return CType too. Need to
|
|
# take care; ArgName is misnomer now, and inputs are permitted to conflict with outputs
|
|
# so need to make sure you don't have trouble
|
|
def returntype_type(t: Type, *, mutable: bool) -> str:
    """Translate a (non-multi) JIT return type to C++ source text."""
    # The binds placeholder is ignored: we only render the type text here.
    value = valuetype_type(t, binds="__placeholder__")
    if value is not None:
        return value.cpp_type()

    if isinstance(t, BaseType) and t.name == BaseTy.Tensor:
        return 'Tensor &' if mutable else 'Tensor'

    if isinstance(t, ListType):
        assert t.size is None, f"fixed size list returns not supported: {t}"
        inner = returntype_type(t.elem, mutable=mutable)
        return f"std::vector<{inner}>"

    raise AssertionError(f"unrecognized return type {t}")
|
|
|
|
# Translation of a single return to its C++ type
|
|
def return_type(r: Return) -> str:
    """Translate a single return to its C++ type text."""
    # A mutable (written-to) return renders as a reference type.
    return returntype_type(r.type, mutable=r.is_write)
|
|
|
|
# Translation of a full (possibly multi) return from JIT to its C++ type
|
|
def returns_type(rs: Sequence[Return]) -> str:
    """Translate a full (possibly multi) return from JIT to its C++ type."""
    if not rs:
        return 'void'
    if len(rs) == 1:
        return return_type(rs[0])
    # Multiple returns are packed into a std::tuple.
    inner = ','.join(return_type(r) for r in rs)
    return f'std::tuple<{inner}>'
|
|
|
|
def return_names(f: NativeFunction) -> Sequence[str]:
    """Compute the C++-facing name of every return of a native function."""
    func = f.func
    single_return = len(func.returns) == 1

    def pick_name(i: int, r: Return) -> str:
        # An inplace function's return argument is implicitly named self.
        # TODO: Consider incorporating this into the data model
        if func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            return 'self'
        # For out functions, the name is the name of the corresponding
        # output argument (r.name will get recorded in field_name later.)
        if func.is_out_fn():
            return func.arguments.out[i].name
        # Explicitly named returns keep their name, suffixed when it
        # would collide with an argument name.
        if r.name:
            conflict = any(r.name == a.name for a in func.schema_order_arguments())
            if conflict and not func.is_out_fn():
                return f'{r.name}_return'
            return r.name
        # Anonymous returns become result, or result0, result1, ...
        # (zero-indexed) for multi-returns.
        return 'result' if single_return else f'result{i}'

    return [pick_name(i, r) for i, r in enumerate(func.returns)]
|
|
|
|
# Simple keyword translations from JIT schema default-value spellings to
# the C++ expressions that represent them.  Defaults not listed here pass
# through unchanged (see default_expr below).
JIT_TO_CPP_DEFAULT = {
    'False': 'false',
    'True': 'true',
    'None': 'c10::nullopt',  # UGH this one is type directed
    'Mean': 'at::Reduction::Mean',
    '[]': '{}',
    'contiguous_format': 'MemoryFormat::Contiguous',
    'long': 'at::kLong',
}
|
|
|
|
# Convert a JIT default into C++ expression representing the default
|
|
def default_expr(d: str, t: Type) -> str:
    """Convert a JIT default value into a C++ expression for that default.

    d is the default as spelled in the JIT schema; t is the argument's
    schema type, which directs a few type-dependent translations.
    """
    # An absent optional Tensor defaults to an empty c10::optional.
    if d == 'None' and str(t) == 'Tensor?':
        return '{}'
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ''
            i = 1
            # Walk the characters between the surrounding single quotes,
            # re-escaping them for a double-quoted C++ string literal.
            while i + 1 < len(d):
                if d[i] != '\\':
                    if d[i] == '"':
                        # A bare double quote must be escaped in C++.
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    if d[i + 1] == "'":
                        # \' only needs escaping inside single quotes;
                        # drop the backslash in the double-quoted form.
                        s += "'"
                    else:
                        # Any other escape sequence is kept verbatim.
                        s += d[i:i + 2]
                    i += 2

            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == 'None':
            return 'c10::nullopt'

        # Otherwise translate the default against the wrapped element type.
        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if (d.startswith('[') and d.endswith(']')):
            # A JIT list literal [...] becomes a C++ braced init list.
            return '{' + d[1:-1] + '}'
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    # Fall back to the simple keyword translations (or pass d through).
    return JIT_TO_CPP_DEFAULT.get(d, d)
|
|
|
|
# Convert an argument into its C++ API form
|
|
|
|
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *, cpp_no_default_args: Set[str], method: bool, faithful: bool,
    has_tensor_options: bool
) -> List[Binding]:
    """Convert one schema argument into its C++ API binding(s)."""
    def recurse(sub: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
        # Re-dispatch a constituent argument with the same settings.
        return argument(
            sub, cpp_no_default_args=cpp_no_default_args, method=method,
            faithful=faithful, has_tensor_options=has_tensor_options)

    if isinstance(a, SelfArgument):
        # In method form the receiver is implicit; caller is responsible
        # for installing implicit this in context!
        return [] if method else recurse(a.argument)

    if isinstance(a, Argument):
        binds: ArgName
        if has_tensor_options and a.name == "memory_format":
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: Optional[str] = None
        if a.default is not None and a.name not in cpp_no_default_args:
            default = default_expr(a.default, a.type)
        return [Binding(
            ctype=argument_type(a, binds=binds),
            name=a.name,
            default=default,
            argument=a,
        )]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            # Faithful signatures spell out the four constituent arguments.
            bindings: List[Binding] = []
            for part in (a.dtype, a.layout, a.device, a.pin_memory):
                bindings.extend(recurse(part))
            return bindings
        default = None
        # Enforced by NativeFunction.__post_init__
        assert 'options' not in cpp_no_default_args
        if all(x.default == "None" for x in a.all()):
            default = '{}'
        elif a.dtype.default == "long":
            default = 'at::kLong'  # TODO: this is wrong
        return [Binding(
            ctype=ConstRefCType(BaseCType('TensorOptions', 'options')),
            name='options',
            default=default,
            argument=a,
        )]

    assert_never(a)
|
|
|
|
def arguments(
    arguments: Arguments,
    *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
    """Convert the full argument group of a schema into C++ API bindings."""
    # Faithful signatures keep schema order (outs last); the convenience
    # C++ overload hoists out arguments to the front.
    if faithful:
        ordered = list(arguments.non_out) + list(arguments.out)
    else:
        ordered = list(arguments.out) + list(arguments.non_out)
    has_tensor_options = arguments.tensor_options is not None
    result: List[Binding] = []
    for a in ordered:
        for binding in argument(
                a, faithful=faithful, method=method,
                has_tensor_options=has_tensor_options,
                cpp_no_default_args=cpp_no_default_args):
            # Faithful overloads never carry C++ default arguments.
            result.append(binding.no_default() if faithful else binding)
    return result
|