mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add minimal skeleton for _C type stubs, delete torch.autograd stub (#38080)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38080 Originally, my plan was to just delete the torch.autograd stub, but this triggered a bunch of downstream errors relating to non-existent to _C modules, and so instead of ignoring those files, I decided to add a minimal _C type stubs, where it was easy (cases which were codegened I ignored). Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D21487841 Pulled By: ezyang fbshipit-source-id: cfcc467ff1c146d242cb9ff33a46ba26b33b8213
This commit is contained in:
parent
464e5a6c07
commit
7e9af67ca1
9
mypy.ini
9
mypy.ini
|
|
@ -29,9 +29,6 @@ python_version = 3.6
|
|||
# Extension modules without stubs.
|
||||
#
|
||||
|
||||
[mypy-torch._C]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-torch._C._jit_tree_views]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
|
@ -193,6 +190,9 @@ ignore_errors = True
|
|||
[mypy-torch.utils.bundled_inputs]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.mkldnn]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.tensorboard.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
@ -253,6 +253,9 @@ ignore_errors = True
|
|||
[mypy-torch.utils.hipify.hipify_python]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.autograd]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.autograd._functions.tensor]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
86
torch/_C/__init__.pyi
Normal file
86
torch/_C/__init__.pyi
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import torch
|
||||
from typing import Optional, TypeVar, Callable, Any
|
||||
|
||||
from . import _nn as _nn
|
||||
from . import _onnx as _onnx
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# Defined in torch/csrc/autograd/python_function.cpp
|
||||
class _FunctionBase(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
|
||||
class _LegacyVariableBase(object):
|
||||
def __init__(
|
||||
self,
|
||||
data: Optional['torch.Tensor']=...,
|
||||
requires_grad: Optional[bool]=...,
|
||||
volatile: Optional[bool]=...,
|
||||
_grad_fn: Optional[_FunctionBase]=...
|
||||
) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
|
||||
|
||||
# Defined in torch/csrc/Module.cpp
|
||||
def _show_config() -> str: ...
|
||||
def _parallel_info() -> str: ...
|
||||
def _add_docstr(obj: T, doc_obj: str) -> T: ...
|
||||
def _from_dlpack(data: Any) -> 'torch.Tensor': ...
|
||||
def _to_dlpack(data: 'torch.Tensor') -> Any: ...
|
||||
def _set_backcompat_broadcast_warn(arg: bool) -> None: ...
|
||||
def _get_backcompat_broadcast_warn() -> bool: ...
|
||||
def _set_backcompat_keepdim_warn(arg: bool) -> None: ...
|
||||
def _get_backcompat_keepdim_warn() -> bool: ...
|
||||
def _is_xnnpack_enabled() -> bool: ...
|
||||
def _get_mkldnn_enabled() -> bool: ...
|
||||
def _set_mkldnn_enabled(arg: bool) -> None: ...
|
||||
has_openmp: bool
|
||||
has_mkldnn: bool
|
||||
has_mkl: bool
|
||||
|
||||
# Defined in tools/autograd/templates/python_torch_functions.cpp
|
||||
# TODO: This is technically wrong
|
||||
class _VariableFunctions(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/jit/python/script_init.cpp
|
||||
class FileCheck(object):
|
||||
# TODO
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/Generator.cpp
|
||||
class Generator(object):
|
||||
device: 'torch.device'
|
||||
def get_state(self) -> 'torch.Tensor': ...
|
||||
def set_state(self, _new_state: 'torch.Tensor') -> Generator: ...
|
||||
def manual_seed(self, seed: int) -> Generator: ...
|
||||
def seed(self) -> int: ...
|
||||
def initial_seed(self) -> int: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
num_calling_threads: int
|
||||
num_worker_threads: int
|
||||
num_warmup_iters: int
|
||||
num_iters: int
|
||||
profiler_output_path: str
|
||||
|
||||
class BenchmarkExecutionStats(object):
|
||||
latency_avg_ms: float
|
||||
num_iters: int
|
||||
|
||||
class ThroughputBenchmark(object):
|
||||
def __init__(self, module: Any) -> None: ...
|
||||
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_variable.cpp
|
||||
# This is gonna need to be code'genned.
|
||||
class _TensorBase(object):
|
||||
...
|
||||
3
torch/_C/_nn.pyi
Normal file
3
torch/_C/_nn.pyi
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# Defined in tools/autograd/templates/python_nn_functions.cpp
|
||||
class _nn(object):
|
||||
...
|
||||
36
torch/_C/_onnx.pyi
Normal file
36
torch/_C/_onnx.pyi
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Defined in torch/csrc/onnx/init.cpp
|
||||
|
||||
from enum import Enum
|
||||
|
||||
PYTORCH_ONNX_CAFFE2_BUNDLE: bool
|
||||
IR_VERSION: int
|
||||
PRODUCER_VERSION: str
|
||||
|
||||
class TensorProtoDataType(Enum):
|
||||
UNDEFINED = ...
|
||||
FLOAT = ...
|
||||
UINT8 = ...
|
||||
INT8 = ...
|
||||
UINT16 = ...
|
||||
INT16 = ...
|
||||
INT32 = ...
|
||||
INT64 = ...
|
||||
STRING = ...
|
||||
BOOL = ...
|
||||
FLOAT16 = ...
|
||||
DOUBLE = ...
|
||||
UINT32 = ...
|
||||
UINT64 = ...
|
||||
COMPLEX64 = ...
|
||||
COMPLEX128 = ...
|
||||
|
||||
class OperatorExportTypes(Enum):
|
||||
ONNX = ...
|
||||
ONNX_ATEN = ...
|
||||
ONNX_ATEN_FALLBACK = ...
|
||||
RAW = ...
|
||||
|
||||
class TrainingMode(Enum):
|
||||
EVAL = ...
|
||||
PRESERVE = ...
|
||||
TRAINING = ...
|
||||
|
|
@ -6,6 +6,8 @@ for which gradients should be computed with the ``requires_grad=True`` keyword.
|
|||
"""
|
||||
import torch
|
||||
import warnings
|
||||
from typing import Any, Callable, Union, Tuple, Sequence, Optional
|
||||
from torch.types import _TensorOrTensors
|
||||
|
||||
from .variable import Variable
|
||||
from .function import Function, NestedIOFunction
|
||||
|
|
@ -50,7 +52,13 @@ def _make_grads(outputs, grads):
|
|||
return tuple(new_grads)
|
||||
|
||||
|
||||
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
|
||||
def backward(
|
||||
tensors: _TensorOrTensors,
|
||||
grad_tensors: Optional[_TensorOrTensors] = None,
|
||||
retain_graph: Optional[bool] = None,
|
||||
create_graph: bool = False,
|
||||
grad_variables: Optional[_TensorOrTensors] = None,
|
||||
) -> None:
|
||||
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
|
||||
|
||||
The graph is differentiated using the chain rule. If any of ``tensors``
|
||||
|
|
@ -115,8 +123,15 @@ def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False,
|
|||
allow_unreachable=True) # allow_unreachable flag
|
||||
|
||||
|
||||
def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
|
||||
only_inputs=True, allow_unused=False):
|
||||
def grad(
|
||||
outputs: _TensorOrTensors,
|
||||
inputs: _TensorOrTensors,
|
||||
grad_outputs: Optional[_TensorOrTensors] = None,
|
||||
retain_graph: Optional[bool] = None,
|
||||
create_graph: bool = False,
|
||||
only_inputs: bool = True,
|
||||
allow_unused: bool = False
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.
|
||||
|
||||
``grad_outputs`` should be a sequence of length matching ``output``
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
from typing import Any, Callable, Union, Tuple, Sequence, Optional
|
||||
from .. import Tensor
|
||||
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
|
||||
set_grad_enabled as set_grad_enabled
|
||||
from . import profiler
|
||||
|
||||
# The Variable API has been deprecated.
|
||||
# Variable(tensor) and Variable(tensor, requires_grad) still work, but they return Tensors instead of Variables.
|
||||
def Variable(tensor: Tensor, requires_grad: bool=...) -> Tensor: ...
|
||||
|
||||
class Function:
|
||||
@staticmethod
|
||||
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: ...
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any: ...
|
||||
|
||||
class NestedIOFunction(Function):
|
||||
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
|
||||
# superclass (Function) but are instance methods here, which mypy reports as incomptabile.
|
||||
def backward(self, *gradients: Any) -> Any: ... # type: ignore
|
||||
def forward(self, *args: Any) -> tuple: ... # type: ignore
|
||||
def save_for_backward(self, *args: Any) -> None:...
|
||||
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:...
|
||||
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def forward_extended(self, *input: Any) -> None:...
|
||||
def backward_extended(self, *grad_output: Any) -> None: ...
|
||||
|
||||
# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
|
||||
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
|
||||
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
|
||||
# For now, we permit any input.
|
||||
def gradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., raise_exception: bool=..., check_sparse_nnz: bool=...) -> bool: ...
|
||||
def gradgradcheck(func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]], inputs: Union[Tensor, Tuple[Tensor, ...]], eps: float=..., atol: float=..., rtol: float=..., gen_non_contig_grad_outputs: bool=..., raise_exception: bool=...) -> bool: ...
|
||||
|
||||
class detect_anomaly:
|
||||
def __enter__(self) -> None: ...
|
||||
def __exit__(self, *args: Any) -> bool: ...
|
||||
|
||||
class set_detect_anomaly:
|
||||
def __init__(self, mode: bool) -> None: ...
|
||||
def __enter__(self) -> None:...
|
||||
def __exit__(self, *args: Any) -> bool: ...
|
||||
|
||||
_TensorOrTensors = Union[Tensor, Sequence[Tensor]]
|
||||
def backward(tensors: _TensorOrTensors, grad_tensors: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=...) -> None: ...
|
||||
def grad(outputs: _TensorOrTensors, inputs: _TensorOrTensors, grad_outputs: Optional[_TensorOrTensors]=..., retain_graph: Optional[bool]=..., create_graph: bool=..., only_inputs: bool=..., allow_unused: bool=...) -> Tuple[Tensor, ...]: ...
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import torch
|
||||
import warnings
|
||||
|
||||
from typing import Any
|
||||
|
||||
class detect_anomaly(object):
|
||||
r"""Context-manager that enable anomaly detection for the autograd engine.
|
||||
|
||||
|
|
@ -65,16 +67,16 @@ class detect_anomaly(object):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.prev = torch.is_anomaly_enabled()
|
||||
warnings.warn('Anomaly Detection has been enabled. '
|
||||
'This mode will increase the runtime '
|
||||
'and should only be enabled for debugging.')
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
torch.set_anomaly_enabled(True)
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
torch.set_anomaly_enabled(self.prev)
|
||||
return False
|
||||
|
||||
|
|
@ -94,13 +96,13 @@ class set_detect_anomaly(object):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, mode):
|
||||
def __init__(self, mode: bool) -> None:
|
||||
self.prev = torch.is_anomaly_enabled()
|
||||
torch.set_anomaly_enabled(mode)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: Any) -> bool:
|
||||
torch.set_anomaly_enabled(self.prev)
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from torch._six import with_metaclass
|
|||
import functools
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _ContextMethodMixin(object):
|
||||
|
|
@ -150,7 +151,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
|
|||
is_traceable = False
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args, **kwargs):
|
||||
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""Performs the operation.
|
||||
|
||||
This function is to be overridden by all subclasses.
|
||||
|
|
@ -165,7 +166,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
|
|||
" autograd.Function.")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
r"""Defines a formula for differentiating the operation.
|
||||
|
||||
This function is to be overridden by all subclasses.
|
||||
|
|
@ -347,6 +348,8 @@ _map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o:
|
|||
|
||||
|
||||
class NestedIOFunction(Function):
|
||||
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
|
||||
# superclass (Function) but are instance methods here, which mypy reports as incompatible.
|
||||
|
||||
def _do_forward(self, *input):
|
||||
self._nested_input = input
|
||||
|
|
@ -364,21 +367,21 @@ class NestedIOFunction(Function):
|
|||
del self._to_save_nested
|
||||
return result
|
||||
|
||||
def backward(self, *gradients):
|
||||
def backward(self, *gradients: Any) -> Any:
|
||||
nested_gradients = _unflatten(gradients, self._nested_output)
|
||||
result = self.backward_extended(*nested_gradients)
|
||||
return tuple(_iter_None_tensors(result))
|
||||
|
||||
__call__ = _do_forward
|
||||
|
||||
def forward(self, *args):
|
||||
def forward(self, *args: Any) -> Any:
|
||||
nested_tensors = _map_tensor_data(self._nested_input)
|
||||
result = self.forward_extended(*nested_tensors)
|
||||
del self._nested_input
|
||||
self._nested_output = result
|
||||
return tuple(_iter_tensors(result))
|
||||
|
||||
def save_for_backward(self, *args):
|
||||
def save_for_backward(self, *args: Any) -> None:
|
||||
self.to_save = tuple(_iter_tensors(args))
|
||||
self._to_save_nested = args
|
||||
|
||||
|
|
@ -387,14 +390,14 @@ class NestedIOFunction(Function):
|
|||
flat_tensors = super(NestedIOFunction, self).saved_tensors
|
||||
return _unflatten(flat_tensors, self._to_save_nested)
|
||||
|
||||
def mark_dirty(self, *args, **kwargs):
|
||||
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
|
||||
|
||||
def mark_non_differentiable(self, *args, **kwargs):
|
||||
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
|
||||
|
||||
def forward_extended(self, *input):
|
||||
def forward_extended(self, *input: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def backward_extended(self, *grad_output):
|
||||
def backward_extended(self, *grad_output: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import torch
|
||||
from torch.types import _TensorOrTensors
|
||||
from torch._six import container_abcs, istuple
|
||||
import torch.testing
|
||||
from itertools import product
|
||||
import warnings
|
||||
from typing import Callable, Union, Optional
|
||||
|
||||
def zero_gradients(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
|
|
@ -189,7 +191,25 @@ def _differentiable_outputs(x):
|
|||
return tuple(o for o in _as_tuple(x) if o.requires_grad)
|
||||
|
||||
|
||||
def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True, check_sparse_nnz=False, nondet_tol=0.0):
|
||||
# Note [VarArg of Tensors]
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
|
||||
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
|
||||
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
|
||||
# For now, we permit any input.
|
||||
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
|
||||
# For now, we permit any input.
|
||||
|
||||
def gradcheck(
|
||||
func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors]
|
||||
inputs: _TensorOrTensors,
|
||||
eps: float = 1e-6,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3,
|
||||
raise_exception: bool = True,
|
||||
check_sparse_nnz: bool = False,
|
||||
nondet_tol: float = 0.0
|
||||
) -> bool:
|
||||
r"""Check gradients computed via small finite differences against analytical
|
||||
gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type
|
||||
and with ``requires_grad=True``.
|
||||
|
|
@ -331,9 +351,17 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
|
|||
return True
|
||||
|
||||
|
||||
def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-3,
|
||||
gen_non_contig_grad_outputs=False, raise_exception=True,
|
||||
nondet_tol=0.0):
|
||||
def gradgradcheck(
|
||||
func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors]
|
||||
inputs: _TensorOrTensors,
|
||||
grad_outputs: Optional[_TensorOrTensors] = None,
|
||||
eps: float = 1e-6,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3,
|
||||
gen_non_contig_grad_outputs: bool = False,
|
||||
raise_exception: bool = True,
|
||||
nondet_tol: float = 0.0
|
||||
) -> bool:
|
||||
r"""Check gradients of gradients computed via small finite differences
|
||||
against analytical gradients w.r.t. tensors in :attr:`inputs` and
|
||||
:attr:`grad_outputs` that are of floating point or complex type and with
|
||||
|
|
|
|||
7
torch/types.py
Normal file
7
torch/types.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import torch
|
||||
from typing import Union, Sequence
|
||||
|
||||
# Convenience aliases for common composite types that we need
|
||||
# to talk about in PyTorch
|
||||
|
||||
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||
Loading…
Reference in New Issue
Block a user