mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
pyi codegen update - remove Declarations.yaml (#48754)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48754

The goal of this PR is to kill Declarations.yaml in the pyi codegen, in favor of native_functions + the existing python object model.

**High-level design**

Since the python signatures used by the `python_arg_parser` are "supposed" to resemble the corresponding pyi type hint signatures, I re-used the existing python object model that Jiakai defined in `tools/codegen/api/python.py`. This means that the pyi codegen now reads `native_functions.yaml`, parses it into a bunch of `PythonSignatureGroup` objects, and emits corresponding method + function variants of type-hint signatures for each one, respectively into `__init__.pyi` and `_VariableFunctions.pyi`.

What makes this uglier is that pyi and the python arg parser have a number of differences in how they're emitted. I expressed that through a `pyi` flag on the `PythonSignature` dataclass, which tells it whether to print itself as a pyi signature or as an arg_parser signature.

One thing worth noting is how pyi generates signatures differently for native vs. deprecated op signatures.

For native ops:
- The pyi codegen fuses the functional and out variants of each op into a single signature with an optional `out` argument. Ops without an `out` variant just get an ordinary functional signature.
- Some ops that fit certain criteria also get a second "varargs" signature - basically ops with a single positional argument of type List[int].

For deprecated signatures:
- Functional and out variants are not fused - they each get their own signature entry.
- There are no varargs signatures.

This is currently implemented through the `signature_str()` and `signature_str_vararg()` methods on the `PythonSignature`/`PythonSignatureDeprecated` classes. `signature_str()` knows how to print itself with/without out arguments, differently for native/deprecated ops. `signature_str_vararg()` optionally returns a vararg variant of the signature if one exists.

**Calling out the gap between python_arg_parser vs. pyi**

The two formats are notably different, so I don't think we can expect to unify them completely. That said, I encountered a number of differences in the pyi codegen that looked wrong - I tried to call them out in the PR, to be removed later. Just as an example, compare the `svd` signature in the python_arg_parser vs. the pyi type hint.

python_arg_parser:
```
static PythonArgParser parser({
  "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
}, /*traceable=*/true);
```

pyi:
```
def svd(input: Tensor, some: _bool=True, compute_uv: _bool=True, *, out: Optional[Tensor]=None) -> namedtuple_U_S_V: ...
```

The two have obvious syntactic differences that we probably don't plan on changing: the python_arg_parser doesn't include `def` or return types, and it puts the type hint before the variable name. But the type of `out` in the pyi is probably wrong, since `svd` has multiple output params. I tried to clearly call out any instances of the pyi codegen diverging in a way that looks buggy, so we can clean it up in a later PR (see the comments for details).
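To make the dual-format idea concrete, here is a minimal, self-contained sketch of the pattern; the `ToyArg`/`ToySignature` names are invented for illustration and are not the actual classes in `tools/codegen/api/python.py`:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyArg:
    name: str
    arg_parser_type: str  # schema-style type, e.g. 'Tensor', 'bool'
    pyi_type: str         # pyi-style type, e.g. 'Tensor', '_bool'
    default: Optional[str] = None

@dataclass(frozen=True)
class ToySignature:
    name: str
    args: Tuple[ToyArg, ...]
    pyi_return: str

    def signature_str(self) -> str:
        # python_arg_parser format: "type name=default", no def/return
        formals = [f'{a.arg_parser_type} {a.name}'
                   + (f'={a.default}' if a.default is not None else '')
                   for a in self.args]
        return f'{self.name}({", ".join(formals)})'

    def signature_str_pyi(self) -> str:
        # pyi format: "name: type=default", wrapped in "def ... -> ret: ..."
        formals = [f'{a.name}: {a.pyi_type}'
                   + (f'={a.default}' if a.default is not None else '')
                   for a in self.args]
        return f'def {self.name}({", ".join(formals)}) -> {self.pyi_return}: ...'

svd = ToySignature(
    name='svd',
    args=(ToyArg('input', 'Tensor', 'Tensor'),
          ToyArg('some', 'bool', '_bool', 'True'),
          ToyArg('compute_uv', 'bool', '_bool', 'True')),
    pyi_return='namedtuple_U_S_V',
)
print(svd.signature_str())
# svd(Tensor input, bool some=True, bool compute_uv=True)
print(svd.signature_str_pyi())
# def svd(input: Tensor, some: _bool=True, compute_uv: _bool=True) -> namedtuple_U_S_V: ...
```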
Another particularly ugly "bug" that I kept in to maintain byte-for-byte compatibility is the fact that the pyi codegen groups operator overloads together. It turns out that the only reason it does this (as far as I can tell) is that it tacks an out argument onto signatures that don't have one, if ANY overload of that op has an out variant. E.g. consider the pyi type hints generated for `nanmedian` in `_VF.pyi`:

```
@overload
def nanmedian(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
@overload
def nanmedian(input: Tensor, dim: _int, keepdim: _bool=False, *, out: Optional[Tensor]=None) -> namedtuple_values_indices: ...
@overload
def nanmedian(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool=False, *, out: Optional[Tensor]=None) -> namedtuple_values_indices: ...
```

And the corresponding native_functions.yaml entries:

```
- func: nanmedian(Tensor self) -> Tensor
- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
```

Signature 2 corresponds to entries 2 and 3 in native_functions, and signature 3 corresponds to entries 4 and 5. But signature 1 has an optional out argument, even though entry 1 in native_functions.yaml has no out variant. I'd like to delete that logic in a later PR - that will also have the added benefit of no longer requiring the pyi codegen to group overloads together. We can just operate independently on each PythonSignatureGroup.

**More detailed accounting of the changes**

Per file:

gen_python_functions.py
- `load_signatures()` can now skip deprecated signatures. Needed because pyi only includes deprecated functions, and skips their method variants (maybe we should add them in...?)
- Moved `namedtuple_fieldnames` into python.py.
- `group_overloads()` can now opt not to sort the overloads (needed for byte-for-byte compat; pyi doesn't sort for some reason).

python.py
- Gave `PythonSignature` and `PythonSignatureDeprecated` a `pyi` flag that tells them whether to print themselves in pyi vs. python_arg_parser format.
- Added a `PythonReturns` dataclass, which is now a member of `PythonSignature`. It is only used by pyi. I found this useful because python returns need to know how to deal with named tuple returns properly. I also moved `namedtuple_fieldnames` into this file from gen_python_functions.

gen_pyi.py
- Merged `get_py_torch_functions` and `get_py_variable_methods` into a single function, since they're very similar.
- Lifted out all of the pyi type hint type-mapping mess and dropped it into python.py. This required updating the mapping to deal with NativeFunction objects instead of the outputs of Declarations.yaml (this was most of the logic in `type_to_python`, `arg_to_type_hint`, and `generate_type_hints`). `generate_type_hints` is now a small orchestration function that gathers the different signatures for each PythonSignatureGroup.
- NamedTuples are now generated by calling `PythonReturn.named_tuple()` (in `generate_named_tuples()`), rather than appending to a global list.

A lot of hardcoded pyi signatures still live in `gen_pyi.py`. I didn't look too closely into whether or not any of that can be removed as part of this PR.

Test Plan: Imported from OSS

Reviewed By: ljk53

Differential Revision: D25343802

Pulled By: bdhirsh

fbshipit-source-id: f73e99e1afef934ff41e4aca3dabf34273459a52
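As a worked illustration of the vararg rule described above, here is a hypothetical helper (not the codegen's actual API - the real logic lives in `signature_str_pyi_vararg`): a vararg hint is only produced when the sole positional argument is an int list, and the vararg is annotated with the element type rather than the list type:

```python
from typing import List, Optional, Tuple

def maybe_vararg_hint(name: str, positional: List[Tuple[str, str]], returns: str) -> Optional[str]:
    """positional is a list of (arg_name, pyi_type) pairs. Per the rule above,
    a vararg variant is only emitted when the sole positional argument is an
    int list (rendered as '_size' in pyi)."""
    if len(positional) != 1 or positional[0][1] != '_size':
        return None
    arg_name = positional[0][0]
    # the vararg variant annotates the element type (_int), not the list type
    return f'def {name}(*{arg_name}: _int) -> {returns}: ...'

print(maybe_vararg_hint('permute', [('dims', '_size')], 'Tensor'))
# def permute(*dims: _int) -> Tensor: ...
print(maybe_vararg_hint('add', [('input', 'Tensor'), ('other', 'Tensor')], 'Tensor'))
# None
```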
This commit is contained in:
parent f2c3efd51f
commit ba6511b304
@@ -38,6 +38,8 @@ mkdir -p "$OUT"/pyi/torch/_C
 mkdir -p "$OUT"/pyi/torch/nn
 python -m tools.pyi.gen_pyi \
-  --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \
+  --native-functions-path aten/src/ATen/native/native_functions.yaml \
+  --deprecated-functions-path tools/autograd/deprecated.yaml \
   --out "$OUT"/pyi

 # autograd codegen (called by torch codegen but can run independently)
@@ -35,6 +35,7 @@ files = tools/codegen/gen.py,
     tools/autograd/gen_trace_type.py,
     tools/autograd/gen_variable_factories.py,
     tools/autograd/load_derivatives.py,
+    tools/pyi/gen_pyi.py,
     torch/utils/benchmark/utils/common.py,
     torch/utils/benchmark/utils/timer.py,
     torch/utils/benchmark/utils/valgrind_wrapper/*.py,
@@ -193,25 +193,28 @@ def load_signatures(
     deprecated_yaml_path: str,
     *,
     method: bool,
+    skip_deprecated: bool = False,
+    pyi: bool = False,
 ) -> Sequence[PythonSignatureNativeFunctionPair]:
     native_functions = list(filter(should_generate_py_binding, parse_native_yaml(native_yaml_path)))

     @with_native_function
     def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
         return PythonSignatureNativeFunctionPair(
-            signature=signature(f, method=method),
+            signature=signature(f, method=method, pyi=pyi),
             function=f,
         )

     pairs = list(map(gen_signature_pairs, native_functions))
-    deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method)
-    return pairs + deprecated
+    deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method, pyi=pyi)
+    return pairs if skip_deprecated else pairs + deprecated

 def load_deprecated_signatures(
     pairs: Sequence[PythonSignatureNativeFunctionPair],
     deprecated_yaml_path: str,
     *,
     method: bool,
+    pyi: bool,
 ) -> List[PythonSignatureNativeFunctionPair]:
     # The deprecated.yaml doesn't have complete type information, we need
     # find and leverage the original ATen signature (to which it delegates
@@ -225,6 +228,10 @@ def load_deprecated_signatures(
         opname = str(f.func.name.name.base)
         if f.func.is_out_fn():
             opname += '_out'
+        # TODO: remove HACK
+        # I think we want to differentiate inplace functions here.. but we currently don't for the arg parser
+        if f.func.name.name.inplace and pyi:
+            opname += '_'
         args = CppSignatureGroup.from_schema(f.func, method=False).signature.arguments()
         # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
         types = ', '.join(argument_type_str(a.argument.type)
@@ -308,6 +315,7 @@ def load_deprecated_signatures(
                 method=python_sig.method,
                 deprecated_args_names=tuple(args),
                 deprecated_args_exprs=tuple(call_args),
+                returns=python_sig.returns,
             ),
             function=pair.function,
         ))
@@ -320,31 +328,10 @@ def load_deprecated_signatures(
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

-# TODO: remove the copy of this method in 'tools/pyi/gen_pyi.py'.
-@with_native_function
-def namedtuple_fieldnames(f: NativeFunction) -> List[str]:
-    returns = f.func.returns
-    if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
-        return []
-    else:
-        if any(map(lambda r: r.name is None, returns)):
-            # When building on Windows, `PyStructSequence_UnnamedField` could not be
-            # resolved by the linker for some reason, which cause error in building:
-            #
-            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
-            # PyStructSequence_UnnamedField
-            #
-            # Thus, at this point in time, we do not support unnamed
-            # fields in namedtuple; you must either name all fields,
-            # or none of them.
-            raise ValueError("Unnamed field is not supported by codegen")
-
-        return list(map(lambda r: str(r.name), returns))
-
 @with_native_function
 def gen_namedtuple_typename_key(f: NativeFunction) -> str:
     name = cpp.name(f.func)
-    fieldnames = namedtuple_fieldnames(f)
+    fieldnames = namedtuple_fieldnames(f.func.returns)
     return '_'.join([name] + fieldnames)

 def emit_namedtuple_typedefs(
@@ -360,7 +347,7 @@ def emit_namedtuple_typedefs(
     typedefs: List[str] = []    # typedef declarations and init code

     for overload in overloads:
-        fieldnames = namedtuple_fieldnames(overload.function)
+        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
         if not fieldnames:
             continue

@@ -651,7 +638,9 @@ def method_def(
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

 def group_overloads(
-    overloads: Sequence[PythonSignatureNativeFunctionPair]
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+    *,
+    sort: bool = True,
 ) -> Sequence[PythonSignatureGroup]:
     bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
     outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
@@ -700,7 +689,9 @@ def group_overloads(
             outplace=outplace.function if outplace is not None else None,
         ))

-    return sort_overloads(grouped)
+    # TODO: unconditionally sort
+    # maintaining byte-for-byte compatibility for pyi codegen for now
+    return grouped if not sort else sort_overloads(grouped)

 # This function declares a partial order on declarations, and sorts them according
 # to its linear extension. This is necessary, because there's some ambiguity in the
@@ -1,8 +1,8 @@
 import re
 import os
 import yaml
-from collections import defaultdict
 from .nested_dict import nested_dict
+from typing import Dict, List


 __all__ = [
@@ -52,7 +52,7 @@ def uninplace_api_name(api_name):
     return api_name


-def write(dirname, name, template, env):
+def write(dirname: str, name: str, template: CodeTemplate, env: Dict[str, List[str]]) -> None:
     env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename)
     path = os.path.join(dirname, name)
     # See Note [Unchanging results for ninja]
@@ -69,12 +69,6 @@ def write(dirname, name, template, env):
     else:
         print("Skipped writing {}".format(path))

-def is_tensor_method(declaration):
-    return 'Tensor' in declaration['method_of']
-
-def is_torch_function(declaration):
-    return 'namespace' in declaration['method_of']
-
 def is_out_variant(decl):
     return decl['name'].endswith('_out')

@@ -92,12 +86,6 @@ def load_op_list_and_strip_overload(op_list, op_list_path):
     # strip out the overload part
     return {opname.split('.', 1)[0] for opname in op_list}

-def group_declarations_by_op_name(declarations):
-    groups = defaultdict(list)
-    for d in declarations:
-        groups[op_name(d)].append(d)
-    return groups
-
 def is_output(arg):
     return arg.get('output', False)

@@ -173,6 +173,53 @@ from tools.codegen.model import *
 # }
 #

+# TODO: stick this more firmly in the data model somewhere?
+def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
+    if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
+        return []
+    else:
+        if any(map(lambda r: r.name is None, returns)):
+            # When building on Windows, `PyStructSequence_UnnamedField` could not be
+            # resolved by the linker for some reason, which cause error in building:
+            #
+            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
+            # PyStructSequence_UnnamedField
+            #
+            # Thus, at this point in time, we do not support unnamed
+            # fields in namedtuple; you must either name all fields,
+            # or none of them.
+            raise ValueError("Unnamed field is not supported by codegen")
+
+        return list(map(lambda r: str(r.name), returns))
+
+@dataclass(frozen=True)
+class PythonReturns:
+    returns: Tuple[Return, ...]
+
+    def named_tuple_pyi(self) -> Optional[Tuple[str, str]]:
+        python_returns = [argument_type_str_pyi(r.type) for r in self.returns]
+        field_names = namedtuple_fieldnames(self.returns)
+        if field_names:
+            namedtuple_name = '_'.join(['namedtuple'] + field_names)
+            tuple_args = [f'("{name}", {typ})' for name, typ in zip(field_names, python_returns)]
+            namedtuple_def = f'NamedTuple("{namedtuple_name}", [{", ".join(tuple_args)}])'
+            return namedtuple_name, namedtuple_def
+        return None
+
+    def returns_str_pyi(self) -> str:
+        named_tuple = self.named_tuple_pyi()
+        if named_tuple is not None:
+            namedtuple_name, _ = named_tuple
+            return namedtuple_name
+
+        python_returns = [argument_type_str_pyi(r.type) for r in self.returns]
+        if len(python_returns) > 1:
+            return 'Tuple[' + ', '.join(python_returns) + ']'
+        if len(python_returns) == 1:
+            return python_returns[0]
+        return 'None'
+

 @dataclass(frozen=True)
 class PythonArgument:
     name: str
@@ -189,26 +236,56 @@ class PythonArgument:

     # Compute argument formal for python argument parsing.
     # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
-    def argument_str(self, *, method: bool = False) -> str:
-        type_str = argument_type_str(self.type)
+    def argument_str(self, *, method: bool = False, pyi: bool = False, deprecated: bool = False) -> str:
+        type_str = argument_type_str_pyi(self.type, pyi_out_arg=pyi and isinstance(self, PythonOutArgument)) \
+            if pyi else argument_type_str(self.type)

+        name = self.name
         # s/self/input/ outside method bindings
         # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
         # for the parse string
-        name = self.name
-        if name == 'self' and type_str == 'Tensor' and not method:
+        if name == 'self' and type_str == 'Tensor' and not method and not deprecated:
             name = 'input'

+        if pyi:
+            if name == 'from':  # from is a Python keyword...
+                name += '_'
+            # pyi merges the _out and functional variants into the same signature, with an optional out arg
+            if name == 'out' and type_str == 'Tensor' and not deprecated:
+                type_str = 'Optional[' + type_str + ']'
+
+        # TODO: remove diff. pyi deprecated signatures don't get defaults for their out arg
+        treat_as_no_default = pyi and deprecated and isinstance(self, PythonOutArgument) and self.default == 'None'
+
         # add default
-        if self.default is not None:
-            default = {
-                'nullptr': 'None',
-                'c10::nullopt': 'None',
-                '{}': 'None',
-            }.get(self.default, self.default)
-            return f'{type_str} {name}={default}'
+        if self.default is not None and not treat_as_no_default:
+            if pyi:
+                if isinstance(self.type, ListType) and self.type.elem == BaseType(BaseTy.int) and \
+                        self.default.startswith('{') and self.default.endswith('}'):
+                    default = '(' + self.default[1:-1] + ')'
+                else:
+                    default = {
+                        'nullptr': 'None',
+                        'c10::nullopt': 'None',
+                        '{}': 'None',
+                        'MemoryFormat::Contiguous': 'contiguous_format',
+                        'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine',
+                    }.get(self.default, self.default)
+                # TODO: remove requires_grad special case (byte-for-byte compat)
+                return f'{name}:{type_str}={default}' if name == 'requires_grad' else f'{name}: {type_str}={default}'
+            else:
+                default = {
+                    'nullptr': 'None',
+                    'c10::nullopt': 'None',
+                    '{}': 'None',
+                }.get(self.default, self.default)
+                return f'{type_str} {name}={default}'
         else:
-            return f'{type_str} {name}'
+            if pyi:
+                # TODO: remove requires_grad special case (byte-for-byte compat)
+                return f'{name}:{type_str}' if name == 'requires_grad' else f'{name}: {type_str}'
+            else:
+                return f'{type_str} {name}'

 @dataclass(frozen=True)
 class PythonOutArgument(PythonArgument):
@@ -238,6 +315,7 @@ class PythonOutArgument(PythonArgument):
                 raise RuntimeError(f'Unsupported output type: {outputs}')
             return PythonOutArgument(
                 name='out',
+                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                 type=ListType(BaseType(BaseTy.Tensor), size),
                 default='None',
                 default_init=None,
@@ -260,6 +338,9 @@ class PythonSignature:

     output_args: Optional[PythonOutArgument]

+    # Return types, which are only used by pyi
+    returns: PythonReturns
+
     # These are scattered kwargs arguments belonging to TensorOptions.
     # When binding to C++, they are packed into a TensorOptions object 'options'.
     # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
@@ -276,13 +357,23 @@ class PythonSignature:
         return False

     def arguments(
-        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
+        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False, hacky_add_output: bool = False
     ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
         result: List[Union[PythonArgument, PythonOutArgument]] = []
         result.extend(self.input_args)
         result.extend(self.input_kwargs)
         if self.output_args is not None and not skip_outputs:
             result.append(self.output_args)
+        # TODO: remove HACK
+        # in the existing pyi codegen, we tack on an optional out argument to every operator overload
+        # if there exists at least one overload with an out variant. This seems wrong.
+        elif hacky_add_output:
+            result.extend([PythonOutArgument(
+                name='out',
+                type=OptionalType(BaseType(BaseTy.Tensor)),
+                default='None',
+                default_init=None,
+                outputs=())])
         if not skip_tensor_options:
             result.extend(self.tensor_options_args)
         return tuple(result)
@@ -301,18 +392,57 @@ class PythonSignature:
     # for error parsing.
     #
     # For a translation to mypy-valid type signatures, see
-    # tools/gen_pyi.py. If you change any logic here, please
+    # signature_str_pyi. If you change any logic here, please
     # check that file too.
     def signature_str(self, *, skip_outputs: bool = False) -> str:
-        schema_formals: List[str] = \
-            list(map(lambda a: a.argument_str(method=self.method),
-                     self.arguments(skip_outputs=skip_outputs)))
+        args = self.arguments(skip_outputs=skip_outputs)
+        schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method), args))
         positional_argc = len(self.input_args)
         if len(schema_formals) > positional_argc:
             schema_formals.insert(positional_argc, '*')

         return f'{self.name}({", ".join(schema_formals)})'

+    def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str:
+        args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output)
+        schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True), args))
+        positional_argc = len(self.input_args)
+        if len(schema_formals) > positional_argc:
+            schema_formals.insert(positional_argc, '*')
+
+        # only pyi signatures include returns
+        returns_str = self.returns.returns_str_pyi()
+        # pyi also includes self (with no typing/defaults) for methods
+        if self.method:
+            schema_formals.insert(0, "self")
+        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]:
+        # only pyi uses vararg signatures
+        args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output)
+        schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True), args))
+        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
+        num_args = self.arguments_count()
+        num_positionalargs = len(self.input_args)
+
+        have_vararg_version = False
+        if num_args > 0:
+            vararg_type = args[0].type
+            if isinstance(vararg_type, ListType) and str(vararg_type.elem) == 'int' and num_positionalargs == 1:
+                have_vararg_version = True
+
+        if not have_vararg_version:
+            return None
+        # Below are the major changes in vararg vs. regular pyi signatures
+        # vararg signatures also omit the asterix
+        schema_formals[0] = '*' + args[0].name + ': _int'
+
+        returns_str = self.returns.returns_str_pyi()
+        # pyi also includes self (with no typing/defaults) for methods
+        if self.method:
+            schema_formals.insert(0, "self")
+        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+
 # The deprecated python signature involves some special logic, so create a
 # dedicated data model to store these extra properties.
 @dataclass(frozen=True)
@@ -340,6 +470,20 @@ class PythonSignatureDeprecated(PythonSignature):
     def signature_str(self, *, skip_outputs: bool = False) -> str:
         return PythonSignature.signature_str(self, skip_outputs=skip_outputs) + '|deprecated'

+    def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str:
+        args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output)
+        schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True, deprecated=True), args))
+        positional_argc = len(self.input_args)
+        if len(schema_formals) > positional_argc:
+            schema_formals.insert(positional_argc, '*')
+
+        returns_str = self.returns.returns_str_pyi()
+        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
+
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]:
+        # the codegen doesn't include vararg variants for deprecated signatures
+        return None
+
 # This struct is used to hold the PythonSignature and its corresponding
 # NativeFunction BEFORE grouping base and out-variant functions.
 # Why not store NativeFunction in PythonSignature or construct PythonSignature
@@ -520,12 +664,75 @@ def argument(a: Argument) -> PythonArgument:
         default_init=None,
     )

-def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
+def argument_type_str_pyi(t: Type, *, pyi_out_arg: bool = False) -> str:
+    add_optional = False
+    if isinstance(t, OptionalType):
+        t = t.elem
+        add_optional = True
+
+    if isinstance(t, BaseType):
+        if t.name == BaseTy.int:
+            ret = '_int'
+        elif t.name == BaseTy.float:
+            ret = '_float'
+        elif t.name == BaseTy.str:
+            ret = 'str'
+        elif t.name == BaseTy.Scalar:
+            ret = 'Number'
+        elif t.name == BaseTy.ScalarType:
+            ret = '_dtype'
+        elif t.name == BaseTy.bool:
+            ret = '_bool'
+        elif t.name == BaseTy.QScheme:
+            ret = '_qscheme'
+        elif t.name == BaseTy.Layout:
+            ret = '_layout'
+        elif t.name == BaseTy.Device:
+            ret = 'Union[_device, str, None]'
+        elif t.name == BaseTy.MemoryFormat:
+            ret = 'memory_format'
+        elif t.name == BaseTy.Dimname:
+            ret = 'Union[str, ellipsis, None]'
+        elif t.name in [BaseTy.Tensor, BaseTy.Generator,
+                        BaseTy.Storage, BaseTy.Stream, BaseTy.str]:
+            # These python schema type names line up with their function schema names
+            ret = t.name.name
+
+    elif isinstance(t, ListType):
+        if pyi_out_arg and t.is_tensor_like():
+            # TODO remove HACK
+            # pyi blindly treats all tensor-like out args as having type Tensor
+            return 'Tensor'
+        if str(t.elem) == 'int':
+            ret = 'Union[_int, _size]' if t.size is not None else '_size'
+        elif t.is_tensor_like():
+            # TODO: this doesn't seem right...
+            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
+            # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
+            if isinstance(t.elem, OptionalType):
+                add_optional = True
+            ret = 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]' if t.size is not None else \
+                'Union[Tuple[Tensor, ...], List[Tensor]]'
+        elif str(t.elem) == 'float':
+            ret = 'Sequence[float]'
+        else:
+            elem = argument_type_str_pyi(t.elem)
+            ret = f'Sequence[{elem}]'
+
+    if add_optional:
+        ret = 'Optional[' + ret + ']'
+    return ret
+
+    raise RuntimeError(f'unrecognized type {repr(t)}')
+
+# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
+def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> PythonSignature:
     # Use cpp api to gather TensorOptions fields from kwargs.
-    # Skip ThisArgument if this is method signature.
+    # Skip SelfArgument if this is method.
     # Skip TensorOptionsArguments in C++ signature. Python side TensorOptions
     # arguments are created based on different rules - see below.
-    args = tuple(a for a in cpp.group_arguments(f.func, method=method) if isinstance(a, Argument))
+    cpp_args = cpp.group_arguments(f.func, method=method)
+    args = tuple(a for a in cpp_args if isinstance(a, Argument))

     input_arg_set = set(a.name for a in f.func.arguments.positional)
     kwarg_only_set = set(a.name for a in f.func.arguments.kwarg_only)
@@ -561,13 +768,15 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
         tensor_options_args.append(PythonArgument(
             name='dtype',
             type=BaseType(BaseTy.ScalarType),
-            default=_dtype_default_type_hack(name),
+            default=_dtype_default_type_hack(name, pyi=pyi),
             default_init='self.scalar_type()' if is_like_or_new_function else None,
         ))
+        # TODO: probably a bug, kill this diff?
+        # pyi signatures have a slightly different type/default for layout
         tensor_options_args.append(PythonArgument(
             name='layout',
-            type=OptionalType(BaseType(BaseTy.Layout)),
-            default='torch.strided',
+            type=BaseType(BaseTy.Layout) if pyi else OptionalType(BaseType(BaseTy.Layout)),
+            default='strided' if pyi else 'torch.strided',
             default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None,
         ))
         tensor_options_args.append(PythonArgument(
@@ -576,12 +785,15 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
             default='None',
             default_init='self.device()' if is_like_or_new_function else None,
         ))
-        tensor_options_args.append(PythonArgument(
-            name='pin_memory',
-            type=BaseType(BaseTy.bool),
-            default='False',
-            default_init=None,
-        ))
+        # TODO: probably a bug, kill this diff?
+        # pyi signatures don't include pin memory
+        if not pyi:
+            tensor_options_args.append(PythonArgument(
+                name='pin_memory',
+                type=BaseType(BaseTy.bool),
+                default='False',
+                default_init=None,
+            ))
         tensor_options_args.append(PythonArgument(
             name='requires_grad',
             type=BaseType(BaseTy.bool),
@@ -589,18 +801,21 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
             default_init=None,
         ))

+    returns = PythonReturns(returns=f.func.returns)
+
     return PythonSignature(
         name=str(f.func.name.name),
         input_args=input_args,
         input_kwargs=input_kwargs,
         output_args=PythonOutArgument.from_outputs(outputs),
         tensor_options_args=tuple(tensor_options_args),
+        returns=returns,
         method=method,
     )

 # TODO blowtorch
-def _dtype_default_type_hack(name: str) -> str:
-    if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices':
+def _dtype_default_type_hack(name: str, *, pyi: bool) -> str:
+    if not pyi and (name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices'):
         return 'torch.int64'
     else:
         return 'None'
@@ -3,13 +3,14 @@ import os
 import collections
 from pprint import pformat

-import yaml
-import re
 import argparse

-from ..autograd.utils import YamlLoader, CodeTemplate, write, group_declarations_by_op_name, is_tensor_method, is_torch_function
-from ..autograd.gen_python_functions import SKIP_PYTHON_BINDINGS, SKIP_PYTHON_BINDINGS_SIGNATURES
-from ..autograd.gen_autograd import load_aten_declarations
+from tools.codegen.model import *
+from tools.codegen.api.python import *
+from typing import Sequence, List, Mapping, Dict
+
+from ..autograd.utils import CodeTemplate, write
+from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads

 """
 This module implements generation of type stubs for PyTorch,
@@ -28,60 +29,48 @@ Here's our general strategy:
   (the latter case should be pretty rare).

 - We go through automatically bound functions based on the
-  type information recorded in Declarations.yaml and
+  type information recorded in native_functions.yaml and
   generate type hints for them (generate_type_hints)

 There are a number of type hints which we've special-cased;
 read gen_pyi for the gory details.
 """

-# TODO: remove after migrating entire codegen to the new data model.
-def should_generate_python_binding(declaration):
-    name = declaration['name']
-    for pattern in SKIP_PYTHON_BINDINGS:
-        if re.match('^' + pattern + '$', name):
-            return False
+# TODO: consider waiting to group by base name until we actually need to
+# (after computing type hint signatures, when adding @overload directives)
+def group_by_base_name(python_funcs: Sequence[PythonSignatureNativeFunctionPair]) -> Mapping[str, List[PythonSignatureGroup]]:
+    groups = group_overloads(python_funcs, sort=False)
+    d = collections.defaultdict(list)
+    for g in groups:
+        name = g.signature.name
+        d[name].append(g)
+    return d

-    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
-    signature = '{}({})'.format(name, ', '.join(simple_types))
-    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
-        if pattern == signature:
-            return False
-
-    return True
-
-
-def get_py_variable_methods(declarations):
+def get_py_torch_functions(
+        python_funcs: Sequence[PythonSignatureNativeFunctionPair],
+        method: bool = False,
+) -> Mapping[str, Sequence[PythonSignatureGroup]]:
     """
     Get declarations (grouped by name) which should be generated
-    as methods on Tensor.
+    as either functions in the "torch" module or methods on Tensor.
     """
-    def should_bind(declaration):
-        return (should_generate_python_binding(declaration) and
-                not declaration.get('python_module') and
-                is_tensor_method(declaration))
+    def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool:
+        return (should_generate_py_binding(python_func.function) and
+                not python_func.function.python_module and
+                Variant.function in python_func.function.variants)

-    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])
+    def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
+        return (should_generate_py_binding(python_func.function) and
+                not python_func.function.python_module and
+                Variant.method in python_func.function.variants)


-def get_py_torch_functions(declarations):
-    """
-    Get declarations (grouped by name) which should be generated
-    as functions in the "torch" module.
-    """
-    def should_bind(declaration):
-        return (should_generate_python_binding(declaration) and
-                not declaration.get('python_module') and
-                is_torch_function(declaration))
-
-    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])
+    should_bind = should_bind_method if method else should_bind_function
+    return group_by_base_name([f for f in python_funcs if should_bind(f)])


 # TODO: Consider defining some aliases for our Union[...] types, to make
 # the stubs to read on the human eye.

 needed_modules = set()

 DEVICE_PARAM = "device: Union[_device, str, None]=None"
 FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"

@@ -144,90 +133,6 @@ blocklist = [
 ]


-def type_to_python(typename, size=None):
-    """type_to_python(typename: str, size: str) -> str
-
-    Transforms a Declarations.yaml type name into a Python type specification
-    as used for type hints.
-    """
-    typename = typename.replace(' ', '')  # normalize spaces, e.g., 'Generator *'
-
-    # Disambiguate explicitly sized int/tensor lists from implicitly
-    # sized ones. These permit non-list inputs too. (IntArrayRef[] and
-    # TensorList[] are not real types; this is just for convenience.)
-    if typename in {'IntArrayRef', 'TensorList'} and size is not None:
-        typename += '[]'
-
-    typename = {
-        'Device': 'Device',
-        'Generator': 'Generator',
-        'IntegerTensor': 'Tensor',
-        'Scalar': 'Number',
-        'ScalarType': '_dtype',
-        'Storage': 'Storage',
-        'BoolTensor': 'Tensor',
-        'IndexTensor': 'Tensor',
-        'Tensor': 'Tensor',
-        'MemoryFormat': 'memory_format',
-        'IntArrayRef': '_size',
-        'IntArrayRef[]': 'Union[_int, _size]',
-        'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
-        'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
-        'bool': '_bool',
-        'double': '_float',
-        'int64_t': '_int',
-        'accreal': 'Number',
-        'real': 'Number',
-        'void*': '_int',  # data_ptr
-        'void': 'None',
-        'std::string': 'str',
-        'Dimname': 'Union[str, ellipsis, None]',
-        'DimnameList': 'Sequence[Union[str, ellipsis, None]]',
-        'QScheme': '_qscheme',
-        'ArrayRef<double>' : 'Sequence[float]',
-        'Stream': 'Stream',
-    }[typename]
-
-    return typename
-
-
-def arg_to_type_hint(arg):
-    """arg_to_type_hint(arg) -> str
-
-    This takes one argument in a Declarations and returns a string
-    representing this argument in a type hint signature.
-    """
-    name = arg['name']
-    if name == 'from':  # from is a Python keyword...
-        name += '_'
-    typename = type_to_python(arg['dynamic_type'], arg.get('size'))
-    if arg.get('is_nullable'):
-        typename = 'Optional[' + typename + ']'
-    if 'default' in arg:
-        default = arg['default']
-        if default == 'nullptr':
-            default = None
-        elif default == 'c10::nullopt':
-            default = None
-        elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
-            if arg['dynamic_type'] == 'Tensor' and default == '{}':
-                default = None
-            elif arg['dynamic_type'] == 'Generator' and default == '{}':
-                default = None
-            elif arg['dynamic_type'] == 'IntArrayRef':
-                default = '(' + default[1:-1] + ')'
-            else:
-                raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
-        elif default == 'MemoryFormat::Contiguous':
-            default = 'contiguous_format'
-        elif default == 'QScheme::PER_TENSOR_AFFINE':
-            default = 'per_tensor_affine'
-        default = '={}'.format(default)
-    else:
-        default = ''
-    return name + ': ' + typename + default
-
-
 binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
               'matmul', 'floordiv',
               'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow',  # reverse arithmetic
@@ -241,7 +146,7 @@ to_py_type_ops = ('bool', 'float', 'complex', 'long', 'index', 'int', 'nonzero')
 all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


-def sig_for_ops(opname):
+def sig_for_ops(opname: str) -> List[str]:
     """sig_for_ops(opname : str) -> List[str]

     Returns signatures for operator special functions (__add__ etc.)"""
@@ -271,146 +176,66 @@ def sig_for_ops(opname):
     else:
         raise Exception("unknown op", opname)


-# Copied from 'gen_python_functions.py'
-# TODO: consolidate after migrating to the new codegen model in 'tools/codegen'.
-def namedtuple_fieldnames(declaration):
-    returns = declaration['returns']
-    if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
-        return []
-    else:
-        def get_field_name(x):
-            # See Note [field_name versus name]
-            if 'field_name' not in x:
-                # When building on Windows, `PyStructSequence_UnnamedField` could not be
-                # resolved by the linker for some reason, which cause error in building:
-                #
-                # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
-                # PyStructSequence_UnnamedField
-                #
-                # Thus, at this point in time, we do not support unnamed
-                # fields in namedtuple; you must either name all fields,
-                # or none of them.
-                raise ValueError("Unnamed field is not supported by codegen")
+def generate_named_tuples(funcs: Sequence[PythonSignatureGroup]) -> Dict[str, str]:
+    namedtuples: Dict[str, str] = {}
+    for sig_group in funcs:
+        named_tuple = sig_group.signature.returns.named_tuple_pyi()
+        if named_tuple is not None:
+            tuple_name, tuple_def = named_tuple
+            if tuple_name in namedtuples:
+                assert namedtuples[tuple_name] == tuple_def
             else:
-                return x['field_name']
-    return [get_field_name(x) for x in returns]
+                namedtuples[tuple_name] = tuple_def
+    return namedtuples


-def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
-    """generate_type_hints(fname, decls, is_tensor=False)
+def generate_type_hints(funcs: Sequence[PythonSignatureGroup], is_tensor: bool = False) -> List[str]:
+    """generate_type_hints(funcs, is_tensor=False)

     Generates type hints for the declarations pertaining to the function
-    :attr:`fname`. attr:`decls` are the declarations from the parsed
-    Declarations.yaml.
-    :attr:`namedtuples` is a dictionary for accumulating NamedTuple definitions.
+    :attr:`funcs` are the func from the parsed native_functions.yaml.
     The :attr:`is_tensor` flag indicates whether we are parsing
     members of the Tensor class (true) or functions in the
     `torch` namespace (default, false).

     This function currently encodes quite a bit about the semantics of
     the translation C++ -> Python.
     """
-    if fname in blocklist:
-        return []
-
     type_hints = []
-    dnames = ([d['name'] for d in decls])
-    has_out = fname + '_out' in dnames
+    any_out = any([g for g in funcs if g.outplace is not None])

-    if has_out:
-        decls = [d for d in decls if d['name'] != fname + '_out']
+    for sig_group in funcs:
+        # Some deprecated ops that are on the blocklist are still included in pyi
+        if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
+            continue

-    for decl in decls:
-        render_kw_only_separator = True  # whether we add a '*' if we see a keyword only argument
-        python_args = []
-
-        has_tensor_options = 'TensorOptions' in (a['dynamic_type'] for a in decl['arguments'])
-
-        for a in decl['arguments']:
-            if a['dynamic_type'] != 'TensorOptions':
-                if a.get('kwarg_only', False) and render_kw_only_separator:
-                    python_args.append('*')
-                    render_kw_only_separator = False
-                try:
-                    python_args.append(arg_to_type_hint(a))
-                except Exception:
-                    print("Error while processing function {}".format(fname))
-                    raise
-
-        if 'self: Tensor' in python_args:
-            self_index = python_args.index('self: Tensor')
-            python_args.remove('self: Tensor')
-            if is_tensor:
-                python_args = ['self'] + python_args
-            else:
-                python_args.insert(self_index, 'input: Tensor')
-        else:
-            if is_tensor:
-                raise Exception("method without self is unexpected")
-
-        if has_out:
-            if render_kw_only_separator:
-                python_args.append('*')
-                render_kw_only_separator = False
-            python_args.append('out: Optional[Tensor]=None')
-
-        if has_tensor_options:
-            if render_kw_only_separator:
-                python_args.append('*')
-                render_kw_only_separator = False
-            python_args += ["dtype: _dtype=None",
-                            "layout: _layout=strided",
-                            "device: Union[_device, str, None]=None",
-                            "requires_grad:_bool=False"]
-
-        python_args_s = ', '.join(python_args)
-        python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]
-        field_names = namedtuple_fieldnames(decl)
-
-        if field_names:
-            namedtuple_name = '_'.join(['namedtuple'] + field_names)
-            tuple_args = ['("{}", {})'.format(name, typ) for name, typ in zip(field_names, python_returns)]
-            namedtuple_def = 'NamedTuple("{}", [{}])'.format(namedtuple_name, ', '.join(tuple_args))
-            if namedtuple_name in namedtuples:
-                assert namedtuples[namedtuple_name] == namedtuple_def
-            else:
-                namedtuples[namedtuple_name] = namedtuple_def
-            python_returns_s = namedtuple_name
-        elif len(python_returns) > 1:
-            python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
-        elif len(python_returns) == 1:
-            python_returns_s = python_returns[0]
-        else:
-            python_returns_s = 'None'
-
-        type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
-        numargs = len(decl['arguments'])
-        vararg_pos = int(is_tensor)
-        have_vararg_version = (numargs > vararg_pos and
-                               decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef'} and
-                               (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
-                               (not is_tensor or decl['arguments'][0]['name'] == 'self'))
+        # deprecated signatures have separate entries for their functional and out variants
+        # (as opposed to the native ops, which fuse the two into a single signature).
+        # generate the functional variant here, if an out variant exists.
+        if sig_group.signature.deprecated and sig_group.outplace is not None:
+            type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True)
             type_hints.append(type_hint)

+        # TODO: remove HACK
+        # the pyi codegen currently adds an optional out param in cases where the current op does NOT have an out variant,
+        # but an overload of the op DOES have an out variant.
+        # TODO: After that, we should consider killing this method entirely and operating per PythonSignatureGroup
+        # rather than grouping their overloads together
+        # (since there isn't much else semantically meaningful about grouping overloads)
+        # this hack also doesn't apply to deprecated ops
+        hacky_add_output = any_out and sig_group.outplace is None and not sig_group.signature.deprecated
+        # PythonSignatureGroups that have both a functional + out variant get a single signature, with an optional out argument
+        # Generates the out variant if one exists. Otherwise, generate the functional variant
+        type_hint = sig_group.signature.signature_str_pyi(
+            skip_outputs=sig_group.outplace is None, hacky_add_output=hacky_add_output)
         type_hints.append(type_hint)

-        if have_vararg_version:
-            # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
-            # is an IntArrayRef, it will be used as a vararg variant.
-            # The following outputs the vararg variant, the "pass a list variant" is output above.
-            # The other thing is that in Python, the varargs are annotated with the element type, not the list type.
-            typelist = decl['arguments'][vararg_pos]['dynamic_type']
-            vararg_type = '_int'
-            # replace first argument and eliminate '*' if present
-            python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
-                            ': ' + vararg_type] + python_args[vararg_pos + 2:])
-            python_args_s = ', '.join(python_args)
-            type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
-            type_hints.append(type_hint)
+        # Some operators also additionally have a vararg variant of their signature
+        type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
+            skip_outputs=sig_group.outplace is None, hacky_add_output=hacky_add_output)
+        if type_hint_vararg:
+            type_hints.append(type_hint_vararg)

     return type_hints

-def gen_nn_functional(out):
+def gen_nn_functional(out: str) -> None:
     # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
     # through an `_add_docstr` call
     imports = [
@@ -475,10 +300,10 @@ def gen_nn_functional(out):
     stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in'))
     write(out, 'torch/_C/_nn.pyi', stubs, env)

-def gen_nn_pyi(out):
+def gen_nn_pyi(out: str) -> None:
     gen_nn_functional(out)

-def gen_pyi(declarations_path, out):
+def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None:
     """gen_pyi()

     This function generates a pyi file for torch.
@@ -491,16 +316,13 @@ def gen_pyi(declarations_path, out):
     # checking. If you are update this, consider if your change
     # also needs to update the other file.

-    # Load information from YAML
-    declarations = load_aten_declarations(declarations_path)
-
     # Dictionary for NamedTuple definitions
-    namedtuples = {}
+    namedtuples: Dict[str, str] = {}

     # Generate type signatures for top-level functions
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-    unsorted_function_hints = collections.defaultdict(list)
+    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
     unsorted_function_hints.update({
         'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
         'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
@@ -560,21 +382,13 @@ def gen_pyi(declarations_path, out):
                 ' other: Union[Tensor, Number],'
                 ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))

-    function_declarations = get_py_torch_functions(declarations)
-    for name in sorted(function_declarations.keys()):
-        unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name], namedtuples)
-
-    # Generate type signatures for deprecated functions
-
-    # TODO: Maybe we shouldn't generate type hints for deprecated
-    # functions :) However, examples like those addcdiv rely on these.
-    with open('tools/autograd/deprecated.yaml', 'r') as f:
-        deprecated = yaml.load(f, Loader=YamlLoader)
-    for d in deprecated:
-        name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
-        sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
-        sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
-        unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))
+    function_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=False, pyi=True)
+    sig_groups = get_py_torch_functions(function_signatures)
+    for name in sorted(sig_groups.keys()):
+        unsorted_function_hints[name] += generate_type_hints(sig_groups[name])
+        # deprecated signatures are not used when computing named tuples
+        native_groups = [g for g in sig_groups[name] if not g.signature.deprecated]
+        namedtuples.update(generate_named_tuples(native_groups))

     function_hints = []
     for name, hints in sorted(unsorted_function_hints.items()):
@@ -585,26 +399,26 @@ def gen_pyi(declarations_path, out):
     # Generate type signatures for Tensor methods
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-    unsorted_tensor_method_hints = collections.defaultdict(list)
+    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
     unsorted_tensor_method_hints.update({
         'size': ['def size(self) -> Size: ...',
                  'def size(self, _int) -> _int: ...'],
         'stride': ['def stride(self) -> Tuple[_int]: ...',
                    'def stride(self, _int) -> _int: ...'],
-        'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
-                     format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
+        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
+                     format(FACTORY_PARAMS)],
         'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
         # new and __init__ have the same signatures differ only in return type
         # Adapted from legacy_tensor_ctor and legacy_tensor_new
         'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM),
                 'def new(self, storage: Storage) -> Tensor: ...',
                 'def new(self, other: Tensor) -> Tensor: ...',
-                'def new(self, size: {}, *, {}) -> Tensor: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
+                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                 ],
         '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
                      'def __init__(self, storage: Storage) -> None: ...',
                      'def __init__(self, other: Tensor) -> None: ...',
-                     'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
+                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
                      ],
         'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
         # clamp has no default values in the Declarations
@@ -679,10 +493,14 @@ def gen_pyi(declarations_path, out):
     for name in simple_conversions:
         unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))

-    tensor_method_declarations = get_py_variable_methods(declarations)
-    for name in sorted(tensor_method_declarations.keys()):
-        unsorted_tensor_method_hints[name] += \
-            generate_type_hints(name, tensor_method_declarations[name], namedtuples, is_tensor=True)
+    # pyi tensor methods don't currently include deprecated signatures for some reason
+    # TODO: we should probably add them in
+    tensor_method_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
+    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)
+
+    for name in sorted(tensor_method_sig_groups.keys()):
+        unsorted_tensor_method_hints[name] += generate_type_hints(tensor_method_sig_groups[name], is_tensor=True)
+        namedtuples.update(generate_named_tuples(tensor_method_sig_groups[name]))

     for op in all_ops:
         name = '__{}__'.format(op)
@@ -764,17 +582,20 @@ def gen_pyi(declarations_path, out):
     gen_nn_pyi(out)


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description='Generate type stubs for PyTorch')
-    parser.add_argument('--declarations-path', metavar='DECL',
-                        default='torch/share/ATen/Declarations.yaml',
-                        help='path to Declarations.yaml')
+    parser.add_argument('--native-functions-path', metavar='NATIVE',
+                        default='aten/src/ATen/native/native_functions.yaml',
+                        help='path to native_functions.yaml')
+    parser.add_argument('--deprecated-functions-path', metavar='DEPRECATED',
+                        default='tools/autograd/deprecated.yaml',
+                        help='path to deprecated.yaml')
     parser.add_argument('--out', metavar='OUT',
                         default='.',
                         help='path to output directory')
     args = parser.parse_args()
-    gen_pyi(args.declarations_path, args.out)
+    gen_pyi(args.native_functions_path, args.deprecated_functions_path, args.out)


 if __name__ == '__main__':
@@ -234,9 +234,9 @@ add_custom_command(
     "${TORCH_SRC_DIR}/nn/functional.pyi"
   COMMAND
     "${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
-      --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
+      --native-functions-path "aten/src/ATen/native/native_functions.yaml"
+      --deprecated-functions-path "tools/autograd/deprecated.yaml"
   DEPENDS
-    "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
     "${TORCH_SRC_DIR}/_C/__init__.pyi.in"
    "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
    "${TORCH_SRC_DIR}/nn/functional.pyi.in"