mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix typing errors in torch.distributed.nn.* directory. (#47533)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47533 Test Plan: Imported from OSS Reviewed By: walterddr Differential Revision: D24952500 Pulled By: xuzhao9 fbshipit-source-id: 8e66784fd8f9f111b6329e0bb48d6cd61c690a4a
This commit is contained in:
parent
915050ed66
commit
7f66fa62ca
3
mypy.ini
3
mypy.ini
|
|
@ -62,6 +62,9 @@ ignore_errors = False
|
||||||
[mypy-torch.distributed.distributed_c10d.*]
|
[mypy-torch.distributed.distributed_c10d.*]
|
||||||
ignore_errors = False
|
ignore_errors = False
|
||||||
|
|
||||||
|
[mypy-torch.distributed.nn.*]
|
||||||
|
ignore_errors = False
|
||||||
|
|
||||||
[mypy-torch.testing._internal.hypothesis_utils.*]
|
[mypy-torch.testing._internal.hypothesis_utils.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from torch.distributed.nn.jit import instantiator
|
||||||
from torch.distributed.rpc.utils import _parse_remote_device
|
from torch.distributed.rpc.utils import _parse_remote_device
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
|
||||||
_grad_t = Union[Tuple[Tensor, ...], Tensor]
|
_grad_t = Union[Tuple[Tensor, ...], Tensor]
|
||||||
|
|
@ -52,7 +53,7 @@ def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls=
|
||||||
|
|
||||||
|
|
||||||
def _param_rrefs(module_rref, recurse):
|
def _param_rrefs(module_rref, recurse):
|
||||||
ret = []
|
ret: List[rpc.RRef[Parameter]] = []
|
||||||
for param in module_rref.local_value().parameters(recurse):
|
for param in module_rref.local_value().parameters(recurse):
|
||||||
ret.append(rpc.RRef(param))
|
ret.append(rpc.RRef(param))
|
||||||
return ret
|
return ret
|
||||||
|
|
@ -216,45 +217,45 @@ class _RemoteModule(nn.Module):
|
||||||
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
|
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
|
||||||
_raise_not_supported(self.register_parameter.__name__)
|
_raise_not_supported(self.register_parameter.__name__)
|
||||||
|
|
||||||
def add_module(self, name: str, module: Optional["Module"]) -> None:
|
def add_module(self, name: str, module: Optional[Module]) -> None:
|
||||||
_raise_not_supported(self.add_module.__name__)
|
_raise_not_supported(self.add_module.__name__)
|
||||||
|
|
||||||
def apply(self: T, fn: Callable[["Module"], None]) -> T:
|
def apply(self: T, fn: Callable[[Module], None]) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.apply.__name__)
|
_raise_not_supported(self.apply.__name__)
|
||||||
|
|
||||||
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
|
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.cuda.__name__)
|
_raise_not_supported(self.cuda.__name__)
|
||||||
|
|
||||||
def cpu(self: T) -> T:
|
def cpu(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.cpu.__name__)
|
_raise_not_supported(self.cpu.__name__)
|
||||||
|
|
||||||
def type(self: T, dst_type: Union[dtype, str]) -> T:
|
def type(self: T, dst_type: Union[dtype, str]) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.type.__name__)
|
_raise_not_supported(self.type.__name__)
|
||||||
|
|
||||||
def float(self: T) -> T:
|
def float(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.float.__name__)
|
_raise_not_supported(self.float.__name__)
|
||||||
|
|
||||||
def double(self: T) -> T:
|
def double(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.double.__name__)
|
_raise_not_supported(self.double.__name__)
|
||||||
|
|
||||||
def half(self: T) -> T:
|
def half(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.half.__name__)
|
_raise_not_supported(self.half.__name__)
|
||||||
|
|
||||||
def bfloat16(self: T) -> T:
|
def bfloat16(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.bfloat16.__name__)
|
_raise_not_supported(self.bfloat16.__name__)
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.to.__name__)
|
_raise_not_supported(self.to.__name__)
|
||||||
|
|
||||||
def register_backward_hook(
|
def register_backward_hook( # type: ignore[return]
|
||||||
self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, Tensor]]
|
self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, Tensor]]
|
||||||
) -> RemovableHandle:
|
) -> RemovableHandle:
|
||||||
_raise_not_supported(self.register_backward_hook.__name__)
|
_raise_not_supported(self.register_backward_hook.__name__)
|
||||||
|
|
||||||
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
|
def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
|
||||||
_raise_not_supported(self.register_forward_pre_hook.__name__)
|
_raise_not_supported(self.register_forward_pre_hook.__name__)
|
||||||
|
|
||||||
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
|
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
|
||||||
_raise_not_supported(self.register_forward_hook.__name__)
|
_raise_not_supported(self.register_forward_hook.__name__)
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||||
|
|
@ -272,47 +273,47 @@ class _RemoteModule(nn.Module):
|
||||||
"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
|
"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
def named_parameters(
|
def named_parameters( # type: ignore[return]
|
||||||
self, prefix: str = "", recurse: bool = True
|
self, prefix: str = "", recurse: bool = True
|
||||||
) -> Iterator[Tuple[str, Tensor]]:
|
) -> Iterator[Tuple[str, Tensor]]:
|
||||||
_raise_not_supported(self.named_parameters.__name__)
|
_raise_not_supported(self.named_parameters.__name__)
|
||||||
|
|
||||||
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
|
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
|
||||||
_raise_not_supported(self.buffers.__name__)
|
_raise_not_supported(self.buffers.__name__)
|
||||||
|
|
||||||
def named_buffers(
|
def named_buffers( # type: ignore[return]
|
||||||
self, prefix: str = "", recurse: bool = True
|
self, prefix: str = "", recurse: bool = True
|
||||||
) -> Iterator[Tuple[str, Tensor]]:
|
) -> Iterator[Tuple[str, Tensor]]:
|
||||||
_raise_not_supported(self.named_buffers.__name__)
|
_raise_not_supported(self.named_buffers.__name__)
|
||||||
|
|
||||||
def children(self) -> Iterator["Module"]:
|
def children(self) -> Iterator[Module]: # type: ignore[return]
|
||||||
_raise_not_supported(self.children.__name__)
|
_raise_not_supported(self.children.__name__)
|
||||||
|
|
||||||
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
|
def named_children(self) -> Iterator[Tuple[str, Module]]: # type: ignore[return]
|
||||||
_raise_not_supported(self.named_children.__name__)
|
_raise_not_supported(self.named_children.__name__)
|
||||||
|
|
||||||
def modules(self) -> Iterator["Module"]:
|
def modules(self) -> Iterator[Module]: # type: ignore[return]
|
||||||
_raise_not_supported(self.modules.__name__)
|
_raise_not_supported(self.modules.__name__)
|
||||||
|
|
||||||
def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""):
|
def named_modules(self, memo: Optional[Set[Module]] = None, prefix: str = ""):
|
||||||
_raise_not_supported(self.named_modules.__name__)
|
_raise_not_supported(self.named_modules.__name__)
|
||||||
|
|
||||||
def train(self: T, mode: bool = True) -> T:
|
def train(self: T, mode: bool = True) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.train.__name__)
|
_raise_not_supported(self.train.__name__)
|
||||||
|
|
||||||
def eval(self: T) -> T:
|
def eval(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.eval.__name__)
|
_raise_not_supported(self.eval.__name__)
|
||||||
|
|
||||||
def requires_grad_(self: T, requires_grad: bool = True) -> T:
|
def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.requires_grad_.__name__)
|
_raise_not_supported(self.requires_grad_.__name__)
|
||||||
|
|
||||||
def zero_grad(self) -> None:
|
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||||
_raise_not_supported(self.zero_grad.__name__)
|
_raise_not_supported(self.zero_grad.__name__)
|
||||||
|
|
||||||
def share_memory(self: T) -> T:
|
def share_memory(self: T) -> T: # type: ignore[return]
|
||||||
_raise_not_supported(self.share_memory.__name__)
|
_raise_not_supported(self.share_memory.__name__)
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str: # type: ignore[return]
|
||||||
_raise_not_supported(self.extra_repr.__name__)
|
_raise_not_supported(self.extra_repr.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing import Optional
|
||||||
from torch.distributed.nn.jit.templates.remote_module_template import (
|
from torch.distributed.nn.jit.templates.remote_module_template import (
|
||||||
REMOTE_MODULE_TEMPLATE,
|
REMOTE_MODULE_TEMPLATE,
|
||||||
)
|
)
|
||||||
|
|
@ -37,11 +38,12 @@ def get_arg_return_types_from_interface(module_interface):
|
||||||
|
|
||||||
arg_str_list = []
|
arg_str_list = []
|
||||||
arg_type_str_list = []
|
arg_type_str_list = []
|
||||||
|
assert method_schema is not None
|
||||||
for argument in method_schema.arguments:
|
for argument in method_schema.arguments:
|
||||||
arg_str_list.append(argument.name)
|
arg_str_list.append(argument.name)
|
||||||
|
|
||||||
if argument.has_default_value():
|
if argument.has_default_value():
|
||||||
default_value_str = " = {}".format(argument.default)
|
default_value_str = " = {}".format(argument.default_value)
|
||||||
else:
|
else:
|
||||||
default_value_str = ""
|
default_value_str = ""
|
||||||
arg_type_str = "{name}: {type}{default_value}".format(
|
arg_type_str = "{name}: {type}{default_value}".format(
|
||||||
|
|
@ -63,6 +65,7 @@ def get_arg_return_types_from_interface(module_interface):
|
||||||
|
|
||||||
|
|
||||||
def _write(out_path, text):
|
def _write(out_path, text):
|
||||||
|
old_text: Optional[str]
|
||||||
try:
|
try:
|
||||||
with open(out_path, "r") as f:
|
with open(out_path, "r") as f:
|
||||||
old_text = f.read()
|
old_text = f.read()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user