[Pipelining] Update schedules to use I, B actions. (#138886)

Also, update tests to use I (BACKWARD_INPUT) vs B (FULL_BACKWARD)
consistently.

Previously, schedules would issue a 'B' operation and leave it ambiguous
whether that operation should be BACKWARD_INPUT or FULL_BACKWARD,
depending on a separate flag (use_full_backward) passed to the schedule
class, which would determine which behavior was taken at runtime.

Now, use_full_backward is removed and the schedule class is required to
produce unambiguous IR.  The logic for 'use_full_backward' is removed
from the runtime.

_validate_pipeline_order is replaced with _simulate_comms_compute. Both
offer similar functionality, to validate the correctness of a schedule
IR.  'validate' operates on compute-only IR, while simulate operates on
compute + comm IR.  To convert from using validate to simulate, you have
to first insert comm actions via '_add_send_recv'.

'simulate' was inefficiently written before this PR and needed to be
optimized to run quickly for extra large schedules with >32 ranks and
microbatches per rank used in some unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138886
Approved by: https://github.com/H-Huang
This commit is contained in:
Will Constable 2024-10-31 17:46:20 -07:00 committed by PyTorch MergeBot
parent 094d288f40
commit 84416618a6
4 changed files with 77 additions and 210 deletions

View File

@ -140,8 +140,6 @@ class ScheduleWithW(PipelineScheduleMulti):
self.use_full_backward = False
# Go through two microbatches
# TODO(whc) unify the semantics of the IR for old runtime with new runtime.
# make 'I' a supported action in old runtime
self.pipeline_order = {
0: [
_Action(0, F, 0),
@ -149,12 +147,12 @@ class ScheduleWithW(PipelineScheduleMulti):
_Action(2, F, 0),
_Action(2, F, 1),
None,
_Action(2, B, 0),
_Action(2, I, 0),
_Action(2, W, 0),
_Action(0, B, 0),
_Action(2, B, 1),
_Action(0, I, 0),
_Action(2, I, 1),
_Action(0, W, 0),
_Action(0, B, 1),
_Action(0, I, 1),
_Action(2, W, 1),
_Action(0, W, 1),
],
@ -163,12 +161,12 @@ class ScheduleWithW(PipelineScheduleMulti):
_Action(1, F, 0),
_Action(1, F, 1),
_Action(3, F, 0),
_Action(3, B, 0),
_Action(3, I, 0),
_Action(3, F, 1),
_Action(1, B, 0),
_Action(3, B, 1),
_Action(1, I, 0),
_Action(3, I, 1),
_Action(3, W, 0),
_Action(1, B, 1),
_Action(1, I, 1),
_Action(1, W, 0),
_Action(3, W, 1),
_Action(1, W, 1),

View File

@ -25,7 +25,6 @@ from torch.distributed.pipelining.schedules import (
_PipelineSchedule,
_PipelineScheduleRuntime,
_simulate_comms_compute,
_validate_pipeline_order,
B,
F,
get_schedule_class,
@ -267,9 +266,19 @@ class TestSchedulePlan(TestCase):
formatted_pipeline_order = _format_pipeline_order(
schedule.pipeline_order
)
# print(formatted_pipeline_order)
_validate_pipeline_order(
schedule.pipeline_order, num_microbatches, num_stages
def stage_to_rank(stage):
return stage % group_size
comms_sch = _add_send_recv(
schedule.pipeline_order,
stage_to_rank=stage_to_rank,
num_stages=num_stages,
)
_simulate_comms_compute(
comms_sch,
stage_to_rank=stage_to_rank,
num_stages=num_stages,
)
@parametrize(
@ -299,11 +308,20 @@ class TestSchedulePlan(TestCase):
schedule.pipeline_order
)
# print(formatted_pipeline_order)
_validate_pipeline_order(
def stage_to_rank(stage):
return stage % group_size
comms_sch = _add_send_recv(
schedule.pipeline_order,
num_microbatches,
num_stages,
enable_zero_bubble=(ScheduleClass is ScheduleInterleavedZeroBubble),
stage_to_rank=stage_to_rank,
num_stages=num_stages,
)
# print(_format_pipeline_order(comms_sch))
_simulate_comms_compute(
comms_sch,
stage_to_rank=stage_to_rank,
num_stages=num_stages,
)
@ -678,8 +696,6 @@ class TestScheduleLowering(TestCase):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=[0, 0],
# TODO should we test both T/F?
use_full_backward=True,
)
schedule._load_actions(
{
@ -792,7 +808,6 @@ class TestScheduleLowering(TestCase):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=[0],
use_full_backward=False,
)
schedule._load_actions(
{

View File

@ -364,7 +364,7 @@ class ScheduleTest(MultiProcContinousTest):
"ScheduleClass",
[ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
)
@parametrize("use_new_runtime", [False])
@parametrize("use_new_runtime", [False, True])
def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -422,7 +422,6 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
use_full_backward=old_schedule.use_full_backward,
)
tmp_schedule._load_actions(old_schedule.pipeline_order)
# test that csv round-trip works for compute_comms schedule
@ -431,7 +430,6 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
use_full_backward=old_schedule.use_full_backward,
)
with tempfile.NamedTemporaryFile() as f:
tmp_schedule._dump_csv(f.name)
@ -442,7 +440,6 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
use_full_backward=old_schedule.use_full_backward,
)
one_more_schedule._load_actions(
schedule.pipeline_order_with_comms, format="compute_comms"

View File

@ -209,150 +209,6 @@ def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -
return formatted_table
def _validate_pipeline_order(
pipeline_order: Dict[int, List[Optional[_Action]]],
num_microbatches: int,
num_stages: int,
enable_zero_bubble: bool = False,
):
"""
pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
Validating that the pipeline order follows the rules:
1. Forward action for a microbatch must be before the Backward action for that microbatch
2. Recv for a microbatch must be before the send for that microbatch
3. Microbatch index is handled in sequential order for each stage
4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
5. Same microbatch cannot be handled in the same time step across ranks
"""
# microbatch_index: (current computation type, current stage)
microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
for timestep in range(max_timestep):
error_msg: List[str] = []
current_timestep_actions = []
for rank in range(len(pipeline_order)):
action = (
pipeline_order[rank][timestep]
if timestep < len(pipeline_order[rank])
else None
)
if action is not None:
computation_type = action.computation_type
if computation_type != _ComputationType.BACKWARD_WEIGHT:
current_timestep_actions.append(action)
# TODO: enable this
# if len(current_timestep_actions) == 0:
# error_msg.append(
# "All actions were None, there is an unnecessary gap in the schedule"
# )
# Ensure that no microbatch is operated on twice in current_timestep_actions
unique_microbatch_indices = {
action.microbatch_index for action in current_timestep_actions
}
if len(unique_microbatch_indices) != len(current_timestep_actions):
error_msg.append(
"Duplicate microbatch index found in current_timestep_actions"
)
for action in current_timestep_actions:
stage_index = action.stage_index
computation_type = action.computation_type
mb_index = action.microbatch_index
assert (
mb_index is not None
), "All currently supported action types require valid microbatch_index"
if mb_index >= num_microbatches:
error_msg.append(f"Microbatch index {mb_index} out of range")
# first microbatch
if mb_index not in microbatch_process_info:
if computation_type != _ComputationType.FORWARD or stage_index != 0:
error_msg.append(f"Incorrect start for microbatch {mb_index}")
microbatch_process_info[mb_index] = (computation_type, stage_index)
else:
# if the microbatch is included, check that the current stage is right after prev
prev_computation, prev_stage = microbatch_process_info[mb_index]
if prev_computation == _ComputationType.FORWARD:
if prev_stage == num_stages - 1:
expected_stage = num_stages - 1
expected_computation = _ComputationType.FULL_BACKWARD
else:
expected_stage = prev_stage + 1
expected_computation = _ComputationType.FORWARD
elif prev_computation == _ComputationType.FULL_BACKWARD:
if prev_stage == 0:
error_msg.append(
f"[{mb_index=}] already finished backward computation"
)
break
else:
expected_stage = prev_stage - 1
expected_computation = _ComputationType.FULL_BACKWARD
else:
raise ValueError(
f"Computation type {prev_computation} not supported"
)
if expected_computation is not None:
if expected_computation != computation_type:
error_msg.append(
f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
)
if expected_stage != stage_index:
error_msg.append(
f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
)
microbatch_process_info[mb_index] = (
expected_computation,
expected_stage,
)
if not enable_zero_bubble:
if len(error_msg) != 0:
raise RuntimeError(
f"Error at timestep {timestep}: " + ",".join(error_msg)
)
return
for rank in range(len(pipeline_order)):
backward_steps: Set[Tuple[int, int]] = set()
weight_steps: Set[Tuple[int, int]] = set()
for action in pipeline_order[rank]:
if action is None:
continue
stage_index = action.stage_index
computation_type = action.computation_type
mb_index = action.microbatch_index
if computation_type == _ComputationType.FULL_BACKWARD:
if mb_index is not None:
backward_steps.add((mb_index, stage_index))
elif computation_type == _ComputationType.BACKWARD_WEIGHT:
if (mb_index, stage_index) not in backward_steps:
error_msg.append(
f"{mb_index=}, {stage_index=} Weight happened before bwd"
)
if (mb_index, stage_index) in weight_steps:
error_msg.append(
f"{mb_index=}, {stage_index=} Duplicated weight step"
)
if mb_index is not None:
weight_steps.add((mb_index, stage_index))
if len(backward_steps) != len(weight_steps):
error_msg.append("Length weight steps != Length bwd steps")
if len(error_msg) != 0:
raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
class _PipelineSchedule(ABC):
def __init__(
self,
@ -1136,7 +992,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
stage_index_to_group_rank: Optional[Dict[int, int]] = None,
use_full_backward: bool = True,
use_full_backward: Optional[bool] = None,
):
# Init parent
super().__init__(
@ -1168,7 +1024,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
# This will be set during init of derived schedules
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.use_full_backward = use_full_backward
if use_full_backward is not None:
logger.warning(
"Deprecation warning: 'use_full_backward' is no longer supported. "
"Simply stop passing it, and everything should still work fine."
)
def _initialize_stages(self, args: Tuple[Any, ...], kwargs):
# may be 'none' value (if this stage sends its output shapes to the next stage via P2P)
@ -1203,7 +1064,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
num_microbatches: int,
):
# We will count all the actions per stage and ensure they happen in a valid order
# (e.g. F before B before W for a given microbatch)
# (e.g. F before (B, I) before W for a given microbatch)
stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
stage_id: {
F: set(),
@ -1227,15 +1088,18 @@ class PipelineScheduleMulti(_PipelineSchedule):
elif ctype == B:
assert (
mb_id in stage_actions[s_id][F]
), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
stage_actions[s_id][B].add(mb_id)
elif ctype == I:
assert (
mb_id in stage_actions[s_id][F]
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
# TODO(whc) do we need to track I separately from B or should we just merge them for simplicity
stage_actions[s_id][B].add(mb_id)
elif ctype == W:
assert (
not self.use_full_backward
), "Schedule contains 'W' actions, but is configured to use full backward"
assert (
mb_id in stage_actions[s_id][B]
), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
stage_actions[s_id][W].add(mb_id)
for s_id in stage_actions:
@ -1358,25 +1222,28 @@ class PipelineScheduleMulti(_PipelineSchedule):
)
self._maybe_compute_loss(stage, output, target_mbs, mb_index)
ops.extend(stage.get_fwd_send_ops(mb_index))
# TODO(whc) for now i'm going with the hopefully backward-compatible position that legacy IR with
# 'B' maps to ambiguous backward that is either full or d_Input based on 'use_full_backward' flag.
# Later, we should deprecate this flag, and rely on the IR to determine the type of backward
elif computation_type == _ComputationType.FULL_BACKWARD:
# perform backward computation
stage = stage_index_to_stage[stage_index]
loss = self._maybe_get_loss(stage, mb_index)
stage.backward_one_chunk(
mb_index, loss=loss, full_backward=self.use_full_backward
mb_index,
loss=loss,
full_backward=True,
)
ops.extend(stage.get_bwd_send_ops(mb_index))
elif computation_type == _ComputationType.BACKWARD_INPUT:
# perform backward computation
stage = stage_index_to_stage[stage_index]
loss = self._maybe_get_loss(stage, mb_index)
stage.backward_one_chunk(
mb_index,
loss=loss,
full_backward=False,
)
ops.extend(stage.get_bwd_send_ops(mb_index))
elif computation_type == _ComputationType.BACKWARD_WEIGHT:
# perform weight update
if self.use_full_backward:
raise ValueError(
f"We detected a weight update in the pipeline schedule, but \
{self.use_full_backward=}"
)
stage = stage_index_to_stage[stage_index]
stage.backward_weight_one_chunk(mb_index)
else:
@ -1404,11 +1271,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
# however that is not necessarily true of get_fwd_recv_ops
stage = stage_index_to_stage[stage_index + 1]
ops.extend(stage.get_fwd_recv_ops(mb_index))
elif (
computation_type == _ComputationType.FULL_BACKWARD
or computation_type == _ComputationType.BACKWARD_WEIGHT
elif computation_type in (
FULL_BACKWARD,
BACKWARD_INPUT,
BACKWARD_WEIGHT,
):
# Previous rank doing backward or weight update has no influence for the current rank forward recv
# Previous rank doing backward has no influence for the current rank forward recv
pass
else:
raise ValueError(
@ -1427,13 +1295,10 @@ class PipelineScheduleMulti(_PipelineSchedule):
mb_index is not None
), "All currently supported action types require valid microbatch_index"
# Only handle receives for the backwards from a next rank
if (
computation_type == _ComputationType.FORWARD
or computation_type == _ComputationType.BACKWARD_WEIGHT
):
if computation_type in (FORWARD, BACKWARD_WEIGHT):
# Next rank doing forward or weight update has no influence for the current rank backward recv
pass
elif computation_type == _ComputationType.FULL_BACKWARD:
elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
# If not the first stage, then receive bwd gradients
if stage_index - 1 in stage_index_to_stage:
# TODO: We are assuming that stage will always receive from stage+1
@ -1719,12 +1584,6 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
if stage_uses_fsdp:
_assert_unsharded(stage_idx)
if self.use_full_backward:
raise ValueError(
f"We detected a weight update in the pipeline schedule, but \
{self.use_full_backward=}"
)
if not stage.is_last:
assert (
stage_idx,
@ -1749,11 +1608,6 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
if stage_uses_fsdp:
_assert_unsharded(stage_idx)
if self.use_full_backward:
raise ValueError(
f"We detected a weight update in the pipeline schedule, but \
{self.use_full_backward=}"
)
stage.backward_weight_one_chunk(mb_index)
else:
raise ValueError(f"{action=} is unknown or unsupported")
@ -1883,6 +1737,10 @@ def _get_1f1b_rank_ops(
backward_op_ids = []
weight_op_count = 0
FULL_BACKWARD_OR_BACKWARD_INPUT = (
BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD
)
for op in range(total_ops):
# Warmup phase
if op < warmup_ops:
@ -1911,7 +1769,7 @@ def _get_1f1b_rank_ops(
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
) + 1
rank_ops.append(
_Action(bwd_stage_index, _ComputationType.FULL_BACKWARD, bwd_mb_index)
_Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index)
)
backward_op_ids.append(op)
@ -1942,7 +1800,7 @@ def _get_1f1b_rank_ops(
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
) + 1
rank_ops.append(
_Action(bwd_stage_index, _ComputationType.FULL_BACKWARD, bwd_mb_index)
_Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index)
)
backward_op_ids.append(op)
@ -2119,7 +1977,6 @@ class ScheduleInterleavedZeroBubble(PipelineScheduleMulti):
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
use_full_backward=False,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank