[PP] Add DualPipeV schedule (#159591)
Added the DualPipeV schedule according to http://github.com/deepseek-ai/DualPipe/blob/main/dualpipe/dualpipev.py#L11

![DualPipeV schedule diagram](https://github.com/user-attachments/assets/4e843bb9-87cd-4d11-936c-7dfe8ee12f16)

This schedule doesn't perform the actual "overlap" during execution, but it provides the scaffolding and schedule definition we need to run it E2E in torchtitan. Supporting the overlapped operation will be addressed in follow-up PRs.

Tests:
```sh
python test/distributed/pipelining/test_schedule_multiproc.py -k test_v_shape_schedules
python test/distributed/pipelining/test_schedule.py -k test_pipeline_order_for_v_schedules
```

Also tested in TorchTitan, where it runs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159591
Approved by: https://github.com/wconstab
This commit is contained in:
parent
20bdabbb3c
commit
198b5fd2d4
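
For orientation, a minimal usage sketch of the new schedule. This is not part of the diff below: the toy model, data, and process-group setup are illustrative assumptions, while the `PipelineStage`/`ScheduleDualPipeV` signatures, the V-shaped stage placement, and the rank-0 `step()` pattern follow the tests added in this PR.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage, ScheduleDualPipeV

# Assumed: init_process_group was already called with 2 pipeline ranks.
# DualPipeV puts 2 stages on each rank in a V shape, e.g. with 4 stages total
# rank 0 holds stages [0, 3] and rank 1 holds stages [1, 2], as in the tests.
rank = dist.get_rank()
device = torch.device("cuda", rank)
n_stages = 4
n_microbatches = 8  # DualPipeV requires n_microbatches >= n_stages

# Toy per-stage modules; in practice these are shards of the real model.
layers = [nn.Linear(16, 16).to(device) for _ in range(n_stages)]
stage_indices = [rank, n_stages - 1 - rank]
stages = [PipelineStage(layers[i], i, n_stages, device) for i in stage_indices]

schedule = ScheduleDualPipeV(stages, n_microbatches, loss_fn=nn.MSELoss())

x = torch.randn(n_microbatches * 4, 16, device=device)
target = torch.randn(n_microbatches * 4, 16, device=device)
losses: list[torch.Tensor] = []

# Rank 0 hosts both the first and the last stage, so it feeds inputs/targets
# and collects the output and per-microbatch losses; other ranks just step.
if rank == 0:
    out = schedule.step(x, target=target, losses=losses)
else:
    schedule.step()
```
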
@@ -504,6 +504,10 @@ The following set of APIs transform your model into a pipeline representation.
.. autoclass:: ScheduleZBVZeroBubble
```

```{eval-rst}
.. autoclass:: ScheduleDualPipeV
```

```{eval-rst}
.. autoclass:: PipelineScheduleSingle
   :members:

@@ -0,0 +1,4 @@
0F0,0F1,0F2,0F3,0F4,0F5,0F6,7F0,7I0,7W0,7F1,7I1,7W1,7F2,7I2,7W2,7F3,(0F7;7B3)OVERLAP_F_B,(7F4;0B0)OVERLAP_F_B,(0F8;7B4)OVERLAP_F_B,(7F5;0B1)OVERLAP_F_B,(0F9;7B5)OVERLAP_F_B,(7F6;0B2)OVERLAP_F_B,7B6,(7F7;0B3)OVERLAP_F_B,7B7,(7F8;0B4)OVERLAP_F_B,7B8,(7F9;0B5)OVERLAP_F_B,7B9,0I6,0W6,0I7,0W7,0I8,0W8,0I9,0W9
1F0,1F1,1F2,1F3,1F4,6F0,1F5,6F1,6I0,6W0,6F2,6I1,6W1,6F3,(1F6;6B2)OVERLAP_F_B,(6F4;1B0)OVERLAP_F_B,(1F7;6B3)OVERLAP_F_B,(6F5;1B1)OVERLAP_F_B,(1F8;6B4)OVERLAP_F_B,(6F6;1B2)OVERLAP_F_B,(1F9;6B5)OVERLAP_F_B,(6F7;1B3)OVERLAP_F_B,6B6,(6F8;1B4)OVERLAP_F_B,6B7,(6F9;1B5)OVERLAP_F_B,6B8,1B6,6I9,1I7,6W9,1I8,1W7,1I9,1W8,1W9
2F0,2F1,2F2,5F0,2F3,5F1,2F4,5F2,5I0,5W0,5F3,(2F5;5B1)OVERLAP_F_B,(5F4;2B0)OVERLAP_F_B,(2F6;5B2)OVERLAP_F_B,(5F5;2B1)OVERLAP_F_B,(2F7;5B3)OVERLAP_F_B,(5F6;2B2)OVERLAP_F_B,(2F8;5B4)OVERLAP_F_B,(5F7;2B3)OVERLAP_F_B,(2F9;5B5)OVERLAP_F_B,(5F8;2B4)OVERLAP_F_B,5B6,(5F9;2B5)OVERLAP_F_B,5B7,2B6,5B8,2I7,5I9,2I8,2W7,2I9,5W9,2W8,2W9
3F0,4F0,3F1,4F1,3F2,4F2,3F3,4F3,3F4,4B0,(4F4;3B0)OVERLAP_F_B,(3F5;4B1)OVERLAP_F_B,(4F5;3B1)OVERLAP_F_B,(3F6;4B2)OVERLAP_F_B,(4F6;3B2)OVERLAP_F_B,(3F7;4B3)OVERLAP_F_B,(4F7;3B3)OVERLAP_F_B,(3F8;4B4)OVERLAP_F_B,(4F8;3B4)OVERLAP_F_B,(3F9;4B5)OVERLAP_F_B,(4F9;3B5)OVERLAP_F_B,4B6,3B6,4B7,3B7,4I8,3I8,4I9,3I9,4W8,3W8,4W9,3W9
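
Each entry in the rows above is a compute-only action of the form `{stage}{op}{microbatch}`, where F, I, W, and B correspond to FORWARD, BACKWARD_INPUT, BACKWARD_WEIGHT, and FULL_BACKWARD, and a parenthesized pair tagged `OVERLAP_F_B` is an overlapped forward+backward slot. A rough standalone decoder sketch of that apparent text format (the tests below parse these strings with the private `_Action.from_str`; this is only an independent illustration, not PyTorch's parser):

```python
import re

# F = FORWARD, I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD;
# "(0F7;7B3)OVERLAP_F_B" is an overlapped forward+backward slot.
OPS = {"F": "FORWARD", "I": "BACKWARD_INPUT", "W": "BACKWARD_WEIGHT", "B": "FULL_BACKWARD"}

def decode(entry: str) -> list[tuple[int, str, int]]:
    # Strip the "(...)" wrapper of overlapped slots, then decode each sub-action.
    inner = entry[1:entry.index(")")] if entry.startswith("(") else entry
    out = []
    for part in inner.split(";"):
        stage, op, mb = re.fullmatch(r"(\d+)([FIWB])(\d+)", part).groups()
        out.append((int(stage), OPS[op], int(mb)))
    return out

print(decode("7I0"))                   # [(7, 'BACKWARD_INPUT', 0)]
print(decode("(0F7;7B3)OVERLAP_F_B"))  # [(0, 'FORWARD', 7), (7, 'FULL_BACKWARD', 3)]
```
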
@@ -10,10 +10,12 @@ from model_registry import MultiMLP
import torch
from torch.distributed.pipelining import (
    Schedule1F1B,
    ScheduleDualPipeV,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,
    ScheduleLoopedBFS,
    ScheduleZBVZeroBubble,
)
from torch.distributed.pipelining._utils import generate_stage_to_rank_mapping
from torch.distributed.pipelining.schedules import (

@@ -348,10 +350,91 @@ class TestSchedulePlan(TestCase):
                    num_stages=num_stages,
                )

    @parametrize(
        "ScheduleClass",
        [ScheduleDualPipeV, ScheduleZBVZeroBubble],
    )
    def test_pipeline_order_for_v_schedules(self, ScheduleClass):
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            with self.subTest(
                num_local_stages=num_local_stages,
                num_microbatches=num_microbatches,
                group_size=group_size,
            ):
                num_stages = num_local_stages * group_size
                stages = [
                    MockPipelineStage(group_size=group_size, num_stages=num_stages)
                    for i in range(num_local_stages)
                ]

                # V schedules only support 2 stages per rank so if num_local_stages is not 2, ensure an error is thrown
                if num_local_stages != 2:
                    with self.assertRaises(ValueError):
                        ScheduleClass(
                            stages,
                            num_microbatches,
                        )
                    continue

                # DualPipeV requires num_microbatches to be >= num_stages
                if ScheduleClass == ScheduleDualPipeV and num_microbatches < num_stages:
                    with self.assertRaises(ValueError):
                        ScheduleClass(
                            stages,
                            num_microbatches,
                        )
                    continue

                # Create schedule and validate it
                schedule = ScheduleClass(stages, num_microbatches)
                _validate_schedule(
                    schedule.pipeline_order, group_size, num_stages, num_microbatches
                )


instantiate_parametrized_tests(TestSchedulePlan)


class TestScheduleCsv(TestCase):
    @parametrize(
        "ScheduleClass,csv_name",
        [
            (ScheduleDualPipeV, "dualpipev_4rank_10mb"),
        ],
    )
    def test_csv_compare(self, ScheduleClass, csv_name):
        """
        Test that the schedule matches the expected CSV. This is a regression test to ensure that the schedule
        is not changed unintentionally.
        """
        num_local_stages = 2
        group_size = 4
        num_stages = num_local_stages * group_size
        stages = [
            MockPipelineStage(group_size=group_size, num_stages=num_stages)
            for _ in range(num_local_stages)
        ]
        num_microbatches = 10
        schedule = ScheduleClass(stages, num_microbatches)
        comms_csv = os.path.join(ARTIFACTS_DIR, f"{csv_name}.csv")
        sch = schedule.pipeline_order

        # Uncomment to regenerate reference output
        # schedule._dump_csv("test.csv", "compute_only")

        sch_ref = {}
        with open(comms_csv, newline="") as ref:
            for rank, row in enumerate(csv.reader(ref)):
                sch_ref[rank] = [_Action.from_str(s) for s in row]

        for rank in sch_ref:
            for timestep, (a, b) in enumerate(zip(sch[rank], sch_ref[rank])):
                self.assertEqual(a, b, f"Mismatch at {timestep=}, {a=}, expected {b}")


instantiate_parametrized_tests(TestScheduleCsv)


class TestScheduleLowering(TestCase):
    """Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""
@@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
    pipeline,
    PipelineStage,
    Schedule1F1B,
    ScheduleDualPipeV,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,

@@ -106,7 +107,9 @@ class ScheduleTest(MultiProcContinousTest):
        stage_modules = [mod.get_submodule(submod_name) for submod_name in submod_names]
        stages = [
            PipelineStage(stage_module, stage_idx, n_stages, self.device)
-           for stage_module, stage_idx in zip(stage_modules, stage_indices)
+           for stage_module, stage_idx in zip(
+               stage_modules, stage_indices, strict=True
+           )
        ]
        return stages, stage_modules, submod_names
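
The switch to `zip(..., strict=True)` above (Python 3.10+) makes a mismatch between stage modules and stage indices fail loudly instead of silently truncating; for example:

```python
modules = ["m0", "m1"]
indices = [0, 3, 7]

try:
    list(zip(modules, indices, strict=True))
except ValueError as err:
    # Without strict=True this would silently drop index 7.
    print(err)  # zip() argument 2 is longer than argument 1
```
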
@@ -137,7 +140,13 @@ class ScheduleTest(MultiProcContinousTest):
                raise AssertionError(
                    f"One gradient is None for {param_name}: {grad1} vs {grad2}"
                )
-           torch.testing.assert_close(grad1, grad2, rtol=rtol, atol=atol)
+           try:
+               torch.testing.assert_close(grad1, grad2, rtol=rtol, atol=atol)
+           except AssertionError:
+               print(
+                   f"Numerical issues detected for {param_name}: param grad {grad1} vs ref grad {grad2}"
+               )
+               raise

        if submod_names is None:
            # Single stage case - need to detect tracer vs manual pipeline

@@ -682,16 +691,69 @@ class ScheduleTest(MultiProcContinousTest):
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
-       "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble]
+       "schedule_class",
+       [ScheduleZBVZeroBubble, ScheduleDualPipeV],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_v_shape_schedules(self, schedule_class, use_new_runtime):
        # n_stages = 8
        # rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
        n_stages = 4
        rank_stages = {0: [0, 3], 1: [1, 2]}
        mod, ref_mod, x, target, loss_fn = self._setup_models_and_data(
            n_layers=n_stages
        )

        # Run reference
        ref_out, ref_loss = self._run_reference_model(ref_mod, x, target, loss_fn)

        # Create multi-stage pipeline with custom stage indices
        num_microbatches = 8
        stage_indices = rank_stages[self.rank]
        stages, stage_modules, submod_names = self._create_multi_stage_pipeline(
            mod, len(stage_indices), n_stages, stage_indices
        )

        schedule = schedule_class(
            stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
        )

        if schedule_class != ScheduleDualPipeV and use_new_runtime:
            old_schedule = schedule
            schedule = _PipelineScheduleRuntime(
                stages, num_microbatches, loss_fn=loss_fn
            )
            schedule._load_actions(old_schedule.pipeline_order)

        # Run pipeline - special case where first and last stage are on rank 0
        out = None
        losses = []
        for _ in range(2):
            self._zero_gradients(stage_modules)
            if self.rank == 0:
                out = schedule.step(x, target=target, losses=losses)
            else:
                schedule.step()

        # Verify results (rank 0 has both first and last stages)
        if self.rank == 0:
            torch.testing.assert_close(out, ref_out)
            pipe_loss = sum(losses)
            torch.testing.assert_close(pipe_loss, ref_loss)

        # Check gradients using helper method
        self._check_gradients(stage_modules, ref_mod, submod_names)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "schedule_class",
        [ScheduleVShaped, ScheduleUnbalanced],
    )
    @parametrize("use_new_runtime", [False, True])
    def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
-       if schedule_class is ScheduleZBVZeroBubble:
-           n_stages = 4
-           rank_stages = {0: [0, 3], 1: [1, 2]}
-       else:
-           n_stages = schedule_class.n_stages
-           rank_stages = schedule_class.rank_stages
+       n_stages = schedule_class.n_stages
+       rank_stages = schedule_class.rank_stages

        mod, ref_mod, x, target, loss_fn = self._setup_models_and_data(
            n_layers=n_stages

@@ -3,6 +3,7 @@ from ._IR import Pipe, pipe_split, pipeline, SplitPoint
from .schedules import (
    _ScheduleForwardOnly,
    Schedule1F1B,
    ScheduleDualPipeV,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleInterleavedZeroBubble,

@@ -25,4 +26,5 @@ __all__ = [
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
    "ScheduleZBVZeroBubble",
    "ScheduleDualPipeV",
]

@@ -18,7 +18,7 @@ from torch.distributed.fsdp import FSDPModule, UnshardHandle
from torch.nn.modules.loss import _Loss
from torch.profiler import record_function

-from ._utils import generate_stage_to_rank_mapping
+from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping
from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase

@@ -33,6 +33,7 @@ __all__ = [
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
    "ScheduleZBVZeroBubble",
    "ScheduleDualPipeV",
]

logger = logging.getLogger(__name__)

@@ -1796,17 +1797,24 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        else:
            raise NotImplementedError(f"{format=} is not implemented")

-   def _dump_csv(self, filename: str):
-       """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
-       # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
-       # that it does not exist if it was created from a compute_comms schedule.
-       assert self.pipeline_order_with_comms is not None, (
-           "Must initialize compute_comms schedule before dump_csv"
-       )
-       with open(filename, "w", newline="") as csvfile:
-           writer = csv.writer(csvfile)
-           for rank in self.pipeline_order_with_comms:
-               writer.writerow(self.pipeline_order_with_comms[rank])
+   def _dump_csv(self, filename: str, format: str = "compute_comms"):
+       """Dump a CSV representation of the schedule into a file with the provided filename."""
+       if format == "compute_only":
+           assert self.pipeline_order is not None, (
+               "Compute only schedule must be available"
+           )
+           with open(filename, "w", newline="") as csvfile:
+               writer = csv.writer(csvfile)
+               for rank in self.pipeline_order:
+                   writer.writerow(self.pipeline_order[rank])
+       elif format == "compute_comms":
+           assert self.pipeline_order_with_comms is not None, (
+               "Must initialize compute_comms schedule before dump_csv"
+           )
+           with open(filename, "w", newline="") as csvfile:
+               writer = csv.writer(csvfile)
+               for rank in self.pipeline_order_with_comms:
+                   writer.writerow(self.pipeline_order_with_comms[rank])

    def _simulate(self):
        return _simulate_comms_compute(
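
The `format` argument added to `_dump_csv` above is what the commented regeneration line in `TestScheduleCsv` relies on. A rough round-trip sketch using the new `compute_only` format; `schedule` is assumed to be an already-built `ScheduleDualPipeV` (see the test setup earlier in this diff), and `_dump_csv`/`_Action.from_str` are private helpers, so this mirrors what the test does rather than a public API:

```python
import csv

from torch.distributed.pipelining.schedules import _Action

# `schedule` is assumed to be a constructed ScheduleDualPipeV instance.
schedule._dump_csv("dualpipev_compute_only.csv", "compute_only")

# Read the dump back into per-rank action lists, as the regression test does.
with open("dualpipev_compute_only.csv", newline="") as f:
    order = {
        rank: [_Action.from_str(s) for s in row]
        for rank, row in enumerate(csv.reader(f))
    }
print({rank: len(actions) for rank, actions in order.items()})  # actions per rank
```
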
@@ -2750,6 +2758,227 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti):
        return rank_ops


class ScheduleDualPipeV(_PipelineScheduleRuntime):
    """
    The DualPipeV schedule. A more efficient schedule variant based on the
    DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437

    Based on the open sourced code from https://github.com/deepseek-ai/DualPipe
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
        )
        for stage in self._stages:
            stage.stage_index_to_group_rank = self.stage_index_to_group_rank

        self.n_local_stages = len(stages)
        if self.n_local_stages != 2:
            raise ValueError(
                "ZBV requires exactly 2 stages per rank, but got "
                f"{self.n_local_stages}."
            )
        if n_microbatches < self._num_stages:
            raise ValueError(
                "DualPipeV requires at least as many microbatches as stages, but got "
                f"{n_microbatches} microbatches and {self._num_stages} stages."
            )

        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
        self._load_actions(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
        actions: list[Optional[_Action]] = []
        counters: dict[
            tuple[int, _ComputationType], int
        ] = {}  # (stage_index, computation_type) -> mb_index
        weight_queue = []  # Queue of (stage_index, mb_index) for pending weight actions

        num_ranks = self.pp_group_size
        num_chunks = self._n_microbatches

        rank_to_stages = generate_rank_to_stage_mapping(
            num_ranks, num_ranks * 2, style="v"
        )
        stage0_index, stage1_index = rank_to_stages[rank]

        def increment_backward_counts(stage_index: int):
            """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used."""
            input_key = (stage_index, BACKWARD_INPUT)
            weight_key = (stage_index, BACKWARD_WEIGHT)
            counters[input_key] = counters.get(input_key, 0) + 1
            counters[weight_key] = counters.get(weight_key, 0) + 1

        def add_overlap_f_b(
            actions: list,
            forward_stage: int,
            backward_stage: int,
        ):
            """Helper method to add an overlapped forward+backward action which tracks microbatch index."""
            # Create new overlapped forward+backward action with sub_actions
            forward_key = (forward_stage, FORWARD)
            backward_key = (backward_stage, BACKWARD_INPUT)

            forward_mb = counters.get(forward_key, 0)
            backward_mb = counters.get(backward_key, 0)

            sub_actions = (
                _Action(forward_stage, FORWARD, forward_mb),
                _Action(backward_stage, FULL_BACKWARD, backward_mb),
            )
            actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions))

            # Update counters for sub_actions
            counters[forward_key] = forward_mb + 1
            increment_backward_counts(backward_stage)

        def add_action(
            actions: list,
            stage_index: int,
            computation_type: _ComputationType,
        ):
            # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter
            key = (
                (stage_index, computation_type)
                if computation_type != FULL_BACKWARD
                else (stage_index, BACKWARD_INPUT)
            )
            mb_index = counters.get(key, 0)
            actions.append(_Action(stage_index, computation_type, mb_index))

            # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters
            if computation_type == FULL_BACKWARD:
                increment_backward_counts(stage_index)
            else:
                # If BACKWARD_INPUT is updated, add corresponding weight action to queue
                if computation_type == BACKWARD_INPUT:
                    # Add weight action to queue for later processing
                    weight_queue.append((stage_index, mb_index))
                counters[key] = mb_index + 1

        def add_weight_action_if_pending(actions: list):
            """Helper method to add a weight action from the queue."""
            if not weight_queue:
                return  # No pending weight actions, skip
            # Pop the oldest weight action from the queue
            actual_stage_index, weight_mb_index = weight_queue.pop(0)
            actions.append(
                _Action(
                    actual_stage_index,
                    BACKWARD_WEIGHT,
                    weight_mb_index,
                )
            )
            # Update the counter for the actual stage that was processed
            weight_key = (actual_stage_index, BACKWARD_WEIGHT)
            counters[weight_key] = counters.get(weight_key, 0) + 1

        # Step 1: F0
        step_1 = (num_ranks - rank - 1) * 2
        for _ in range(step_1):
            add_action(actions, stage0_index, FORWARD)

        # Step 2: F0F1
        step_2 = rank + 1
        for _ in range(step_2):
            add_action(actions, stage0_index, FORWARD)
            add_action(actions, stage1_index, FORWARD)

        # Step 3: I1W1F1 (Use zero bubble)
        step_3 = num_ranks - rank - 1
        for _ in range(step_3):
            add_action(actions, stage1_index, BACKWARD_INPUT)
            add_weight_action_if_pending(actions)
            add_action(actions, stage1_index, FORWARD)

        # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward)
        step_4 = num_chunks - num_ranks * 2 + rank + 1
        for i in range(step_4):
            if i == 0 and rank == num_ranks - 1:
                # NOTE: We don't overlap these two chunks to further reduce bubble size.
                add_action(actions, stage0_index, FORWARD)
                add_action(actions, stage1_index, FULL_BACKWARD)
            else:
                add_overlap_f_b(
                    actions,
                    forward_stage=stage0_index,
                    backward_stage=stage1_index,
                )
            add_overlap_f_b(
                actions,
                forward_stage=stage1_index,
                backward_stage=stage0_index,
            )

        # Step 5: B1-F1B0
        step_5 = num_ranks - rank - 1
        for _ in range(step_5):
            add_action(actions, stage1_index, FULL_BACKWARD)
            add_overlap_f_b(
                actions,
                forward_stage=stage1_index,
                backward_stage=stage0_index,
            )

        # Step 6: B1B0 (The second half of the chunks use zero bubble)
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage1_index, comp_type)
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage0_index, comp_type)

        # Step 7: W0B0
        step_7 = num_ranks - rank - 1
        for _ in range(step_7):
            add_weight_action_if_pending(actions)
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage0_index, comp_type)

        # Step 8: W0
        step_8 = rank + 1
        for _ in range(step_8):
            add_weight_action_if_pending(actions)

        return actions
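
As a hand-worked check of the step sizes above, cross-referenced against the first row of the `dualpipev_4rank_10mb` reference CSV earlier in this diff: with 4 ranks and 10 microbatches, rank 0 holds stages 0 and 7 and the eight phases come out as follows.

```python
# Rank 0 of 4, 10 microbatches: stage0_index = 0, stage1_index = 7.
num_ranks, num_chunks, rank = 4, 10, 0

step_1 = (num_ranks - rank - 1) * 2              # 6 -> 0F0 .. 0F5
step_2 = rank + 1                                # 1 -> 0F6, 7F0
step_3 = num_ranks - rank - 1                    # 3 -> 7I0,7W0,7F1 ... 7I2,7W2,7F3
step_4 = num_chunks - num_ranks * 2 + rank + 1   # 3 -> six OVERLAP_F_B slots
step_5 = num_ranks - rank - 1                    # 3 -> 7B6,(7F7;0B3) ... 7B8,(7F9;0B5)
step_6 = rank + 1                                # 1 -> 7B9, then 0I6 (zero bubble on)
step_7 = num_ranks - rank - 1                    # 3 -> 0W6,0I7, 0W7,0I8, 0W8,0I9
step_8 = rank + 1                                # 1 -> 0W9

# Stage 0 sees step_1 + step_2 + step_4 forwards, i.e. one per microbatch.
assert step_1 + step_2 + step_4 == num_chunks
```
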

def get_schedule_class(schedule_name: str):
    """
    Maps a schedule name (case insensitive) to its corresponding class object.

@@ -2766,6 +2995,7 @@ def get_schedule_class(schedule_name: str):
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
        "ZBVZeroBubble": ScheduleZBVZeroBubble,
        "DualPipeV": ScheduleDualPipeV,
    }
    lowercase_keys = {k.lower(): k for k in schedule_map.keys()}
    lowercase_schedule_name = schedule_name.lower()
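
With the `"DualPipeV"` entry added above, name-based lookup resolves case-insensitively, which is how config-driven callers (e.g. torchtitan) can select the schedule by string:

```python
from torch.distributed.pipelining import ScheduleDualPipeV
from torch.distributed.pipelining.schedules import get_schedule_class

# Lookup is case-insensitive thanks to the lowercase_keys mapping above.
assert get_schedule_class("dualpipev") is ScheduleDualPipeV
assert get_schedule_class("DualPipeV") is ScheduleDualPipeV
```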