Add type annotations to torch._C._distributed_rpc module. (#46624)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46624

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D24761656

Pulled By: xuzhao9

fbshipit-source-id: b55aee5dd2b97f573a50e5bbfddde7d984943fec
Authored by Xu Zhao on 2020-11-06 00:47:23 -08:00; committed by Facebook GitHub Bot
Parent: 73a3e70b24
Commit: eaa993a2e0
13 changed files with 321 additions and 51 deletions
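As a quick illustration of what the new stubs make checkable, here is a minimal sketch (hypothetical worker name, thread count, and init_method values; assumes a distributed-enabled PyTorch build) that uses only signatures declared in the added .pyi file, so mypy can verify it instead of treating torch._C._distributed_rpc as untyped:

from torch.distributed.rpc import ProcessGroupRpcBackendOptions, WorkerInfo

# Constructor arguments are declared in the stub, so mypy rejects, for example,
# passing a string where num_send_recv_threads expects an int.
opts = ProcessGroupRpcBackendOptions(
    num_send_recv_threads=8,
    rpc_timeout=60.0,
    init_method="tcp://localhost:29500",
)

info = WorkerInfo("worker0", 0)   # name: str, worker_id: int, per the stub
worker_name: str = info.name
worker_id: int = info.id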


@@ -56,6 +56,12 @@ ignore_errors = True
[mypy-torch.distributed.*]
ignore_errors = True
[mypy-torch.distributed.rpc.*]
ignore_errors = False
[mypy-torch.distributed.rpc._testing.*]
ignore_errors = True
[mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True


@@ -4,10 +4,10 @@ from enum import Enum
# Defined in tools/autograd/init.cpp
class ProfilerState(Enum):
Disable = 0
CPU = 1
CUDA = 2
NVTX = 3
Disable = ...
CPU = ...
CUDA = ...
NVTX = ...
class ProfilerConfig:


@@ -31,14 +31,14 @@ class Reducer:
...
class ReduceOp(Enum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BAND = 4
BOR = 5
BXOR = 6
UNUSED = 7
SUM = ...
PRODUCT = ...
MIN = ...
MAX = ...
BAND = ...
BOR = ...
BXOR = ...
UNUSED = ...
class BroadcastOptions:
rootRank: int


@@ -0,0 +1,194 @@
from typing import Tuple, Dict, Optional, List, Any, overload
from datetime import timedelta
import enum
import torch
from . import Future
from ._autograd import ProfilerConfig, ProfilerState, ProfilerEvent
from ._distributed_c10d import ProcessGroup, Store
# This module is defined in torch/csrc/distributed/rpc/init.cpp
_DEFAULT_NUM_SEND_RECV_THREADS: int
_DEFAULT_INIT_METHOD: str
_DEFAULT_NUM_WORKER_THREADS: int
_UNSET_RPC_TIMEOUT: float
_DEFAULT_RPC_TIMEOUT_SEC: float
class RpcBackendOptions:
rpc_timeout: float
init_method: str
def __init__(
self,
rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = _DEFAULT_INIT_METHOD,
): ...
class WorkerInfo:
def __init__(self, name: str, worker_id: int): ...
@property
def name(self) -> str: ...
@property
def id(self) -> int: ...
def __eq__(self, other: object) -> bool: ...
def __repr__(self) -> str: ...
class RpcAgent:
def join(self): ...
def sync(self): ...
def shutdown(self): ...
@overload
def get_worker_info(self) -> WorkerInfo: ...
@overload
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def get_debug_info(self) -> Dict[str, str]: ...
def get_metrics(self) -> Dict[str, str]: ...
class PyRRef:
def __init__(self, value: Any, type_hint: Any = None): ...
def is_owner(self) -> bool: ...
def confirmed_by_owner(self) -> bool: ...
def owner(self) -> WorkerInfo: ...
def owner_name(self) -> str: ...
def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
def local_value(self) -> Any: ...
def rpc_sync(self) -> Any: ...
def rpc_async(self) -> Any: ...
def remote(self) -> Any: ...
def _serialize(self) -> Tuple: ...
@staticmethod
def _deserialize(tp: Tuple) -> 'PyRRef': ...
def _get_type(self) -> Any: ...
def _get_future(self) -> Future: ...
def _get_profiling_future(self) -> Future: ...
def _set_profiling_future(self, profilingFuture: Future): ...
def __repr__(self) -> str: ...
...
class ProcessGroupRpcBackendOptions(RpcBackendOptions):
num_send_recv_threads: int
def __init__(
self,
num_send_recv_threads: int,
rpc_timeout: float,
init_method: str
): ...
class ProcessGroupAgent(RpcAgent):
def __init__(
self,
worker_name: str,
pg: ProcessGroup,
numSendRecvThreads: int,
rpcTimeout: timedelta
): ...
@overload
def get_worker_info(self) -> WorkerInfo: ...
@overload
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
@overload
def get_worker_info(self, id: int) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def join(self): ...
def shutdown(self): ...
def sync(self): ...
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
num_worker_threads: int
device_maps: Dict[str, Dict[int, int]]
def __init__(
self,
num_worker_threads: int,
_transports: Optional[List],
_channels: Optional[List],
rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = _DEFAULT_INIT_METHOD,
device_maps: Dict[str, Dict[int, int]] = dict()): ...
def set_device_map(self, to: str, device_map: Dict[str, Dict[int, int]]): ...
class TensorPipeAgent(RpcAgent):
def __init__(
self,
store: Store,
name: str,
worker_id: int,
world_size: int,
pg: ProcessGroup,
opts: _TensorPipeRpcBackendOptionsBase,
): ...
def join(self): ...
def shutdown(self): ...
@overload
def get_worker_info(self) -> WorkerInfo: ...
@overload
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
@overload
def get_worker_info(self, id: int) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ...
def _is_current_rpc_agent_set() -> bool: ...
def _get_current_rpc_agent() -> RpcAgent: ...
def _set_and_start_rpc_agent(agent: RpcAgent): ...
def _reset_current_rpc_agent(): ...
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
def _destroy_rref_context(ignoreRRefLeak: bool): ...
def _rref_context_get_debug_info() -> Dict[str, str]: ...
def _cleanup_python_rpc_handler(): ...
def _invoke_rpc_builtin(
dst: WorkerInfo,
opName: str,
rpcTimeoutSeconds: float,
*args: Any,
**kwargs: Any
): ...
def _invoke_rpc_python_udf(
dst: WorkerInfo,
pickledPythonUDF: str,
tensors: List[torch.Tensor],
rpcTimeoutSeconds: float,
isAsyncExecution: bool
): ...
def _invoke_rpc_torchscript(
dstWorkerName: str,
qualifiedNameStr: str,
argsTuple: Tuple,
kwargsDict: Dict,
rpcTimeoutSeconds: float,
isAsyncExecution: bool,
): ...
def _invoke_remote_builtin(
dst: WorkerInfo,
opName: str,
rpcTimeoutSeconds: float,
*args: Any,
**kwargs: Any
): ...
def _invoke_remote_python_udf(
dst: WorkerInfo,
pickledPythonUDF: str,
tensors: List[torch.Tensor],
rpcTimeoutSeconds: float,
isAsyncExecution: bool,
): ...
def _invoke_remote_torchscript(
dstWorkerName: WorkerInfo,
qualifiedNameStr: str,
rpcTimeoutSeconds: float,
isAsyncExecution: bool,
*args: Any,
**kwargs: Any
): ...
def get_rpc_timeout() -> float: ...
def enable_gil_profiling(flag: bool): ...
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
class RemoteProfilerManager:
@staticmethod
def set_current_profiling_key(key: str): ...
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
def _set_profiler_node_id(default_node_id: int): ...
def _enable_jit_rref_pickle(): ...
def _disable_jit_rref_pickle(): ...
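As a usage sketch (a hypothetical helper, not part of this change), the declarations above let annotated call sites be type-checked; the function below relies only on signatures from the stub and needs an already-constructed agent to actually run:

from typing import Dict, List
from torch._C._distributed_rpc import RpcAgent, WorkerInfo

def summarize_workers(agent: RpcAgent) -> Dict[str, int]:
    # get_worker_infos() is declared to return List[WorkerInfo], so mypy can
    # check the .name / .id attribute accesses below.
    infos: List[WorkerInfo] = agent.get_worker_infos()
    return {info.name: info.id for info in infos}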


@@ -38,7 +38,15 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
throw python_error();
}
auto module = py::handle(rpc_module).cast<py::module>();
auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
if (!torch_C_module) {
throw python_error();
}
auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
auto m = torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");
auto module = py::handle(m).cast<py::module>();
auto rpcBackendOptions =
shared_ptr_class_<RpcBackendOptions>(
@@ -114,6 +122,20 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
"join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
.def(
"sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
.def(
"shutdown",
&RpcAgent::shutdown,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (RpcAgent::*)(void)const) &
RpcAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (RpcAgent::*)(const std::string&)const) &
RpcAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_infos",
&RpcAgent::getWorkerInfos,


@@ -1,6 +1,7 @@
import logging
import threading
from typing import Generator, Tuple
import torch
import torch.distributed as dist
@@ -20,12 +21,46 @@ if is_available() and not torch._C._rpc_init():
if is_available():
from . import api, backend_registry, functions, _set_profiler_node_id
from . import (
from . import api, backend_registry, functions
from torch._C._distributed_rpc import (
_disable_jit_rref_pickle,
_enable_jit_rref_pickle,
_disable_server_process_global_profiler,
_enable_server_process_global_profiler,
_set_and_start_rpc_agent,
_reset_current_rpc_agent,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
_set_profiler_node_id,
_is_current_rpc_agent_set,
_rref_context_get_debug_info,
_cleanup_python_rpc_handler,
_invoke_rpc_builtin,
_invoke_rpc_python_udf,
_invoke_rpc_torchscript,
_invoke_remote_builtin,
_invoke_remote_python_udf,
_invoke_remote_torchscript,
_set_rpc_timeout,
_get_current_rpc_agent,
get_rpc_timeout,
enable_gil_profiling,
RpcBackendOptions,
_TensorPipeRpcBackendOptionsBase,
ProcessGroupRpcBackendOptions,
RpcAgent,
PyRRef,
ProcessGroupAgent,
TensorPipeAgent,
RemoteProfilerManager,
WorkerInfo,
_DEFAULT_INIT_METHOD,
_DEFAULT_NUM_SEND_RECV_THREADS,
_DEFAULT_NUM_WORKER_THREADS,
_UNSET_RPC_TIMEOUT,
_DEFAULT_RPC_TIMEOUT_SEC,
) # noqa: F401
from torch._C._distributed_c10d import Store
from .api import * # noqa: F401
from .options import TensorPipeRpcBackendOptions # noqa: F401
from .backend_registry import BackendType
@@ -36,6 +71,7 @@ if is_available():
import numbers
rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
def init_rpc(
name,
@@ -104,18 +140,19 @@ if is_available():
raise TypeError(
f"Could not infer backend for options {rpc_backend_options}"
)
if backend != BackendType.TENSORPIPE:
# Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined]
logger.warning(
f"RPC was initialized with no explicit backend but with options "
f"RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined]
f"corresponding to {backend}, hence that backend will be used "
f"instead of the default {BackendType.TENSORPIPE}. To silence this "
f"warning pass `backend={backend}` explicitly."
)
if backend is None:
backend = BackendType.TENSORPIPE
backend = BackendType.TENSORPIPE # type: ignore[attr-defined]
if backend == BackendType.PROCESS_GROUP:
if backend == BackendType.PROCESS_GROUP: # type: ignore[attr-defined]
logger.warning(
"RPC was initialized with the PROCESS_GROUP backend which is "
"deprecated and slated to be removed and superseded by the TENSORPIPE "
@@ -176,7 +213,7 @@ if is_available():
def _init_rpc_backend(
backend=backend_registry.BackendType.TENSORPIPE,
backend=BackendType.TENSORPIPE, # type: ignore[attr-defined]
store=None,
name=None,
rank=-1,
@@ -204,7 +241,6 @@ if is_available():
@api._require_initialized
def _get_debug_info():
from . import _rref_context_get_debug_info
info = _rref_context_get_debug_info()
info.update(api._get_current_rpc_agent().get_debug_info())
info.update(dist_autograd._get_debug_info())
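The rendezvous_iterator line in this file uses the annotate-before-assign pattern; a generic, self-contained sketch of the same idea (pair_stream and start_stream are hypothetical names):

from typing import Generator, Tuple

# A bare annotation declares the eventual type without creating the name at
# runtime; a later assignment elsewhere in the module then type-checks.
pair_stream: Generator[Tuple[int, str], None, None]

def start_stream(n: int) -> None:
    global pair_stream
    pair_stream = ((i, f"item-{i}") for i in range(n))

start_stream(3)
assert next(pair_stream) == (0, "item-0")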


@@ -4,11 +4,11 @@ import functools
import inspect
import logging
import threading
from typing import Generic, TypeVar
from typing import Generic, TypeVar, Set, Any
import torch
from . import (
from torch._C._distributed_rpc import (
PyRRef,
RemoteProfilerManager,
WorkerInfo,
@@ -99,10 +99,10 @@ class AllGatherStates(object):
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized when initializing the RPC layer.
_ALL_WORKER_NAMES = None
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id = 0
_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates)
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
def _init_rpc_states(agent):
@@ -379,16 +379,18 @@ GenericWithOneTypeVar = Generic[T]
try:
# Combine the implementation class and the type class.
class RRef(PyRRef, GenericWithOneTypeVar):
class RRef(PyRRef, Generic[T]):
pass
except TypeError as exc:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
# Mypy doesn't understand __class__ (mypy bug #4177)
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore
pass
# Combine the implementation class and the type class.
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
# Types for classes expecting a certain generic parameter (mypy bug #7791)
class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore
pass
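The try/except above guards against a metaclass conflict between PyRRef and Generic[T]; the same situation and fix, reduced to plain hypothetical classes, looks like this:

# Stand-alone sketch of the metaclass-conflict workaround: when two bases have
# unrelated metaclasses, a derived class needs a metaclass inheriting from both.
class MetaA(type):
    pass

class MetaB(type):
    pass

class Base(metaclass=MetaA):
    pass

class Mixin(metaclass=MetaB):
    pass

# class Combined(Base, Mixin): ...   # raises TypeError: metaclass conflict

class MetaAB(MetaA, MetaB):
    pass

class Combined(Base, Mixin, metaclass=MetaAB):
    pass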
@@ -564,7 +566,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
dst_worker_info.name,
)
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
# Mypy doesn't support re-def of a variable not in the same block (#1174)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
with ctx_manager as rf:
args = args if args else ()
@@ -639,7 +642,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
dst_worker_info.name,
)
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
# Mypy doesn't support re-def of a variable not in the same block (#1174)
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
with ctx_manager as rf:
args = args if args else ()


@@ -28,8 +28,10 @@ _backend_type_doc = """
"""
# Create an enum type, `BackendType`, with empty members.
BackendType = enum.Enum(value="BackendType", names={})
BackendType.__repr__ = _backend_type_repr
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names=dict()) # type: ignore[misc]
# Unable to assign a function a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
BackendType.__doc__ = _backend_type_doc
def backend_registered(backend_name):
@@ -73,8 +75,10 @@ def register_backend(
},
**existing_enum_dict
)
BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)
BackendType.__repr__ = _backend_type_repr
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc]
# Unable to assign a function a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
BackendType.__doc__ = _backend_type_doc
return BackendType[backend_name]
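A reduced sketch of the functional Enum pattern used here (Backend and its members are hypothetical); mypy cannot model members created this way, which is why the real code carries the ignore comments:

import enum

def _repr(self) -> str:
    return f"Backend({self.name})"

# Build the enum from a mapping of member names to values, mirroring how
# BackendType is rebuilt above.
Backend = enum.Enum(value="Backend", names=dict(TENSORPIPE=1, PROCESS_GROUP=2))  # type: ignore[misc]
Backend.__repr__ = _repr  # type: ignore[assignment]

print(repr(Backend.TENSORPIPE))  # prints: Backend(TENSORPIPE)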


@@ -1,6 +1,6 @@
from datetime import timedelta
from . import (
from torch._C._distributed_rpc import (
_DEFAULT_INIT_METHOD,
_DEFAULT_NUM_SEND_RECV_THREADS,
_DEFAULT_NUM_WORKER_THREADS,
@@ -10,16 +10,16 @@ from . import (
# For any RpcAgent.
DEFAULT_RPC_TIMEOUT_SEC = _DEFAULT_RPC_TIMEOUT_SEC
DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
DEFAULT_SHUTDOWN_TIMEOUT = 5.0
DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0
# For ProcessGroupAgent.
DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS
DEFAULT_NUM_SEND_RECV_THREADS: int = _DEFAULT_NUM_SEND_RECV_THREADS
# For TensorPipeAgent.
DEFAULT_NUM_WORKER_THREADS = _DEFAULT_NUM_WORKER_THREADS
DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
# Ensure that we don't time out when there are long periods of time without
# any operations against the underlying ProcessGroup.
DEFAULT_PROCESS_GROUP_TIMEOUT = timedelta(milliseconds=2 ** 31 - 1)
DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1)
# Value indicating that timeout is not set for RPC call, and the default should be used.
UNSET_RPC_TIMEOUT = _UNSET_RPC_TIMEOUT
UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT


@@ -160,5 +160,6 @@ def async_execution(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)
wrapper._wrapped_async_rpc_function = fn
# Can't declare and use attributes of function objects (mypy#2087)
wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
return wrapper
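A generic sketch of the pattern above (tag_async and add are hypothetical): a wrapper function is tagged with an attribute so other code can detect decorated functions, and mypy flags attribute writes on function objects, hence the ignore:

import functools
from typing import Any, Callable

def tag_async(fn: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)
    # Mark the wrapper so callers can recover the original function.
    wrapper._wrapped_fn = fn  # type: ignore[attr-defined]
    return wrapper

@tag_async
def add(x: int, y: int) -> int:
    return x + y

assert getattr(add, "_wrapped_fn", None) is not None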


@@ -10,7 +10,7 @@ from enum import Enum
import torch
import torch.distributed as dist
from . import _get_current_rpc_agent
from torch._C._distributed_rpc import _get_current_rpc_agent
# Thread local tensor tables to store tensors while pickling torch.Tensor
@@ -37,7 +37,8 @@ class _InternalRPCPickler:
"""
def __init__(self):
self._dispatch_table = copyreg.dispatch_table.copy()
# Ignore type error because dispatch_table is defined in third-party package
self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined]
self._dispatch_table[torch.Tensor] = self._tensor_reducer
@classmethod
@@ -80,9 +81,11 @@ class _InternalRPCPickler:
#
# The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
# The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer
# Ignore type error because dispatch_table is defined in third-party package
p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index]
# An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
p.dispatch_table[dist.rpc.RRef] = self._rref_reducer
# Ignore type error because dispatch_table is defined in third-party package
p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index]
# save _thread_local_tensor_tables.send_tables if it is in nested call
global _thread_local_tensor_tables
@@ -224,8 +227,8 @@ def _start_record_function(exec_type, func_name, current_worker_name, dest_worke
profile_key = "rpc_{}#{}({} -> {})".format(
exec_type.value, str(func_name), current_worker_name, dest_worker_name
)
rf = torch.autograd._RecordFunction()
torch.autograd._run_before_callbacks(rf, profile_key)
rf = torch.autograd._RecordFunction() # type: ignore[attr-defined]
torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined]
return rf
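A stand-alone sketch (Handle and its reducer are hypothetical) of the copyreg/pickle dispatch_table customization that _InternalRPCPickler relies on: a per-pickler table maps a type to a reducer returning (callable, args):

import copyreg
import io
import pickle

class Handle:
    def __init__(self, ident: int) -> None:
        self.ident = ident

def _rebuild_handle(ident: int) -> Handle:
    return Handle(ident)

def _reduce_handle(obj: Handle):
    return (_rebuild_handle, (obj.ident,))

table = copyreg.dispatch_table.copy()   # start from the global table
table[Handle] = _reduce_handle          # override pickling for Handle only

buf = io.BytesIO()
p = pickle.Pickler(buf)
p.dispatch_table = table                # per-pickler dispatch table
p.dump(Handle(7))
assert pickle.loads(buf.getvalue()).ident == 7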


@@ -1,4 +1,4 @@
from . import _TensorPipeRpcBackendOptionsBase
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
from . import constants as rpc_contants
import torch


@@ -2601,7 +2601,7 @@ class RpcTest(RpcAgentTestFixture):
@dist_init
def test_disable_gil_profiling(self):
# test that rpc.enable_gil_profilig(false) will result in
# test that rpc.enable_gil_profiling(false) will result in
# GIL wait time not being recorded.
# GIL profiling should be disabled by default.