Fix typing errors in torch.distributed.nn.* directory. (#47533)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47533

Test Plan: Imported from OSS

Reviewed By: walterddr

Differential Revision: D24952500

Pulled By: xuzhao9

fbshipit-source-id: 8e66784fd8f9f111b6329e0bb48d6cd61c690a4a
This commit is contained in:
Xu Zhao 2020-11-16 23:16:02 -08:00 committed by Facebook GitHub Bot
parent 915050ed66
commit 7f66fa62ca
3 changed files with 36 additions and 29 deletions

View File

@ -62,6 +62,9 @@ ignore_errors = False
[mypy-torch.distributed.distributed_c10d.*] [mypy-torch.distributed.distributed_c10d.*]
ignore_errors = False ignore_errors = False
[mypy-torch.distributed.nn.*]
ignore_errors = False
[mypy-torch.testing._internal.hypothesis_utils.*] [mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True ignore_errors = True

View File

@ -20,6 +20,7 @@ from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.utils import _parse_remote_device from torch.distributed.rpc.utils import _parse_remote_device
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from torch.nn import Module
_grad_t = Union[Tuple[Tensor, ...], Tensor] _grad_t = Union[Tuple[Tensor, ...], Tensor]
@ -52,7 +53,7 @@ def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls=
def _param_rrefs(module_rref, recurse): def _param_rrefs(module_rref, recurse):
ret = [] ret: List[rpc.RRef[Parameter]] = []
for param in module_rref.local_value().parameters(recurse): for param in module_rref.local_value().parameters(recurse):
ret.append(rpc.RRef(param)) ret.append(rpc.RRef(param))
return ret return ret
@ -216,45 +217,45 @@ class _RemoteModule(nn.Module):
def register_parameter(self, name: str, param: Optional[Parameter]) -> None: def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
_raise_not_supported(self.register_parameter.__name__) _raise_not_supported(self.register_parameter.__name__)
def add_module(self, name: str, module: Optional["Module"]) -> None: def add_module(self, name: str, module: Optional[Module]) -> None:
_raise_not_supported(self.add_module.__name__) _raise_not_supported(self.add_module.__name__)
def apply(self: T, fn: Callable[["Module"], None]) -> T: def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return]
_raise_not_supported(self.apply.__name__) _raise_not_supported(self.apply.__name__)
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
_raise_not_supported(self.cuda.__name__) _raise_not_supported(self.cuda.__name__)
def cpu(self: T) -> T: def cpu(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.cpu.__name__) _raise_not_supported(self.cpu.__name__)
def type(self: T, dst_type: Union[dtype, str]) -> T: def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return]
_raise_not_supported(self.type.__name__) _raise_not_supported(self.type.__name__)
def float(self: T) -> T: def float(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.float.__name__) _raise_not_supported(self.float.__name__)
def double(self: T) -> T: def double(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.double.__name__) _raise_not_supported(self.double.__name__)
def half(self: T) -> T: def half(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.half.__name__) _raise_not_supported(self.half.__name__)
def bfloat16(self: T) -> T: def bfloat16(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.bfloat16.__name__) _raise_not_supported(self.bfloat16.__name__)
def to(self, *args, **kwargs): def to(self, *args, **kwargs) -> T: # type: ignore[return]
_raise_not_supported(self.to.__name__) _raise_not_supported(self.to.__name__)
def register_backward_hook( def register_backward_hook( # type: ignore[return]
self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]] self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]]
) -> RemovableHandle: ) -> RemovableHandle:
_raise_not_supported(self.register_backward_hook.__name__) _raise_not_supported(self.register_backward_hook.__name__)
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
_raise_not_supported(self.register_forward_pre_hook.__name__) _raise_not_supported(self.register_forward_pre_hook.__name__)
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
_raise_not_supported(self.register_forward_hook.__name__) _raise_not_supported(self.register_forward_hook.__name__)
def state_dict(self, destination=None, prefix="", keep_vars=False): def state_dict(self, destination=None, prefix="", keep_vars=False):
@ -272,47 +273,47 @@ class _RemoteModule(nn.Module):
"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead." "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
) )
def named_parameters( def named_parameters( # type: ignore[return]
self, prefix: str = "", recurse: bool = True self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, Tensor]]: ) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_parameters.__name__) _raise_not_supported(self.named_parameters.__name__)
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
_raise_not_supported(self.buffers.__name__) _raise_not_supported(self.buffers.__name__)
def named_buffers( def named_buffers( # type: ignore[return]
self, prefix: str = "", recurse: bool = True self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, Tensor]]: ) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__) _raise_not_supported(self.named_buffers.__name__)
def children(self) -> Iterator["Module"]: def children(self) -> Iterator[Module]: # type: ignore[return]
_raise_not_supported(self.children.__name__) _raise_not_supported(self.children.__name__)
def named_children(self) -> Iterator[Tuple[str, "Module"]]: def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return]
_raise_not_supported(self.named_children.__name__) _raise_not_supported(self.named_children.__name__)
def modules(self) -> Iterator["Module"]: def modules(self) -> Iterator[Module]: # type: ignore[return]
_raise_not_supported(self.modules.__name__) _raise_not_supported(self.modules.__name__)
def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""): def named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = ""):
_raise_not_supported(self.named_modules.__name__) _raise_not_supported(self.named_modules.__name__)
def train(self: T, mode: bool = True) -> T: def train(self: T, mode: bool = True) -> T: # type: ignore[return]
_raise_not_supported(self.train.__name__) _raise_not_supported(self.train.__name__)
def eval(self: T) -> T: def eval(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.eval.__name__) _raise_not_supported(self.eval.__name__)
def requires_grad_(self: T, requires_grad: bool = True) -> T: def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
_raise_not_supported(self.requires_grad_.__name__) _raise_not_supported(self.requires_grad_.__name__)
def zero_grad(self) -> None: def zero_grad(self, set_to_none: bool = False) -> None:
_raise_not_supported(self.zero_grad.__name__) _raise_not_supported(self.zero_grad.__name__)
def share_memory(self: T) -> T: def share_memory(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.share_memory.__name__) _raise_not_supported(self.share_memory.__name__)
def extra_repr(self) -> str: def extra_repr(self) -> str: # type: ignore[return]
_raise_not_supported(self.extra_repr.__name__) _raise_not_supported(self.extra_repr.__name__)

View File

@ -6,6 +6,7 @@ import sys
import tempfile import tempfile
import torch import torch
from typing import Optional
from torch.distributed.nn.jit.templates.remote_module_template import ( from torch.distributed.nn.jit.templates.remote_module_template import (
REMOTE_MODULE_TEMPLATE, REMOTE_MODULE_TEMPLATE,
) )
@ -37,11 +38,12 @@ def get_arg_return_types_from_interface(module_interface):
arg_str_list = [] arg_str_list = []
arg_type_str_list = [] arg_type_str_list = []
assert method_schema is not None
for argument in method_schema.arguments: for argument in method_schema.arguments:
arg_str_list.append(argument.name) arg_str_list.append(argument.name)
if argument.has_default_value(): if argument.has_default_value():
default_value_str = " = {}".format(argument.default) default_value_str = " = {}".format(argument.default_value)
else: else:
default_value_str = "" default_value_str = ""
arg_type_str = "{name}: {type}{default_value}".format( arg_type_str = "{name}: {type}{default_value}".format(
@ -63,6 +65,7 @@ def get_arg_return_types_from_interface(module_interface):
def _write(out_path, text): def _write(out_path, text):
old_text: Optional[str]
try: try:
with open(out_path, "r") as f: with open(out_path, "r") as f:
old_text = f.read() old_text = f.read()