Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
We use `stage_index_to_group_rank` in the stage to determine which send/recv ops to issue, and in the schedule for IR generation. However, we don't need to expose this as an argument in our schedule class, so this stack of PRs removes it. This PR creates a `stage_index_to_group_rank` utility function and removes the arg for the ZBV schedule. In a following PR I will add code to infer `stage_index_to_group_rank` for the CSV schedule path, and then we will be able to remove this argument from our classes entirely.

Related comment from @wconstab: https://github.com/pytorch/torchtitan/issues/774#issuecomment-2619793741

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146193
Approved by: https://github.com/wconstab
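As a rough sketch of the direction (illustrative only; the module path and the surrounding variables `pp_size`, `num_stages`, and `stage_index` are assumptions, not taken from this PR), a stage or schedule could derive the placement from the utility defined below instead of accepting an explicit `stage_index_to_group_rank` argument:

    # Hypothetical call site; the import path is an assumption.
    from torch.distributed.pipelining._utils import generate_stage_to_rank_mapping

    stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style="v")  # ZBV-style placement
    peer_rank = stage_to_rank[stage_index]  # group rank that owns a given stage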
133 lines
3.7 KiB
Python
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Dict, Union

import torch
from torch import fx


logger = logging.getLogger(__name__)


def flatten_args_detach(args):
    """
    Flatten the args into a list form and detach the tensors from the computational graph.
    """
    flat_detached_args = []

    def extract_tensor_args(a):
        nonlocal flat_detached_args
        if isinstance(a, torch.Tensor):
            val = a.detach().requires_grad_(a.requires_grad)
            flat_detached_args.append(val)
            return val
        else:
            flat_detached_args.append(a)
            return a
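
    # fx.node.map_aggregate recursively walks nested tuples, lists, dicts, and
    # slices, applies the callback to every leaf, and rebuilds the same
    # container structure from the returned values.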
    new_args = fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return new_args, flat_detached_args


def flatten_args(args):
    """
    Flatten the args into a list form.
    """
    flat_args = []

    def extract_tensor_args(a):
        nonlocal flat_args
        flat_args.append(a)
        return a

    fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return flat_args


class PipeliningShapeError(RuntimeError):
    """Shape mismatch between configured and runtime values."""


def validate_tensor_metadata(desc, expected, given):
    if not expected.shape == given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if not expected.dtype == given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if not expected.stride() == given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )


def validate_tensors_metadata(
    desc,
    expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
    actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    for i in range(len(expected_tensors)):
        validate_tensor_metadata(
            f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
        )


def generate_stage_to_rank_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> Dict[int, int]:
    """
    Compute the stage id to rank mapping for either a looped or V-style schedule.

    Most commonly num_stages == pp_size * 2, but this function can be used to
    compute the mapping for any number of stages per rank.
    """
    mapping = {}
    if num_stages % pp_size != 0:
        raise ValueError(
            f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
        )
    if style == "loop":
        for stage_index in range(num_stages):
            mapping[stage_index] = stage_index % pp_size
    elif style == "v":
        rank_index = 0
        for stage_index in range(num_stages):
            mapping[stage_index] = rank_index
            # don't change rank if we are on the border (to keep the V shape)
            if (stage_index + 1) % pp_size == 0:
                continue
            if (stage_index // pp_size) % 2 == 0:
                rank_index += 1
            else:
                rank_index -= 1
    else:
        raise ValueError(f"Style {style} is not supported.")
    return mapping


@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object).
    """

    graph: fx.Graph
    num_stages: int
    has_loss_and_backward: bool