[PP] Add DualPipeV schedule (#159591)

Added the DualPipeV schedule, following the reference implementation at http://github.com/deepseek-ai/DualPipe/blob/main/dualpipe/dualpipev.py#L11

<img width="3633" height="486" alt="image" src="https://github.com/user-attachments/assets/4e843bb9-87cd-4d11-936c-7dfe8ee12f16" />

This schedule doesn't yet perform the actual "overlap" during execution, but it provides the scaffolding and schedule definition needed to run it end-to-end (E2E) in torchtitan. Support for the overlapped operation will be added in follow-up PRs.
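
For context, a minimal wiring sketch modeled on the new `test_v_shape_schedules` test is shown below. It is not part of this PR; it assumes `torchrun --nproc-per-node=2` on 2+ GPUs with NCCL, and the `Linear` stages, sizes, and rank-to-stage layout are illustrative placeholders:

```python
# Sketch only: mirrors the pattern in test_v_shape_schedules, with toy Linear stages.
import os

import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleDualPipeV

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])  # expected to be 2 for this layout
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
dist.init_process_group("nccl", rank=rank, world_size=world_size)

# 4 stages laid out in a "V" across 2 ranks: rank 0 owns stages [0, 3], rank 1 owns [1, 2].
n_stages, n_microbatches, d_hid = 4, 8, 16
rank_stages = {0: [0, 3], 1: [1, 2]}

stages = [
    PipelineStage(torch.nn.Linear(d_hid, d_hid).to(device), stage_idx, n_stages, device)
    for stage_idx in rank_stages[rank]
]
schedule = ScheduleDualPipeV(stages, n_microbatches, loss_fn=torch.nn.MSELoss())

x = torch.randn(n_microbatches * 4, d_hid, device=device)
target = torch.randn_like(x)
losses: list[torch.Tensor] = []
if rank == 0:
    # With the V layout, rank 0 holds both the first and the last stage,
    # so it supplies the inputs and collects the outputs/losses.
    out = schedule.step(x, target=target, losses=losses)
    print(f"output shape: {out.shape}, total loss: {sum(losses)}")
else:
    schedule.step()

dist.destroy_process_group()
```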

Tests:
```sh
python test/distributed/pipelining/test_schedule_multiproc.py -k test_v_shape_schedules
python test/distributed/pipelining/test_schedule.py -k test_pipeline_order_for_v_schedules
```

Also tested in torchtitan, where it runs successfully.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159591
Approved by: https://github.com/wconstab
Howard Huang 2025-08-13 17:56:41 -07:00 committed by PyTorch MergeBot
parent 20bdabbb3c
commit 198b5fd2d4
6 changed files with 406 additions and 21 deletions


@@ -504,6 +504,10 @@ The following set of APIs transform your model into a pipeline representation.
.. autoclass:: ScheduleZBVZeroBubble
```
```{eval-rst}
.. autoclass:: ScheduleDualPipeV
```
```{eval-rst}
.. autoclass:: PipelineScheduleSingle
:members:


@@ -0,0 +1,4 @@
0F0,0F1,0F2,0F3,0F4,0F5,0F6,7F0,7I0,7W0,7F1,7I1,7W1,7F2,7I2,7W2,7F3,(0F7;7B3)OVERLAP_F_B,(7F4;0B0)OVERLAP_F_B,(0F8;7B4)OVERLAP_F_B,(7F5;0B1)OVERLAP_F_B,(0F9;7B5)OVERLAP_F_B,(7F6;0B2)OVERLAP_F_B,7B6,(7F7;0B3)OVERLAP_F_B,7B7,(7F8;0B4)OVERLAP_F_B,7B8,(7F9;0B5)OVERLAP_F_B,7B9,0I6,0W6,0I7,0W7,0I8,0W8,0I9,0W9
1F0,1F1,1F2,1F3,1F4,6F0,1F5,6F1,6I0,6W0,6F2,6I1,6W1,6F3,(1F6;6B2)OVERLAP_F_B,(6F4;1B0)OVERLAP_F_B,(1F7;6B3)OVERLAP_F_B,(6F5;1B1)OVERLAP_F_B,(1F8;6B4)OVERLAP_F_B,(6F6;1B2)OVERLAP_F_B,(1F9;6B5)OVERLAP_F_B,(6F7;1B3)OVERLAP_F_B,6B6,(6F8;1B4)OVERLAP_F_B,6B7,(6F9;1B5)OVERLAP_F_B,6B8,1B6,6I9,1I7,6W9,1I8,1W7,1I9,1W8,1W9
2F0,2F1,2F2,5F0,2F3,5F1,2F4,5F2,5I0,5W0,5F3,(2F5;5B1)OVERLAP_F_B,(5F4;2B0)OVERLAP_F_B,(2F6;5B2)OVERLAP_F_B,(5F5;2B1)OVERLAP_F_B,(2F7;5B3)OVERLAP_F_B,(5F6;2B2)OVERLAP_F_B,(2F8;5B4)OVERLAP_F_B,(5F7;2B3)OVERLAP_F_B,(2F9;5B5)OVERLAP_F_B,(5F8;2B4)OVERLAP_F_B,5B6,(5F9;2B5)OVERLAP_F_B,5B7,2B6,5B8,2I7,5I9,2I8,2W7,2I9,5W9,2W8,2W9
3F0,4F0,3F1,4F1,3F2,4F2,3F3,4F3,3F4,4B0,(4F4;3B0)OVERLAP_F_B,(3F5;4B1)OVERLAP_F_B,(4F5;3B1)OVERLAP_F_B,(3F6;4B2)OVERLAP_F_B,(4F6;3B2)OVERLAP_F_B,(3F7;4B3)OVERLAP_F_B,(4F7;3B3)OVERLAP_F_B,(3F8;4B4)OVERLAP_F_B,(4F8;3B4)OVERLAP_F_B,(3F9;4B5)OVERLAP_F_B,(4F9;3B5)OVERLAP_F_B,4B6,3B6,4B7,3B7,4I8,3I8,4I9,3I9,4W8,3W8,4W9,3W9
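
A rough reading guide for the rows above (inferred from the `_Action` string format used in the regression test below, so treat it as a guide rather than a spec): each entry is `<stage><type><microbatch>`, where `F` is a forward, `B` a full backward, and `I`/`W` the split backward-input/backward-weight passes of the zero-bubble variant; `(…;…)OVERLAP_F_B` groups a forward and a backward that the schedule marks as overlappable. For example, `0F0` is the forward of microbatch 0 on stage 0, and `(0F7;7B3)OVERLAP_F_B` pairs stage 0's forward of microbatch 7 with stage 7's backward of microbatch 3.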


@@ -10,10 +10,12 @@ from model_registry import MultiMLP
import torch
from torch.distributed.pipelining import (
Schedule1F1B,
ScheduleDualPipeV,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
ScheduleZBVZeroBubble,
)
from torch.distributed.pipelining._utils import generate_stage_to_rank_mapping
from torch.distributed.pipelining.schedules import (
@@ -348,10 +350,91 @@ class TestSchedulePlan(TestCase):
num_stages=num_stages,
)
@parametrize(
"ScheduleClass",
[ScheduleDualPipeV, ScheduleZBVZeroBubble],
)
def test_pipeline_order_for_v_schedules(self, ScheduleClass):
for num_local_stages, num_microbatches, group_size in self.test_cases:
with self.subTest(
num_local_stages=num_local_stages,
num_microbatches=num_microbatches,
group_size=group_size,
):
num_stages = num_local_stages * group_size
stages = [
MockPipelineStage(group_size=group_size, num_stages=num_stages)
for i in range(num_local_stages)
]
# V schedules only support 2 stages per rank so if num_local_stages is not 2, ensure an error is thrown
if num_local_stages != 2:
with self.assertRaises(ValueError):
ScheduleClass(
stages,
num_microbatches,
)
continue
# DualPipeV requires num_microbatches to be >= num_stages
if ScheduleClass == ScheduleDualPipeV and num_microbatches < num_stages:
with self.assertRaises(ValueError):
ScheduleClass(
stages,
num_microbatches,
)
continue
# Create schedule and validate it
schedule = ScheduleClass(stages, num_microbatches)
_validate_schedule(
schedule.pipeline_order, group_size, num_stages, num_microbatches
)
instantiate_parametrized_tests(TestSchedulePlan)
class TestScheduleCsv(TestCase):
@parametrize(
"ScheduleClass,csv_name",
[
(ScheduleDualPipeV, "dualpipev_4rank_10mb"),
],
)
def test_csv_compare(self, ScheduleClass, csv_name):
"""
Test that the schedule matches the expected CSV. This is a regression test to ensure that the schedule
is not changed unintentionally.
"""
num_local_stages = 2
group_size = 4
num_stages = num_local_stages * group_size
stages = [
MockPipelineStage(group_size=group_size, num_stages=num_stages)
for _ in range(num_local_stages)
]
num_microbatches = 10
schedule = ScheduleClass(stages, num_microbatches)
comms_csv = os.path.join(ARTIFACTS_DIR, f"{csv_name}.csv")
sch = schedule.pipeline_order
# Uncomment to regenerate reference output
# schedule._dump_csv("test.csv", "compute_only")
sch_ref = {}
with open(comms_csv, newline="") as ref:
for rank, row in enumerate(csv.reader(ref)):
sch_ref[rank] = [_Action.from_str(s) for s in row]
for rank in sch_ref:
for timestep, (a, b) in enumerate(zip(sch[rank], sch_ref[rank])):
self.assertEqual(a, b, f"Mismatch at {timestep=}, {a=}, expected {b}")
instantiate_parametrized_tests(TestScheduleCsv)
class TestScheduleLowering(TestCase):
"""Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""


@@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
pipeline,
PipelineStage,
Schedule1F1B,
ScheduleDualPipeV,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
@@ -106,7 +107,9 @@ class ScheduleTest(MultiProcContinousTest):
stage_modules = [mod.get_submodule(submod_name) for submod_name in submod_names]
stages = [
PipelineStage(stage_module, stage_idx, n_stages, self.device)
for stage_module, stage_idx in zip(stage_modules, stage_indices)
for stage_module, stage_idx in zip(
stage_modules, stage_indices, strict=True
)
]
return stages, stage_modules, submod_names
@@ -137,7 +140,13 @@ class ScheduleTest(MultiProcContinousTest):
raise AssertionError(
f"One gradient is None for {param_name}: {grad1} vs {grad2}"
)
torch.testing.assert_close(grad1, grad2, rtol=rtol, atol=atol)
try:
torch.testing.assert_close(grad1, grad2, rtol=rtol, atol=atol)
except AssertionError:
print(
f"Numerical issues detected for {param_name}: param grad {grad1} vs ref grad {grad2}"
)
raise
if submod_names is None:
# Single stage case - need to detect tracer vs manual pipeline
@@ -682,16 +691,69 @@ class ScheduleTest(MultiProcContinousTest):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize(
"schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@parametrize("use_new_runtime", [False, True])
def test_v_shape_schedules(self, schedule_class, use_new_runtime):
# n_stages = 8
# rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
n_stages = 4
rank_stages = {0: [0, 3], 1: [1, 2]}
mod, ref_mod, x, target, loss_fn = self._setup_models_and_data(
n_layers=n_stages
)
# Run reference
ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn)
# Create multi-stage pipeline with custom stage indices
num_microbatches = 8
stage_indices = rank_stages[self.rank]
stages, stage_modules, submod_names = self._create_multi_stage_pipeline(
mod, len(stage_indices), n_stages, stage_indices
)
schedule = schedule_class(
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
if schedule_class != ScheduleDualPipeV and use_new_runtime:
old_schedule = schedule
schedule = _PipelineScheduleRuntime(
stages, num_microbatches, loss_fn=loss_fn
)
schedule._load_actions(old_schedule.pipeline_order)
# Run pipeline - special case where first and last stage are on rank 0
out = None
losses = []
for _ in range(2):
self._zero_gradients(stage_modules)
if self.rank == 0:
out = schedule.step(x, target=target, losses=losses)
else:
schedule.step()
# Verify results (rank 0 has both first and last stages)
if self.rank == 0:
torch.testing.assert_close(out, ref_out)
pipe_loss = sum(losses)
torch.testing.assert_close(pipe_loss, ref_loss)
# Check gradients using helper method
self._check_gradients(stage_modules, ref_mod, submod_names)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize(
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@parametrize("use_new_runtime", [False, True])
def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
if schedule_class is ScheduleZBVZeroBubble:
n_stages = 4
rank_stages = {0: [0, 3], 1: [1, 2]}
else:
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
mod, ref_mod, x, target, loss_fn = self._setup_models_and_data(
n_layers=n_stages


@@ -3,6 +3,7 @@ from ._IR import Pipe, pipe_split, pipeline, SplitPoint
from .schedules import (
_ScheduleForwardOnly,
Schedule1F1B,
ScheduleDualPipeV,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
@@ -25,4 +26,5 @@ __all__ = [
"ScheduleLoopedBFS",
"ScheduleInterleavedZeroBubble",
"ScheduleZBVZeroBubble",
"ScheduleDualPipeV",
]


@@ -18,7 +18,7 @@ from torch.distributed.fsdp import FSDPModule, UnshardHandle
from torch.nn.modules.loss import _Loss
from torch.profiler import record_function
from ._utils import generate_stage_to_rank_mapping
from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping
from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase
@@ -33,6 +33,7 @@ __all__ = [
"ScheduleLoopedBFS",
"ScheduleInterleavedZeroBubble",
"ScheduleZBVZeroBubble",
"ScheduleDualPipeV",
]
logger = logging.getLogger(__name__)
@@ -1796,17 +1797,24 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
else:
raise NotImplementedError(f"{format=} is not implemented")
def _dump_csv(self, filename: str):
"""Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
# TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
# that it does not exist if it was created from a compute_comms schedule.
assert self.pipeline_order_with_comms is not None, (
"Must initialize compute_comms schedule before dump_csv"
)
with open(filename, "w", newline="") as csvfile:
writer = csv.writer(csvfile)
for rank in self.pipeline_order_with_comms:
writer.writerow(self.pipeline_order_with_comms[rank])
def _dump_csv(self, filename: str, format: str = "compute_comms"):
"""Dump a CSV representation of the schedule into a file with the provided filename."""
if format == "compute_only":
assert self.pipeline_order is not None, (
"Compute only schedule must be available"
)
with open(filename, "w", newline="") as csvfile:
writer = csv.writer(csvfile)
for rank in self.pipeline_order:
writer.writerow(self.pipeline_order[rank])
elif format == "compute_comms":
assert self.pipeline_order_with_comms is not None, (
"Must initialize compute_comms schedule before dump_csv"
)
with open(filename, "w", newline="") as csvfile:
writer = csv.writer(csvfile)
for rank in self.pipeline_order_with_comms:
writer.writerow(self.pipeline_order_with_comms[rank])
def _simulate(self):
return _simulate_comms_compute(
@@ -2750,6 +2758,227 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti):
return rank_ops
class ScheduleDualPipeV(_PipelineScheduleRuntime):
"""
The DualPipeV schedule. A more efficient schedule variant based on the
DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437.
Based on the open-sourced code from https://github.com/deepseek-ai/DualPipe.
"""
def __init__(
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
self.pp_group_size = stages[0].group_size
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
)
self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
self.pp_group_size, self._num_stages, style="v"
)
for stage in self._stages:
stage.stage_index_to_group_rank = self.stage_index_to_group_rank
self.n_local_stages = len(stages)
if self.n_local_stages != 2:
raise ValueError(
"ZBV requires exactly 2 stages per rank, but got "
f"{self.n_local_stages}."
)
if n_microbatches < self._num_stages:
raise ValueError(
"DualPipeV requires at least as many microbatches as stages, but got "
f"{n_microbatches} microbatches and {self._num_stages} stages."
)
self.rank = stages[0].group_rank
self.num_stages = stages[0].num_stages
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._load_actions(self.pipeline_order)
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
actions: list[Optional[_Action]] = []
counters: dict[
tuple[int, _ComputationType], int
] = {} # (stage_index, computation_type) -> mb_index
weight_queue = [] # Queue of (stage_index, mb_index) for pending weight actions
num_ranks = self.pp_group_size
num_chunks = self._n_microbatches
rank_to_stages = generate_rank_to_stage_mapping(
num_ranks, num_ranks * 2, style="v"
)
stage0_index, stage1_index = rank_to_stages[rank]
def increment_backward_counts(stage_index: int):
"""Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used."""
input_key = (stage_index, BACKWARD_INPUT)
weight_key = (stage_index, BACKWARD_WEIGHT)
counters[input_key] = counters.get(input_key, 0) + 1
counters[weight_key] = counters.get(weight_key, 0) + 1
def add_overlap_f_b(
actions: list,
forward_stage: int,
backward_stage: int,
):
"""Helper method to add an overlapped forward+backward action which tracks microbatch index."""
# Create new overlapped forward+backward action with sub_actions
forward_key = (forward_stage, FORWARD)
backward_key = (backward_stage, BACKWARD_INPUT)
forward_mb = counters.get(forward_key, 0)
backward_mb = counters.get(backward_key, 0)
sub_actions = (
_Action(forward_stage, FORWARD, forward_mb),
_Action(backward_stage, FULL_BACKWARD, backward_mb),
)
actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions))
# Update counters for sub_actions
counters[forward_key] = forward_mb + 1
increment_backward_counts(backward_stage)
def add_action(
actions: list,
stage_index: int,
computation_type: _ComputationType,
):
# Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter
key = (
(stage_index, computation_type)
if computation_type != FULL_BACKWARD
else (stage_index, BACKWARD_INPUT)
)
mb_index = counters.get(key, 0)
actions.append(_Action(stage_index, computation_type, mb_index))
# If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters
if computation_type == FULL_BACKWARD:
increment_backward_counts(stage_index)
else:
# If BACKWARD_INPUT is updated, add corresponding weight action to queue
if computation_type == BACKWARD_INPUT:
# Add weight action to queue for later processing
weight_queue.append((stage_index, mb_index))
counters[key] = mb_index + 1
def add_weight_action_if_pending(actions: list):
"""Helper method to add a weight action from the queue."""
if not weight_queue:
return # No pending weight actions, skip
# Pop the oldest weight action from the queue
actual_stage_index, weight_mb_index = weight_queue.pop(0)
actions.append(
_Action(
actual_stage_index,
BACKWARD_WEIGHT,
weight_mb_index,
)
)
# Update the counter for the actual stage that was processed
weight_key = (actual_stage_index, BACKWARD_WEIGHT)
counters[weight_key] = counters.get(weight_key, 0) + 1
# Step 1: F0
step_1 = (num_ranks - rank - 1) * 2
for _ in range(step_1):
add_action(actions, stage0_index, FORWARD)
# Step 2: F0F1
step_2 = rank + 1
for _ in range(step_2):
add_action(actions, stage0_index, FORWARD)
add_action(actions, stage1_index, FORWARD)
# Step 3: I1W1F1 (Use zero bubble)
step_3 = num_ranks - rank - 1
for _ in range(step_3):
add_action(actions, stage1_index, BACKWARD_INPUT)
add_weight_action_if_pending(actions)
add_action(actions, stage1_index, FORWARD)
# Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward)
step_4 = num_chunks - num_ranks * 2 + rank + 1
for i in range(step_4):
if i == 0 and rank == num_ranks - 1:
# NOTE: We don't overlap these two chunks to further reduce bubble size.
add_action(actions, stage0_index, FORWARD)
add_action(actions, stage1_index, FULL_BACKWARD)
else:
add_overlap_f_b(
actions,
forward_stage=stage0_index,
backward_stage=stage1_index,
)
add_overlap_f_b(
actions,
forward_stage=stage1_index,
backward_stage=stage0_index,
)
# Step 5: B1-F1B0
step_5 = num_ranks - rank - 1
for _ in range(step_5):
add_action(actions, stage1_index, FULL_BACKWARD)
add_overlap_f_b(
actions,
forward_stage=stage1_index,
backward_stage=stage0_index,
)
# Step 6: B1B0 (The second half of the chunks use zero bubble)
step_6 = rank + 1
enable_zb = False
for i in range(step_6):
if i == step_6 // 2 and rank % 2 == 1:
enable_zb = True
comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
add_action(actions, stage1_index, comp_type)
if i == step_6 // 2 and rank % 2 == 0:
enable_zb = True
comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
add_action(actions, stage0_index, comp_type)
# Step 7: W0B0
step_7 = num_ranks - rank - 1
for _ in range(step_7):
add_weight_action_if_pending(actions)
comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
add_action(actions, stage0_index, comp_type)
# Step 8: W0
step_8 = rank + 1
for _ in range(step_8):
add_weight_action_if_pending(actions)
return actions
def get_schedule_class(schedule_name: str):
"""
Maps a schedule name (case insensitive) to its corresponding class object.
@@ -2766,6 +2995,7 @@ def get_schedule_class(schedule_name: str):
"PipelineScheduleSingle": PipelineScheduleSingle,
"PipelineScheduleMulti": PipelineScheduleMulti,
"ZBVZeroBubble": ScheduleZBVZeroBubble,
"DualPipeV": ScheduleDualPipeV,
}
lowercase_keys = {k.lower(): k for k in schedule_map.keys()}
lowercase_schedule_name = schedule_name.lower()
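
For completeness, a quick sanity-check sketch of the new registry entry (not part of this PR; assumes `get_schedule_class` keeps resolving names case-insensitively, as its docstring states):

```python
from torch.distributed.pipelining.schedules import ScheduleDualPipeV, get_schedule_class

# Both spellings should resolve to the new class through the lowercase lookup table.
assert get_schedule_class("DualPipeV") is ScheduleDualPipeV
assert get_schedule_class("dualpipev") is ScheduleDualPipeV
```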