pytorch/torch/distributed/pipelining/_utils.py
Howard Huang 4ee7d0de86 Add generate_stage_to_rank_mapping utility (#146193)
We use `stage_index_to_group_rank` in the stage to determine which send/recv ops to issue, and in the schedule for IR generation. However, we don't need to expose this as an argument on our schedule classes, so this stack of PRs removes it.

This PR creates a `stage_index_to_group_rank` utility function and removes the argument from the ZBV schedule. In a follow-up PR I will add code to infer `stage_index_to_group_rank` for the CSV schedule path, after which we can remove this argument from our classes entirely.

Related comment from @wconstab https://github.com/pytorch/torchtitan/issues/774#issuecomment-2619793741

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146193
Approved by: https://github.com/wconstab
2025-02-05 21:26:45 +00:00


# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Union

import torch
from torch import fx

logger = logging.getLogger(__name__)


def flatten_args_detach(args):
"""
    Flatten the args into a list form and detach the tensors from the computational graph.
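
    Example (illustrative):
        x = torch.ones(2, requires_grad=True)
        new_args, flat = flatten_args_detach((x, 1))
        # new_args mirrors the input structure with the tensor detached;
        # flat is [detached_x, 1], where detached_x keeps requires_grad=True
        # but is no longer connected to x's autograd graph.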
"""
    flat_detached_args = []

    def extract_tensor_args(a):
        nonlocal flat_detached_args
        if isinstance(a, torch.Tensor):
            val = a.detach().requires_grad_(a.requires_grad)
            flat_detached_args.append(val)
            return val
        else:
            flat_detached_args.append(a)
            return a

    new_args = fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )
    return new_args, flat_detached_args


def flatten_args(args):
"""
Flatten the args into a list form.
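
    Example (illustrative):
        flatten_args((torch.ones(2), [3, "a"]))
        # -> [tensor([1., 1.]), 3, "a"]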
"""
    flat_args = []

    def extract_tensor_args(a):
        nonlocal flat_args
        flat_args.append(a)
        return a

    fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )
    return flat_args


class PipeliningShapeError(RuntimeError):
"""Shape mismatch between configured and runtime values."""
def validate_tensor_metadata(desc, expected, given):
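    """
    Check that ``given`` matches ``expected`` in shape, dtype, and stride,
    raising ``PipeliningShapeError`` on the first mismatch.
    """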
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )


def validate_tensors_metadata(
desc,
expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
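    """
    Validate that ``actual_tensors`` match ``expected_tensors`` element-wise
    in count, shape, dtype, and stride.

    Example (illustrative):
        expected = [torch.empty(2, 4)]
        validate_tensors_metadata("stage0 inputs", expected, [torch.empty(2, 4)])  # passes
        validate_tensors_metadata("stage0 inputs", expected, [torch.empty(3, 4)])  # raises
    """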
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(
f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
)
for i in range(len(expected_tensors)):
validate_tensor_metadata(
f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
        )


def generate_stage_to_rank_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, int]:
"""
Compute the stage id to rank mapping for either a looped or V-style schedule.
Most commonly num_stages == pp_size * 2, but this function can be used to
compute the mapping for any number of stages per rank.
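
    Example (illustrative): with pp_size=4 and num_stages=8,
        generate_stage_to_rank_mapping(4, 8, style="loop")
        # -> {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}
        generate_stage_to_rank_mapping(4, 8, style="v")
        # -> {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}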
"""
mapping = {}
if num_stages % pp_size != 0:
raise ValueError(
f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
)
if style == "loop":
for stage_index in range(num_stages):
mapping[stage_index] = stage_index % pp_size
elif style == "v":
rank_index = 0
for stage_index in range(num_stages):
mapping[stage_index] = rank_index
            # don't change rank if we are at the border (to keep the V shape)
if (stage_index + 1) % pp_size == 0:
continue
if (stage_index // pp_size) % 2 == 0:
rank_index += 1
else:
rank_index -= 1
else:
raise ValueError(f"Style {style} is not supported.")
    return mapping


@dataclass
class PipeInfo:
"""
Captures information for a pipeline (`Pipe` object).
"""
graph: fx.Graph
num_stages: int
has_loss_and_backward: bool