Add minimal skeleton for _C type stubs, delete torch.autograd stub (#38080)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38080

Originally, my plan was to just delete the torch.autograd stub, but
this triggered a bunch of downstream errors relating to non-existent
_C modules, so instead of ignoring those files, I decided to add
minimal _C type stubs where it was easy (I skipped the cases that are
codegenned).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21487841

Pulled By: ezyang

fbshipit-source-id: cfcc467ff1c146d242cb9ff33a46ba26b33b8213
Author: Edward Yang <ezyang@fb.com>, 2020-05-08 22:31:40 -07:00 (committed by Facebook GitHub Bot)
parent 464e5a6c07
commit 7e9af67ca1
10 changed files with 208 additions and 71 deletions
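For readers unfamiliar with the mechanism: a .pyi stub file declares only signatures, with `...` in place of bodies, so mypy can type-check callers of a C extension module it cannot introspect. A minimal sketch of the pattern (the module and names below are hypothetical, purely to illustrate the shape of the stubs added in this commit):

    # _spam.pyi -- hypothetical stub for a C extension module named "_spam"
    from typing import Any

    has_fast_path: bool  # module-level attribute: declare the type only

    # Free function: full signature, '...' instead of a body
    def frobnicate(data: Any, copy: bool = ...) -> Any: ...

    # Class defined in C++: list the methods the Python side relies on
    class Frobnicator(object):
        def run(self, *args: Any, **kwargs: Any) -> Any: ...

Once such a stub exists, the corresponding `ignore_missing_imports = True` entry in mypy.ini can be dropped, which is exactly what the first hunk below does for torch._C.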

mypy.ini

@@ -29,9 +29,6 @@ python_version = 3.6
 # Extension modules without stubs.
 #
-[mypy-torch._C]
-ignore_missing_imports = True
 [mypy-torch._C._jit_tree_views]
 ignore_missing_imports = True
@@ -193,6 +190,9 @@ ignore_errors = True
 [mypy-torch.utils.bundled_inputs]
 ignore_errors = True
+[mypy-torch.utils.mkldnn]
+ignore_errors = True
 [mypy-torch.utils.tensorboard.*]
 ignore_errors = True
@@ -253,6 +253,9 @@ ignore_errors = True
 [mypy-torch.utils.hipify.hipify_python]
 ignore_errors = True
+[mypy-torch.autograd]
+ignore_errors = True
 [mypy-torch.autograd._functions.tensor]
 ignore_errors = True

torch/_C/__init__.pyi (new file, 86 lines)

@@ -0,0 +1,86 @@
import torch
from typing import Optional, TypeVar, Callable, Any

from . import _nn as _nn
from . import _onnx as _onnx

T = TypeVar('T')

# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
    # TODO
    ...

# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(object):
    def __init__(
        self,
        data: Optional['torch.Tensor']=...,
        requires_grad: Optional[bool]=...,
        volatile: Optional[bool]=...,
        _grad_fn: Optional[_FunctionBase]=...
    ) -> None: ...

# Defined in torch/csrc/jit/python/init.cpp
def _jit_get_operation(op_name: str) -> Callable: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...

# Defined in torch/csrc/Module.cpp
def _show_config() -> str: ...
def _parallel_info() -> str: ...
def _add_docstr(obj: T, doc_obj: str) -> T: ...
def _from_dlpack(data: Any) -> 'torch.Tensor': ...
def _to_dlpack(data: 'torch.Tensor') -> Any: ...
def _set_backcompat_broadcast_warn(arg: bool) -> None: ...
def _get_backcompat_broadcast_warn() -> bool: ...
def _set_backcompat_keepdim_warn(arg: bool) -> None: ...
def _get_backcompat_keepdim_warn() -> bool: ...
def _is_xnnpack_enabled() -> bool: ...
def _get_mkldnn_enabled() -> bool: ...
def _set_mkldnn_enabled(arg: bool) -> None: ...
has_openmp: bool
has_mkldnn: bool
has_mkl: bool

# Defined in tools/autograd/templates/python_torch_functions.cpp
# TODO: This is technically wrong
class _VariableFunctions(object):
    # TODO
    ...

# Defined in torch/csrc/jit/python/script_init.cpp
class FileCheck(object):
    # TODO
    ...

# Defined in torch/csrc/Generator.cpp
class Generator(object):
    device: 'torch.device'
    def get_state(self) -> 'torch.Tensor': ...
    def set_state(self, _new_state: 'torch.Tensor') -> Generator: ...
    def manual_seed(self, seed: int) -> Generator: ...
    def seed(self) -> int: ...
    def initial_seed(self) -> int: ...

# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
    num_calling_threads: int
    num_worker_threads: int
    num_warmup_iters: int
    num_iters: int
    profiler_output_path: str

class BenchmarkExecutionStats(object):
    latency_avg_ms: float
    num_iters: int

class ThroughputBenchmark(object):
    def __init__(self, module: Any) -> None: ...
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...

# Defined in torch/csrc/autograd/python_variable.cpp
# This is gonna need to be codegenned.
class _TensorBase(object):
    ...
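As a quick illustration of what the stub buys us (a hedged example, not part of this diff; it assumes torch re-exports torch._C.Generator as torch.Generator, as it does at runtime), code like this can now be checked against the Generator signatures above:

    import torch

    g: torch.Generator = torch.Generator().manual_seed(42)  # manual_seed returns the Generator
    state: torch.Tensor = g.get_state()                     # get_state returns a Tensor
    n: int = g.initial_seed()                               # initial_seed returns an int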

torch/_C/_nn.pyi (new file, 3 lines)

@@ -0,0 +1,3 @@
# Defined in tools/autograd/templates/python_nn_functions.cpp
class _nn(object):
    ...

torch/_C/_onnx.pyi (new file, 36 lines)

@@ -0,0 +1,36 @@
# Defined in torch/csrc/onnx/init.cpp
from enum import Enum
PYTORCH_ONNX_CAFFE2_BUNDLE: bool
IR_VERSION: int
PRODUCER_VERSION: str
class TensorProtoDataType(Enum):
    UNDEFINED = ...
    FLOAT = ...
    UINT8 = ...
    INT8 = ...
    UINT16 = ...
    INT16 = ...
    INT32 = ...
    INT64 = ...
    STRING = ...
    BOOL = ...
    FLOAT16 = ...
    DOUBLE = ...
    UINT32 = ...
    UINT64 = ...
    COMPLEX64 = ...
    COMPLEX128 = ...

class OperatorExportTypes(Enum):
    ONNX = ...
    ONNX_ATEN = ...
    ONNX_ATEN_FALLBACK = ...
    RAW = ...

class TrainingMode(Enum):
    EVAL = ...
    PRESERVE = ...
    TRAINING = ...

torch/autograd/__init__.py

@@ -6,6 +6,8 @@ for which gradients should be computed with the ``requires_grad=True`` keyword.
 """
 import torch
 import warnings
+from typing import Any, Callable, Union, Tuple, Sequence, Optional
+from torch.types import _TensorOrTensors
 from .variable import Variable
 from .function import Function, NestedIOFunction
@@ -50,7 +52,13 @@ def _make_grads(outputs, grads):
     return tuple(new_grads)

-def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
+def backward(
+    tensors: _TensorOrTensors,
+    grad_tensors: Optional[_TensorOrTensors] = None,
+    retain_graph: Optional[bool] = None,
+    create_graph: bool = False,
+    grad_variables: Optional[_TensorOrTensors] = None,
+) -> None:
     r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.

     The graph is differentiated using the chain rule. If any of ``tensors``
@@ -115,8 +123,15 @@ def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False,
         allow_unreachable=True)  # allow_unreachable flag

-def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
-         only_inputs=True, allow_unused=False):
+def grad(
+    outputs: _TensorOrTensors,
+    inputs: _TensorOrTensors,
+    grad_outputs: Optional[_TensorOrTensors] = None,
+    retain_graph: Optional[bool] = None,
+    create_graph: bool = False,
+    only_inputs: bool = True,
+    allow_unused: bool = False
+) -> Tuple[torch.Tensor, ...]:
     r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.

     ``grad_outputs`` should be a sequence of length matching ``output``

torch/autograd/__init__.pyi (deleted)

@@ -1,46 +0,0 @@
from typing import Any, Callable, Union, Tuple, Sequence, Optional
from .. import Tensor
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
set_grad_enabled as set_grad_enabled
from . import profiler
# The Variable API has been deprecated.
# Variable(tensor) and Variable(tensor, requires_grad) still work, but they return Tensors instead of Variables.
def Variable(tensor: Tensor, requires_grad: bool=...) -> Tensor: ...
class Function:
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: ...
    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any: ...

class NestedIOFunction(Function):
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.
    def backward(self, *gradients: Any) -> Any: ...  # type: ignore
    def forward(self, *args: Any) -> tuple: ...  # type: ignore
    def save_for_backward(self, *args: Any) -> None: ...
    def mark_dirty(self, *args: Any, **kwargs: Any) -> None: ...
    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
    def forward_extended(self, *input: Any) -> None: ...
    def backward_extended(self, *grad_output: Any) -> None: ...
# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., raise_exception: bool=..., check_sparse_nnz: bool=...) -> bool: ...
def gradgradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., gen_non_contig_grad_outputs: bool=..., raise_exception: bool=...) -> bool: ...
class detect_anomaly:
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> bool: ...

class set_detect_anomaly:
    def __init__(self, mode: bool) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> bool: ...
_TensorOrTensors = Union[Tensor, Sequence[Tensor]]
def backward(tensors: _TensorOrTensors, grad_tensors: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=...) -> None: ...
def grad(outputs: _TensorOrTensors, inputs: _TensorOrTensors, grad_outputs: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=..., only_inputs: bool=..., allow_unused: bool=...) -> Tuple[Tensor, ...]: ...

torch/autograd/anomaly_mode.py

@@ -1,6 +1,8 @@
 import torch
 import warnings
+from typing import Any

 class detect_anomaly(object):
     r"""Context-manager that enable anomaly detection for the autograd engine.
@@ -65,16 +67,16 @@ class detect_anomaly(object):
     """

-    def __init__(self):
+    def __init__(self) -> None:
         self.prev = torch.is_anomaly_enabled()
         warnings.warn('Anomaly Detection has been enabled. '
                       'This mode will increase the runtime '
                       'and should only be enabled for debugging.')

-    def __enter__(self):
+    def __enter__(self) -> None:
         torch.set_anomaly_enabled(True)

-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> bool:
         torch.set_anomaly_enabled(self.prev)
         return False
@@ -94,13 +96,13 @@ class set_detect_anomaly(object):
     """

-    def __init__(self, mode):
+    def __init__(self, mode: bool) -> None:
         self.prev = torch.is_anomaly_enabled()
         torch.set_anomaly_enabled(mode)

-    def __enter__(self):
+    def __enter__(self) -> None:
         pass

-    def __exit__(self, *args):
+    def __exit__(self, *args: Any) -> bool:
         torch.set_anomaly_enabled(self.prev)
         return False

torch/autograd/function.py

@@ -5,6 +5,7 @@ from torch._six import with_metaclass
 import functools
 import warnings
 from collections import OrderedDict
+from typing import Any

 class _ContextMethodMixin(object):
@@ -150,7 +151,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
     is_traceable = False

     @staticmethod
-    def forward(ctx, *args, **kwargs):
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
         r"""Performs the operation.

         This function is to be overridden by all subclasses.
@@ -165,7 +166,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
                                  " autograd.Function.")

     @staticmethod
-    def backward(ctx, *grad_outputs):
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
         r"""Defines a formula for differentiating the operation.

         This function is to be overridden by all subclasses.
@@ -347,6 +348,8 @@ _map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o:

 class NestedIOFunction(Function):
+    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
+    # superclass (Function) but are instance methods here, which mypy reports as incompatible.

     def _do_forward(self, *input):
         self._nested_input = input
@@ -364,21 +367,21 @@ class NestedIOFunction(Function):
         del self._to_save_nested
         return result

-    def backward(self, *gradients):
+    def backward(self, *gradients: Any) -> Any:
         nested_gradients = _unflatten(gradients, self._nested_output)
         result = self.backward_extended(*nested_gradients)
         return tuple(_iter_None_tensors(result))

     __call__ = _do_forward

-    def forward(self, *args):
+    def forward(self, *args: Any) -> Any:
         nested_tensors = _map_tensor_data(self._nested_input)
         result = self.forward_extended(*nested_tensors)
         del self._nested_input
         self._nested_output = result
         return tuple(_iter_tensors(result))

-    def save_for_backward(self, *args):
+    def save_for_backward(self, *args: Any) -> None:
         self.to_save = tuple(_iter_tensors(args))
         self._to_save_nested = args
@@ -387,14 +390,14 @@ class NestedIOFunction(Function):
         flat_tensors = super(NestedIOFunction, self).saved_tensors
         return _unflatten(flat_tensors, self._to_save_nested)

-    def mark_dirty(self, *args, **kwargs):
+    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
         self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

-    def mark_non_differentiable(self, *args, **kwargs):
+    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
         self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

-    def forward_extended(self, *input):
+    def forward_extended(self, *input: Any) -> None:
         raise NotImplementedError

-    def backward_extended(self, *grad_output):
+    def backward_extended(self, *grad_output: Any) -> None:
         raise NotImplementedError

torch/autograd/gradcheck.py

@@ -1,8 +1,10 @@
 import torch
+from torch.types import _TensorOrTensors
 from torch._six import container_abcs, istuple
 import torch.testing
 from itertools import product
 import warnings
+from typing import Callable, Union, Optional

 def zero_gradients(x):
     if isinstance(x, torch.Tensor):
@@ -189,7 +191,25 @@ def _differentiable_outputs(x):
     return tuple(o for o in _as_tuple(x) if o.requires_grad)

-def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True, check_sparse_nnz=False, nondet_tol=0.0):
+# Note [VarArg of Tensors]
+# ~~~~~~~~~~~~~~~~~~~~~~~~
+# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
+# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
+# the '...' first argument of Callable can be replaced with VarArg(Tensor).
+# For now, we permit any input.
+# the '...' first argument of Callable can be replaced with VarArg(Tensor).
+# For now, we permit any input.
+def gradcheck(
+    func: Callable[..., Union[_TensorOrTensors]],  # See Note [VarArg of Tensors]
+    inputs: _TensorOrTensors,
+    eps: float = 1e-6,
+    atol: float = 1e-5,
+    rtol: float = 1e-3,
+    raise_exception: bool = True,
+    check_sparse_nnz: bool = False,
+    nondet_tol: float = 0.0
+) -> bool:
     r"""Check gradients computed via small finite differences against analytical
     gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type
     and with ``requires_grad=True``.
@@ -331,9 +351,17 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
     return True

-def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-3,
-                  gen_non_contig_grad_outputs=False, raise_exception=True,
-                  nondet_tol=0.0):
+def gradgradcheck(
+    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
+    inputs: _TensorOrTensors,
+    grad_outputs: Optional[_TensorOrTensors] = None,
+    eps: float = 1e-6,
+    atol: float = 1e-5,
+    rtol: float = 1e-3,
+    gen_non_contig_grad_outputs: bool = False,
+    raise_exception: bool = True,
+    nondet_tol: float = 0.0
+) -> bool:
     r"""Check gradients of gradients computed via small finite differences
     against analytical gradients w.r.t. tensors in :attr:`inputs` and
     :attr:`grad_outputs` that are of floating point or complex type and with
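For orientation, a typical call that the new gradcheck annotations describe (standard usage of the existing API, shown here only as a hedged example) looks like:

    import torch

    def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return (x * y).sum()

    # gradcheck wants double-precision inputs with requires_grad=True
    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    y = torch.randn(3, dtype=torch.double, requires_grad=True)

    # 'func' takes a vararg of tensors (see Note [VarArg of Tensors]); 'inputs' is _TensorOrTensors
    ok: bool = torch.autograd.gradcheck(f, (x, y), eps=1e-6, atol=1e-5)
    assert ok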

torch/types.py (new file, 7 lines)

@@ -0,0 +1,7 @@
import torch
from typing import Union, Sequence
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
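As a hedged usage sketch (the helper below is hypothetical, not part of this commit), the alias lets a caller-facing signature accept either a single Tensor or any sequence of Tensors, just as the annotated backward/grad/gradcheck signatures above do:

    import torch
    from torch.types import _TensorOrTensors

    def as_tensor_list(tensors: _TensorOrTensors) -> list:
        # Normalize the "one tensor or many" convention into a plain list.
        if isinstance(tensors, torch.Tensor):
            return [tensors]
        return list(tensors)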