Revert "[pipelining] [BE] Move pipeline_order validation to schedules.py (#129369)"

This reverts commit ec789a3c9d.

Reverted https://github.com/pytorch/pytorch/pull/129369 on behalf of https://github.com/clee2000 due to broke test/distributed/pipelining/test_schedule.py::ScheduleTest::test_non_symmetric_stage_ids_ScheduleClass0 on distributed cuda https://github.com/pytorch/pytorch/actions/runs/9766039400/job/26959115773 ec789a3c9d.  You can see the error on the PR, but Dr. CI classified it wrong ([comment](https://github.com/pytorch/pytorch/pull/129369#issuecomment-2204568418))
This commit is contained in:
PyTorch MergeBot 2024-07-02 22:30:53 +00:00
parent b6f781e433
commit b5fdbc1a9f
4 changed files with 231 additions and 244 deletions

View File

@@ -6,6 +6,7 @@ import os
import sys
import tempfile
import unittest
from typing import Dict, List, Optional, Tuple
from model_registry import ModelWithKwargs, MultiMLP
from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW
@@ -21,10 +22,7 @@ from torch.distributed.pipelining import (
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import (
_format_pipeline_order,
_validate_pipeline_order,
)
from torch.distributed.pipelining.schedules import _Action, _ComputationType
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
@@ -610,7 +608,153 @@ class ScheduleTest(MultiProcContinousTest):
instantiate_parametrized_tests(ScheduleTest)
def format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]):
import itertools
# Calculate the maximum number of steps across all ranks
num_steps = max(len(actions) for actions in pipeline_order.values())
step_labels = [
"Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
]
# Sorting the dictionary by keys and retrieving values in that order
rank_actions = [
pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
]
# Transpose the list of lists (rows to columns)
transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
# Generate column labels for ranks
num_ranks = len(pipeline_order)
rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
# Calculate the maximum length of each column, considering labels
max_lengths = [
max(len(str(item)) if item is not None else 0 for item in col)
for col in zip(step_labels, *transposed_actions)
]
# Format the header row with rank labels
header_row = " " * (len(step_labels[0]) + 2) + " ".join(
f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
)
# Format each row with its corresponding label
formatted_rows = [
f"{label}: "
+ " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
for label, row in zip(step_labels, transposed_actions)
]
# Join the rows into a single string
formatted_table = (
"=========== ALL_RANK_ACTIONS ===========\n"
+ header_row
+ "\n"
+ "\n".join(formatted_rows)
+ "\n"
)
return formatted_table
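For a quick sense of the output, here is a minimal usage sketch with a hypothetical two-rank schedule (illustrative only, not taken from the test suite); since the formatter only calls str() on each entry, short strings stand in for _Action values:
toy_order = {
    0: ["F0", "F1", None, "B0"],
    1: [None, "F0", "F1", "B0"],
}
# Prints one row per time step and one column per rank,
# under the "ALL_RANK_ACTIONS" banner added above.
print(format_pipeline_order(toy_order))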
class TestSchedulePlan(unittest.TestCase):
def _validate_pipeline_order(
self,
pipeline_order: Dict[int, List[Optional[_Action]]],
num_microbatches: int,
num_stages: int,
):
"""
pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
Validating that the pipeline order follows the rules:
1. Forward action for a microbatch must be before the Backward action for that microbatch
2. Recv for a microbatch must be before the send for that microbatch
3. Microbatch index is handled in sequential order for each stage
4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
5. Same microbatch cannot be handled in the same time step across ranks
"""
# microbatch_index: (current computation type, current stage)
microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
for timestep in range(max_timestep):
error_msg = []
current_timestep_actions = []
for rank in range(len(pipeline_order)):
action = (
pipeline_order[rank][timestep]
if timestep < len(pipeline_order[rank])
else None
)
if action is not None:
current_timestep_actions.append(action)
# TODO: enable this
# if len(current_timestep_actions) == 0:
# error_msg.append(
# "All actions were None, there is an unnecessary gap in the schedule"
# )
# Ensure that no microbatch is operated on twice in current_timestep_actions
unique_microbatch_indices = {
action[1] for action in current_timestep_actions
}
if len(unique_microbatch_indices) != len(current_timestep_actions):
error_msg.append(
"Duplicate microbatch index found in current_timestep_actions"
)
# Add additional checks for other rules here...
for action in current_timestep_actions:
computation_type, mb_index, stage_index = action
if mb_index >= num_microbatches:
error_msg.append(f"Microbatch index {mb_index} out of range")
# first microbatch
if mb_index not in microbatch_process_info:
if computation_type != _ComputationType.FORWARD or stage_index != 0:
error_msg.append(f"Incorrect start for microbatch {mb_index}")
microbatch_process_info[mb_index] = (computation_type, stage_index)
else:
# if the microbatch is included, check that the current stage is right after prev
prev_computation, prev_stage = microbatch_process_info[mb_index]
if prev_computation == _ComputationType.FORWARD:
if prev_stage == num_stages - 1:
expected_stage = num_stages - 1
expected_computation = _ComputationType.BACKWARD
else:
expected_stage = prev_stage + 1
expected_computation = _ComputationType.FORWARD
elif prev_computation == _ComputationType.BACKWARD:
if prev_stage == 0:
error_msg.append(
f"[{mb_index=}] already finished backward computation"
)
expected_stage = None
expected_computation = None
else:
expected_stage = prev_stage - 1
expected_computation = _ComputationType.BACKWARD
else:
raise ValueError(
f"Computation type {prev_computation} not supported"
)
if expected_computation is not None:
if expected_computation != computation_type:
error_msg.append(
f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
)
if expected_stage != stage_index:
error_msg.append(
f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
)
microbatch_process_info[mb_index] = (
expected_computation,
expected_stage,
)
if len(error_msg) != 0:
self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg))
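To make the docstring rules concrete, here is a hand-built order (hypothetical, not taken from the test suite) for num_stages=2 and num_microbatches=2 that passes every check above; rank r runs stage r, and each action is a (computation_type, microbatch_index, stage_index) tuple, which unpacks the same way as _Action:
F = _ComputationType.FORWARD
B = _ComputationType.BACKWARD
# GPipe-style order: forwards flow through stages 0 -> 1, backwards flow
# back through 1 -> 0, and no microbatch appears on two ranks in the same
# time step (rule 5).
valid_order = {
    0: [(F, 0, 0), (F, 1, 0), None, None, (B, 0, 0), (B, 1, 0)],
    1: [None, (F, 0, 1), (F, 1, 1), (B, 0, 1), (B, 1, 1), None],
}
# Inside a TestSchedulePlan test this would pass without calling self.fail:
# self._validate_pipeline_order(valid_order, num_microbatches=2, num_stages=2)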
@parametrize(
"ScheduleClass",
[ScheduleFlexibleInterleaved1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS],
@@ -669,11 +813,8 @@ class TestSchedulePlan(unittest.TestCase):
]
schedule = ScheduleClass(stages, num_microbatches)
formatted_pipeline_order = _format_pipeline_order(
schedule.pipeline_order
)
# print(formatted_pipeline_order)
_validate_pipeline_order(
# print(format_pipeline_order(schedule.pipeline_order))
self._validate_pipeline_order(
schedule.pipeline_order, num_microbatches, num_stages
)

View File

@@ -7,6 +7,7 @@ from typing import List, Tuple, Union
import torch
from torch import fx
logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import csv
import itertools
import logging
import re
from abc import ABC, abstractmethod
@@ -93,144 +92,6 @@ class _Action(NamedTuple):
)
def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
"""
Formats the pipeline order in a timestep (row) x rank (column) grid of actions
and returns the formatted string
"""
# Calculate the maximum number of steps across all ranks
num_steps = max(len(actions) for actions in pipeline_order.values())
step_labels = [
"Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
]
# Sorting the dictionary by keys and retrieving values in that order
rank_actions = [
pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
]
# Transpose the list of lists (rows to columns)
transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
# Generate column labels for ranks
num_ranks = len(pipeline_order)
rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
# Calculate the maximum length of each column, considering labels
max_lengths = [
max(len(str(item)) if item is not None else 0 for item in col)
for col in zip(step_labels, *transposed_actions)
]
# Format the header row with rank labels
header_row = " " * (len(step_labels[0]) + 2) + " ".join(
f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
)
# Format each row with its corresponding label
formatted_rows = [
f"{label}: "
+ " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
for label, row in zip(step_labels, transposed_actions)
]
# Join the rows into a single string
formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
return formatted_table
def _validate_pipeline_order(
pipeline_order: Dict[int, List[Optional[_Action]]],
num_microbatches: int,
num_stages: int,
):
"""
pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
Validating that the pipeline order follows the rules:
1. Forward action for a microbatch must be before the Backward action for that microbatch
2. Recv for a microbatch must be before the send for that microbatch
3. Microbatch index is handled in sequential order for each stage
4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
5. Same microbatch cannot be handled in the same time step across ranks
"""
# microbatch_index: (current computation type, current stage)
microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
for timestep in range(max_timestep):
error_msg: List[str] = []
current_timestep_actions = []
for rank in range(len(pipeline_order)):
action = (
pipeline_order[rank][timestep]
if timestep < len(pipeline_order[rank])
else None
)
if action is not None:
current_timestep_actions.append(action)
# TODO: enable this
# if len(current_timestep_actions) == 0:
# error_msg.append(
# "All actions were None, there is an unnecessary gap in the schedule"
# )
# Ensure that no microbatch is operated on twice in current_timestep_actions
unique_microbatch_indices = {action[1] for action in current_timestep_actions}
if len(unique_microbatch_indices) != len(current_timestep_actions):
error_msg.append(
"Duplicate microbatch index found in current_timestep_actions"
)
for action in current_timestep_actions:
computation_type, mb_index, stage_index = action
if mb_index >= num_microbatches:
error_msg.append(f"Microbatch index {mb_index} out of range")
# first microbatch
if mb_index not in microbatch_process_info:
if computation_type != _ComputationType.FORWARD or stage_index != 0:
error_msg.append(f"Incorrect start for microbatch {mb_index}")
microbatch_process_info[mb_index] = (computation_type, stage_index)
else:
# if the microbatch is included, check that the current stage is right after prev
prev_computation, prev_stage = microbatch_process_info[mb_index]
if prev_computation == _ComputationType.FORWARD:
if prev_stage == num_stages - 1:
expected_stage = num_stages - 1
expected_computation = _ComputationType.BACKWARD
else:
expected_stage = prev_stage + 1
expected_computation = _ComputationType.FORWARD
elif prev_computation == _ComputationType.BACKWARD:
if prev_stage == 0:
error_msg.append(
f"[{mb_index=}] already finished backward computation"
)
break
else:
expected_stage = prev_stage - 1
expected_computation = _ComputationType.BACKWARD
else:
raise ValueError(
f"Computation type {prev_computation} not supported"
)
if expected_computation is not None:
if expected_computation != computation_type:
error_msg.append(
f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
)
if expected_stage != stage_index:
error_msg.append(
f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
)
microbatch_process_info[mb_index] = (
expected_computation,
expected_stage,
)
if len(error_msg) != 0:
raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
class _PipelineSchedule(ABC):
def __init__(
self,
@@ -943,7 +804,6 @@ class PipelineScheduleMulti(_PipelineSchedule):
all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])
for time_step, action in enumerate(self.pipeline_order[self.rank]):
try:
ops: List[dist.P2POp] = []
if action is not None:
computation_type, mb_index, stage_index = action
@@ -999,9 +859,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
# Previous rank doing backward or weight update has no influence for the current rank forward recv
pass
else:
raise ValueError(
f"Unknown computation type {computation_type}"
)
raise ValueError(f"Unknown computation type {computation_type}")
for next_rank in all_next_ranks:
next_rank_ops = self.pipeline_order[next_rank]
@@ -1025,24 +883,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
stage = stage_index_to_stage[stage_index - 1]
ops.extend(stage.get_bwd_recv_ops(mb_index))
else:
raise ValueError(
f"Unknown computation type {computation_type}"
)
raise ValueError(f"Unknown computation type {computation_type}")
# do the communication
if ops:
_batch_p2p(ops).wait()
except Exception as e:
logger.error(
"[Rank %s] pipeline schedule %s caught the following exception \
at time_step %s when running action %s",
self.rank,
self.__class__.__name__,
time_step,
action,
)
logger.error("%s", _format_pipeline_order(self.pipeline_order))
raise e
# Return losses if there is a container passed in
self._update_losses(self._stages, losses)

View File

@@ -385,7 +385,7 @@ class _PipelineStageBase(ABC):
else:
if not (grad is None and grad_recv_stage is None):
raise RuntimeError(
f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
f"[{self.stage_index}] for chunk {bwd_chunk_id - 1} has gradients {grad} "
f"and is expecting to send gradients to stage {grad_recv_stage}"
)
return ops