Add minimal skeleton for _C type stubs, delete torch.autograd stub (#38080)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38080

Originally, my plan was to just delete the torch.autograd stub, but
this triggered a bunch of downstream errors about non-existent _C
modules. So instead of ignoring those files, I decided to add minimal
_C type stubs where it was easy to do so (I skipped the cases that are
code-generated).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21487841

Pulled By: ezyang

fbshipit-source-id: cfcc467ff1c146d242cb9ff33a46ba26b33b8213
Commit 7e9af67ca1 (parent 464e5a6c07)
Authored by Edward Yang on 2020-05-08 22:31:40 -07:00, committed by Facebook GitHub Bot
10 changed files with 208 additions and 71 deletions
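
For context, here is a minimal sketch (not part of this commit) of the stub pattern the summary describes; the module name `_fast_ops` and function `fused_add` are invented for illustration, while the real stubs added by this commit live in the torch/_C/*.pyi files below:

# _fast_ops.pyi -- hypothetical hand-written stub for a C extension module.
# Only the names that downstream Python code actually touches need real
# signatures; everything else can be left as a placeholder with `...`.
from typing import Any

def fused_add(a: Any, b: Any) -> Any: ...

class _OpaqueHandle(object):
    # TODO: flesh out once typed callers need the full surface
    ...

With a stub like this on mypy's search path, the corresponding `ignore_missing_imports = True` entry in mypy.ini is no longer needed, which is exactly what the first hunk below does for torch._C.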

mypy.ini (modified)

@@ -29,9 +29,6 @@ python_version = 3.6
# Extension modules without stubs.
#
[mypy-torch._C]
ignore_missing_imports = True
[mypy-torch._C._jit_tree_views]
ignore_missing_imports = True
@@ -193,6 +190,9 @@ ignore_errors = True
[mypy-torch.utils.bundled_inputs]
ignore_errors = True
[mypy-torch.utils.mkldnn]
ignore_errors = True
[mypy-torch.utils.tensorboard.*]
ignore_errors = True
@@ -253,6 +253,9 @@ ignore_errors = True
[mypy-torch.utils.hipify.hipify_python]
ignore_errors = True
[mypy-torch.autograd]
ignore_errors = True
[mypy-torch.autograd._functions.tensor]
ignore_errors = True
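
The two settings touched above do different things: `ignore_missing_imports` only silences "cannot find module or stub" errors at the import site, while `ignore_errors` suppresses every error reported inside the listed module itself. A small illustration (not from the commit; the module choices are just examples):

# typecheck_demo.py
# `import torch._C` is now resolved through the new torch/_C/__init__.pyi stub
# rather than being silenced with ignore_missing_imports.
import torch._C

# torch.utils.mkldnn is listed under ignore_errors, so errors reported inside
# that module are suppressed; importers like this file are unaffected.
import torch.utils.mkldnn

print(torch._C._show_config())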

torch/_C/__init__.pyi (new file, 86 lines added)

@@ -0,0 +1,86 @@
import torch
from typing import Optional, TypeVar, Callable, Any
from . import _nn as _nn
from . import _onnx as _onnx
T = TypeVar('T')
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
# TODO
...
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(object):
def __init__(
self,
data: Optional['torch.Tensor']=...,
requires_grad: Optional[bool]=...,
volatile: Optional[bool]=...,
_grad_fn: Optional[_FunctionBase]=...
) -> None: ...
# Defined in torch/csrc/jit/python/init.cpp
def _jit_get_operation(op_name: str) -> Callable: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
# Defined in torch/csrc/Module.cpp
def _show_config() -> str: ...
def _parallel_info() -> str: ...
def _add_docstr(obj: T, doc_obj: str) -> T: ...
def _from_dlpack(data: Any) -> 'torch.Tensor': ...
def _to_dlpack(data: 'torch.Tensor') -> Any: ...
def _set_backcompat_broadcast_warn(arg: bool) -> None: ...
def _get_backcompat_broadcast_warn() -> bool: ...
def _set_backcompat_keepdim_warn(arg: bool) -> None: ...
def _get_backcompat_keepdim_warn() -> bool: ...
def _is_xnnpack_enabled() -> bool: ...
def _get_mkldnn_enabled() -> bool: ...
def _set_mkldnn_enabled(arg: bool) -> None: ...
has_openmp: bool
has_mkldnn: bool
has_mkl: bool
# Defined in tools/autograd/templates/python_torch_functions.cpp
# TODO: This is technically wrong
class _VariableFunctions(object):
# TODO
...
# Defined in torch/csrc/jit/python/script_init.cpp
class FileCheck(object):
# TODO
...
# Defined in torch/csrc/Generator.cpp
class Generator(object):
device: 'torch.device'
def get_state(self) -> 'torch.Tensor': ...
def set_state(self, _new_state: 'torch.Tensor') -> Generator: ...
def manual_seed(self, seed: int) -> Generator: ...
def seed(self) -> int: ...
def initial_seed(self) -> int: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
num_calling_threads: int
num_worker_threads: int
num_warmup_iters: int
num_iters: int
profiler_output_path: str
class BenchmarkExecutionStats(object):
latency_avg_ms: float
num_iters: int
class ThroughputBenchmark(object):
def __init__(self, module: Any) -> None: ...
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
# Defined in torch/csrc/autograd/python_variable.cpp
# This is going to need to be codegenned.
class _TensorBase(object):
...
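
A short usage sketch (not part of the diff) of code that this stub lets mypy reason about; the calls below are existing torch APIs, shown only as an illustration:

import torch

# Generator.manual_seed() is annotated as returning the Generator itself,
# and get_state() as returning a Tensor.
gen = torch.Generator().manual_seed(42)
state: torch.Tensor = gen.get_state()

# _show_config() and _parallel_info() are annotated as returning str.
build_info: str = torch._C._show_config()
print(build_info)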

torch/_C/_nn.pyi (new file, 3 lines added)

@@ -0,0 +1,3 @@
# Defined in tools/autograd/templates/python_nn_functions.cpp
class _nn(object):
...

torch/_C/_onnx.pyi (new file, 36 lines added)

@@ -0,0 +1,36 @@
# Defined in torch/csrc/onnx/init.cpp
from enum import Enum
PYTORCH_ONNX_CAFFE2_BUNDLE: bool
IR_VERSION: int
PRODUCER_VERSION: str
class TensorProtoDataType(Enum):
UNDEFINED = ...
FLOAT = ...
UINT8 = ...
INT8 = ...
UINT16 = ...
INT16 = ...
INT32 = ...
INT64 = ...
STRING = ...
BOOL = ...
FLOAT16 = ...
DOUBLE = ...
UINT32 = ...
UINT64 = ...
COMPLEX64 = ...
COMPLEX128 = ...
class OperatorExportTypes(Enum):
ONNX = ...
ONNX_ATEN = ...
ONNX_ATEN_FALLBACK = ...
RAW = ...
class TrainingMode(Enum):
EVAL = ...
PRESERVE = ...
TRAINING = ...
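
A brief illustration (not part of the commit) of what the _onnx stub now covers; the attribute path below reflects how these enums are exposed at runtime, though user code typically passes them through the torch.onnx export APIs instead:

import torch

# The export enums gain stub entries, so assignments like these are checked.
export_type = torch._C._onnx.OperatorExportTypes.ONNX
mode = torch._C._onnx.TrainingMode.EVAL
print(export_type, mode)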

torch/autograd/__init__.py (modified)

@@ -6,6 +6,8 @@ for which gradients should be computed with the ``requires_grad=True`` keyword.
"""
import torch
import warnings
from typing import Any, Callable, Union, Tuple, Sequence, Optional
from torch.types import _TensorOrTensors
from .variable import Variable
from .function import Function, NestedIOFunction
@@ -50,7 +52,13 @@ def _make_grads(outputs, grads):
return tuple(new_grads)
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
def backward(
tensors: _TensorOrTensors,
grad_tensors: Optional[_TensorOrTensors] = None,
retain_graph: Optional[bool] = None,
create_graph: bool = False,
grad_variables: Optional[_TensorOrTensors] = None,
) -> None:
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
The graph is differentiated using the chain rule. If any of ``tensors``
@@ -115,8 +123,15 @@ def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False,
allow_unreachable=True) # allow_unreachable flag
def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
only_inputs=True, allow_unused=False):
def grad(
outputs: _TensorOrTensors,
inputs: _TensorOrTensors,
grad_outputs: Optional[_TensorOrTensors] = None,
retain_graph: Optional[bool] = None,
create_graph: bool = False,
only_inputs: bool = True,
allow_unused: bool = False
) -> Tuple[torch.Tensor, ...]:
r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.
``grad_outputs`` should be a sequence of length matching ``output``
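
A small usage sketch (not from the diff) of the two functions whose signatures gain annotations here:

import torch

x = torch.randn(3, requires_grad=True)

# grad() is annotated as returning Tuple[Tensor, ...].
y = (x * x).sum()
(dx,) = torch.autograd.grad(y, x)

# backward() is annotated as returning None; it accumulates into .grad instead.
z = (x ** 3).sum()
torch.autograd.backward(z)
assert x.grad is not None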

torch/autograd/__init__.pyi (deleted)

@@ -1,46 +0,0 @@
from typing import Any, Callable, Union, Tuple, Sequence, Optional
from .. import Tensor
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
set_grad_enabled as set_grad_enabled
from . import profiler
# The Variable API has been deprecated.
# Variable(tensor) and Variable(tensor, requires_grad) still work, but they return Tensors instead of Variables.
def Variable(tensor: Tensor, requires_grad: bool=...) -> Tensor: ...
class Function:
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: ...
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any: ...
class NestedIOFunction(Function):
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
# superclass (Function) but are instance methods here, which mypy reports as incompatible.
def backward(self, *gradients: Any) -> Any: ... # type: ignore
def forward(self, *args: Any) -> tuple: ... # type: ignore
def save_for_backward(self, *args: Any) -> None:...
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:...
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
def forward_extended(self, *input: Any) -> None:...
def backward_extended(self, *grad_output: Any) -> None: ...
# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., raise_exception: bool=..., check_sparse_nnz: bool=...) -> bool: ...
def gradgradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., gen_non_contig_grad_outputs: bool=..., raise_exception: bool=...) -> bool: ...
class detect_anomaly:
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> bool: ...
class set_detect_anomaly:
def __init__(self, mode: bool) -> None: ...
def __enter__(self) -> None:...
def __exit__(self, *args: Any) -> bool: ...
_TensorOrTensors = Union[Tensor, Sequence[Tensor]]
def backward(tensors: _TensorOrTensors, grad_tensors: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=...) -> None: ...
def grad(outputs: _TensorOrTensors, inputs: _TensorOrTensors, grad_outputs: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=..., only_inputs: bool=..., allow_unused: bool=...) -> Tuple[Tensor, ...]: ...

torch/autograd/anomaly_mode.py (modified)

@@ -1,6 +1,8 @@
import torch
import warnings
from typing import Any
class detect_anomaly(object):
r"""Context-manager that enable anomaly detection for the autograd engine.
@@ -65,16 +67,16 @@ class detect_anomaly(object):
"""
def __init__(self):
def __init__(self) -> None:
self.prev = torch.is_anomaly_enabled()
warnings.warn('Anomaly Detection has been enabled. '
'This mode will increase the runtime '
'and should only be enabled for debugging.')
def __enter__(self):
def __enter__(self) -> None:
torch.set_anomaly_enabled(True)
def __exit__(self, *args):
def __exit__(self, *args: Any) -> bool:
torch.set_anomaly_enabled(self.prev)
return False
@@ -94,13 +96,13 @@ class set_detect_anomaly(object):
"""
def __init__(self, mode):
def __init__(self, mode: bool) -> None:
self.prev = torch.is_anomaly_enabled()
torch.set_anomaly_enabled(mode)
def __enter__(self):
def __enter__(self) -> None:
pass
def __exit__(self, *args):
def __exit__(self, *args: Any) -> bool:
torch.set_anomaly_enabled(self.prev)
return False
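
For reference, a short example (not part of the commit) of the context managers whose dunder methods are annotated above; __exit__ returning False, now spelled out in the signature, means exceptions raised inside the block still propagate:

import torch

with torch.autograd.detect_anomaly():
    x = torch.randn(4, requires_grad=True)
    (x * 2).sum().backward()

with torch.autograd.set_detect_anomaly(False):
    # anomaly detection is disabled inside this block and restored on exit
    pass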

torch/autograd/function.py (modified)

@@ -5,6 +5,7 @@ from torch._six import with_metaclass
import functools
import warnings
from collections import OrderedDict
from typing import Any
class _ContextMethodMixin(object):
@@ -150,7 +151,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
is_traceable = False
@staticmethod
def forward(ctx, *args, **kwargs):
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
r"""Performs the operation.
This function is to be overridden by all subclasses.
@@ -165,7 +166,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
" autograd.Function.")
@staticmethod
def backward(ctx, *grad_outputs):
def backward(ctx: Any, *grad_outputs: Any) -> Any:
r"""Defines a formula for differentiating the operation.
This function is to be overridden by all subclasses.
@@ -347,6 +348,8 @@ _map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o:
class NestedIOFunction(Function):
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
# superclass (Function) but are instance methods here, which mypy reports as incompatible.
def _do_forward(self, *input):
self._nested_input = input
@@ -364,21 +367,21 @@ class NestedIOFunction(Function):
del self._to_save_nested
return result
def backward(self, *gradients):
def backward(self, *gradients: Any) -> Any:
nested_gradients = _unflatten(gradients, self._nested_output)
result = self.backward_extended(*nested_gradients)
return tuple(_iter_None_tensors(result))
__call__ = _do_forward
def forward(self, *args):
def forward(self, *args: Any) -> Any:
nested_tensors = _map_tensor_data(self._nested_input)
result = self.forward_extended(*nested_tensors)
del self._nested_input
self._nested_output = result
return tuple(_iter_tensors(result))
def save_for_backward(self, *args):
def save_for_backward(self, *args: Any) -> None:
self.to_save = tuple(_iter_tensors(args))
self._to_save_nested = args
@@ -387,14 +390,14 @@ class NestedIOFunction(Function):
flat_tensors = super(NestedIOFunction, self).saved_tensors
return _unflatten(flat_tensors, self._to_save_nested)
def mark_dirty(self, *args, **kwargs):
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
def mark_non_differentiable(self, *args, **kwargs):
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
def forward_extended(self, *input):
def forward_extended(self, *input: Any) -> None:
raise NotImplementedError
def backward_extended(self, *grad_output):
def backward_extended(self, *grad_output: Any) -> None:
raise NotImplementedError
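
For reference, a minimal custom Function (not part of the commit) whose overrides line up with the staticmethod annotations added above:

from typing import Any

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # d(x*x)/dx = 2x, scaled by the incoming gradient
        return 2 * x * grad_output


t = torch.randn(5, requires_grad=True)
Square.apply(t).sum().backward()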

torch/autograd/gradcheck.py (modified)

@@ -1,8 +1,10 @@
import torch
from torch.types import _TensorOrTensors
from torch._six import container_abcs, istuple
import torch.testing
from itertools import product
import warnings
from typing import Callable, Union, Optional
def zero_gradients(x):
if isinstance(x, torch.Tensor):
@@ -189,7 +191,25 @@ def _differentiable_outputs(x):
return tuple(o for o in _as_tuple(x) if o.requires_grad)
def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True, check_sparse_nnz=False, nondet_tol=0.0):
# Note [VarArg of Tensors]
# ~~~~~~~~~~~~~~~~~~~~~~~~
# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(
func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors]
inputs: _TensorOrTensors,
eps: float = 1e-6,
atol: float = 1e-5,
rtol: float = 1e-3,
raise_exception: bool = True,
check_sparse_nnz: bool = False,
nondet_tol: float = 0.0
) -> bool:
r"""Check gradients computed via small finite differences against analytical
gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type
and with ``requires_grad=True``.
@@ -331,9 +351,17 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
return True
def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-3,
gen_non_contig_grad_outputs=False, raise_exception=True,
nondet_tol=0.0):
def gradgradcheck(
func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors]
inputs: _TensorOrTensors,
grad_outputs: Optional[_TensorOrTensors] = None,
eps: float = 1e-6,
atol: float = 1e-5,
rtol: float = 1e-3,
gen_non_contig_grad_outputs: bool = False,
raise_exception: bool = True,
nondet_tol: float = 0.0
) -> bool:
r"""Check gradients of gradients computed via small finite differences
against analytical gradients w.r.t. tensors in :attr:`inputs` and
:attr:`grad_outputs` that are of floating point or complex type and with
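
A short usage sketch (not from the diff) matching the newly annotated gradcheck signature; double precision keeps the finite-difference error within the default tolerances:

import torch

inp = torch.randn(4, dtype=torch.double, requires_grad=True)

# func may be any callable taking tensors and returning a Tensor or tuple of
# Tensors (see Note [VarArg of Tensors] above); gradcheck returns bool.
assert torch.autograd.gradcheck(torch.sigmoid, (inp,))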

torch/types.py (new file, 7 lines added)

@@ -0,0 +1,7 @@
import torch
from typing import Union, Sequence
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
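
A hypothetical helper (not part of the commit) showing how the new alias is meant to be consumed; the function name is invented for illustration:

import torch
from torch.types import _TensorOrTensors


def total_norm(tensors: _TensorOrTensors) -> torch.Tensor:
    # Accepts either a single Tensor or any Sequence of Tensors, per the alias.
    if isinstance(tensors, torch.Tensor):
        tensors = (tensors,)
    return torch.stack([t.norm() for t in tensors]).sum()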