pytorch/torch/distributed/launcher/api.py
Philip Meier d5988c5eca remove unused type: ignore directives (#60006)
Summary:
During development it is common practice to put `type: ignore` comments on lines that are correct, but that `mypy` does not recognize as such. This often stems from the fact that the `mypy` version in use wasn't able to handle the pattern.

With every new release `mypy` gets better at handling complex code. In addition to fixing all the previously accepted but now failing patterns, we should also revisit all `type: ignore` comments to see if they are still needed. Fortunately, we don't need to do this manually: by adding `warn_unused_ignores = True` to the configuration, `mypy` will error out whenever it encounters a `type: ignore` that is no longer needed.
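
For example, enabling the check in `mypy.ini`:

    [mypy]
    warn_unused_ignores = True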

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60006

Reviewed By: jbschlosser, malfet

Differential Revision: D29133237

Pulled By: albanD

fbshipit-source-id: 41e82edc5cd5affa7ccedad044b59b94dad4425a
2021-06-18 07:23:31 -07:00


#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union, cast, Tuple
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
logger = get_logger()
@dataclass
class LaunchConfig:
"""
min_nodes: Minimum amount of nodes that the user function will
be launched on. Elastic agent ensures that the user
function start only when the min_nodes amount enters
the rendezvous.
max_nodes: Maximum amount of nodes that the user function
will be launched on.
nproc_per_node: On each node the elastic agent will launch
this amount of workers that will execute user
defined function.
rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
rdzv_endpoint: The endpoint of the rdzv sync. storage.
rdzv_id: The unique run id of the job (if not passed a unique one will be
deduced from run environment - flow workflow id in flow - or auto generated).
role: User defined role of the worker (defaults to "trainer").
max_restarts: The maximum amount of restarts that elastic agent will conduct
on workers before failure.
monitor_interval: The interval in seconds that is used by the elastic_agent
as a period of monitoring workers.
start_method: The method is used by the elastic agent to start the
workers (spawn, fork, forkserver).
log_dir: base log directory where log files are written. If not set,
one is created in a tmp dir but NOT removed on exit.
redirects: configuration to redirect stdout/stderr to log files.
Pass a single ``Std`` enum to redirect all workers,
or a mapping keyed by local_rank to selectively redirect.
tee: configuration to "tee" stdout/stderr to console + log file.
metrics_cfg: configuration to initialize metrics.
"""
min_nodes: int
max_nodes: int
nproc_per_node: int
run_id: str = ""
role: str = "default_role"
rdzv_endpoint: str = ""
rdzv_backend: str = "etcd"
rdzv_configs: Dict[str, Any] = field(default_factory=dict)
rdzv_timeout: int = 900
max_restarts: int = 3
monitor_interval: float = 30
start_method: str = "spawn"
log_dir: Optional[str] = None
redirects: Union[Std, Dict[int, Std]] = Std.NONE
tee: Union[Std, Dict[int, Std]] = Std.NONE
metrics_cfg: Dict[str, str] = field(default_factory=dict)
def __post_init__(self):
self.rdzv_configs["timeout"] = self.rdzv_timeout
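
# A minimal, hedged sketch of constructing a ``LaunchConfig`` (the values
# below are illustrative, not recommended defaults):
#
#   config = LaunchConfig(
#       min_nodes=1,
#       max_nodes=2,
#       nproc_per_node=4,
#       rdzv_backend="etcd",
#       rdzv_endpoint="localhost:2379",
#       rdzv_timeout=300,
#   )
#
# ``__post_init__`` copies ``rdzv_timeout`` into ``rdzv_configs["timeout"]``,
# which is how the timeout reaches the rendezvous backend.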
class elastic_launch:
"""
Launches an torchelastic agent on the container that invoked the entrypoint.
1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
``entrypoint`` can be a function or a command.
2. The return value is a map of each worker's output mapped
by their respective global rank.
Usage
::
def worker_fn(foo):
# ...
def main():
# entrypoint is a function.
outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
# return rank 0's output
return outputs[0]
# entrypoint is a command and ``script.py`` is the python module.
ouptuts = elestic_launch(LaunchConfig, "script.py")(args)
ouptuts = elestic_launch(LaunchConfig, "python")("script.py")
"""
def __init__(
self,
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
):
self._config = config
self._entrypoint = entrypoint
def __call__(self, *args):
return launch_agent(self._config, self._entrypoint, list(args))
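
# Hedged usage sketch for ``elastic_launch``: pass a ``LaunchConfig``
# instance plus a function or command (``train_fn`` and the argument
# values are hypothetical):
#
#   config = LaunchConfig(min_nodes=1, max_nodes=1, nproc_per_node=2)
#   outputs = elastic_launch(config, train_fn)(arg1, arg2)
#   rank0_result = outputs[0]  # return value of the global rank 0 worker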
def _construct_event(config: LaunchConfig) -> events.Event:
metadata = {
"rdzv_backend": config.rdzv_backend,
"run_id": config.run_id,
"role": config.role,
}
return events.Event(
name="torch.distributed.elastic.launch_agent",
source=events.EventSource.AGENT,
metadata=cast(Dict[str, events.EventMetadataValue], metadata),
)
def _get_entrypoint_name(
entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
"""Retrive entrypoint name with the rule:
1. If entrypoint is a function, use ``entrypont.__qualname__``.
2. If entrypoint is a string, check its value:
2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args``
which does not start with hifen letter (for example, "-u" will be skipped).
2.2 otherwise, use ``entrypoint`` value.
3. Otherwise, return empty string.
"""
if isinstance(entrypoint, Callable): # type: ignore[arg-type]
return entrypoint.__name__ # type: ignore[union-attr]
elif isinstance(entrypoint, str):
if entrypoint == sys.executable:
return next((arg for arg in args if arg[0] != "-"), "")
else:
return entrypoint
else:
return ""
def _get_addr_and_port(
rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
if rdzv_parameters.backend != "static":
return (None, None)
endpoint = rdzv_parameters.endpoint
endpoint = endpoint.strip()
if not endpoint:
        raise ValueError(
            "The rendezvous endpoint is missing. Try to add --master_addr and --master_port"
        )
master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
if master_port == -1:
raise ValueError(
f"port is missing in endpoint: {endpoint}. Try to specify --master_port"
)
return (master_addr, master_port)
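
# Illustrative behavior of ``_get_addr_and_port`` (endpoint values are
# hypothetical):
#
#   backend != "static"                        -> (None, None)
#   "static" + endpoint "node0.example:29500"  -> ("node0.example", 29500)
#   "static" + endpoint "node0.example"        -> ValueError (port missing)
#   "static" + endpoint ""                     -> ValueError (endpoint missing)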
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# torch.distributed.elastic.multiprocessing.errors.record.
@record
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
if not config.run_id:
run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generating a new one: {run_id}")
config.run_id = run_id
entrypoint_name = _get_entrypoint_name(entrypoint, args)
logger.info(
f"Starting elastic_operator with launch configs:\n"
f" entrypoint : {entrypoint_name}\n"
f" min_nodes : {config.min_nodes}\n"
f" max_nodes : {config.max_nodes}\n"
f" nproc_per_node : {config.nproc_per_node}\n"
f" run_id : {config.run_id}\n"
f" rdzv_backend : {config.rdzv_backend}\n"
f" rdzv_endpoint : {config.rdzv_endpoint}\n"
f" rdzv_configs : {config.rdzv_configs}\n"
f" max_restarts : {config.max_restarts}\n"
f" monitor_interval : {config.monitor_interval}\n"
f" log_dir : {config.log_dir}\n"
f" metrics_cfg : {config.metrics_cfg}\n"
)
rdzv_parameters = RendezvousParameters(
backend=config.rdzv_backend,
endpoint=config.rdzv_endpoint,
run_id=config.run_id,
min_nodes=config.min_nodes,
max_nodes=config.max_nodes,
**config.rdzv_configs,
)
agent = None
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
master_addr, master_port = _get_addr_and_port(rdzv_parameters)
try:
spec = WorkerSpec(
role=config.role,
local_world_size=config.nproc_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_handler,
max_restarts=config.max_restarts,
monitor_interval=config.monitor_interval,
redirects=config.redirects,
tee=config.tee,
master_addr=master_addr,
master_port=master_port,
)
cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
metrics.initialize_metrics(cfg)
agent = LocalElasticAgent(
spec=spec, start_method=config.start_method, log_dir=config.log_dir
)
result = agent.run()
events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
if result.is_failed():
# ChildFailedError is treated specially by @record
# if the error files for the failed children exist
# @record will copy the first error (root cause)
# to the error file of the launcher process.
raise ChildFailedError(
name=entrypoint_name,
failures=result.failures,
)
else:
return result.return_values
except ChildFailedError:
raise
except Exception:
if agent:
events.record(agent.get_agent_status_event(WorkerState.FAILED))
else:
events.record(_construct_event(config))
raise
finally:
rdzv_handler.shutdown()
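
# End-to-end sketch (hypothetical values); this is effectively what
# ``elastic_launch(config, entrypoint)(*args)`` calls under the hood:
#
#   config = LaunchConfig(min_nodes=1, max_nodes=1, nproc_per_node=2)
#   results = launch_agent(config, train_fn, ["--epochs", "1"])
#   # ``results`` maps each worker's global rank to its return value;
#   # if any worker fails, a ChildFailedError is raised instead.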