From 7f66fa62ca31cc1b5178c532cb152288e55a7c3c Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 16 Nov 2020 23:16:02 -0800 Subject: [PATCH] Fix typing errors in torch.distributed.nn.* directory. (#47533) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47533 Test Plan: Imported from OSS Reviewed By: walterddr Differential Revision: D24952500 Pulled By: xuzhao9 fbshipit-source-id: 8e66784fd8f9f111b6329e0bb48d6cd61c690a4a --- mypy.ini | 3 ++ torch/distributed/nn/api/remote_module.py | 57 ++++++++++++----------- torch/distributed/nn/jit/instantiator.py | 5 +- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/mypy.ini b/mypy.ini index 27efe12bb0f..cbbd502fd4e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -62,6 +62,9 @@ ignore_errors = False [mypy-torch.distributed.distributed_c10d.*] ignore_errors = False +[mypy-torch.distributed.nn.*] +ignore_errors = False + [mypy-torch.testing._internal.hypothesis_utils.*] ignore_errors = True diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index e0aebf87ae8..3ee7e9b2a4b 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -20,6 +20,7 @@ from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.utils import _parse_remote_device from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle +from torch.nn import Module _grad_t = Union[Tuple[Tensor, ...], Tensor] @@ -52,7 +53,7 @@ def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls= def _param_rrefs(module_rref, recurse): - ret = [] + ret: List[rpc.RRef[Parameter]] = [] for param in module_rref.local_value().parameters(recurse): ret.append(rpc.RRef(param)) return ret @@ -216,45 +217,45 @@ class _RemoteModule(nn.Module): def register_parameter(self, name: str, param: Optional[Parameter]) -> None: _raise_not_supported(self.register_parameter.__name__) - def add_module(self, name: str, module: Optional["Module"]) -> None: + def add_module(self, name: str, module: Optional[Module]) -> None: _raise_not_supported(self.add_module.__name__) - def apply(self: T, fn: Callable[["Module"], None]) -> T: + def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return] _raise_not_supported(self.apply.__name__) - def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: + def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return] _raise_not_supported(self.cuda.__name__) - def cpu(self: T) -> T: + def cpu(self: T) -> T: # type: ignore[return] _raise_not_supported(self.cpu.__name__) - def type(self: T, dst_type: Union[dtype, str]) -> T: + def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return] _raise_not_supported(self.type.__name__) - def float(self: T) -> T: + def float(self: T) -> T: # type: ignore[return] _raise_not_supported(self.float.__name__) - def double(self: T) -> T: + def double(self: T) -> T: # type: ignore[return] _raise_not_supported(self.double.__name__) - def half(self: T) -> T: + def half(self: T) -> T: # type: ignore[return] _raise_not_supported(self.half.__name__) - def bfloat16(self: T) -> T: + def bfloat16(self: T) -> T: # type: ignore[return] _raise_not_supported(self.bfloat16.__name__) - def to(self, *args, **kwargs): + def to(self, *args, **kwargs) -> T: # type: ignore[return] _raise_not_supported(self.to.__name__) - def register_backward_hook( - self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]] + def register_backward_hook( # type: ignore[return] + self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]] ) -> RemovableHandle: _raise_not_supported(self.register_backward_hook.__name__) - def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: + def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return] _raise_not_supported(self.register_forward_pre_hook.__name__) - def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: + def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return] _raise_not_supported(self.register_forward_hook.__name__) def state_dict(self, destination=None, prefix="", keep_vars=False): @@ -272,47 +273,47 @@ class _RemoteModule(nn.Module): "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead." ) - def named_parameters( + def named_parameters( # type: ignore[return] self, prefix: str = "", recurse: bool = True ) -> Iterator[Tuple[str, Tensor]]: _raise_not_supported(self.named_parameters.__name__) - def buffers(self, recurse: bool = True) -> Iterator[Tensor]: + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return] _raise_not_supported(self.buffers.__name__) - def named_buffers( + def named_buffers( # type: ignore[return] self, prefix: str = "", recurse: bool = True ) -> Iterator[Tuple[str, Tensor]]: _raise_not_supported(self.named_buffers.__name__) - def children(self) -> Iterator["Module"]: + def children(self) -> Iterator[Module]: # type: ignore[return] _raise_not_supported(self.children.__name__) - def named_children(self) -> Iterator[Tuple[str, "Module"]]: + def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return] _raise_not_supported(self.named_children.__name__) - def modules(self) -> Iterator["Module"]: + def modules(self) -> Iterator[Module]: # type: ignore[return] _raise_not_supported(self.modules.__name__) - def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""): + def named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = ""): _raise_not_supported(self.named_modules.__name__) - def train(self: T, mode: bool = True) -> T: + def train(self: T, mode: bool = True) -> T: # type: ignore[return] _raise_not_supported(self.train.__name__) - def eval(self: T) -> T: + def eval(self: T) -> T: # type: ignore[return] _raise_not_supported(self.eval.__name__) - def requires_grad_(self: T, requires_grad: bool = True) -> T: + def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return] _raise_not_supported(self.requires_grad_.__name__) - def zero_grad(self) -> None: + def zero_grad(self, set_to_none: bool = False) -> None: _raise_not_supported(self.zero_grad.__name__) - def share_memory(self: T) -> T: + def share_memory(self: T) -> T: # type: ignore[return] _raise_not_supported(self.share_memory.__name__) - def extra_repr(self) -> str: + def extra_repr(self) -> str: # type: ignore[return] _raise_not_supported(self.extra_repr.__name__) diff --git a/torch/distributed/nn/jit/instantiator.py b/torch/distributed/nn/jit/instantiator.py index 346984c90da..950343f0933 100644 --- a/torch/distributed/nn/jit/instantiator.py +++ b/torch/distributed/nn/jit/instantiator.py @@ -6,6 +6,7 @@ import sys import tempfile import torch +from typing import Optional from torch.distributed.nn.jit.templates.remote_module_template import ( REMOTE_MODULE_TEMPLATE, ) @@ -37,11 +38,12 @@ def get_arg_return_types_from_interface(module_interface): arg_str_list = [] arg_type_str_list = [] + assert method_schema is not None for argument in method_schema.arguments: arg_str_list.append(argument.name) if argument.has_default_value(): - default_value_str = " = {}".format(argument.default) + default_value_str = " = {}".format(argument.default_value) else: default_value_str = "" arg_type_str = "{name}: {type}{default_value}".format( @@ -63,6 +65,7 @@ def get_arg_return_types_from_interface(module_interface): def _write(out_path, text): + old_text: Optional[str] try: with open(out_path, "r") as f: old_text = f.read()