Add type annotations to torch._C._distributed_rpc module. (#46624)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46624

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D24761656

Pulled By: xuzhao9

fbshipit-source-id: b55aee5dd2b97f573a50e5bbfddde7d984943fec
This commit is contained in:
  parent 73a3e70b24
  commit eaa993a2e0

mypy.ini (6 lines changed)
@@ -56,6 +56,12 @@ ignore_errors = True
 [mypy-torch.distributed.*]
 ignore_errors = True
 
+[mypy-torch.distributed.rpc.*]
+ignore_errors = False
+
+[mypy-torch.distributed.rpc._testing.*]
+ignore_errors = True
+
 [mypy-torch.testing._internal.hypothesis_utils.*]
 ignore_errors = True
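With ignore_errors = False for torch.distributed.rpc.*, mypy now reports errors it finds inside that package instead of suppressing them, while the _testing subpackage and hypothesis_utils remain ignored. As a rough illustration of the effect, assuming a file that lived under the now-checked package (everything below is invented for this sketch):

    # Hypothetical module imagined under torch/distributed/rpc/, for illustration only.
    def add_one(x: int) -> int:
        return x + 1

    result: str = add_one(41)  # with ignore_errors = False, mypy flags this assignment; it was silenced before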
@@ -4,10 +4,10 @@ from enum import Enum
 # Defined in tools/autograd/init.cpp
 
 class ProfilerState(Enum):
-    Disable = 0
-    CPU = 1
-    CUDA = 2
-    NVTX = 3
+    Disable = ...
+    CPU = ...
+    CUDA = ...
+    NVTX = ...
 
 
 class ProfilerConfig:
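The switch from literal values to ellipses follows the usual stub-file convention: a .pyi file only has to declare that a member exists and what type it has, not what it evaluates to. A minimal sketch of that convention, not taken from the PyTorch tree:

    # color.pyi -- hypothetical stub; "..." elides the runtime value
    from enum import Enum

    class Color(Enum):
        RED = ...
        BLUE = ...

If this were executed as ordinary Python, BLUE would become an alias of RED (identical values), but a stub is only read by the type checker and never imported at runtime, so the ellipses are harmless.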
@@ -31,14 +31,14 @@ class Reducer:
     ...
 
 class ReduceOp(Enum):
-    SUM = 0
-    PRODUCT = 1
-    MIN = 2
-    MAX = 3
-    BAND = 4
-    BOR = 5
-    BXOR = 6
-    UNUSED = 7
+    SUM = ...
+    PRODUCT = ...
+    MIN = ...
+    MAX = ...
+    BAND = ...
+    BOR = ...
+    BXOR = ...
+    UNUSED = ...
 
 class BroadcastOptions:
     rootRank: int
torch/_C/_distributed_rpc.pyi (new file, 194 lines)

@@ -0,0 +1,194 @@
from typing import Tuple, Dict, Optional, List, Any, overload
from datetime import timedelta
import enum
import torch
from . import Future
from ._autograd import ProfilerConfig, ProfilerState, ProfilerEvent
from ._distributed_c10d import ProcessGroup, Store

# This module is defined in torch/csrc/distributed/rpc/init.cpp

_DEFAULT_NUM_SEND_RECV_THREADS: int
_DEFAULT_INIT_METHOD: str
_DEFAULT_NUM_WORKER_THREADS: int
_UNSET_RPC_TIMEOUT: float
_DEFAULT_RPC_TIMEOUT_SEC: float

class RpcBackendOptions:
    rpc_timeout: float
    init_method: str
    def __init__(
        self,
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
    ): ...

class WorkerInfo:
    def __init__(self, name: str, worker_id: int): ...
    @property
    def name(self) -> str: ...
    @property
    def id(self) -> int: ...
    def __eq__(self, other: object) -> bool: ...
    def __repr__(self) -> str: ...

class RpcAgent:
    def join(self): ...
    def sync(self): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def get_debug_info(self) -> Dict[str, str]: ...
    def get_metrics(self) -> Dict[str, str]: ...

class PyRRef:
    def __init__(self, value: Any, type_hint: Any = None): ...
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def local_value(self) -> Any: ...
    def rpc_sync(self) -> Any: ...
    def rpc_async(self) -> Any: ...
    def remote(self) -> Any: ...
    def _serialize(self) -> Tuple: ...
    @staticmethod
    def _deserialize(tp: Tuple) -> 'PyRRef': ...
    def _get_type(self) -> Any: ...
    def _get_future(self) -> Future: ...
    def _get_profiling_future(self) -> Future: ...
    def _set_profiling_future(self, profilingFuture: Future): ...
    def __repr__(self) -> str: ...
    ...

class ProcessGroupRpcBackendOptions(RpcBackendOptions):
    num_send_recv_threads: int
    def __init__(
        self,
        num_send_recv_threads: int,
        rpc_timeout: float,
        init_method: str
    ): ...

class ProcessGroupAgent(RpcAgent):
    def __init__(
        self,
        worker_name: str,
        pg: ProcessGroup,
        numSendRecvThreads: int,
        rpcTimeout: timedelta
    ): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def join(self): ...
    def shutdown(self): ...
    def sync(self): ...

class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    num_worker_threads: int
    device_maps: Dict[str, Dict[int, int]]
    def __init__(
        self,
        num_worker_threads: int,
        _transports: Optional[List],
        _channels: Optional[List],
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
        device_maps: Dict[str, Dict[int, int]] = dict()): ...
    def set_device_map(self, to: str, device_map: Dict[str, Dict[int, int]]): ...

class TensorPipeAgent(RpcAgent):
    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        world_size: int,
        pg: ProcessGroup,
        opts: _TensorPipeRpcBackendOptionsBase,
    ): ...
    def join(self): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ...

def _is_current_rpc_agent_set() -> bool: ...
def _get_current_rpc_agent() -> RpcAgent: ...
def _set_and_start_rpc_agent(agent: RpcAgent): ...
def _reset_current_rpc_agent(): ...
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
def _destroy_rref_context(ignoreRRefLeak: bool): ...
def _rref_context_get_debug_info() -> Dict[str, str]: ...
def _cleanup_python_rpc_handler(): ...
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any
): ...
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool
): ...
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: Tuple,
    kwargsDict: Dict,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any
): ...
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
def _invoke_remote_torchscript(
    dstWorkerName: WorkerInfo,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any
): ...
def get_rpc_timeout() -> float: ...
def enable_gil_profiling(flag: bool): ...
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...

class RemoteProfilerManager:
    @staticmethod
    def set_current_profiling_key(key: str): ...

def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
def _set_profiler_node_id(default_node_id: int): ...
def _enable_jit_rref_pickle(): ...
def _disable_jit_rref_pickle(): ...
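These stubs change nothing at runtime; they let a type checker validate Python code that calls into the extension module. A small sketch of the kind of annotation they enable, assuming a distributed-enabled build of torch (the helper function below is invented for illustration; the imported names come from the stub above):

    from typing import List
    from torch._C._distributed_rpc import RpcAgent, get_rpc_timeout

    def describe_workers(agent: RpcAgent) -> List[str]:
        # get_worker_infos() is declared as -> List[WorkerInfo] in the stub,
        # so the checker knows each item exposes .name and .id.
        return ["{}({})".format(w.name, w.id) for w in agent.get_worker_infos()]

    timeout: float = get_rpc_timeout()  # declared as -> float in the stub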
@@ -38,7 +38,15 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
     throw python_error();
   }
 
-  auto module = py::handle(rpc_module).cast<py::module>();
+  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+  if (!torch_C_module) {
+    throw python_error();
+  }
+
+  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
+  auto m = torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");
+
+  auto module = py::handle(m).cast<py::module>();
 
   auto rpcBackendOptions =
       shared_ptr_class_<RpcBackendOptions>(
@@ -114,6 +122,20 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
           "join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
       .def(
           "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
+      .def(
+          "shutdown",
+          &RpcAgent::shutdown,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "get_worker_info",
+          (const WorkerInfo& (RpcAgent::*)(void)const) &
+              RpcAgent::getWorkerInfo,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "get_worker_info",
+          (const WorkerInfo& (RpcAgent::*)(const std::string&)const) &
+              RpcAgent::getWorkerInfo,
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "get_worker_infos",
           &RpcAgent::getWorkerInfos,
@@ -1,6 +1,7 @@
 import logging
 import threading
 
+from typing import Generator, Tuple
 import torch
 import torch.distributed as dist
@@ -20,12 +21,46 @@ if is_available() and not torch._C._rpc_init():
 
 
 if is_available():
-    from . import api, backend_registry, functions, _set_profiler_node_id
-    from . import (
+    from . import api, backend_registry, functions
+    from torch._C._distributed_rpc import (
         _disable_jit_rref_pickle,
         _enable_jit_rref_pickle,
         _disable_server_process_global_profiler,
         _enable_server_process_global_profiler,
         _set_and_start_rpc_agent,
         _reset_current_rpc_agent,
         _delete_all_user_and_unforked_owner_rrefs,
         _destroy_rref_context,
         _set_profiler_node_id,
         _is_current_rpc_agent_set,
         _rref_context_get_debug_info,
         _cleanup_python_rpc_handler,
         _invoke_rpc_builtin,
         _invoke_rpc_python_udf,
         _invoke_rpc_torchscript,
         _invoke_remote_builtin,
         _invoke_remote_python_udf,
         _invoke_remote_torchscript,
         _set_rpc_timeout,
         _get_current_rpc_agent,
         get_rpc_timeout,
         enable_gil_profiling,
         RpcBackendOptions,
         _TensorPipeRpcBackendOptionsBase,
         ProcessGroupRpcBackendOptions,
         RpcAgent,
         PyRRef,
         ProcessGroupAgent,
         TensorPipeAgent,
         RemoteProfilerManager,
         WorkerInfo,
         _DEFAULT_INIT_METHOD,
         _DEFAULT_NUM_SEND_RECV_THREADS,
         _DEFAULT_NUM_WORKER_THREADS,
         _UNSET_RPC_TIMEOUT,
         _DEFAULT_RPC_TIMEOUT_SEC,
     )  # noqa: F401
     from torch._C._distributed_c10d import Store
     from .api import *  # noqa: F401
     from .options import TensorPipeRpcBackendOptions  # noqa: F401
     from .backend_registry import BackendType
@@ -36,6 +71,7 @@ if is_available():
 
     import numbers
 
+    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
 
     def init_rpc(
         name,
@@ -104,18 +140,19 @@ if is_available():
             raise TypeError(
                 f"Could not infer backend for options {rpc_backend_options}"
             )
-            if backend != BackendType.TENSORPIPE:
+            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
+            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                 logger.warning(
-                    f"RPC was initialized with no explicit backend but with options "
+                    f"RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                     f"corresponding to {backend}, hence that backend will be used "
                     f"instead of the default {BackendType.TENSORPIPE}. To silence this "
                     f"warning pass `backend={backend}` explicitly."
                 )
 
         if backend is None:
-            backend = BackendType.TENSORPIPE
+            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]
 
-        if backend == BackendType.PROCESS_GROUP:
+        if backend == BackendType.PROCESS_GROUP:  # type: ignore[attr-defined]
             logger.warning(
                 "RPC was initialized with the PROCESS_GROUP backend which is "
                 "deprecated and slated to be removed and superseded by the TENSORPIPE "
@@ -176,7 +213,7 @@ if is_available():
 
 
     def _init_rpc_backend(
-        backend=backend_registry.BackendType.TENSORPIPE,
+        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
         store=None,
         name=None,
         rank=-1,
@@ -204,7 +241,6 @@ if is_available():
 
     @api._require_initialized
     def _get_debug_info():
-        from . import _rref_context_get_debug_info
         info = _rref_context_get_debug_info()
         info.update(api._get_current_rpc_agent().get_debug_info())
         info.update(dist_autograd._get_debug_info())
@@ -4,11 +4,11 @@ import functools
 import inspect
 import logging
 import threading
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Set, Any
 
 import torch
 
-from . import (
+from torch._C._distributed_rpc import (
     PyRRef,
     RemoteProfilerManager,
     WorkerInfo,
@@ -99,10 +99,10 @@ class AllGatherStates(object):
 
 # States used by `def _all_gather()`.
 # `_ALL_WORKER_NAMES` is initialized on initiaizing RPC layer.
-_ALL_WORKER_NAMES = None
+_ALL_WORKER_NAMES: Set[Any] = set()
 _all_gather_dict_lock = threading.RLock()
 _all_gather_sequence_id = 0
-_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates)
+_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
 
 
 def _init_rpc_states(agent):
@@ -379,16 +379,18 @@ GenericWithOneTypeVar = Generic[T]
 
 try:
     # Combine the implementation class and the type class.
-    class RRef(PyRRef, GenericWithOneTypeVar):
+    class RRef(PyRRef, Generic[T]):
         pass
 except TypeError as exc:
     # TypeError: metaclass conflict: the metaclass of a derived class
     # must be a (non-strict) subclass of the metaclasses of all its bases
-    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
+    # Mypy doesn't understand __class__ (mypy bug #4177)
+    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore
         pass
 
     # Combine the implementation class and the type class.
-    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
+    # Types for classes expecting a certain generic parameter (mypy bug #7791)
+    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore
         pass
 
 
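The try/except above guards against a metaclass conflict: PyRRef comes from the C extension and may carry a metaclass unrelated to the one behind the Generic machinery, in which case Python refuses to combine the two bases until a merged metaclass is supplied. A self-contained sketch of the same workaround with stand-in classes (all names below are invented for illustration):

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class _ExtensionMeta(type):        # stands in for the extension class's metaclass
        pass

    class _TypingMeta(type):           # stands in for the typing base's metaclass
        pass

    class ExtensionImpl(metaclass=_ExtensionMeta):   # stands in for PyRRef
        pass

    class TypingBase(Generic[T], metaclass=_TypingMeta):
        pass

    try:
        # Combine the implementation class and the type class.
        class Ref(ExtensionImpl, TypingBase[T]):
            pass
    except TypeError:
        # Metaclass conflict: derive a metaclass from both bases' metaclasses.
        class _RefMeta(type(ExtensionImpl), type(TypingBase)):
            pass

        class Ref(ExtensionImpl, TypingBase[T], metaclass=_RefMeta):
            pass

    print(type(Ref).__name__)  # _RefMeta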
@@ -564,7 +566,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
             dst_worker_info.name,
         )
         RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
-        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
 
     with ctx_manager as rf:
         args = args if args else ()
@@ -639,7 +642,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
             dst_worker_info.name,
         )
         RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
-        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
 
     with ctx_manager as rf:
         args = args if args else ()
@@ -28,8 +28,10 @@ _backend_type_doc = """
 """
 
 # Create an enum type, `BackendType`, with empty members.
-BackendType = enum.Enum(value="BackendType", names={})
-BackendType.__repr__ = _backend_type_repr
+# Can't handle Function Enum API (mypy bug #9079)
+BackendType = enum.Enum(value="BackendType", names=dict())  # type: ignore[misc]
+# Unable to assign a function a method (mypy bug #2427)
+BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
 BackendType.__doc__ = _backend_type_doc
 
 def backend_registered(backend_name):
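The ignores above exist because BackendType is built with the functional form of the Enum API, whose members mypy cannot infer statically. A minimal, self-contained sketch of that API and of how register_backend() later rebuilds the enum with an extra member (the names here are invented for illustration):

    import enum

    # Create an enum dynamically from a names mapping...
    BackendDemo = enum.Enum(value="BackendDemo", names={"FOO": 1})

    # ...then rebuild it with an additional member, as register_backend() does.
    extended = {m.name: m.value for m in BackendDemo}
    extended["BAR"] = 2
    BackendDemo = enum.Enum(value="BackendDemo", names=extended)

    print(list(BackendDemo))  # [<BackendDemo.FOO: 1>, <BackendDemo.BAR: 2>]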
@@ -73,8 +75,10 @@ def register_backend(
         },
         **existing_enum_dict
     )
-    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)
-    BackendType.__repr__ = _backend_type_repr
+    # Can't handle Function Enum API (mypy bug #9079)
+    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
+    # Unable to assign a function a method (mypy bug #2427)
+    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
     BackendType.__doc__ = _backend_type_doc
     return BackendType[backend_name]
@@ -1,6 +1,6 @@
 from datetime import timedelta
 
-from . import (
+from torch._C._distributed_rpc import (
     _DEFAULT_INIT_METHOD,
     _DEFAULT_NUM_SEND_RECV_THREADS,
     _DEFAULT_NUM_WORKER_THREADS,
@@ -10,16 +10,16 @@ from . import (
 
 
 # For any RpcAgent.
-DEFAULT_RPC_TIMEOUT_SEC = _DEFAULT_RPC_TIMEOUT_SEC
-DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
-DEFAULT_SHUTDOWN_TIMEOUT = 5.0
+DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
+DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
+DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0
 
 # For ProcessGroupAgent.
-DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS
+DEFAULT_NUM_SEND_RECV_THREADS: int = _DEFAULT_NUM_SEND_RECV_THREADS
 # For TensorPipeAgent.
-DEFAULT_NUM_WORKER_THREADS = _DEFAULT_NUM_WORKER_THREADS
+DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
 # Ensure that we don't time out when there are long periods of time without
 # any operations against the underlying ProcessGroup.
-DEFAULT_PROCESS_GROUP_TIMEOUT = timedelta(milliseconds=2 ** 31 - 1)
+DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1)
 # Value indicating that timeout is not set for RPC call, and the default should be used.
-UNSET_RPC_TIMEOUT = _UNSET_RPC_TIMEOUT
+UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
@@ -160,5 +160,6 @@ def async_execution(fn):
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
         return fn(*args, **kwargs)
-    wrapper._wrapped_async_rpc_function = fn
+    # Can't declare and use attributes of function objects (mypy#2087)
+    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
     return wrapper
@@ -10,7 +10,7 @@ from enum import Enum
 import torch
 import torch.distributed as dist
 
-from . import _get_current_rpc_agent
+from torch._C._distributed_rpc import _get_current_rpc_agent
 
 
 # Thread local tensor tables to store tensors while pickling torch.Tensor
@@ -37,7 +37,8 @@ class _InternalRPCPickler:
     """
 
     def __init__(self):
-        self._dispatch_table = copyreg.dispatch_table.copy()
+        # Ignore type error because dispatch_table is defined in third-party package
+        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
         self._dispatch_table[torch.Tensor] = self._tensor_reducer
 
     @classmethod
@@ -80,9 +81,11 @@ class _InternalRPCPickler:
         #
         # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
         # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
-        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer
+        # Ignore type error because dispatch_table is defined in third-party package
+        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
         # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
-        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer
+        # Ignore type error because dispatch_table is defined in third-party package
+        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]
 
         # save _thread_local_tensor_tables.send_tables if it is in nested call
         global _thread_local_tensor_tables
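The attribute annotated here, dispatch_table, is the standard copyreg/pickle hook for registering per-pickler reduce functions; the ignores only work around how mypy modeled it at the time. A small self-contained sketch of the same pattern (the Point class and reducer are invented for illustration):

    import copyreg
    import io
    import pickle

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    def _reduce_point(p):
        # Tell pickle to rebuild the object by calling Point(x, y) on load.
        return (Point, (p.x, p.y))

    buf = io.BytesIO()
    pickler = pickle.Pickler(buf)
    # Per-pickler dispatch table seeded from the global copyreg table --
    # the same pattern _InternalRPCPickler uses for torch.Tensor and RRefs.
    pickler.dispatch_table = copyreg.dispatch_table.copy()
    pickler.dispatch_table[Point] = _reduce_point
    pickler.dump(Point(1, 2))

    restored = pickle.loads(buf.getvalue())
    print(restored.x, restored.y)  # 1 2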
@@ -224,8 +227,8 @@ def _start_record_function(exec_type, func_name, current_worker_name, dest_worke
     profile_key = "rpc_{}#{}({} -> {})".format(
         exec_type.value, str(func_name), current_worker_name, dest_worker_name
     )
-    rf = torch.autograd._RecordFunction()
-    torch.autograd._run_before_callbacks(rf, profile_key)
+    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
+    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
     return rf
 
 
@@ -1,4 +1,4 @@
-from . import _TensorPipeRpcBackendOptionsBase
+from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
 from . import constants as rpc_contants
 
 import torch
@@ -2601,7 +2601,7 @@ class RpcTest(RpcAgentTestFixture):
 
     @dist_init
     def test_disable_gil_profiling(self):
-        # test that rpc.enable_gil_profilig(false) will result in
+        # test that rpc.enable_gil_profiling(false) will result in
         # GIL wait time not being recorded.
 
         # GIL profiling should be disabled by default.