pyi codegen update - remove Declarations.yaml (#48754)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48754

The goal of this PR is to kill Declarations.yaml in the pyi codegen, in favor of native_functions + the existing python object model.

**High-level design**

Since the python signatures used by the `python_arg_parser` are “supposed” to resemble the corresponding pyi type hint signatures, I re-used the existing python object model that Jiakai defined in `tools/codegen/api/python.py`. This means that the pyi codegen now reads `native_functions.yaml`, parses it into a bunch of `PythonSignatureGroup` objects, and emits corresponding method + function variants of type-hint signatures for each one, respectively into `__init__.pyi` and `_VariableFunctions.pyi`.

What makes this uglier is that pyi and the python arg parser have a number of differences in how they’re emitted. I expressed that through a `pyi` flag on the `PythonSignature` dataclass, which tells it whether to print itself as a pyi signature or as an arg_parser signature.
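As an illustration of the formatting split (a hypothetical helper, not the real `argument_str()` — the actual logic is gated on the new `pyi` flag): the arg-parser string renders a formal as `Type name=default`, while the pyi stub renders it as `name: Type=default`.

```python
# Hypothetical sketch: convert a python_arg_parser-style formal into a
# pyi-style formal. This only shows the shape of the difference; the real
# code also handles defaults like 'c10::nullopt', keyword renames, etc.
def to_pyi_formal(arg_parser_formal: str) -> str:
    typ, _, rest = arg_parser_formal.partition(' ')   # "Tensor input" -> ("Tensor", "input")
    name, eq, default = rest.partition('=')           # "some=True" -> ("some", "=", "True")
    pyi = f'{name}: {typ}'
    return f'{pyi}={default}' if eq else pyi
```

For example, `to_pyi_formal('bool some=True')` yields `'some: bool=True'`.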

One thing worth noting is how the pyi codegen generates signatures differently for native vs. deprecated ops.

For native ops:
- The pyi codegen fuses functional and out variants of each op into a single signature with an optional `out` argument. Ops without an `out` variant just get an ordinary functional signature.
- Some ops that fit certain criteria also get a second “varargs” signature - basically ops with a single positional argument of type List[int].
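The vararg rule can be sketched as follows (a hypothetical helper; the real check lives in `signature_str_pyi_vararg()`):

```python
from typing import List, Optional, Tuple

# Hypothetical sketch of the vararg-eligibility rule: a second overload is
# emitted only when there is exactly one positional argument whose schema
# type is int[] (i.e. List[int]); it becomes a `*name: _int` vararg formal.
def vararg_formal(positional: List[Tuple[str, str]]) -> Optional[str]:
    # positional: (name, schema_type) pairs for the positional arguments
    if len(positional) == 1 and positional[0][1] == 'int[]':
        name, _ = positional[0]
        return f'*{name}: _int'
    return None
```

So an op like `ones(int[] size, ...)` gets a `def ones(*size: _int, ...)` overload alongside the regular `def ones(size: _size, ...)` one, while a two-positional op does not.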

For deprecated signatures:
- Functional and out variants are not fused - they each get their own signature entry
- There are no varargs signatures

This is currently implemented through the `signature_str_pyi()` and `signature_str_pyi_vararg()` methods on the `PythonSignature`/`PythonSignatureDeprecated` classes. `signature_str_pyi()` knows how to print itself with/without out arguments, differently for native/deprecated ops. `signature_str_pyi_vararg()` optionally returns a vararg variant of the signature if one exists.
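The fusion for native ops can be sketched as (a hypothetical, heavily simplified stand-in for `signature_str_pyi()`):

```python
from typing import List

# Hypothetical sketch: a native op's functional and out variants collapse into
# one pyi stub; the out variant only contributes an optional keyword-only `out`.
def fused_pyi_signature(name: str, formals: List[str], returns: str, has_out: bool) -> str:
    formals = list(formals)
    if has_out:
        formals += ['*', 'out: Optional[Tensor]=None']
    return f'def {name}({", ".join(formals)}) -> {returns}: ...'
```

Deprecated signatures skip this fusion: their functional and out variants are each rendered as their own stub line.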

**Calling out the gap between python_arg_parser vs. pyi**

The two formats are notably different, so I don’t think we can expect to unify them completely. That said, I encountered a number of differences in the pyi codegen that looked wrong; I tried to call them out in the PR, to be removed later. Just as an example, looking at the `svd` signature in the python_arg_parser vs. the pyi type hint:

python_arg_parser
```
static PythonArgParser parser({
  "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
}, /*traceable=*/true);
```

Pyi
```
def svd(input: Tensor, some: _bool=True, compute_uv: _bool=True, *, out: Optional[Tensor]=None) -> namedtuple_U_S_V: ...
```

The two have obvious syntactic differences that we probably don’t plan on changing: the python_arg_parser doesn’t include `def` or return types, and it includes the type hint before the variable name. But the type of `out` in pyi is probably wrong, since `svd` has multiple output params. I tried to clearly call out any instances of the pyi codegen diverging in a way that looks buggy, so we can clean it up in a later PR (see the comments for details).

Another particularly ugly “bug” that I kept in to maintain byte-for-byte compatibility is the fact that the pyi codegen groups operator overloads together. It turns out that the only reason it does this (as far as I can tell) is that it tacks on an out argument to signatures that don’t have one, if ANY overload of that op has an out variant.

E.g. consider the pyi type hints generated for `nanmedian` in `_VF.pyi`:
```
@overload
def nanmedian(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
@overload
def nanmedian(input: Tensor, dim: _int, keepdim: _bool=False, *, out: Optional[Tensor]=None) -> namedtuple_values_indices: ...
@overload
def nanmedian(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool=False, *, out: Optional[Tensor]=None) -> namedtuple_values_indices: ...
```

And the corresponding native_functions.yaml entries:
```
- func: nanmedian(Tensor self) -> Tensor
- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
```

Signature 2 corresponds to entries 2 and 3 in native_functions, and Signature 3 corresponds to entries 4 and 5. But signature 1 has an optional out argument, even though entry 1 in native_functions.yaml has no out variant.

I’d like to delete that logic in a later PR; that will also have the added benefit of no longer requiring the pyi codegen to group overloads together. We can just operate independently on each PythonSignatureGroup.
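The quirk can be sketched as follows (hypothetical helper; the real code threads a `hacky_add_output` flag through `PythonSignature.arguments()`):

```python
from typing import List, Sequence, Tuple

# Hypothetical sketch of the grouping quirk: every overload of a base name gets
# an optional `out` kwarg if ANY overload of that name has an out variant --
# even overloads whose native_functions entry has no out variant at all.
def pyi_formals_for_group(overloads: Sequence[Tuple[List[str], bool]]) -> List[str]:
    # overloads: (formals, has_out_variant) per overload sharing one base name
    any_out = any(has_out for _, has_out in overloads)
    result = []
    for formals, has_out in overloads:
        formals = list(formals)
        if any_out:  # note: not `has_out` -- this is the suspicious behavior
            formals += ['*', 'out: Optional[Tensor]=None']
        result.append(', '.join(formals))
    return result
```

With `nanmedian` above, the out-less base overload still picks up `out: Optional[Tensor]=None` because its siblings have out variants.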

**More detailed accounting of the changes**

Per file:

gen_python_functions.py
- `load_signatures()` can now skip deprecated signatures. Needed because pyi only includes deprecated functions, and skips their method variants (maybe we should add them in…?)
- Moved `namedtuple_fieldnames` into python.py
- `group_overloads()` can now opt to not sort the overloads (needed for byte-for-byte compatibility; the pyi codegen doesn’t sort, for some reason)

python.py
- Gave `PythonSignature` and `PythonSignatureDeprecated` a `pyi` flag that tells them whether to print themselves in pyi vs. python_arg_parser format
- Added a `PythonReturns` dataclass, which is now a member of PythonSignature. It is only used by pyi. I found this useful because python returns need to know how to deal with named tuple returns properly. I also moved `namedtuple_fieldnames` into this file from gen_python_functions
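The named-tuple handling mirrors the `named_tuple_pyi` method added in this PR; simplified, it builds the NamedTuple typedef emitted into the stub from the named return fields:

```python
from typing import List, Optional, Tuple

# Simplified version of PythonReturns.named_tuple_pyi(): given named multi-return
# fields and their pyi types, produce the NamedTuple name and its definition.
def named_tuple_pyi(field_names: List[str], field_types: List[str]) -> Optional[Tuple[str, str]]:
    if not field_names:
        return None
    name = '_'.join(['namedtuple'] + field_names)
    fields = ', '.join(f'("{n}", {t})' for n, t in zip(field_names, field_types))
    return name, f'NamedTuple("{name}", [{fields}])'
```

For `nanmedian.dim`'s `(Tensor values, Tensor indices)` returns this yields the `namedtuple_values_indices` type used in the stubs above.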

gen_pyi.py
- Merged `get_py_torch_functions` and `get_py_variable_methods` into a single function, since they’re very similar
- Lifted out all of the pyi type hint type-mapping mess and dropped it into python.py. This required updating the mapping to deal with NativeFunction objects instead of the outputs of Declarations.yaml (this was most of the logic in `type_to_python`, `arg_to_type_hint`, and `generate_type_hints`).  `generate_type_hints` is now a small orchestration function that gathers the different signatures for each PythonSignatureGroup.
- NamedTuples are now generated by calling `PythonReturn.named_tuple()` (in `generate_named_tuples()`), rather than appending to a global list

A lot of hardcoded pyi signatures still live in `gen_pyi.py`. I didn’t look too closely into whether any of that can be removed as part of this PR.

Test Plan: Imported from OSS

Reviewed By: ljk53

Differential Revision: D25343802

Pulled By: bdhirsh

fbshipit-source-id: f73e99e1afef934ff41e4aca3dabf34273459a52
This commit is contained in:
Brian Hirsh 2020-12-07 10:37:38 -08:00 committed by Facebook GitHub Bot
parent f2c3efd51f
commit ba6511b304
7 changed files with 378 additions and 360 deletions


@@ -38,6 +38,8 @@ mkdir -p "$OUT"/pyi/torch/_C
mkdir -p "$OUT"/pyi/torch/nn
python -m tools.pyi.gen_pyi \
--declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \
--native-functions-path aten/src/ATen/native/native_functions.yaml \
--deprecated-functions-path tools/autograd/deprecated.yaml \
--out "$OUT"/pyi
# autograd codegen (called by torch codegen but can run independently)


@@ -35,6 +35,7 @@ files = tools/codegen/gen.py,
tools/autograd/gen_trace_type.py,
tools/autograd/gen_variable_factories.py,
tools/autograd/load_derivatives.py,
tools/pyi/gen_pyi.py,
torch/utils/benchmark/utils/common.py,
torch/utils/benchmark/utils/timer.py,
torch/utils/benchmark/utils/valgrind_wrapper/*.py,


@@ -193,25 +193,28 @@ def load_signatures(
deprecated_yaml_path: str,
*,
method: bool,
skip_deprecated: bool = False,
pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
native_functions = list(filter(should_generate_py_binding, parse_native_yaml(native_yaml_path)))
@with_native_function
def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
return PythonSignatureNativeFunctionPair(
signature=signature(f, method=method),
signature=signature(f, method=method, pyi=pyi),
function=f,
)
pairs = list(map(gen_signature_pairs, native_functions))
deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method)
return pairs + deprecated
deprecated = load_deprecated_signatures(pairs, deprecated_yaml_path, method=method, pyi=pyi)
return pairs if skip_deprecated else pairs + deprecated
def load_deprecated_signatures(
pairs: Sequence[PythonSignatureNativeFunctionPair],
deprecated_yaml_path: str,
*,
method: bool,
pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
# The deprecated.yaml doesn't have complete type information, we need
# find and leverage the original ATen signature (to which it delegates
@@ -225,6 +228,10 @@ def load_deprecated_signatures(
opname = str(f.func.name.name.base)
if f.func.is_out_fn():
opname += '_out'
# TODO: remove HACK
# I think we want to differentiate inplace functions here.. but we currently don't for the arg parser
if f.func.name.name.inplace and pyi:
opname += '_'
args = CppSignatureGroup.from_schema(f.func, method=False).signature.arguments()
# Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
types = ', '.join(argument_type_str(a.argument.type)
@@ -308,6 +315,7 @@ def load_deprecated_signatures(
method=python_sig.method,
deprecated_args_names=tuple(args),
deprecated_args_exprs=tuple(call_args),
returns=python_sig.returns,
),
function=pair.function,
))
@@ -320,31 +328,10 @@ def load_deprecated_signatures(
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# TODO: remove the copy of this method in 'tools/pyi/gen_pyi.py'.
@with_native_function
def namedtuple_fieldnames(f: NativeFunction) -> List[str]:
returns = f.func.returns
if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
return []
else:
if any(map(lambda r: r.name is None, returns)):
# When building on Windows, `PyStructSequence_UnnamedField` could not be
# resolved by the linker for some reason, which cause error in building:
#
# python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
# PyStructSequence_UnnamedField
#
# Thus, at this point in time, we do not support unnamed
# fields in namedtuple; you must either name all fields,
# or none of them.
raise ValueError("Unnamed field is not supported by codegen")
return list(map(lambda r: str(r.name), returns))
@with_native_function
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
name = cpp.name(f.func)
fieldnames = namedtuple_fieldnames(f)
fieldnames = namedtuple_fieldnames(f.func.returns)
return '_'.join([name] + fieldnames)
def emit_namedtuple_typedefs(
@@ -360,7 +347,7 @@ def emit_namedtuple_typedefs(
typedefs: List[str] = [] # typedef declarations and init code
for overload in overloads:
fieldnames = namedtuple_fieldnames(overload.function)
fieldnames = namedtuple_fieldnames(overload.function.func.returns)
if not fieldnames:
continue
@@ -651,7 +638,9 @@ def method_def(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair]
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
sort: bool = True,
) -> Sequence[PythonSignatureGroup]:
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
@@ -700,7 +689,9 @@ def group_overloads(
outplace=outplace.function if outplace is not None else None,
))
return sort_overloads(grouped)
# TODO: unconditionally sort
# maintaining byte-for-byte compatibility for pyi codegen for now
return grouped if not sort else sort_overloads(grouped)
# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the


@@ -1,8 +1,8 @@
import re
import os
import yaml
from collections import defaultdict
from .nested_dict import nested_dict
from typing import Dict, List
__all__ = [
@@ -52,7 +52,7 @@ def uninplace_api_name(api_name):
return api_name
def write(dirname, name, template, env):
def write(dirname: str, name: str, template: CodeTemplate, env: Dict[str, List[str]]) -> None:
env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename)
path = os.path.join(dirname, name)
# See Note [Unchanging results for ninja]
@@ -69,12 +69,6 @@ def write(dirname, name, template, env):
else:
print("Skipped writing {}".format(path))
def is_tensor_method(declaration):
return 'Tensor' in declaration['method_of']
def is_torch_function(declaration):
return 'namespace' in declaration['method_of']
def is_out_variant(decl):
return decl['name'].endswith('_out')
@@ -92,12 +86,6 @@ def load_op_list_and_strip_overload(op_list, op_list_path):
# strip out the overload part
return {opname.split('.', 1)[0] for opname in op_list}
def group_declarations_by_op_name(declarations):
groups = defaultdict(list)
for d in declarations:
groups[op_name(d)].append(d)
return groups
def is_output(arg):
return arg.get('output', False)


@@ -173,6 +173,53 @@ from tools.codegen.model import *
# }
#
# TODO: stick this more firmly in the data model somewhere?
def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
return []
else:
if any(map(lambda r: r.name is None, returns)):
# When building on Windows, `PyStructSequence_UnnamedField` could not be
# resolved by the linker for some reason, which cause error in building:
#
# python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
# PyStructSequence_UnnamedField
#
# Thus, at this point in time, we do not support unnamed
# fields in namedtuple; you must either name all fields,
# or none of them.
raise ValueError("Unnamed field is not supported by codegen")
return list(map(lambda r: str(r.name), returns))
@dataclass(frozen=True)
class PythonReturns:
returns: Tuple[Return, ...]
def named_tuple_pyi(self) -> Optional[Tuple[str, str]]:
python_returns = [argument_type_str_pyi(r.type) for r in self.returns]
field_names = namedtuple_fieldnames(self.returns)
if field_names:
namedtuple_name = '_'.join(['namedtuple'] + field_names)
tuple_args = [f'("{name}", {typ})' for name, typ in zip(field_names, python_returns)]
namedtuple_def = f'NamedTuple("{namedtuple_name}", [{", ".join(tuple_args)}])'
return namedtuple_name, namedtuple_def
return None
def returns_str_pyi(self) -> str:
named_tuple = self.named_tuple_pyi()
if named_tuple is not None:
namedtuple_name, _ = named_tuple
return namedtuple_name
python_returns = [argument_type_str_pyi(r.type) for r in self.returns]
if len(python_returns) > 1:
return 'Tuple[' + ', '.join(python_returns) + ']'
if len(python_returns) == 1:
return python_returns[0]
return 'None'
@dataclass(frozen=True)
class PythonArgument:
name: str
@@ -189,26 +236,56 @@ class PythonArgument:
# Compute argument formal for python argument parsing.
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
def argument_str(self, *, method: bool = False) -> str:
type_str = argument_type_str(self.type)
def argument_str(self, *, method: bool = False, pyi: bool = False, deprecated: bool = False) -> str:
type_str = argument_type_str_pyi(self.type, pyi_out_arg=pyi and isinstance(self, PythonOutArgument)) \
if pyi else argument_type_str(self.type)
name = self.name
# s/self/input/ outside method bindings
# [old codegen] TODO: remove this? doesn't rename in codegen, it's just
# for the parse string
name = self.name
if name == 'self' and type_str == 'Tensor' and not method:
if name == 'self' and type_str == 'Tensor' and not method and not deprecated:
name = 'input'
if pyi:
if name == 'from': # from is a Python keyword...
name += '_'
# pyi merges the _out and functional variants into the same signature, with an optional out arg
if name == 'out' and type_str == 'Tensor' and not deprecated:
type_str = 'Optional[' + type_str + ']'
# TODO: remove diff. pyi deprecated signatures don't get defaults for their out arg
treat_as_no_default = pyi and deprecated and isinstance(self, PythonOutArgument) and self.default == 'None'
# add default
if self.default is not None:
default = {
'nullptr': 'None',
'c10::nullopt': 'None',
'{}': 'None',
}.get(self.default, self.default)
return f'{type_str} {name}={default}'
if self.default is not None and not treat_as_no_default:
if pyi:
if isinstance(self.type, ListType) and self.type.elem == BaseType(BaseTy.int) and \
self.default.startswith('{') and self.default.endswith('}'):
default = '(' + self.default[1:-1] + ')'
else:
default = {
'nullptr': 'None',
'c10::nullopt': 'None',
'{}': 'None',
'MemoryFormat::Contiguous': 'contiguous_format',
'QScheme::PER_TENSOR_AFFINE': 'per_tensor_affine',
}.get(self.default, self.default)
# TODO: remove requires_grad special case (byte-for-byte compat)
return f'{name}:{type_str}={default}' if name == 'requires_grad' else f'{name}: {type_str}={default}'
else:
default = {
'nullptr': 'None',
'c10::nullopt': 'None',
'{}': 'None',
}.get(self.default, self.default)
return f'{type_str} {name}={default}'
else:
return f'{type_str} {name}'
if pyi:
# TODO: remove requires_grad special case (byte-for-byte compat)
return f'{name}:{type_str}' if name == 'requires_grad' else f'{name}: {type_str}'
else:
return f'{type_str} {name}'
@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
@@ -238,6 +315,7 @@ class PythonOutArgument(PythonArgument):
raise RuntimeError(f'Unsupported output type: {outputs}')
return PythonOutArgument(
name='out',
# TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
type=ListType(BaseType(BaseTy.Tensor), size),
default='None',
default_init=None,
@@ -260,6 +338,9 @@ class PythonSignature:
output_args: Optional[PythonOutArgument]
# Return types, which are only used by pyi
returns: PythonReturns
# These are scattered kwargs arguments belonging to TensorOptions.
# When binding to C++, they are packed into a TensorOptions object 'options'.
# It's possible that the C++ signature doesn't take TensorOptions object (e.g.
@@ -276,13 +357,23 @@ class PythonSignature:
return False
def arguments(
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False, hacky_add_output: bool = False
) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
result: List[Union[PythonArgument, PythonOutArgument]] = []
result.extend(self.input_args)
result.extend(self.input_kwargs)
if self.output_args is not None and not skip_outputs:
result.append(self.output_args)
# TODO: remove HACK
# in the existing pyi codegen, we tack on an optional out argument to every operator overload
# if there exists at least one overload with an out variant. This seems wrong.
elif hacky_add_output:
result.extend([PythonOutArgument(
name='out',
type=OptionalType(BaseType(BaseTy.Tensor)),
default='None',
default_init=None,
outputs=())])
if not skip_tensor_options:
result.extend(self.tensor_options_args)
return tuple(result)
@@ -301,18 +392,57 @@ class PythonSignature:
# for error parsing.
#
# For a translation to mypy-valid type signatures, see
# tools/gen_pyi.py. If you change any logic here, please
# signature_str_pyi. If you change any logic here, please
# check that file too.
def signature_str(self, *, skip_outputs: bool = False) -> str:
schema_formals: List[str] = \
list(map(lambda a: a.argument_str(method=self.method),
self.arguments(skip_outputs=skip_outputs)))
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method), args))
positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc:
schema_formals.insert(positional_argc, '*')
return f'{self.name}({", ".join(schema_formals)})'
def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output)
schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True), args))
positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc:
schema_formals.insert(positional_argc, '*')
# only pyi signatures include returns
returns_str = self.returns.returns_str_pyi()
# pyi also includes self (with no typing/defaults) for methods
if self.method:
schema_formals.insert(0, "self")
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]:
# only pyi uses vararg signatures
args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output)
schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True), args))
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures
num_args = self.arguments_count()
num_positionalargs = len(self.input_args)
have_vararg_version = False
if num_args > 0:
vararg_type = args[0].type
if isinstance(vararg_type, ListType) and str(vararg_type.elem) == 'int' and num_positionalargs == 1:
have_vararg_version = True
if not have_vararg_version:
return None
# Below are the major changes in vararg vs. regular pyi signatures
# vararg signatures also omit the asterisk
schema_formals[0] = '*' + args[0].name + ': _int'
returns_str = self.returns.returns_str_pyi()
# pyi also includes self (with no typing/defaults) for methods
if self.method:
schema_formals.insert(0, "self")
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
# The deprecated python signature involves some special logic, so create a
# dedicated data model to store these extra properties.
@dataclass(frozen=True)
@@ -340,6 +470,20 @@ class PythonSignatureDeprecated(PythonSignature):
def signature_str(self, *, skip_outputs: bool = False) -> str:
return PythonSignature.signature_str(self, skip_outputs=skip_outputs) + '|deprecated'
def signature_str_pyi(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs, hacky_add_output=hacky_add_output)
schema_formals: List[str] = list(map(lambda a: a.argument_str(method=self.method, pyi=True, deprecated=True), args))
positional_argc = len(self.input_args)
if len(schema_formals) > positional_argc:
schema_formals.insert(positional_argc, '*')
returns_str = self.returns.returns_str_pyi()
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False, hacky_add_output: bool = False) -> Optional[str]:
# the codegen doesn't include vararg variants for deprecated signatures
return None
# This struct is used to hold the PythonSignature and its corresponding
# NativeFunction BEFORE grouping base and out-variant functions.
# Why not store NativeFunction in PythonSignature or construct PythonSignature
@@ -520,12 +664,75 @@ def argument(a: Argument) -> PythonArgument:
default_init=None,
)
def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
def argument_type_str_pyi(t: Type, *, pyi_out_arg: bool = False) -> str:
add_optional = False
if isinstance(t, OptionalType):
t = t.elem
add_optional = True
if isinstance(t, BaseType):
if t.name == BaseTy.int:
ret = '_int'
elif t.name == BaseTy.float:
ret = '_float'
elif t.name == BaseTy.str:
ret = 'str'
elif t.name == BaseTy.Scalar:
ret = 'Number'
elif t.name == BaseTy.ScalarType:
ret = '_dtype'
elif t.name == BaseTy.bool:
ret = '_bool'
elif t.name == BaseTy.QScheme:
ret = '_qscheme'
elif t.name == BaseTy.Layout:
ret = '_layout'
elif t.name == BaseTy.Device:
ret = 'Union[_device, str, None]'
elif t.name == BaseTy.MemoryFormat:
ret = 'memory_format'
elif t.name == BaseTy.Dimname:
ret = 'Union[str, ellipsis, None]'
elif t.name in [BaseTy.Tensor, BaseTy.Generator,
BaseTy.Storage, BaseTy.Stream, BaseTy.str]:
# These python schema type names line up with their function schema names
ret = t.name.name
elif isinstance(t, ListType):
if pyi_out_arg and t.is_tensor_like():
# TODO remove HACK
# pyi blindly treats all tensor-like out args as having type Tensor
return 'Tensor'
if str(t.elem) == 'int':
ret = 'Union[_int, _size]' if t.size is not None else '_size'
elif t.is_tensor_like():
# TODO: this doesn't seem right...
# Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
# It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
if isinstance(t.elem, OptionalType):
add_optional = True
ret = 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]' if t.size is not None else \
'Union[Tuple[Tensor, ...], List[Tensor]]'
elif str(t.elem) == 'float':
ret = 'Sequence[float]'
else:
elem = argument_type_str_pyi(t.elem)
ret = f'Sequence[{elem}]'
if add_optional:
ret = 'Optional[' + ret + ']'
return ret
raise RuntimeError(f'unrecognized type {repr(t)}')
# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(f: NativeFunction, *, method: bool = False, pyi: bool = False) -> PythonSignature:
# Use cpp api to gather TensorOptions fields from kwargs.
# Skip ThisArgument if this is method signature.
# Skip SelfArgument if this is method.
# Skip TensorOptionsArguments in C++ signature. Python side TensorOptions
# arguments are created based on different rules - see below.
args = tuple(a for a in cpp.group_arguments(f.func, method=method) if isinstance(a, Argument))
cpp_args = cpp.group_arguments(f.func, method=method)
args = tuple(a for a in cpp_args if isinstance(a, Argument))
input_arg_set = set(a.name for a in f.func.arguments.positional)
kwarg_only_set = set(a.name for a in f.func.arguments.kwarg_only)
@@ -561,13 +768,15 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
tensor_options_args.append(PythonArgument(
name='dtype',
type=BaseType(BaseTy.ScalarType),
default=_dtype_default_type_hack(name),
default=_dtype_default_type_hack(name, pyi=pyi),
default_init='self.scalar_type()' if is_like_or_new_function else None,
))
# TODO: probably a bug, kill this diff?
# pyi signatures have a slightly different type/default for layout
tensor_options_args.append(PythonArgument(
name='layout',
type=OptionalType(BaseType(BaseTy.Layout)),
default='torch.strided',
type=BaseType(BaseTy.Layout) if pyi else OptionalType(BaseType(BaseTy.Layout)),
default='strided' if pyi else 'torch.strided',
default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None,
))
tensor_options_args.append(PythonArgument(
@@ -576,12 +785,15 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
default='None',
default_init='self.device()' if is_like_or_new_function else None,
))
tensor_options_args.append(PythonArgument(
name='pin_memory',
type=BaseType(BaseTy.bool),
default='False',
default_init=None,
))
# TODO: probably a bug, kill this diff?
# pyi signatures don't include pin memory
if not pyi:
tensor_options_args.append(PythonArgument(
name='pin_memory',
type=BaseType(BaseTy.bool),
default='False',
default_init=None,
))
tensor_options_args.append(PythonArgument(
name='requires_grad',
type=BaseType(BaseTy.bool),
@@ -589,18 +801,21 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature:
default_init=None,
))
returns = PythonReturns(returns=f.func.returns)
return PythonSignature(
name=str(f.func.name.name),
input_args=input_args,
input_kwargs=input_kwargs,
output_args=PythonOutArgument.from_outputs(outputs),
tensor_options_args=tuple(tensor_options_args),
returns=returns,
method=method,
)
# TODO blowtorch
def _dtype_default_type_hack(name: str) -> str:
if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices':
def _dtype_default_type_hack(name: str, *, pyi: bool) -> str:
if not pyi and (name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices'):
return 'torch.int64'
else:
return 'None'


@@ -3,13 +3,14 @@ import os
import collections
from pprint import pformat
import yaml
import re
import argparse
from ..autograd.utils import YamlLoader, CodeTemplate, write, group_declarations_by_op_name, is_tensor_method, is_torch_function
from ..autograd.gen_python_functions import SKIP_PYTHON_BINDINGS, SKIP_PYTHON_BINDINGS_SIGNATURES
from ..autograd.gen_autograd import load_aten_declarations
from tools.codegen.model import *
from tools.codegen.api.python import *
from typing import Sequence, List, Mapping, Dict
from ..autograd.utils import CodeTemplate, write
from ..autograd.gen_python_functions import should_generate_py_binding, load_signatures, group_overloads
"""
This module implements generation of type stubs for PyTorch,
@@ -28,60 +29,48 @@ Here's our general strategy:
(the latter case should be pretty rare).
- We go through automatically bound functions based on the
type information recorded in Declarations.yaml and
type information recorded in native_functions.yaml and
generate type hints for them (generate_type_hints)
There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""
# TODO: remove after migrating entire codegen to the new data model.
def should_generate_python_binding(declaration):
name = declaration['name']
for pattern in SKIP_PYTHON_BINDINGS:
if re.match('^' + pattern + '$', name):
return False
# TODO: consider waiting to group by base name until we actually need to
# (after computing type hint signatures, when adding @overload directives)
def group_by_base_name(python_funcs: Sequence[PythonSignatureNativeFunctionPair]) -> Mapping[str, List[PythonSignatureGroup]]:
groups = group_overloads(python_funcs, sort=False)
d = collections.defaultdict(list)
for g in groups:
name = g.signature.name
d[name].append(g)
return d
simple_types = [arg['simple_type'] for arg in declaration['arguments']]
signature = '{}({})'.format(name, ', '.join(simple_types))
for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
if pattern == signature:
return False
return True
def get_py_variable_methods(declarations):
def get_py_torch_functions(
python_funcs: Sequence[PythonSignatureNativeFunctionPair],
method: bool = False,
) -> Mapping[str, Sequence[PythonSignatureGroup]]:
"""
Get declarations (grouped by name) which should be generated
as methods on Tensor.
as either functions in the "torch" module or methods on Tensor.
"""
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
not declaration.get('python_module') and
is_tensor_method(declaration))
def should_bind_function(python_func: PythonSignatureNativeFunctionPair) -> bool:
return (should_generate_py_binding(python_func.function) and
not python_func.function.python_module and
Variant.function in python_func.function.variants)
return group_declarations_by_op_name([d for d in declarations if should_bind(d)])
def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
return (should_generate_py_binding(python_func.function) and
not python_func.function.python_module and
Variant.method in python_func.function.variants)
def get_py_torch_functions(declarations):
"""
Get declarations (grouped by name) which should be generated
as functions in the "torch" module.
"""
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
not declaration.get('python_module') and
is_torch_function(declaration))
return group_declarations_by_op_name([d for d in declarations if should_bind(d)])
should_bind = should_bind_method if method else should_bind_function
return group_by_base_name([f for f in python_funcs if should_bind(f)])
# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier to read.
needed_modules = set()
DEVICE_PARAM = "device: Union[_device, str, None]=None"
FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"
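These parameter strings are spliced verbatim into the hand-written hints further down via `str.format`. A self-contained sketch of how one such hint is assembled (redefining the two constants locally so the snippet stands alone):

```python
DEVICE_PARAM = "device: Union[_device, str, None]=None"
FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"

# the factory params are appended after the op-specific positional arguments
hint = 'def new_ones(self, size: _size, {}) -> Tensor: ...'.format(FACTORY_PARAMS)
```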
@@ -144,90 +133,6 @@ blocklist = [
]
binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
              'matmul', 'floordiv',
              'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow',  # reverse arithmetic
@@ -241,7 +146,7 @@ to_py_type_ops = ('bool', 'float', 'complex', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops

def sig_for_ops(opname: str) -> List[str]:
    """sig_for_ops(opname : str) -> List[str]

    Returns signatures for operator special functions (__add__ etc.)"""
@@ -271,146 +176,66 @@ def sig_for_ops(opname):
    else:
        raise Exception("unknown op", opname)
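A rough illustration of what `sig_for_ops` produces: the sketch below is simplified (the real function also special-cases `pow`, `truediv`, comparisons, and the to-Python conversion ops), but shows the basic split between binary and unary dunder signatures:

```python
from typing import List

binary_ops = ('add', 'sub', 'mul')
unary_ops = ('neg', 'abs', 'invert')

def sig_for_ops(opname: str) -> List[str]:
    # simplified: binary dunders take an arbitrary right-hand side,
    # unary dunders take only self; both return Tensor
    name = '__{}__'.format(opname)
    if opname in binary_ops:
        return ['def {}(self, other: Any) -> Tensor: ...'.format(name)]
    elif opname in unary_ops:
        return ['def {}(self) -> Tensor: ...'.format(name)]
    else:
        raise Exception("unknown op", opname)
```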
def generate_named_tuples(funcs: Sequence[PythonSignatureGroup]) -> Dict[str, str]:
    namedtuples: Dict[str, str] = {}
    for sig_group in funcs:
        named_tuple = sig_group.signature.returns.named_tuple_pyi()
        if named_tuple is not None:
            tuple_name, tuple_def = named_tuple
            if tuple_name in namedtuples:
                assert namedtuples[tuple_name] == tuple_def
            else:
                namedtuples[tuple_name] = tuple_def
    return namedtuples
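The deduplication logic above (two ops that produce the same NamedTuple name must agree on its definition) can be sketched stand-alone, feeding in pre-computed `(name, definition)` pairs instead of real `PythonSignatureGroup` objects:

```python
from typing import Dict, List, Optional, Tuple

def generate_named_tuples(
        named_tuples: List[Optional[Tuple[str, str]]]) -> Dict[str, str]:
    # sketch of the dedup logic: identical names must map to identical defs
    out: Dict[str, str] = {}
    for nt in named_tuples:
        if nt is None:
            continue  # op without named return fields
        name, definition = nt
        if name in out:
            assert out[name] == definition
        else:
            out[name] = definition
    return out

tuple_def = 'NamedTuple("namedtuple_values_indices", [("values", Tensor), ("indices", Tensor)])'
tuples = [
    None,  # e.g. an op returning a plain Tensor
    ('namedtuple_values_indices', tuple_def),
    ('namedtuple_values_indices', tuple_def),  # a second op with the same return shape
]
result = generate_named_tuples(tuples)
```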
def generate_type_hints(funcs: Sequence[PythonSignatureGroup], is_tensor: bool = False) -> List[str]:
    """generate_type_hints(funcs, is_tensor=False)

    Generates type hints for the signature groups in :attr:`funcs`,
    which come from the parsed native_functions.yaml.

    The :attr:`is_tensor` flag indicates whether we are parsing
    members of the Tensor class (true) or functions in the
    `torch` namespace (default, false).

    This function currently encodes quite a bit about the semantics of
    the translation C++ -> Python.
    """
    type_hints = []
    any_out = any([g for g in funcs if g.outplace is not None])

    for sig_group in funcs:
        # Some deprecated ops that are on the blocklist are still included in pyi
        if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
            continue

        # deprecated signatures have separate entries for their functional and out variants
        # (as opposed to the native ops, which fuse the two into a single signature).
        # generate the functional variant here, if an out variant exists.
        if sig_group.signature.deprecated and sig_group.outplace is not None:
            type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True)
            type_hints.append(type_hint)

        # TODO: remove HACK
        # the pyi codegen currently adds an optional out param in cases where the current op
        # does NOT have an out variant, but an overload of the op DOES have an out variant.
        # TODO: After that, we should consider killing this method entirely and operating per
        # PythonSignatureGroup rather than grouping their overloads together
        # (since there isn't much else semantically meaningful about grouping overloads)
        # this hack also doesn't apply to deprecated ops
        hacky_add_output = any_out and sig_group.outplace is None and not sig_group.signature.deprecated

        # PythonSignatureGroups that have both a functional + out variant get a single
        # signature, with an optional out argument. Generate the out variant if one exists;
        # otherwise, generate the functional variant.
        type_hint = sig_group.signature.signature_str_pyi(
            skip_outputs=sig_group.outplace is None, hacky_add_output=hacky_add_output)
        type_hints.append(type_hint)

        # Some operators additionally have a vararg variant of their signature:
        # if the first and only positional argument is an IntArrayRef, PyTorch also
        # accepts it as varargs, annotated with the element type rather than the list type.
        type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
            skip_outputs=sig_group.outplace is None, hacky_add_output=hacky_add_output)
        if type_hint_vararg:
            type_hints.append(type_hint_vararg)

    return type_hints
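A toy rendering of the two pyi-specific behaviors described above: fusing an optional `out` argument into the functional signature, and emitting a vararg overload when the sole positional argument is a size list. The function names mirror the real `PythonSignature` methods, but the free-function signatures here are hypothetical simplifications:

```python
from typing import List, Optional

def signature_str_pyi(name: str, args: List[str], ret: str,
                      has_out: bool) -> str:
    # fused form: the out argument is keyword-only and optional
    parts = list(args)
    if has_out:
        parts += ['*', 'out: Optional[Tensor]=None']
    return 'def {}({}) -> {}: ...'.format(name, ', '.join(parts), ret)

def signature_str_pyi_vararg(name: str, args: List[str],
                             ret: str) -> Optional[str]:
    # vararg form exists only when the sole positional arg is a _size;
    # Python varargs are annotated with the element type, not the list type
    if len(args) == 1 and args[0].endswith(': _size'):
        arg_name = args[0].split(':')[0]
        return 'def {}(*{}: _int) -> {}: ...'.format(name, arg_name, ret)
    return None

fused = signature_str_pyi('zeros', ['size: _size'], 'Tensor', has_out=True)
vararg = signature_str_pyi_vararg('zeros', ['size: _size'], 'Tensor')
```

So a `zeros`-like op with an out variant gets one fused signature plus a vararg overload, while an op like `add` (multiple positional args) gets no vararg form.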
def gen_nn_functional(out: str) -> None:
    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    imports = [
@@ -475,10 +300,10 @@ def gen_nn_functional(out):
    stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in'))
    write(out, 'torch/_C/_nn.pyi', stubs, env)
def gen_nn_pyi(out: str) -> None:
    gen_nn_functional(out)

def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
@@ -491,16 +316,13 @@ def gen_pyi(declarations_path, out):
    # checking. If you are updating this, consider if your change
    # also needs to update the other file.
    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal': ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
@@ -560,21 +382,13 @@ def gen_pyi(declarations_path, out):
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    function_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=False, pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for name in sorted(sig_groups.keys()):
        unsorted_function_hints[name] += generate_type_hints(sig_groups[name])
        # deprecated signatures are not used when computing named tuples
        native_groups = [g for g in sig_groups[name] if not g.signature.deprecated]
        namedtuples.update(generate_named_tuples(native_groups))
    function_hints = []
    for name, hints in sorted(unsorted_function_hints.items()):
@@ -585,26 +399,26 @@ def gen_pyi(declarations_path, out):
    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size': ['def size(self) -> Size: ...',
                 'def size(self, _int) -> _int: ...'],
        'stride': ['def stride(self) -> Tuple[_int]: ...',
                   'def stride(self, _int) -> _int: ...'],
        'new_ones': ['def new_ones(self, size: _size, {}) -> Tensor: ...'.
                     format(FACTORY_PARAMS)],
        'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        # new and __init__ have the same signatures, differing only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': ['def new(self, *args: Any, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                'def new(self, storage: Storage) -> Tensor: ...',
                'def new(self, other: Tensor) -> Tensor: ...',
                'def new(self, size: _size, *, {}) -> Tensor: ...'.format(DEVICE_PARAM),
                ],
        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
                     'def __init__(self, storage: Storage) -> None: ...',
                     'def __init__(self, other: Tensor) -> None: ...',
                     'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM),
                     ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        # clamp has no default values in the Declarations
@@ -679,10 +493,14 @@ def gen_pyi(declarations_path, out):
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_yaml_path, deprecated_yaml_path, method=True, skip_deprecated=True, pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures, method=True)

    for name in sorted(tensor_method_sig_groups.keys()):
        unsorted_tensor_method_hints[name] += generate_type_hints(tensor_method_sig_groups[name], is_tensor=True)
        namedtuples.update(generate_named_tuples(tensor_method_sig_groups[name]))
    for op in all_ops:
        name = '__{}__'.format(op)
@@ -764,17 +582,20 @@ def gen_pyi(declarations_path, out):
    gen_nn_pyi(out)
def main() -> None:
    parser = argparse.ArgumentParser(
        description='Generate type stubs for PyTorch')
    parser.add_argument('--native-functions-path', metavar='NATIVE',
                        default='aten/src/ATen/native/native_functions.yaml',
                        help='path to native_functions.yaml')
    parser.add_argument('--deprecated-functions-path', metavar='DEPRECATED',
                        default='tools/autograd/deprecated.yaml',
                        help='path to deprecated.yaml')
    parser.add_argument('--out', metavar='OUT',
                        default='.',
                        help='path to output directory')
    args = parser.parse_args()
    gen_pyi(args.native_functions_path, args.deprecated_functions_path, args.out)
if __name__ == '__main__':


@@ -234,9 +234,9 @@ add_custom_command(
    "${TORCH_SRC_DIR}/nn/functional.pyi"
  COMMAND
    "${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
      --native-functions-path "aten/src/ATen/native/native_functions.yaml"
      --deprecated-functions-path "tools/autograd/deprecated.yaml"
  DEPENDS
    "${TORCH_SRC_DIR}/_C/__init__.pyi.in"
    "${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
    "${TORCH_SRC_DIR}/nn/functional.pyi.in"