mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
add typing annotations for a few torch.utils.* modules (#43806)
Summary: Fixes https://github.com/pytorch/pytorch/issues/43431. Depends on [gh-43862](https://github.com/pytorch/pytorch/pull/43862) (EDIT: now merged) Modules: - torch.utils.mkldnn - torch.utils.mobile_optimizer - torch.utils.bundled_inputs Pull Request resolved: https://github.com/pytorch/pytorch/pull/43806 Reviewed By: gmagogsfm Differential Revision: D23635151 Pulled By: SplitInfinity fbshipit-source-id: a85b75a7927dde6cc55bcb361f8ff601ffb0b2a1
This commit is contained in:
parent
7d78a6fcdd
commit
cdf5e2ae86
9
mypy.ini
9
mypy.ini
|
|
@ -195,15 +195,6 @@ ignore_errors = True
|
|||
[mypy-torch._overrides]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.bundled_inputs]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.mkldnn]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.mobile_optimizer]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.tensorboard._caffe2_graph]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -561,7 +561,7 @@ def gen_pyi(declarations_path, out):
|
|||
],
|
||||
'get_device': ['def get_device(self) -> _int: ...'],
|
||||
'contiguous': ['def contiguous(self) -> Tensor: ...'],
|
||||
'is_contiguous': ['def is_contiguous(self) -> _bool: ...'],
|
||||
'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'],
|
||||
'is_cuda': ['is_cuda: _bool'],
|
||||
'is_leaf': ['is_leaf: _bool'],
|
||||
'is_sparse': ['is_sparse: _bool'],
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ from torch import Tensor
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import (Any, BinaryIO, Callable, ContextManager, Dict, Iterator, List, NamedTuple,
|
||||
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
|
||||
Optional, overload, Sequence, Tuple, TypeVar, Type, Union, Generic,
|
||||
Set, AnyStr)
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
|
||||
|
|
@ -131,6 +132,16 @@ class Future(object):
|
|||
def then(self, callback: Callable) -> Future: ...
|
||||
def set_result(self, result: Any) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
|
||||
class MobileOptimizerType:
|
||||
...
|
||||
|
||||
CONV_BN_FUSION: MobileOptimizerType
|
||||
INSERT_FOLD_PREPACK_OPS: MobileOptimizerType
|
||||
REMOVE_DROPOUT: MobileOptimizerType
|
||||
FUSE_ADD_RELU: MobileOptimizerType
|
||||
HOIST_CONV_PACKED_PARAMS: MobileOptimizerType
|
||||
|
||||
def fork(*args: Any, **kwargs: Any) -> Future: ...
|
||||
def wait(fut: Future) -> Any: ...
|
||||
def _collect_all(futures: List[Future]) -> Future: ...
|
||||
|
|
@ -142,7 +153,9 @@ def _jit_init() -> _bool: ...
|
|||
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
||||
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
|
||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
||||
optimization_blocklist: Set[MobileOptimizerType],
|
||||
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
||||
def _jit_pass_inline(Graph) -> None: ...
|
||||
def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
|
||||
def _jit_can_fuse_on_cpu() -> _bool: ...
|
||||
|
|
@ -293,11 +306,6 @@ class ParameterDict:
|
|||
class BufferDict:
|
||||
def __init__(self, mod: ScriptModule) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_ir.cpp
|
||||
class TensorType:
|
||||
@classmethod
|
||||
def get(cls) -> TensorType: ...
|
||||
|
||||
# Defined in torch/csrc/Module.cpp
|
||||
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
|
||||
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
|
||||
|
|
@ -541,6 +549,8 @@ Stack = List[IValue]
|
|||
class JitType:
|
||||
...
|
||||
|
||||
R = TypeVar('R', bound=JitType)
|
||||
|
||||
class AnyType(JitType):
|
||||
@staticmethod
|
||||
def get() -> AnyType: ...
|
||||
|
|
@ -598,7 +608,7 @@ class InterfaceType(JitType):
|
|||
def getMethod(self, name: str) -> Optional[FunctionSchema]: ...
|
||||
def getMethodNames(self) -> List[str]: ...
|
||||
|
||||
class OptionalType(JitType):
|
||||
class OptionalType(JitType, Generic[R]):
|
||||
def __init__(self, a: JitType) -> None: ...
|
||||
def getElementType(self) -> JitType: ...
|
||||
|
||||
|
|
@ -621,6 +631,10 @@ class EnumType(JitType):
|
|||
) -> None:
|
||||
...
|
||||
|
||||
class TensorType(JitType):
|
||||
@classmethod
|
||||
def get(cls) -> TensorType: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_tree_views.cpp
|
||||
class SourceRange:
|
||||
...
|
||||
|
|
|
|||
|
|
@ -1,5 +1,13 @@
|
|||
from typing import Callable
|
||||
from torch import Tensor
|
||||
from typing import Callable, Optional, List
|
||||
|
||||
# Defined in tools/autograd/templates/python_nn_functions.cpp
|
||||
|
||||
${dispatched_hints}
|
||||
|
||||
# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
|
||||
def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
|
||||
|
||||
# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
|
||||
def mkldnn_reorder_conv2d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
|
||||
def mkldnn_reorder_conv3d_weight(self: Tensor, padding: List, stride: List, dilatation: List, groups: int) -> Tensor: ...
|
||||
|
|
@ -147,7 +147,7 @@ def load(f, map_location=None, _extra_files=None):
|
|||
os.remove("scriptmodule.pt")
|
||||
"""
|
||||
if isinstance(f, string_classes):
|
||||
if not os.path.exists(f):
|
||||
if not os.path.exists(f): # type: ignore
|
||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore
|
||||
if os.path.isdir(f):
|
||||
raise ValueError("The provided filename {} is a directory".format(f)) # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
#!/usr/bin/env python3
|
||||
from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple
|
||||
from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union
|
||||
import textwrap
|
||||
import torch
|
||||
from torch._C import TupleType, OptionalType, ListType
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
|
@ -53,8 +54,8 @@ def augment_model_with_bundled_inputs(
|
|||
raise Exception("Only ScriptModule is supported.")
|
||||
|
||||
forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
|
||||
deflated_inputs_type = torch._C.ListType(torch._C.TupleType(forward_arg_types))
|
||||
inflated_inputs_type = torch._C.OptionalType(deflated_inputs_type)
|
||||
deflated_inputs_type: ListType = ListType(TupleType(forward_arg_types))
|
||||
inflated_inputs_type: OptionalType[ListType] = OptionalType(deflated_inputs_type)
|
||||
model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
|
||||
model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)
|
||||
|
||||
|
|
@ -117,7 +118,7 @@ def augment_model_with_bundled_inputs(
|
|||
"""))
|
||||
|
||||
|
||||
def _inflate_expr(arg: T, ref: str) -> Tuple[T, str]:
|
||||
def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]:
|
||||
# Allow custom inflation expressions any object.
|
||||
# For example, calling custom image-decoding ops.
|
||||
# Or just use "{}" as the format string to ignore size limits.
|
||||
|
|
|
|||
|
|
@ -830,7 +830,7 @@ class SummaryWriter(object):
|
|||
metadata, label_img, fs, subdir, global_step, tag)
|
||||
self._projector_config.embeddings.extend([embedding_info])
|
||||
|
||||
from google.protobuf import text_format
|
||||
from google.protobuf import text_format # type: ignore
|
||||
config_pbtxt = text_format.MessageToString(self._projector_config)
|
||||
write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user