Revert "[pipelining] [BE] Move pipeline_order validation to schedules.py (#129369)"

This reverts commit ec789a3c9d.

Reverted https://github.com/pytorch/pytorch/pull/129369 on behalf of https://github.com/clee2000 because it broke test/distributed/pipelining/test_schedule.py::ScheduleTest::test_non_symmetric_stage_ids_ScheduleClass0 on distributed CUDA (https://github.com/pytorch/pytorch/actions/runs/9766039400/job/26959115773, ec789a3c9d). The error is visible on the PR, but Dr. CI classified it incorrectly ([comment](https://github.com/pytorch/pytorch/pull/129369#issuecomment-2204568418))
PyTorch MergeBot 2024-07-02 22:30:53 +00:00
parent b6f781e433
commit b5fdbc1a9f
4 changed files with 231 additions and 244 deletions

test/distributed/pipelining/test_schedule.py

@@ -6,6 +6,7 @@ import os
 import sys
 import tempfile
 import unittest
+from typing import Dict, List, Optional, Tuple

 from model_registry import ModelWithKwargs, MultiMLP
 from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW
@@ -21,10 +22,7 @@ from torch.distributed.pipelining import (
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
 )
-from torch.distributed.pipelining.schedules import (
-    _format_pipeline_order,
-    _validate_pipeline_order,
-)
+from torch.distributed.pipelining.schedules import _Action, _ComputationType
 from torch.distributed.pipelining.stage import _PipelineStageBase
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
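The private names imported here (`_Action`, `_ComputationType`) are used throughout this diff. As a reading aid, a sketch of their assumed shape, inferred only from how the code below unpacks and compares them; the verbatim definitions live in torch/distributed/pipelining/schedules.py and may differ in detail:

class _ComputationType(Enum):
    # Assumed members: the diff only ever compares against these three.
    FORWARD = 1
    BACKWARD = 2
    WEIGHT = 3

class _Action(NamedTuple):
    # Assumed field order: the diff unpacks actions as
    # (computation_type, microbatch_index, stage_index).
    computation_type: _ComputationType
    microbatch_index: int
    stage_index: int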
@@ -610,7 +608,153 @@ class ScheduleTest(MultiProcContinousTest):

 instantiate_parametrized_tests(ScheduleTest)


+def format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]):
+    import itertools
+
+    # Calculate the maximum number of steps across all ranks
+    num_steps = max(len(actions) for actions in pipeline_order.values())
+    step_labels = [
+        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
+    ]
+    # Sorting the dictionary by keys and retrieving values in that order
+    rank_actions = [
+        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
+    ]
+    # Transpose the list of lists (rows to columns)
+    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
+    # Generate column labels for ranks
+    num_ranks = len(pipeline_order)
+    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
+    # Calculate the maximum length of each column, considering labels
+    max_lengths = [
+        max(len(str(item)) if item is not None else 0 for item in col)
+        for col in zip(step_labels, *transposed_actions)
+    ]
+    # Format the header row with rank labels
+    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
+        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
+    )
+    # Format each row with its corresponding label
+    formatted_rows = [
+        f"{label}: "
+        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
+        for label, row in zip(step_labels, transposed_actions)
+    ]
+    # Join the rows into a single string
+    formatted_table = (
+        "=========== ALL_RANK_ACTIONS ===========\n"
+        + header_row
+        + "\n"
+        + "\n".join(formatted_rows)
+        + "\n"
+    )
+    return formatted_table
+
+
 class TestSchedulePlan(unittest.TestCase):
+    def _validate_pipeline_order(
+        self,
+        pipeline_order: Dict[int, List[Optional[_Action]]],
+        num_microbatches: int,
+        num_stages: int,
+    ):
+        """
+        pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
+        Validating that the pipeline order follows the rules:
+        1. Forward action for a microbatch must be before the Backward action for that microbatch
+        2. Recv for a microbatch must be before the send for that microbatch
+        3. Microbatch index is handled in sequential order for each stage
+        4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
+        5. Same microbatch cannot be handled in the same time step across ranks
+        """
+        # microbatch_index: (current computation type, current stage)
+        microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
+        max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
+        for timestep in range(max_timestep):
+            error_msg: List[str] = []
+            current_timestep_actions = []
+            for rank in range(len(pipeline_order)):
+                action = (
+                    pipeline_order[rank][timestep]
+                    if timestep < len(pipeline_order[rank])
+                    else None
+                )
+                if action is not None:
+                    current_timestep_actions.append(action)
+
+            # TODO: enable this
+            # if len(current_timestep_actions) == 0:
+            #     error_msg.append(
+            #         "All actions were None, there is an unnecessary gap in the schedule"
+            #     )
+
+            # Ensure that no microbatch is operated on twice in current_timestep_actions
+            unique_microbatch_indices = {
+                action[1] for action in current_timestep_actions
+            }
+            if len(unique_microbatch_indices) != len(current_timestep_actions):
+                error_msg.append(
+                    "Duplicate microbatch index found in current_timestep_actions"
+                )
+
+            # Add additional checks for other rules here...
+            for action in current_timestep_actions:
+                computation_type, mb_index, stage_index = action
+                if mb_index >= num_microbatches:
+                    error_msg.append(f"Microbatch index {mb_index} out of range")
+
+                # first microbatch
+                if mb_index not in microbatch_process_info:
+                    if computation_type != _ComputationType.FORWARD or stage_index != 0:
+                        error_msg.append(f"Incorrect start for microbatch {mb_index}")
+                    microbatch_process_info[mb_index] = (computation_type, stage_index)
+                else:
+                    # if the microbatch is included, check that the current stage is right after prev
+                    prev_computation, prev_stage = microbatch_process_info[mb_index]
+
+                    if prev_computation == _ComputationType.FORWARD:
+                        if prev_stage == num_stages - 1:
+                            expected_stage = num_stages - 1
+                            expected_computation = _ComputationType.BACKWARD
+                        else:
+                            expected_stage = prev_stage + 1
+                            expected_computation = _ComputationType.FORWARD
+                    elif prev_computation == _ComputationType.BACKWARD:
+                        if prev_stage == 0:
+                            error_msg.append(
+                                f"[{mb_index=}] already finished backward computation"
+                            )
+                            expected_stage = None
+                            expected_computation = None
+                        else:
+                            expected_stage = prev_stage - 1
+                            expected_computation = _ComputationType.BACKWARD
+                    else:
+                        raise ValueError(
+                            f"Computation type {prev_computation} not supported"
+                        )
+
+                    if expected_computation is not None:
+                        if expected_computation != computation_type:
+                            error_msg.append(
+                                f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
+                            )
+                        if expected_stage != stage_index:
+                            error_msg.append(
+                                f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
+                            )
+
+                    microbatch_process_info[mb_index] = (
+                        expected_computation,
+                        expected_stage,
+                    )
+
+            if len(error_msg) != 0:
+                self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg))
+
     @parametrize(
         "ScheduleClass",
         [ScheduleFlexibleInterleaved1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS],
@@ -669,11 +813,8 @@ class TestSchedulePlan(unittest.TestCase):
             ]
             schedule = ScheduleClass(stages, num_microbatches)
-            formatted_pipeline_order = _format_pipeline_order(
-                schedule.pipeline_order
-            )
-            # print(formatted_pipeline_order)
-            _validate_pipeline_order(
+            # print(format_pipeline_order(schedule.pipeline_order))
+            self._validate_pipeline_order(
                 schedule.pipeline_order, num_microbatches, num_stages
             )
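The restored helpers are also usable standalone. A minimal sketch, assuming the `_Action` field order above (the toy schedule is illustrative, not produced by any schedule class): a GPipe-like order for 2 stages and 2 microbatches that satisfies all five documented rules:

F, B = _ComputationType.FORWARD, _ComputationType.BACKWARD
toy_order = {
    # rank 0 runs stage 0, rank 1 runs stage 1
    0: [_Action(F, 0, 0), _Action(F, 1, 0), None, None,
        _Action(B, 0, 0), _Action(B, 1, 0)],
    1: [None, _Action(F, 0, 1), _Action(F, 1, 1),
        _Action(B, 0, 1), _Action(B, 1, 1), None],
}
print(format_pipeline_order(toy_order))  # renders the Step x Rank grid
# Inside TestSchedulePlan, self._validate_pipeline_order(toy_order, 2, 2) passes;
# shifting rank 1's actions one step earlier would put both ranks on microbatch 0
# in the same time step and trip the duplicate-microbatch check (rule 5).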


@@ -7,6 +7,7 @@ from typing import List, Tuple, Union
 import torch
 from torch import fx

 logger = logging.getLogger(__name__)
torch/distributed/pipelining/schedules.py

@@ -2,7 +2,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import csv
-import itertools
 import logging
 import re
 from abc import ABC, abstractmethod
@@ -93,144 +92,6 @@ class _Action(NamedTuple):
         )


-def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
-    """
-    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
-    and returns the formatted string
-    """
-    # Calculate the maximum number of steps across all ranks
-    num_steps = max(len(actions) for actions in pipeline_order.values())
-    step_labels = [
-        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
-    ]
-    # Sorting the dictionary by keys and retrieving values in that order
-    rank_actions = [
-        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
-    ]
-    # Transpose the list of lists (rows to columns)
-    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
-    # Generate column labels for ranks
-    num_ranks = len(pipeline_order)
-    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
-    # Calculate the maximum length of each column, considering labels
-    max_lengths = [
-        max(len(str(item)) if item is not None else 0 for item in col)
-        for col in zip(step_labels, *transposed_actions)
-    ]
-    # Format the header row with rank labels
-    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
-        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
-    )
-    # Format each row with its corresponding label
-    formatted_rows = [
-        f"{label}: "
-        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
-        for label, row in zip(step_labels, transposed_actions)
-    ]
-    # Join the rows into a single string
-    formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
-    return formatted_table
-
-
-def _validate_pipeline_order(
-    pipeline_order: Dict[int, List[Optional[_Action]]],
-    num_microbatches: int,
-    num_stages: int,
-):
-    """
-    pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
-    Validating that the pipeline order follows the rules:
-    1. Forward action for a microbatch must be before the Backward action for that microbatch
-    2. Recv for a microbatch must be before the send for that microbatch
-    3. Microbatch index is handled in sequential order for each stage
-    4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
-    5. Same microbatch cannot be handled in the same time step across ranks
-    """
-    # microbatch_index: (current computation type, current stage)
-    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
-    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
-    for timestep in range(max_timestep):
-        error_msg: List[str] = []
-        current_timestep_actions = []
-        for rank in range(len(pipeline_order)):
-            action = (
-                pipeline_order[rank][timestep]
-                if timestep < len(pipeline_order[rank])
-                else None
-            )
-            if action is not None:
-                current_timestep_actions.append(action)
-
-        # TODO: enable this
-        # if len(current_timestep_actions) == 0:
-        #     error_msg.append(
-        #         "All actions were None, there is an unnecessary gap in the schedule"
-        #     )
-
-        # Ensure that no microbatch is operated on twice in current_timestep_actions
-        unique_microbatch_indices = {action[1] for action in current_timestep_actions}
-        if len(unique_microbatch_indices) != len(current_timestep_actions):
-            error_msg.append(
-                "Duplicate microbatch index found in current_timestep_actions"
-            )
-
-        for action in current_timestep_actions:
-            computation_type, mb_index, stage_index = action
-            if mb_index >= num_microbatches:
-                error_msg.append(f"Microbatch index {mb_index} out of range")
-
-            # first microbatch
-            if mb_index not in microbatch_process_info:
-                if computation_type != _ComputationType.FORWARD or stage_index != 0:
-                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
-                microbatch_process_info[mb_index] = (computation_type, stage_index)
-            else:
-                # if the microbatch is included, check that the current stage is right after prev
-                prev_computation, prev_stage = microbatch_process_info[mb_index]
-
-                if prev_computation == _ComputationType.FORWARD:
-                    if prev_stage == num_stages - 1:
-                        expected_stage = num_stages - 1
-                        expected_computation = _ComputationType.BACKWARD
-                    else:
-                        expected_stage = prev_stage + 1
-                        expected_computation = _ComputationType.FORWARD
-                elif prev_computation == _ComputationType.BACKWARD:
-                    if prev_stage == 0:
-                        error_msg.append(
-                            f"[{mb_index=}] already finished backward computation"
-                        )
-                        break
-                    else:
-                        expected_stage = prev_stage - 1
-                        expected_computation = _ComputationType.BACKWARD
-                else:
-                    raise ValueError(
-                        f"Computation type {prev_computation} not supported"
-                    )
-
-                if expected_computation is not None:
-                    if expected_computation != computation_type:
-                        error_msg.append(
-                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
-                        )
-                    if expected_stage != stage_index:
-                        error_msg.append(
-                            f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
-                        )
-
-                microbatch_process_info[mb_index] = (
-                    expected_computation,
-                    expected_stage,
-                )
-
-        if len(error_msg) != 0:
-            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
-
-
 class _PipelineSchedule(ABC):
     def __init__(
         self,
@@ -943,106 +804,90 @@ class PipelineScheduleMulti(_PipelineSchedule):
                 all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

         for time_step, action in enumerate(self.pipeline_order[self.rank]):
-            try:
-                ops: List[dist.P2POp] = []
-                if action is not None:
-                    computation_type, mb_index, stage_index = action
-                    if computation_type == _ComputationType.FORWARD:
-                        # perform forward computation
-                        stage = stage_index_to_stage[stage_index]
-                        output = stage.forward_one_chunk(
-                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
-                        )
-                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)
-                        ops.extend(stage.get_fwd_send_ops(mb_index))
-                    elif computation_type == _ComputationType.BACKWARD:
-                        # perform backward computation
-                        stage = stage_index_to_stage[stage_index]
-                        loss = self._maybe_get_loss(stage, mb_index)
-                        stage.backward_one_chunk(
-                            mb_index, loss=loss, full_backward=self.use_full_backward
-                        )
-                        ops.extend(stage.get_bwd_send_ops(mb_index))
-                    elif computation_type == _ComputationType.WEIGHT:
-                        # perform weight update
-                        if self.use_full_backward:
-                            raise ValueError(
-                                f"We detected a weight update in the pipeline schedule, but \
-                                {self.use_full_backward=}"
-                            )
-                        stage = stage_index_to_stage[stage_index]
-                        stage.backward_weight_one_chunk(mb_index)
-                    else:
-                        raise ValueError(f"Unknown computation type {computation_type}")
-
-                # Look at the neighboring ranks for this current timestep and determine whether
-                # this current rank needs to do any recv communication
-                for prev_rank in all_prev_ranks:
-                    prev_rank_ops = self.pipeline_order[prev_rank]
-                    prev_rank_action = None
-                    if time_step < len(prev_rank_ops):
-                        prev_rank_action = prev_rank_ops[time_step]
-                    if prev_rank_action is not None:
-                        computation_type, mb_index, stage_index = prev_rank_action
-                        # Only handle sends for the forward from a previous rank
-                        if computation_type == _ComputationType.FORWARD:
-                            # If not the last stage, then receive fwd activations
-                            if stage_index + 1 in stage_index_to_stage:
-                                # TODO: We are assuming that stage will always receive from stage-1
-                                # however that is not necessarily true of get_fwd_recv_ops
-                                stage = stage_index_to_stage[stage_index + 1]
-                                ops.extend(stage.get_fwd_recv_ops(mb_index))
-                        elif (
-                            computation_type == _ComputationType.BACKWARD
-                            or computation_type == _ComputationType.WEIGHT
-                        ):
-                            # Previous rank doing backward or weight update has no influence for the current rank forward recv
-                            pass
-                        else:
-                            raise ValueError(
-                                f"Unknown computation type {computation_type}"
-                            )
-
-                for next_rank in all_next_ranks:
-                    next_rank_ops = self.pipeline_order[next_rank]
-                    next_rank_action = None
-                    if time_step < len(next_rank_ops):
-                        next_rank_action = next_rank_ops[time_step]
-                    if next_rank_action is not None:
-                        computation_type, mb_index, stage_index = next_rank_action
-                        # Only handle receives for the backwards from a next rank
-                        if (
-                            computation_type == _ComputationType.FORWARD
-                            or computation_type == _ComputationType.WEIGHT
-                        ):
-                            # Next rank doing forward or weight update has no influence for the current rank backward recv
-                            pass
-                        elif computation_type == _ComputationType.BACKWARD:
-                            # If not the first stage, then receive bwd gradients
-                            if stage_index - 1 in stage_index_to_stage:
-                                # TODO: We are assuming that stage will always receive from stage+1
-                                # however that is not necessarily true of get_bwd_recv_ops
-                                stage = stage_index_to_stage[stage_index - 1]
-                                ops.extend(stage.get_bwd_recv_ops(mb_index))
-                        else:
-                            raise ValueError(
-                                f"Unknown computation type {computation_type}"
-                            )
-
-                # do the communication
-                if ops:
-                    _batch_p2p(ops).wait()
-            except Exception as e:
-                logger.error(
-                    "[Rank %s] pipeline schedule %s caught the following exception \
-                    at time_step %s when running action %s",
-                    self.rank,
-                    self.__class__.__name__,
-                    time_step,
-                    action,
-                )
-                logger.error("%s", _format_pipeline_order(self.pipeline_order))
-                raise e
+            ops: List[dist.P2POp] = []
+            if action is not None:
+                computation_type, mb_index, stage_index = action
+                if computation_type == _ComputationType.FORWARD:
+                    # perform forward computation
+                    stage = stage_index_to_stage[stage_index]
+                    output = stage.forward_one_chunk(
+                        mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
+                    )
+                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
+                    ops.extend(stage.get_fwd_send_ops(mb_index))
+                elif computation_type == _ComputationType.BACKWARD:
+                    # perform backward computation
+                    stage = stage_index_to_stage[stage_index]
+                    loss = self._maybe_get_loss(stage, mb_index)
+                    stage.backward_one_chunk(
+                        mb_index, loss=loss, full_backward=self.use_full_backward
+                    )
+                    ops.extend(stage.get_bwd_send_ops(mb_index))
+                elif computation_type == _ComputationType.WEIGHT:
+                    # perform weight update
+                    if self.use_full_backward:
+                        raise ValueError(
+                            f"We detected a weight update in the pipeline schedule, but \
+                            {self.use_full_backward=}"
+                        )
+                    stage = stage_index_to_stage[stage_index]
+                    stage.backward_weight_one_chunk(mb_index)
+                else:
+                    raise ValueError(f"Unknown computation type {computation_type}")
+
+            # Look at the neighboring ranks for this current timestep and determine whether
+            # this current rank needs to do any recv communication
+            for prev_rank in all_prev_ranks:
+                prev_rank_ops = self.pipeline_order[prev_rank]
+                prev_rank_action = None
+                if time_step < len(prev_rank_ops):
+                    prev_rank_action = prev_rank_ops[time_step]
+                if prev_rank_action is not None:
+                    computation_type, mb_index, stage_index = prev_rank_action
+                    # Only handle sends for the forward from a previous rank
+                    if computation_type == _ComputationType.FORWARD:
+                        # If not the last stage, then receive fwd activations
+                        if stage_index + 1 in stage_index_to_stage:
+                            # TODO: We are assuming that stage will always receive from stage-1
+                            # however that is not necessarily true of get_fwd_recv_ops
+                            stage = stage_index_to_stage[stage_index + 1]
+                            ops.extend(stage.get_fwd_recv_ops(mb_index))
+                    elif (
+                        computation_type == _ComputationType.BACKWARD
+                        or computation_type == _ComputationType.WEIGHT
+                    ):
+                        # Previous rank doing backward or weight update has no influence for the current rank forward recv
+                        pass
+                    else:
+                        raise ValueError(f"Unknown computation type {computation_type}")
+
+            for next_rank in all_next_ranks:
+                next_rank_ops = self.pipeline_order[next_rank]
+                next_rank_action = None
+                if time_step < len(next_rank_ops):
+                    next_rank_action = next_rank_ops[time_step]
+                if next_rank_action is not None:
+                    computation_type, mb_index, stage_index = next_rank_action
+                    # Only handle receives for the backwards from a next rank
+                    if (
+                        computation_type == _ComputationType.FORWARD
+                        or computation_type == _ComputationType.WEIGHT
+                    ):
+                        # Next rank doing forward or weight update has no influence for the current rank backward recv
+                        pass
+                    elif computation_type == _ComputationType.BACKWARD:
+                        # If not the first stage, then receive bwd gradients
+                        if stage_index - 1 in stage_index_to_stage:
+                            # TODO: We are assuming that stage will always receive from stage+1
+                            # however that is not necessarily true of get_bwd_recv_ops
+                            stage = stage_index_to_stage[stage_index - 1]
+                            ops.extend(stage.get_bwd_recv_ops(mb_index))
+                    else:
+                        raise ValueError(f"Unknown computation type {computation_type}")
+
+            # do the communication
+            if ops:
+                _batch_p2p(ops).wait()

         # Return losses if there is a container passed in
         self._update_losses(self._stages, losses)
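The restored loop interleaves a rank's own compute with recv matching: at each time step it runs its own action, then inspects what neighboring ranks are scheduled to do at the same time step to decide which receives to post before batching the P2P ops. A condensed sketch of just that matching rule, with a hypothetical helper name that is not part of the API; `owned` plays the role of `stage_index_to_stage` for the current rank:

def recvs_for_timestep(pipeline_order, time_step, owned, all_prev_ranks, all_next_ranks):
    ops = []
    for rank in all_prev_ranks:
        rank_ops = pipeline_order[rank]
        action = rank_ops[time_step] if time_step < len(rank_ops) else None
        if action is not None:
            computation_type, mb_index, stage_index = action
            # A previous rank running FORWARD on stage s feeds activations to stage s + 1.
            if computation_type == _ComputationType.FORWARD and stage_index + 1 in owned:
                ops.extend(owned[stage_index + 1].get_fwd_recv_ops(mb_index))
    for rank in all_next_ranks:
        rank_ops = pipeline_order[rank]
        action = rank_ops[time_step] if time_step < len(rank_ops) else None
        if action is not None:
            computation_type, mb_index, stage_index = action
            # A next rank running BACKWARD on stage s feeds gradients to stage s - 1.
            if computation_type == _ComputationType.BACKWARD and stage_index - 1 in owned:
                ops.extend(owned[stage_index - 1].get_bwd_recv_ops(mb_index))
    return ops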

torch/distributed/pipelining/stage.py

@@ -385,7 +385,7 @@ class _PipelineStageBase(ABC):
         else:
             if not (grad is None and grad_recv_stage is None):
                 raise RuntimeError(
-                    f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
+                    f"[{self.stage_index}] for chunk {bwd_chunk_id - 1} has gradients {grad} "
                     f"and is expecting to send gradients to stage {grad_recv_stage}"
                 )
         return ops