Revert "[pipelining] [BE] Move pipeline_order validation to schedules.py (#129369)"
This reverts commit ec789a3c9d. Reverted https://github.com/pytorch/pytorch/pull/129369 on behalf of https://github.com/clee2000 because it broke test/distributed/pipelining/test_schedule.py::ScheduleTest::test_non_symmetric_stage_ids_ScheduleClass0 on distributed cuda: https://github.com/pytorch/pytorch/actions/runs/9766039400/job/26959115773. You can see the error on the PR, but Dr. CI classified it wrong ([comment](https://github.com/pytorch/pytorch/pull/129369#issuecomment-2204568418))
This commit is contained in:
parent b6f781e433
commit b5fdbc1a9f
@@ -6,6 +6,7 @@ import os
import sys
import tempfile
import unittest
from typing import Dict, List, Optional, Tuple

from model_registry import ModelWithKwargs, MultiMLP
from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW
@@ -21,10 +22,7 @@ from torch.distributed.pipelining import (
    ScheduleInterleaved1F1B,
    ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import (
    _format_pipeline_order,
    _validate_pipeline_order,
)
from torch.distributed.pipelining.schedules import _Action, _ComputationType
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
@@ -610,7 +608,153 @@ class ScheduleTest(MultiProcContinousTest):
instantiate_parametrized_tests(ScheduleTest)


def format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]):
    import itertools

    # Calculate the maximum number of steps across all ranks
    num_steps = max(len(actions) for actions in pipeline_order.values())
    step_labels = [
        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
    ]
    # Sorting the dictionary by keys and retrieving values in that order
    rank_actions = [
        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
    ]
    # Transpose the list of lists (rows to columns)
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    # Generate column labels for ranks
    num_ranks = len(pipeline_order)
    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
    # Calculate the maximum length of each column, considering labels
    max_lengths = [
        max(len(str(item)) if item is not None else 0 for item in col)
        for col in zip(step_labels, *transposed_actions)
    ]
    # Format the header row with rank labels
    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
    )
    # Format each row with its corresponding label
    formatted_rows = [
        f"{label}: "
        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
        for label, row in zip(step_labels, transposed_actions)
    ]
    # Join the rows into a single string
    formatted_table = (
        "=========== ALL_RANK_ACTIONS ===========\n"
        + header_row
        + "\n"
        + "\n".join(formatted_rows)
        + "\n"
    )
    return formatted_table


class TestSchedulePlan(unittest.TestCase):
    def _validate_pipeline_order(
        self,
        pipeline_order: Dict[int, List[Optional[_Action]]],
        num_microbatches: int,
        num_stages: int,
    ):
        """
        pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]

        Validate that the pipeline order follows these rules:
        1. The Forward action for a microbatch must come before its Backward action
        2. The Recv for a microbatch must come before its Send
        3. Microbatch indices are handled in sequential order for each stage
        4. A later stage cannot operate on a microbatch before all previous stages have operated on it
        5. The same microbatch cannot be handled in the same time step across ranks
        """
        # microbatch_index: (current computation type, current stage)
        microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
        max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
        for timestep in range(max_timestep):
            error_msg: List[str] = []
            current_timestep_actions = []
            for rank in range(len(pipeline_order)):
                action = (
                    pipeline_order[rank][timestep]
                    if timestep < len(pipeline_order[rank])
                    else None
                )
                if action is not None:
                    current_timestep_actions.append(action)

            # TODO: enable this
            # if len(current_timestep_actions) == 0:
            #     error_msg.append(
            #         "All actions were None, there is an unnecessary gap in the schedule"
            #     )

            # Ensure that no microbatch is operated on twice in current_timestep_actions
            unique_microbatch_indices = {
                action[1] for action in current_timestep_actions
            }
            if len(unique_microbatch_indices) != len(current_timestep_actions):
                error_msg.append(
                    "Duplicate microbatch index found in current_timestep_actions"
                )

            # Add additional checks for other rules here...
            for action in current_timestep_actions:
                computation_type, mb_index, stage_index = action

                if mb_index >= num_microbatches:
                    error_msg.append(f"Microbatch index {mb_index} out of range")

                # first microbatch
                if mb_index not in microbatch_process_info:
                    if computation_type != _ComputationType.FORWARD or stage_index != 0:
                        error_msg.append(f"Incorrect start for microbatch {mb_index}")
                    microbatch_process_info[mb_index] = (computation_type, stage_index)
                else:
                    # if the microbatch is included, check that the current stage is right after prev
                    prev_computation, prev_stage = microbatch_process_info[mb_index]
                    if prev_computation == _ComputationType.FORWARD:
                        if prev_stage == num_stages - 1:
                            expected_stage = num_stages - 1
                            expected_computation = _ComputationType.BACKWARD
                        else:
                            expected_stage = prev_stage + 1
                            expected_computation = _ComputationType.FORWARD
                    elif prev_computation == _ComputationType.BACKWARD:
                        if prev_stage == 0:
                            error_msg.append(
                                f"[{mb_index=}] already finished backward computation"
                            )
                            expected_stage = None
                            expected_computation = None
                        else:
                            expected_stage = prev_stage - 1
                            expected_computation = _ComputationType.BACKWARD
                    else:
                        raise ValueError(
                            f"Computation type {prev_computation} not supported"
                        )

                    if expected_computation is not None:
                        if expected_computation != computation_type:
                            error_msg.append(
                                f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
                            )

                        if expected_stage != stage_index:
                            error_msg.append(
                                f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
                            )

                    microbatch_process_info[mb_index] = (
                        expected_computation,
                        expected_stage,
                    )

            if len(error_msg) != 0:
                self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg))

    @parametrize(
        "ScheduleClass",
        [ScheduleFlexibleInterleaved1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS],
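To make the five validation rules concrete, here is a hand-built order (a minimal sketch, not taken from the test suite) for 2 stages on 2 ranks with 2 microbatches that the checks above accept. Plain (computation_type, microbatch_index, stage_index) tuples suffice for illustration, since the validator only unpacks and indexes each action:

from torch.distributed.pipelining.schedules import _ComputationType

F, B = _ComputationType.FORWARD, _ComputationType.BACKWARD

# rank -> one action (or None) per timestep; forwards walk stage 0 -> 1,
# backwards walk stage 1 -> 0, and no timestep touches the same microbatch
# on two ranks at once.
pipeline_order = {
    0: [(F, 0, 0), (F, 1, 0), None,      (B, 0, 0), None,      (B, 1, 0)],
    1: [None,      (F, 0, 1), (B, 0, 1), (F, 1, 1), (B, 1, 1), None],
}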
@@ -669,11 +813,8 @@ class TestSchedulePlan(unittest.TestCase):
        ]

        schedule = ScheduleClass(stages, num_microbatches)
        formatted_pipeline_order = _format_pipeline_order(
            schedule.pipeline_order
        )
        # print(formatted_pipeline_order)
        _validate_pipeline_order(
        # print(format_pipeline_order(schedule.pipeline_order))
        self._validate_pipeline_order(
            schedule.pipeline_order, num_microbatches, num_stages
        )
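Since the restored format_pipeline_order only ever calls str() on each entry, short placeholder strings are enough to preview its output (the action labels below are illustrative, not real _Action values):

order = {
    0: ["F0", "F1", "B0", "B1"],
    1: ["",   "F0", "F1", "B0"],
}
print(format_pipeline_order(order))
# Prints a Step (row) x Rank (column) grid under the
# "=========== ALL_RANK_ACTIONS ===========" banner.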
@@ -7,6 +7,7 @@ from typing import List, Tuple, Union
import torch
from torch import fx


logger = logging.getLogger(__name__)
@@ -2,7 +2,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import csv
import itertools
import logging
import re
from abc import ABC, abstractmethod
@@ -93,144 +92,6 @@ class _Action(NamedTuple):
        )


def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string
    """
    # Calculate the maximum number of steps across all ranks
    num_steps = max(len(actions) for actions in pipeline_order.values())
    step_labels = [
        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
    ]
    # Sorting the dictionary by keys and retrieving values in that order
    rank_actions = [
        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
    ]
    # Transpose the list of lists (rows to columns)
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    # Generate column labels for ranks
    num_ranks = len(pipeline_order)
    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
    # Calculate the maximum length of each column, considering labels
    max_lengths = [
        max(len(str(item)) if item is not None else 0 for item in col)
        for col in zip(step_labels, *transposed_actions)
    ]
    # Format the header row with rank labels
    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
    )
    # Format each row with its corresponding label
    formatted_rows = [
        f"{label}: "
        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
        for label, row in zip(step_labels, transposed_actions)
    ]
    # Join the rows into a single string
    formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
    return formatted_table


def _validate_pipeline_order(
    pipeline_order: Dict[int, List[Optional[_Action]]],
    num_microbatches: int,
    num_stages: int,
):
    """
    pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]

    Validate that the pipeline order follows these rules:
    1. The Forward action for a microbatch must come before its Backward action
    2. The Recv for a microbatch must come before its Send
    3. Microbatch indices are handled in sequential order for each stage
    4. A later stage cannot operate on a microbatch before all previous stages have operated on it
    5. The same microbatch cannot be handled in the same time step across ranks
    """
    # microbatch_index: (current computation type, current stage)
    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
    for timestep in range(max_timestep):
        error_msg: List[str] = []
        current_timestep_actions = []
        for rank in range(len(pipeline_order)):
            action = (
                pipeline_order[rank][timestep]
                if timestep < len(pipeline_order[rank])
                else None
            )
            if action is not None:
                current_timestep_actions.append(action)

        # TODO: enable this
        # if len(current_timestep_actions) == 0:
        #     error_msg.append(
        #         "All actions were None, there is an unnecessary gap in the schedule"
        #     )

        # Ensure that no microbatch is operated on twice in current_timestep_actions
        unique_microbatch_indices = {action[1] for action in current_timestep_actions}
        if len(unique_microbatch_indices) != len(current_timestep_actions):
            error_msg.append(
                "Duplicate microbatch index found in current_timestep_actions"
            )

        for action in current_timestep_actions:
            computation_type, mb_index, stage_index = action

            if mb_index >= num_microbatches:
                error_msg.append(f"Microbatch index {mb_index} out of range")

            # first microbatch
            if mb_index not in microbatch_process_info:
                if computation_type != _ComputationType.FORWARD or stage_index != 0:
                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
                microbatch_process_info[mb_index] = (computation_type, stage_index)
            else:
                # if the microbatch is included, check that the current stage is right after prev
                prev_computation, prev_stage = microbatch_process_info[mb_index]

                if prev_computation == _ComputationType.FORWARD:
                    if prev_stage == num_stages - 1:
                        expected_stage = num_stages - 1
                        expected_computation = _ComputationType.BACKWARD
                    else:
                        expected_stage = prev_stage + 1
                        expected_computation = _ComputationType.FORWARD
                elif prev_computation == _ComputationType.BACKWARD:
                    if prev_stage == 0:
                        error_msg.append(
                            f"[{mb_index=}] already finished backward computation"
                        )
                        break
                    else:
                        expected_stage = prev_stage - 1
                        expected_computation = _ComputationType.BACKWARD
                else:
                    raise ValueError(
                        f"Computation type {prev_computation} not supported"
                    )

                if expected_computation is not None:
                    if expected_computation != computation_type:
                        error_msg.append(
                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
                        )

                    if expected_stage != stage_index:
                        error_msg.append(
                            f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
                        )

                microbatch_process_info[mb_index] = (
                    expected_computation,
                    expected_stage,
                )

        if len(error_msg) != 0:
            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))


class _PipelineSchedule(ABC):
    def __init__(
        self,
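The heart of the validator removed above is a per-microbatch state machine: forward computation walks the stages upward, turns around at the last stage, then backward walks them back down. A standalone sketch of that progression (illustrative names, not part of the PyTorch API):

def expected_next(prev_computation: str, prev_stage: int, num_stages: int):
    if prev_computation == "F":
        if prev_stage == num_stages - 1:
            return ("B", num_stages - 1)  # turn around at the last stage
        return ("F", prev_stage + 1)      # keep moving forward
    if prev_computation == "B":
        if prev_stage == 0:
            return (None, None)           # microbatch is finished
        return ("B", prev_stage - 1)      # keep moving backward
    raise ValueError(f"Computation type {prev_computation} not supported")

assert expected_next("F", 0, 4) == ("F", 1)  # forward advances a stage
assert expected_next("F", 3, 4) == ("B", 3)  # last stage flips to backward
assert expected_next("B", 1, 4) == ("B", 0)  # backward walks back to stage 0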
@@ -943,7 +804,6 @@ class PipelineScheduleMulti(_PipelineSchedule):
                all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

        for time_step, action in enumerate(self.pipeline_order[self.rank]):
            try:
                ops: List[dist.P2POp] = []
                if action is not None:
                    computation_type, mb_index, stage_index = action
@@ -999,9 +859,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
                            # Previous rank doing backward or weight update has no influence for the current rank forward recv
                            pass
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )
                            raise ValueError(f"Unknown computation type {computation_type}")

            for next_rank in all_next_ranks:
                next_rank_ops = self.pipeline_order[next_rank]
@@ -1025,24 +883,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
                        stage = stage_index_to_stage[stage_index - 1]
                        ops.extend(stage.get_bwd_recv_ops(mb_index))
                    else:
                        raise ValueError(
                            f"Unknown computation type {computation_type}"
                        )
                        raise ValueError(f"Unknown computation type {computation_type}")

                # do the communication
                if ops:
                    _batch_p2p(ops).wait()
            except Exception as e:
                logger.error(
                    "[Rank %s] pipeline schedule %s caught the following exception \
                     at time_step %s when running action %s",
                    self.rank,
                    self.__class__.__name__,
                    time_step,
                    action,
                )
                logger.error("%s", _format_pipeline_order(self.pipeline_order))
                raise e
        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)
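The communication pattern in this loop batches all point-to-point sends and receives for a time step and waits on them together. A minimal sketch of the same idea against the public torch.distributed API (the schedule uses its private _batch_p2p helper instead; flush_p2p is an illustrative name):

import torch.distributed as dist

def flush_p2p(ops: list) -> None:
    # Coalesce this time step's sends/recvs into one batch and block until
    # they complete, mirroring `if ops: _batch_p2p(ops).wait()` above.
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()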
@@ -385,7 +385,7 @@ class _PipelineStageBase(ABC):
        else:
            if not (grad is None and grad_recv_stage is None):
                raise RuntimeError(
                    f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
                    f"[{self.stage_index}] for chunk {bwd_chunk_id - 1} has gradients {grad} "
                    f"and is expecting to send gradients to stage {grad_recv_stage}"
                )
        return ops