mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55466
Improve the implementation and the unit test coverage of `RendezvousHandlerRegistry`.
### Note
See the original diff (D27442325 (df299dbd7d)) that had to be reverted due to an unexpected Python version incompatibility between the internal and external PyTorch CI tests.
Test Plan: Run the existing and newly-introduced unit tests.
Reviewed By: tierex
Differential Revision: D27623215
fbshipit-source-id: 51538d0f154f64e04f685a95d40d805b478c93f9
281 lines
9.0 KiB
Python
281 lines
9.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import abc
|
|
from typing import Any, Callable, Dict, Optional, Tuple
|
|
|
|
from torch.distributed import Store
|
|
|
|
|
|
class RendezvousError(Exception):
    """Represents the base type for rendezvous errors.

    The other rendezvous exceptions in this module derive from this class, so
    catching it handles any of them.
    """
|
|
|
|
|
|
class RendezvousClosedError(RendezvousError):
    """Raised when a rendezvous is closed."""
|
|
|
|
|
|
class RendezvousTimeoutError(RendezvousError):
    """Raised when a rendezvous did not complete on time."""
|
|
|
|
|
|
class RendezvousConnectionError(RendezvousError):
    """Raised when the connection to a rendezvous backend has failed."""
|
|
|
|
|
|
class RendezvousStateError(RendezvousError):
    """Raised when the state of a rendezvous is corrupt."""
|
|
|
|
|
|
class RendezvousHandler(abc.ABC):
    """
    Main rendezvous interface.

    .. note:: torchelastic users normally **do not** need to implement their
              own ``RendezvousHandler``. An implementation based on
              `etcd <https://etcd.io/>`__ is already provided, and is recommended
              for most users, provided they can deploy it in their environment.

    .. warning:: torchelastic is currently considered experimental,
                 so the APIs may change!
    """

    @abc.abstractmethod
    def get_backend(self) -> str:
        """
        Return the name of the rendezvous backend served by this handler.

        ``RendezvousHandlerRegistry.create_handler`` compares this value with
        the ``backend`` of the parameters the handler was created from.
        """
        pass

    @abc.abstractmethod
    def next_rendezvous(
        self,
    ) -> Tuple[Store, int, int]:
        """
        Main entry-point into the rendezvous barrier.
        Blocks until the rendezvous is complete (and the current
        process is included in the formed worker group), or a timeout occurs, or
        rendezvous was marked closed.

        Returns: a tuple of (``c10d Store``, ``rank``, ``world size``)

        Raises:
            RendezvousClosedError - if rendezvous for the current job is closed.
            RendezvousTimeoutError - on timeout
        """
        pass

    @abc.abstractmethod
    def is_closed(self) -> bool:
        """
        Checks whether rendezvous for current job has been closed,
        which means all future attempts to re-rendezvous (within same job) will
        fail.

        .. note:: ``is_closed`` and ``set_closed`` have semantics of eventual
                  propagation, and should not be used for synchronization.
                  The intention here is that if at least one worker decides
                  the job is finished, it will close the rendezvous, and
                  other workers will soon observe this and stop
                  training/rendezvous-ing as well.
        """
        pass

    @abc.abstractmethod
    def set_closed(self):
        """
        Used to mark the rendezvous (for current job) as closed.
        """
        pass

    @abc.abstractmethod
    def num_nodes_waiting(self) -> int:
        """
        Returns number of workers who *arrived late* at
        the rendezvous barrier, hence weren't included in the current worker
        group.

        Callers should periodically call this method to check whether
        new members are waiting to join the job and if so admit them by
        calling ``next_rendezvous()`` (re-rendezvous).
        """
        pass

    @abc.abstractmethod
    def get_run_id(self) -> str:
        """
        Returns the run_id of this rendezvous handler. The run_id is a user-defined
        id that uniquely identifies an instance of a distributed application.
        It typically maps to a job id and is used to allow workers to join the
        correct distributed application.
        """
        pass

    def shutdown(self) -> bool:
        """
        Closes all resources that were open for rendezvous run.

        Unlike the other methods of this interface, ``shutdown`` is not
        abstract: this default implementation does nothing.

        Usage:

        ::

         def main():
             rdzv_handler = ...
             try:
                 store, rank, world_size = rdzv_handler.next_rendezvous()
             finally:
                 rdzv_handler.shutdown()
        """
        # NOTE(review): the annotation promises ``bool`` but this default body
        # implicitly returns ``None``; kept as-is so existing callers and
        # subclasses observe no change. Confirm the intended default result.
        pass
|
|
|
|
|
|
class RendezvousParameters:
    """Holds the parameters to construct a `RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        **kwargs:
            Additional parameters for the specified backend.

    Raises:
        ValueError:
            If `backend` is empty, `min_nodes` is less than 1, or `max_nodes`
            is less than `min_nodes`.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        **kwargs,
    ):
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        # Backend-specific options; read through `get` and its typed variants.
        self.config = kwargs

    def get(self, key: str, default: Any = None) -> Any:
        """Returns the value for `key` if `key` exists, else `default`."""
        return self.config.get(key, default)

    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
        """Returns the value for `key` as a `bool`.

        `None` and `bool` values are returned unchanged. The integers 1 and 0
        and the case-insensitive strings "1"/"true"/"t"/"yes"/"y" and
        "0"/"false"/"f"/"no"/"n" map to `True` and `False` respectively.

        Raises:
            ValueError: If the value is none of the accepted representations.
        """
        value = self.get(key, default)
        if value is None or isinstance(value, bool):
            return value
        if isinstance(value, int):
            if value == 1:
                return True
            if value == 0:
                return False
        elif isinstance(value, str):
            lowered = value.lower()
            if lowered in ("1", "true", "t", "yes", "y"):
                return True
            if lowered in ("0", "false", "f", "no", "n"):
                return False
        raise ValueError(
            f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
        )

    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
        """Returns the value for `key` as an `int`.

        Returns `None` when the option is absent and no default was given.

        Raises:
            ValueError: If the value cannot be converted with `int()`.
        """
        value = self.get(key, default)
        if value is None:
            return value
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            # `int()` raises TypeError for values such as lists or dicts;
            # report those with the same error as an unparsable string.
            raise ValueError(
                f"The rendezvous configuration option '{key}' does not represent a valid integer "
                "value."
            ) from e
|
|
|
|
|
|
# Type of a backend factory: a callable that receives the `RendezvousParameters`
# and returns a `RendezvousHandler`.
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
|
|
|
|
|
|
class RendezvousHandlerRegistry:
    """Represents a registry of `RendezvousHandler` backends."""

    # Maps a backend name to the callback that constructs its handler.
    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Registers a new rendezvous backend.

        Registering the same `creator` for a `backend` again is a no-op.

        Args:
            backend:
                The name of the backend.
            creator:
                The callback to invoke to construct the `RendezvousHandler`.

        Raises:
            ValueError:
                If `backend` is empty, or is already registered with a
                different creator.
        """
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        current_creator: Optional[RendezvousHandlerCreator] = self._registry.get(backend)

        if current_creator is not None and current_creator != creator:
            raise ValueError(
                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
                f"is already registered with '{current_creator}'."
            )

        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Creates a new `RendezvousHandler`.

        Args:
            params:
                The parameters handed to the creator registered for
                `params.backend`.

        Raises:
            ValueError:
                If no creator is registered for `params.backend`.
            RuntimeError:
                If the created handler reports a backend other than
                `params.backend`.
        """
        try:
            creator = self._registry[params.backend]
        except KeyError as e:
            raise ValueError(
                f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                f"to call `{self.register.__name__}`?"
            ) from e

        handler = creator(params)

        # Do some sanity check.
        if handler.get_backend() != params.backend:
            raise RuntimeError(
                f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
                f"backend '{params.backend}'."
            )

        return handler
|
|
|
|
|
|
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
|