add typing annotations for a few torch.utils.* modules (#43806)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/43431. Depends on [gh-43862](https://github.com/pytorch/pytorch/pull/43862) (EDIT: now merged)

Modules:
- torch.utils.mkldnn
- torch.utils.mobile_optimizer
- torch.utils.bundled_inputs
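
For context, a minimal sketch of the kind of call these annotations let mypy check end to end (the toy module is illustrative; `optimize_for_mobile` is the public entry point of `torch.utils.mobile_optimizer`, assuming a standard CPU build):

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class AddRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)

# Scripting yields a ScriptModule, which optimize_for_mobile is now
# annotated to both accept and return, so mypy can verify this chain.
scripted = torch.jit.script(AddRelu())
optimized = optimize_for_mobile(scripted)
print(optimized(torch.randn(3)))
```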

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43806

Reviewed By: gmagogsfm

Differential Revision: D23635151

Pulled By: SplitInfinity

fbshipit-source-id: a85b75a7927dde6cc55bcb361f8ff601ffb0b2a1
Guilherme Leobas authored on 2020-09-11 10:13:43 -07:00; committed by Facebook GitHub Bot
parent 7d78a6fcdd
commit cdf5e2ae86

7 changed files with 39 additions and 25 deletions

diff --git a/mypy.ini b/mypy.ini

@@ -195,15 +195,6 @@ ignore_errors = True
 [mypy-torch._overrides]
 ignore_errors = True
 
-[mypy-torch.utils.bundled_inputs]
-ignore_errors = True
-
-[mypy-torch.utils.mkldnn]
-ignore_errors = True
-
-[mypy-torch.utils.mobile_optimizer]
-ignore_errors = True
-
 [mypy-torch.utils.tensorboard._caffe2_graph]
 ignore_errors = True
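With these `ignore_errors` overrides removed, a plain mypy run over the repo now type-checks `torch.utils.bundled_inputs`, `torch.utils.mkldnn`, and `torch.utils.mobile_optimizer` instead of silencing any errors they produce.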

diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py

@@ -561,7 +561,7 @@ def gen_pyi(declarations_path, out):
         ],
         'get_device': ['def get_device(self) -> _int: ...'],
         'contiguous': ['def contiguous(self) -> Tensor: ...'],
-        'is_contiguous': ['def is_contiguous(self) -> _bool: ...'],
+        'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
         'is_cuda': ['is_cuda: _bool'],
         'is_leaf': ['is_leaf: _bool'],
         'is_sparse': ['is_sparse: _bool'],
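The widened hint mirrors what `Tensor.is_contiguous` already accepts at runtime; a quick illustration with a toy tensor:

```python
import torch

x = torch.randn(2, 3, 4, 5)
# Default check: row-major (torch.contiguous_format) layout.
print(x.is_contiguous())                                      # True
# The stub now also admits an explicit memory_format argument.
print(x.is_contiguous(memory_format=torch.channels_last))     # False
nhwc = x.to(memory_format=torch.channels_last)
print(nhwc.is_contiguous(memory_format=torch.channels_last))  # True
```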

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in

@@ -5,7 +5,8 @@ from torch import Tensor
 from enum import Enum
 from pathlib import Path
 from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
-                    Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
+                    Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic,
+                    Set, AnyStr)
 
 from torch._six import inf
 from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -131,6 +132,16 @@ class Future(object):
     def then(self, callback: Callable) -> Future: ...
     def set_result(self, result: Any) -> None: ...
 
+# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
+class MobileOptimizerType:
+    ...
+
+CONV_BN_FUSION: MobileOptimizerType
+INSERT_FOLD_PREPACK_OPS: MobileOptimizerType
+REMOVE_DROPOUT: MobileOptimizerType
+FUSE_ADD_RELU: MobileOptimizerType
+HOIST_CONV_PACKED_PARAMS: MobileOptimizerType
+
 def fork(*args: Any, **kwargs: Any) -> Future: ...
 def wait(fut: Future) -> Any: ...
 def _collect_all(futures: List[Future]) -> Future: ...
@@ -142,7 +153,9 @@ def _jit_init() -> _bool: ...
 def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
 def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
 def _jit_get_operation(op_name: str) -> Callable: ...
-def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
+def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
+                                  optimization_blocklist: Set[MobileOptimizerType],
+                                  preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_inline(Graph) -> None: ...
 def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
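The new parameters surface through the public wrapper `torch.utils.mobile_optimizer.optimize_for_mobile`. A hedged sketch of how the `Set[MobileOptimizerType]` annotation gets exercised (toy module; the wrapper's keyword name follows the stub above):

```python
import torch
from torch._C import MobileOptimizerType
from torch.utils.mobile_optimizer import optimize_for_mobile

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)
        self.drop = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.drop(torch.relu(self.fc(x)))

scripted = torch.jit.script(Net().eval())

# Keep dropout nodes by blocklisting the REMOVE_DROPOUT pass; mypy can
# now check that the set really contains MobileOptimizerType values.
optimized = optimize_for_mobile(
    scripted,
    optimization_blocklist={MobileOptimizerType.REMOVE_DROPOUT},
)
```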
@@ -293,11 +306,6 @@ class ParameterDict:
 class BufferDict:
     def __init__(self, mod: ScriptModule) -> None: ...
 
-# Defined in torch/csrc/jit/python/python_ir.cpp
-class TensorType:
-    @classmethod
-    def get(cls) -> TensorType: ...
-
 # Defined in torch/csrc/Module.cpp
 def _initExtension(shm_manager_path: str) -> None: ...  # THPModule_initExtension
 def _autograd_init() -> _bool: ...  # THPAutograd_initExtension
@@ -541,6 +549,8 @@ Stack = List[IValue]
 class JitType:
     ...
 
+R = TypeVar('R', bound=JitType)
+
 class AnyType(JitType):
     @staticmethod
     def get() -> AnyType: ...
@@ -598,7 +608,7 @@ class InterfaceType(JitType):
     def getMethod(self, name: str) -> Optional[FunctionSchema]: ...
     def getMethodNames(self) -> List[str]: ...
 
-class OptionalType(JitType):
+class OptionalType(JitType, Generic[R]):
     def __init__(self, a: JitType) -> None: ...
     def getElementType(self) -> JitType: ...
@@ -621,6 +631,10 @@ class EnumType(JitType):
     ) -> None:
         ...
 
+class TensorType(JitType):
+    @classmethod
+    def get(cls) -> TensorType: ...
+
 # Defined in torch/csrc/jit/python/python_tree_views.cpp
 class SourceRange:
     ...
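Making `OptionalType` generic over `R` (and moving `TensorType` under `JitType`) is what lets call sites name the wrapped element type, as `torch.utils.bundled_inputs` does below. A small sketch, kept inside a function because function-scope annotations are not evaluated at runtime, so the subscript only informs mypy:

```python
from torch._C import ListType, OptionalType, TensorType

def build_types() -> None:
    elem = TensorType.get()
    tensor_list: ListType = ListType(elem)
    # Statically this is an Optional wrapping a ListType; at runtime
    # OptionalType still just takes a JitType argument.
    maybe_list: OptionalType[ListType] = OptionalType(tensor_list)
    print(maybe_list.getElementType())

build_types()
```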

diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in

@@ -1,5 +1,13 @@
-from typing import Callable
+from torch import Tensor
+from typing import Callable, Optional, List
 
 # Defined in tools/autograd/templates/python_nn_functions.cpp
 ${dispatched_hints}
 
+# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
+def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
+
+# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
+def mkldnn_reorder_conv2d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
+def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
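These stubs back `torch.utils.mkldnn`, whose public entry point converts eligible modules to their MKL-DNN variants. A minimal, hedged sketch (requires a PyTorch build with MKL-DNN enabled):

```python
import torch
from torch.utils import mkldnn as mkldnn_utils

linear = torch.nn.Linear(20, 30).eval()  # to_mkldnn expects eval mode
mkldnn_linear = mkldnn_utils.to_mkldnn(linear)

x = torch.randn(128, 20)
# Inputs must be converted to the MKL-DNN layout and back.
y = mkldnn_linear(x.to_mkldnn()).to_dense()
print(y.shape)  # torch.Size([128, 30])
```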

diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py

@@ -147,7 +147,7 @@ def load(f, map_location=None, _extra_files=None):
         os.remove("scriptmodule.pt")
     """
     if isinstance(f, string_classes):
-        if not os.path.exists(f):
+        if not os.path.exists(f):  # type: ignore
             raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore
         if os.path.isdir(f):
             raise ValueError("The provided filename {} is a directory".format(f))  # type: ignore

diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
-from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple
+from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union
 import textwrap
 import torch
+from torch._C import TupleType, OptionalType, ListType
 
 T = TypeVar("T")
@@ -53,8 +54,8 @@ def augment_model_with_bundled_inputs(
         raise Exception("Only ScriptModule is supported.")
 
     forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
-    deflated_inputs_type = torch._C.ListType(torch._C.TupleType(forward_arg_types))
-    inflated_inputs_type = torch._C.OptionalType(deflated_inputs_type)
+    deflated_inputs_type: ListType = ListType(TupleType(forward_arg_types))
+    inflated_inputs_type: OptionalType[ListType] = OptionalType(deflated_inputs_type)
 
     model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
     model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)
@@ -117,7 +118,7 @@
     """))
 
 
-def _inflate_expr(arg: T, ref: str) -> Tuple[T, str]:
+def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]:
     # Allow custom inflation expressions any object.
     # For example, calling custom image-decoding ops.
     # Or just use "{}" as the format string to ignore size limits.
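For reference, a hedged sketch of the annotated entry point in action (toy module; the `inputs` argument is a list of tuples matching `forward`'s signature):

```python
import torch
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

scripted = torch.jit.script(Doubler())
# Attach one bundled input; augmentation also adds accessor methods.
augment_model_with_bundled_inputs(scripted, inputs=[(torch.zeros(4),)])

(args,) = scripted.get_all_bundled_inputs()
print(scripted(*args))  # tensor([0., 0., 0., 0.])
```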

diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py

@@ -830,7 +830,7 @@ class SummaryWriter(object):
             metadata, label_img, fs, subdir, global_step, tag)
         self._projector_config.embeddings.extend([embedding_info])
 
-        from google.protobuf import text_format
+        from google.protobuf import text_format  # type: ignore
         config_pbtxt = text_format.MessageToString(self._projector_config)
         write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)
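(The added ignore is presumably needed because `google.protobuf` exposed no type information to mypy at the time, so the bare import is the one line that has to be silenced.)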