mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Add the ability to customize log lines, plus additional template-like behavior to enrich log information. Motivation: a) Log stream processing/aggregation gains additional value when it includes information about the global rank. As an extension of that, it will be easier to map ranks to hosts from log stream information (less relevant at the moment). b) Users can easily map a failure to the right rank without matching node rank offset + local rank. Implementation - BC change - keeps the log line prefix as `[<role name><local rank>]:` - Optional env variable TORCHELASTIC_LOG_LINE_HEADER that will be used as a prefix when specified; it currently exposes the `role_name`, `rank` and `local_rank` variables, which are bound when the agent assigns the ranks. Test Plan: CI https://fburl.com/mlhub/mzx5xspv Differential Revision: D50584590 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112357 Approved by: https://github.com/kiukchung
289 lines
11 KiB
Python
289 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import sys
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
|
|
from torch.distributed.elastic import events, metrics
|
|
from torch.distributed.elastic.agent.server.api import WorkerSpec
|
|
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
|
|
from torch.distributed.elastic.multiprocessing import SignalException, Std
|
|
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
|
|
from torch.distributed.elastic.rendezvous import RendezvousParameters
|
|
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
|
|
from torch.distributed.elastic.utils.logging import get_logger
|
|
|
|
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class LaunchConfig:
    """
    Creates a rendezvous config.

    Args:
        min_nodes: Minimum amount of nodes that the user function will
                   be launched on. Elastic agent ensures that the user
                   function start only when the min_nodes amount enters
                   the rendezvous.
        max_nodes: Maximum amount of nodes that the user function
                   will be launched on.
        nproc_per_node: On each node the elastic agent will launch
                        this amount of workers that will execute user
                        defined function.
        rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
        rdzv_endpoint: The endpoint of the rdzv sync. storage.
        rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
                      The mapping is copied on construction, so the caller's
                      dict is never modified.
        rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
                      to be removed in future versions, see the note below. The default timeout is 900 seconds.
        run_id: The unique run id of the job (if not passed a unique one will be
                deduced from run environment - flow workflow id in flow - or auto generated).
        role: User defined role of the worker (defaults to "default_role").
        max_restarts: The maximum amount of restarts that elastic agent will conduct
                      on workers before failure.
        monitor_interval: The interval in seconds that is used by the elastic_agent
                          as a period of monitoring workers.
        start_method: The method is used by the elastic agent to start the
                      workers (spawn, fork, forkserver).
        log_dir: base log directory where log files are written. If not set,
                 one is created in a tmp dir but NOT removed on exit.
        log_line_prefix_template: optional template for the prefix of worker
                                  log lines. It is forwarded unchanged to the
                                  elastic agent.
        redirects: configuration to redirect stdout/stderr to log files.
                   Pass a single ``Std`` enum to redirect all workers,
                   or a mapping keyed by local_rank to selectively redirect.
        tee: configuration to "tee" stdout/stderr to console + log file.
        metrics_cfg: configuration to initialize metrics.
        local_addr: address of the local node if any. If not set, a lookup on the local
                    machine's FQDN will be performed.

    .. note::
        `rdzv_timeout` is a legacy argument that will be removed in future.
        Set the timeout via `rdzv_configs['timeout']`

    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_dir: Optional[str] = None
    log_line_prefix_template: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

    def __post_init__(self):
        """Resolve the effective rendezvous timeout into ``rdzv_configs["timeout"]``.

        Precedence: an explicit ``rdzv_timeout`` (anything other than ``-1``)
        wins, then a ``"timeout"`` already present in ``rdzv_configs``, then
        the 900 second default.
        """
        default_timeout = 900
        # Work on a private copy: the timeout is written into this mapping and
        # a dict shared by the caller (e.g. reused for several configs) must
        # not be modified behind the caller's back.
        self.rdzv_configs = dict(self.rdzv_configs)
        if self.rdzv_timeout != -1:
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout
|
|
|
|
|
|
class elastic_launch:
    """
    Callable wrapper that starts a torchelastic agent for an entrypoint.

    The object stores a launch config together with an entrypoint; calling it
    forwards both, plus the positional call arguments, to ``launch_agent``.

    1. Pass the ``entrypoint`` arguments positionally (no keyword arguments).
       ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output keyed
       by its global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

        # entrypoint is a command and ``script.py`` is the python module.
        outputs = elastic_launch(LaunchConfig, "script.py")(args)
        outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Nothing is launched here; the agent only starts on ``__call__``.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        # ``launch_agent`` takes the entrypoint arguments as a list.
        entrypoint_args = list(args)
        return launch_agent(self._config, self._entrypoint, entrypoint_args)
|
|
|
|
|
|
def _get_entrypoint_name(
|
|
entrypoint: Union[Callable, str, None], args: List[Any]
|
|
) -> str:
|
|
"""Retrieve entrypoint name with the rule:
|
|
1. If entrypoint is a function, use ``entrypoint.__qualname__``.
|
|
2. If entrypoint is a string, check its value:
|
|
2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args``
|
|
which does not start with hifen letter (for example, "-u" will be skipped).
|
|
2.2 otherwise, use ``entrypoint`` value.
|
|
3. Otherwise, return empty string.
|
|
"""
|
|
if isinstance(entrypoint, Callable): # type: ignore[arg-type]
|
|
return entrypoint.__name__ # type: ignore[union-attr]
|
|
elif isinstance(entrypoint, str):
|
|
if entrypoint == sys.executable:
|
|
return next((arg for arg in args if arg[0] != "-"), "")
|
|
else:
|
|
return entrypoint
|
|
else:
|
|
return ""
|
|
|
|
|
|
def _get_addr_and_port(
|
|
rdzv_parameters: RendezvousParameters,
|
|
) -> Tuple[Optional[str], Optional[int]]:
|
|
if rdzv_parameters.backend != "static":
|
|
return (None, None)
|
|
endpoint = rdzv_parameters.endpoint
|
|
endpoint = endpoint.strip()
|
|
if not endpoint:
|
|
raise ValueError(
|
|
"Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
|
|
)
|
|
master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
|
|
if master_port == -1:
|
|
raise ValueError(
|
|
f"port is missing in endpoint: {endpoint}. Try to specify --master-port"
|
|
)
|
|
return (master_addr, master_port)
|
|
|
|
|
|
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Build a ``LocalElasticAgent`` from ``config``, run it and return its results.

    Args:
        config: launch settings. If ``config.run_id`` is empty, a random one
            is generated and written back onto ``config``.
        entrypoint: function or command the workers execute.
        args: arguments passed to ``entrypoint``.

    Returns:
        ``return_values`` of the agent's run result.

    Raises:
        ChildFailedError: if the agent's run result reports a failure.
        SignalException: re-raised when the agent dies with a signal; in that
            case the rendezvous handler is left open.
        Exception: any other error from the run is re-raised after a failure
            event has been recorded.
    """
    if not config.run_id:
        generated_run_id = str(uuid.uuid4().int)
        logger.warning(
            "config has no run_id, generated a random run_id: %s", generated_run_id
        )
        config.run_id = generated_run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    # Values for the named placeholders of the start-up log message.
    launch_summary = {
        "entrypoint": entrypoint_name,
        "min_nodes": config.min_nodes,
        "max_nodes": config.max_nodes,
        "nproc_per_node": config.nproc_per_node,
        "run_id": config.run_id,
        "rdzv_backend": config.rdzv_backend,
        "rdzv_endpoint": config.rdzv_endpoint,
        "rdzv_configs": config.rdzv_configs,
        "max_restarts": config.max_restarts,
        "monitor_interval": config.monitor_interval,
        "log_dir": config.log_dir,
        "metrics_cfg": config.metrics_cfg,
    }
    logger.info(
        "Starting elastic_operator with launch configs:\n"
        " entrypoint : %(entrypoint)s\n"
        " min_nodes : %(min_nodes)s\n"
        " max_nodes : %(max_nodes)s\n"
        " nproc_per_node : %(nproc_per_node)s\n"
        " run_id : %(run_id)s\n"
        " rdzv_backend : %(rdzv_backend)s\n"
        " rdzv_endpoint : %(rdzv_endpoint)s\n"
        " rdzv_configs : %(rdzv_configs)s\n"
        " max_restarts : %(max_restarts)s\n"
        " monitor_interval : %(monitor_interval)s\n"
        " log_dir : %(log_dir)s\n"
        " metrics_cfg : %(metrics_cfg)s\n",
        launch_summary,
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # Resolve the endpoint before creating the handler, so an invalid
    # endpoint fails before any rendezvous handler exists.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    worker_args = tuple(args)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=worker_args,
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec,
        start_method=config.start_method,
        log_dir=config.log_dir,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    close_rendezvous = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # The success event means agent.run() completed, NOT that the
        # workers themselves succeeded.
        events.record(agent.get_event_succeeded())

        if not result.is_failed():
            return result.return_values

        # ChildFailedError is treated specially by @record: if the error
        # files for the failed children exist, @record copies the first
        # error (root cause) to the error file of the launcher process.
        raise ChildFailedError(
            name=entrypoint_name,
            failures=result.failures,
        )
    except ChildFailedError:
        # Worker failures propagate without recording an agent failure event.
        raise
    except SignalException:
        # When the agent dies with a signal, do NOT shut down the
        # rdzv_handler: that closes the rendezvous on this rdzv_id
        # permanently and prevents any additional scaling events.
        close_rendezvous = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if close_rendezvous:
            spec.rdzv_handler.shutdown()
|