mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix typing errors in torch.distributed.nn.* directory. (#47533)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47533 Test Plan: Imported from OSS Reviewed By: walterddr Differential Revision: D24952500 Pulled By: xuzhao9 fbshipit-source-id: 8e66784fd8f9f111b6329e0bb48d6cd61c690a4a
This commit is contained in:
parent
915050ed66
commit
7f66fa62ca
3
mypy.ini
3
mypy.ini
|
|
@ -62,6 +62,9 @@ ignore_errors = False
|
|||
[mypy-torch.distributed.distributed_c10d.*]
|
||||
ignore_errors = False
|
||||
|
||||
[mypy-torch.distributed.nn.*]
|
||||
ignore_errors = False
|
||||
|
||||
[mypy-torch.testing._internal.hypothesis_utils.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from torch.distributed.nn.jit import instantiator
|
|||
from torch.distributed.rpc.utils import _parse_remote_device
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
_grad_t = Union[Tuple[Tensor, ...], Tensor]
|
||||
|
|
@ -52,7 +53,7 @@ def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls=
|
|||
|
||||
|
||||
def _param_rrefs(module_rref, recurse):
|
||||
ret = []
|
||||
ret: List[rpc.RRef[Parameter]] = []
|
||||
for param in module_rref.local_value().parameters(recurse):
|
||||
ret.append(rpc.RRef(param))
|
||||
return ret
|
||||
|
|
@ -216,45 +217,45 @@ class _RemoteModule(nn.Module):
|
|||
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
|
||||
_raise_not_supported(self.register_parameter.__name__)
|
||||
|
||||
def add_module(self, name: str, module: Optional["Module"]) -> None:
|
||||
def add_module(self, name: str, module: Optional[Module]) -> None:
|
||||
_raise_not_supported(self.add_module.__name__)
|
||||
|
||||
def apply(self: T, fn: Callable[["Module"], None]) -> T:
|
||||
def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.apply.__name__)
|
||||
|
||||
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
|
||||
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.cuda.__name__)
|
||||
|
||||
def cpu(self: T) -> T:
|
||||
def cpu(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.cpu.__name__)
|
||||
|
||||
def type(self: T, dst_type: Union[dtype, str]) -> T:
|
||||
def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.type.__name__)
|
||||
|
||||
def float(self: T) -> T:
|
||||
def float(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.float.__name__)
|
||||
|
||||
def double(self: T) -> T:
|
||||
def double(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.double.__name__)
|
||||
|
||||
def half(self: T) -> T:
|
||||
def half(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.half.__name__)
|
||||
|
||||
def bfloat16(self: T) -> T:
|
||||
def bfloat16(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.bfloat16.__name__)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
def to(self, *args, **kwargs) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.to.__name__)
|
||||
|
||||
def register_backward_hook(
|
||||
self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]]
|
||||
def register_backward_hook( # type: ignore[return]
|
||||
self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]]
|
||||
) -> RemovableHandle:
|
||||
_raise_not_supported(self.register_backward_hook.__name__)
|
||||
|
||||
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
|
||||
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
|
||||
_raise_not_supported(self.register_forward_pre_hook.__name__)
|
||||
|
||||
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
|
||||
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
|
||||
_raise_not_supported(self.register_forward_hook.__name__)
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
|
|
@ -272,47 +273,47 @@ class _RemoteModule(nn.Module):
|
|||
"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
|
||||
)
|
||||
|
||||
def named_parameters(
|
||||
def named_parameters( # type: ignore[return]
|
||||
self, prefix: str = "", recurse: bool = True
|
||||
) -> Iterator[Tuple[str, Tensor]]:
|
||||
_raise_not_supported(self.named_parameters.__name__)
|
||||
|
||||
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
|
||||
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
|
||||
_raise_not_supported(self.buffers.__name__)
|
||||
|
||||
def named_buffers(
|
||||
def named_buffers( # type: ignore[return]
|
||||
self, prefix: str = "", recurse: bool = True
|
||||
) -> Iterator[Tuple[str, Tensor]]:
|
||||
_raise_not_supported(self.named_buffers.__name__)
|
||||
|
||||
def children(self) -> Iterator["Module"]:
|
||||
def children(self) -> Iterator[Module]: # type: ignore[return]
|
||||
_raise_not_supported(self.children.__name__)
|
||||
|
||||
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
|
||||
def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return]
|
||||
_raise_not_supported(self.named_children.__name__)
|
||||
|
||||
def modules(self) -> Iterator["Module"]:
|
||||
def modules(self) -> Iterator[Module]: # type: ignore[return]
|
||||
_raise_not_supported(self.modules.__name__)
|
||||
|
||||
def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""):
|
||||
def named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = ""):
|
||||
_raise_not_supported(self.named_modules.__name__)
|
||||
|
||||
def train(self: T, mode: bool = True) -> T:
|
||||
def train(self: T, mode: bool = True) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.train.__name__)
|
||||
|
||||
def eval(self: T) -> T:
|
||||
def eval(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.eval.__name__)
|
||||
|
||||
def requires_grad_(self: T, requires_grad: bool = True) -> T:
|
||||
def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.requires_grad_.__name__)
|
||||
|
||||
def zero_grad(self) -> None:
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
_raise_not_supported(self.zero_grad.__name__)
|
||||
|
||||
def share_memory(self: T) -> T:
|
||||
def share_memory(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.share_memory.__name__)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
def extra_repr(self) -> str: # type: ignore[return]
|
||||
_raise_not_supported(self.extra_repr.__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import sys
|
|||
import tempfile
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
from torch.distributed.nn.jit.templates.remote_module_template import (
|
||||
REMOTE_MODULE_TEMPLATE,
|
||||
)
|
||||
|
|
@ -37,11 +38,12 @@ def get_arg_return_types_from_interface(module_interface):
|
|||
|
||||
arg_str_list = []
|
||||
arg_type_str_list = []
|
||||
assert method_schema is not None
|
||||
for argument in method_schema.arguments:
|
||||
arg_str_list.append(argument.name)
|
||||
|
||||
if argument.has_default_value():
|
||||
default_value_str = " = {}".format(argument.default)
|
||||
default_value_str = " = {}".format(argument.default_value)
|
||||
else:
|
||||
default_value_str = ""
|
||||
arg_type_str = "{name}: {type}{default_value}".format(
|
||||
|
|
@ -63,6 +65,7 @@ def get_arg_return_types_from_interface(module_interface):
|
|||
|
||||
|
||||
def _write(out_path, text):
|
||||
old_text: Optional[str]
|
||||
try:
|
||||
with open(out_path, "r") as f:
|
||||
old_text = f.read()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user