# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import csv
import logging
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
from torch.profiler import record_function

from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase


__all__ = [
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleFlexibleInterleaved1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
]

logger = logging.getLogger(__name__)


class _ComputationType(Enum):
    FORWARD = 1
    BACKWARD = 2
    WEIGHT = 3

    def __str__(self):
        str_map = {
            _ComputationType.FORWARD: "F",
            _ComputationType.BACKWARD: "B",
            _ComputationType.WEIGHT: "W",
        }
        return str_map[self]

    @staticmethod
    def from_str(action):
        if action == "F":
            return _ComputationType.FORWARD
        elif action == "B":
            return _ComputationType.BACKWARD
        elif action == "W":
            return _ComputationType.WEIGHT
        else:
            raise RuntimeError(f"Invalid computation type {action}")


F = _ComputationType.FORWARD
B = _ComputationType.BACKWARD
W = _ComputationType.WEIGHT


# Matches action strings like "2F0": stage index, computation type, microbatch
# index. Note the character class is [FBW]; a comma inside the class would also
# (incorrectly) match a literal ",".
_action_regex = re.compile(r"(\d+)([FBW])(\d+)")


class _Action(NamedTuple):
    computation_type: _ComputationType
    microbatch_index: int
    stage_index: int

    def __repr__(self):
        return f"{self.stage_index}{self.computation_type}{self.microbatch_index}"

    @staticmethod
    def from_str(str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][microbatch] e.g. `2F0`
        """
        if match := _action_regex.match(str):
            stage_index, computation_type, microbatch_index = match.groups()
            return _Action(
                _ComputationType.from_str(computation_type),
                int(microbatch_index),
                int(stage_index),
            )
        elif str == "":
            return None
        raise RuntimeError(
            f"Invalid action string: {str}, should be formatted as [stage][action type][microbatch] e.g. 2F0"
        )
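

# A quick round-trip sketch of the action-string format (illustrative comments,
# not executed at import time):
#
#     action = _Action(F, microbatch_index=0, stage_index=2)
#     repr(action)             # '2F0' (stage 2, Forward, microbatch 0)
#     _Action.from_str("2F0")  # _Action(F, 0, 2), the inverse of __repr__
#     _Action.from_str("")     # None (empty CSV cells map to idle steps)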


class _PipelineSchedule(ABC):
    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn
        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
        # They are used to convert a whole batch into microbatches in `step(x)`.
        # See `TensorChunkSpec` for helper methods for creating them.
        # Chunking specification for positional inputs. (default: `None`)
        self._args_chunk_spec = args_chunk_spec
        # Chunking specification for keyword inputs. (default: `None`)
        self._kwargs_chunk_spec = kwargs_chunk_spec
        self._output_merge_spec = output_merge_spec

        # Derived
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch.
        self._internal_losses: List[torch.Tensor] = []
        logger.info(f"Using {self.__class__.__name__}")  # noqa: G004

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        if stage.is_last and self._has_backward:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._has_backward and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state
        """
        # If `stages` is not a list, turn it into a list
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        self._internal_losses.clear()
    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def _check_inputs(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Pre-process/check inputs
        """

        def check_type_and_len(mbs, name: str):
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            kwarg_mbs = [{}] * self._n_microbatches

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks
            return [()] * self._n_microbatches, [{}] * self._n_microbatches

    def _merge_outputs(self, output_chunks: List[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )
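

# A minimal sketch of the split/merge round trip these helpers perform (the
# tensor shapes are made up for illustration; see `microbatch.py` for the
# authoritative chunking semantics):
#
#     x = torch.randn(8, 16)                       # whole batch of 8 samples
#     args_split, kwargs_split = split_args_kwargs_into_chunks(
#         (x,), {}, 4, None, None                  # 4 microbatches of 2 samples
#     )
#     outs = [args[0] * 2 for args in args_split]  # stand-in per-chunk outputs
#     merged = merge_chunks(outs, None)            # re-concatenated along dim 0
#     assert merged.shape == (8, 16)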


def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
    """
    Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
    """
    if len(p2p_ops) == 0:
        return None
    desc_str = f"{desc}, " if desc else ""
    logger.debug(f"batch_p2p {desc_str}{p2p_ops}")  # noqa: G004
    return dist.batch_isend_irecv(p2p_ops).pop()
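

# Illustrative usage (comments only; assumes an initialized process group and
# hypothetical `peer` rank / `send_buf` / `recv_buf` tensors):
#
#     ops = [
#         dist.P2POp(dist.isend, send_buf, peer),
#         dist.P2POp(dist.irecv, recv_buf, peer),
#     ]
#     if work := _batch_p2p(ops, desc="example"):
#         work.wait()  # block until the batched sends/recvs complete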


def _sorted_batch_p2p(
    p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
    """
    Sorts the list of P2P ops by the peer rank, and then calls
    batch_isend_irecv. Return a dictionary of works by peer rank. This function
    helps us avoid hangs in case of skip connections.
    """
    # Arrange p2p_ops by peer rank:
    # int is the peer rank;
    # List is the list of ops towards the peer
    ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
    work_by_peer: Dict[int, dist.Work] = {}
    if len(p2p_ops) == 0:
        return work_by_peer

    # Classify the ops by peer rank
    for op in p2p_ops:
        ops_by_peer[op.peer].append(op)

    # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
    for peer, ops in sorted(ops_by_peer.items()):
        work_by_peer[peer] = _batch_p2p(ops, desc=desc)

    return work_by_peer
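

# Why sorted order helps (an illustrative scenario, not code from this file):
# with skip connections, a rank may need to exchange with several peers in the
# same step. If each rank batched its ops in an arbitrary per-peer order, two
# ranks could each be blocked inside a batch targeting the other while the
# other has not entered the matching batch. Sorting by peer rank gives every
# rank the same global issue order:
#
#     for peer, ops in sorted(ops_by_peer.items()):  # e.g. 0, then 2, everywhere
#         _batch_p2p(ops)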


class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        stage._prepare_forward_infra(n_microbatches)
        if self._has_backward:
            stage._prepare_backward_infra(n_microbatches)

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None


class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug(
                f"[{self._stage.stage_index}] Forwarded microbatch {i}"  # noqa: G004
            )

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have a performance impact because, by the time the first
        # backward arrives, all the forward sends should have finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Run backward
        # Delay send waits
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(i, loss=loss)

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug(
                f"[{self._stage.stage_index}] Backwarded microbatch {i}"  # noqa: G004
            )

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()


class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )
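
        # Worked example (illustrative): with 4 stages and >= 4 microbatches,
        #   stage 0: min(n, 4 - 0) = 4 warmup forwards
        #   stage 1: min(n, 4 - 1) = 3
        #   stage 2: min(n, 4 - 2) = 2
        #   stage 3: min(n, 4 - 3) = 1
        # so each stage enters the steady 1F1B phase exactly when its first
        # backward can arrive.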

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (they should have finished by
            # now; if not, we are heavily communication-bound, in which case
            # eagerly computing the next chunk would not buy us much either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left to be fused with the first 1B in the 1B1F phase below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # the loop exits via the break below
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # We still have some bwd_sends left over from the break above; now it is time to fire them
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Clear previous chunk's backward sends (they should have finished by now)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire them
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)


class PipelineScheduleMulti(_PipelineSchedule):
    """
    Base class for multi-stage schedules.
    Implements the `step` method.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        stage_index_to_group_rank: Optional[Dict[int, int]] = None,
    ):
        if len(stages) <= 1:
            raise ValueError(
                f"Multi-stage schedule expects at least two stages but got {len(stages)}"
            )
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stages = stages
        self._num_stages = stages[0].num_stages
        self.pp_group_size = stages[0].group_size
        self.rank = stages[0].group_rank
        # Set the pipeline stage states
        if stage_index_to_group_rank is not None:
            for stage in self._stages:
                stage.stage_index_to_group_rank = stage_index_to_group_rank
        self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank

        # Set the same has_backward flag for stage object
        for stage in self._stages:
            stage.has_backward = self._has_backward

        self._should_compute_loss = (
            lambda stage: stage.is_last and self._loss_fn is not None
        )

        # This will be set during init of derived schedules
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        self.use_full_backward = True
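
        # For intuition, a populated pipeline_order maps each rank to its
        # per-time-step actions, e.g. (illustrative only):
        #     {0: [_Action(F, 0, 0), _Action(F, 1, 0), ...],
        #      1: [None, _Action(F, 0, 1), ...]}
        # where None means the rank idles at that time step.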

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        for stage in self._stages:
            stage._prepare_forward_infra(n_microbatches)
            if self._has_backward:
                stage._prepare_backward_infra(n_microbatches)

    def _dump_csv(self, filename):
        """Dump a CSV representation of the schedule into a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.
        """
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order:
                writer.writerow(self.pipeline_order[rank])

    def _validate_schedule(self):
        # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
        def _validate_rank_actions(
            actions: Dict[int, List[Optional[_Action]]],
            num_stages: int,
            num_microbatches: int,
        ):
            # We will count all the actions per stage and ensure they happen in a valid order
            # (e.g. F before B before W for a given microbatch)
            stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
                stage_id: {
                    F: set(),
                    B: set(),
                    W: set(),
                }
                for stage_id in range(num_stages)
            }
            for rank in actions:
                for action in actions[rank]:
                    if action is None:
                        continue
                    assert isinstance(
                        action, _Action
                    ), f"Got an invalid action: {action}, expected instance of _Action"
                    s_id = action.stage_index
                    ctype = action.computation_type
                    mb_id = action.microbatch_index
                    if ctype == F:
                        stage_actions[s_id][F].add(mb_id)
                    elif ctype == B:
                        assert (
                            mb_id in stage_actions[s_id][F]
                        ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                        stage_actions[s_id][B].add(mb_id)
                    elif ctype == W:
                        assert (
                            not self.use_full_backward
                        ), "Schedule contains 'W' actions, but is configured to use full backward"
                        assert (
                            mb_id in stage_actions[s_id][B]
                        ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
                        stage_actions[s_id][W].add(mb_id)

            for s_id in stage_actions:
                for ctype in (F, B, W):
                    if ctype == W and self.use_full_backward:
                        # Full-backward schedules fold W into B, so no separate
                        # W actions are expected; skip the count check for W.
                        continue
                    stage_mb = len(stage_actions[s_id][ctype])
                    assert (
                        stage_mb == num_microbatches
                    ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"

        assert (
            len(self.pipeline_order) == self.pp_group_size
        ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
        for rank in range(self.pp_group_size):
            assert (
                rank in self.pipeline_order
            ), f"Schedule is missing actions for rank {rank}"
        _validate_rank_actions(
            self.pipeline_order,
            self._num_stages,
            self._n_microbatches,
        )

    def _load_csv(self, filename):
        """Load a CSV representation of the schedule from a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.
        """
        with open(filename, newline="") as csvfile:
            reader = csv.reader(csvfile)
            for rank, row in enumerate(reader):
                self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
        self._validate_schedule()
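
    # Illustrative CSV layout for _dump_csv/_load_csv: one row per rank, one
    # cell per time step, each cell an action string ("0F0", "1B2", ...), with
    # empty cells for idle (None) steps. A hypothetical row could look like:
    #     0F0,0F1,0F2,,0B0,0B1,0B2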

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        for stage in self._stages:
            stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        for stage in self._stages:
            if stage.is_last:
                return self._merge_outputs(stage.output_chunks)
        # Does not contain the last stage
        return None

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Based on the plan created in step 1 (in __init__), perform the
        # computation and communication dictated by pipeline_order.
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        # Determine prev_rank and next_rank based on which ranks are next to
        # the stages in the pipeline_order
        all_prev_ranks: Set[int] = set()
        all_next_ranks: Set[int] = set()
        for stage_index in stage_index_to_stage.keys():
            # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
            if stage_index > 0:
                all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
            if stage_index < self._num_stages - 1:
                all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

        for time_step, action in enumerate(self.pipeline_order[self.rank]):
            ops: List[dist.P2POp] = []
            if action is not None:
                computation_type, mb_index, stage_index = action
                if computation_type == _ComputationType.FORWARD:
                    # perform forward computation
                    stage = stage_index_to_stage[stage_index]
                    output = stage.forward_one_chunk(
                        mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                    )
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                    ops.extend(stage.get_fwd_send_ops(mb_index))
                elif computation_type == _ComputationType.BACKWARD:
                    # perform backward computation
                    stage = stage_index_to_stage[stage_index]
                    loss = self._maybe_get_loss(stage, mb_index)
                    stage.backward_one_chunk(
                        mb_index, loss=loss, full_backward=self.use_full_backward
                    )
                    ops.extend(stage.get_bwd_send_ops(mb_index))
                elif computation_type == _ComputationType.WEIGHT:
                    # perform weight update
                    if self.use_full_backward:
                        raise ValueError(
                            "We detected a weight update in the pipeline schedule, "
                            f"but {self.use_full_backward=}"
                        )
                    stage = stage_index_to_stage[stage_index]
                    stage.backward_weight_one_chunk(mb_index)
                else:
                    raise ValueError(f"Unknown computation type {computation_type}")

            # Look at the neighboring ranks for the current time step and determine
            # whether this rank needs to do any recv communication
            for prev_rank in all_prev_ranks:
                prev_rank_ops = self.pipeline_order[prev_rank]
                prev_rank_action = None
                if time_step < len(prev_rank_ops):
                    prev_rank_action = prev_rank_ops[time_step]
                if prev_rank_action is not None:
                    computation_type, mb_index, stage_index = prev_rank_action
                    # Only handle sends for the forward from a previous rank
                    if computation_type == _ComputationType.FORWARD:
                        # If not the last stage, then receive fwd activations
                        if stage_index + 1 in stage_index_to_stage:
                            # TODO: We are assuming that stage will always receive from stage-1
                            # however that is not necessarily true of get_fwd_recv_ops
                            stage = stage_index_to_stage[stage_index + 1]
                            ops.extend(stage.get_fwd_recv_ops(mb_index))
                    elif (
                        computation_type == _ComputationType.BACKWARD
                        or computation_type == _ComputationType.WEIGHT
                    ):
                        # A previous rank doing backward or weight update has no
                        # influence on the current rank's forward recv
                        pass
                    else:
                        raise ValueError(f"Unknown computation type {computation_type}")

            for next_rank in all_next_ranks:
                next_rank_ops = self.pipeline_order[next_rank]
                next_rank_action = None
                if time_step < len(next_rank_ops):
                    next_rank_action = next_rank_ops[time_step]
                if next_rank_action is not None:
                    computation_type, mb_index, stage_index = next_rank_action
                    # Only handle receives for the backwards from a next rank
                    if (
                        computation_type == _ComputationType.FORWARD
                        or computation_type == _ComputationType.WEIGHT
                    ):
                        # A next rank doing forward or weight update has no
                        # influence on the current rank's backward recv
                        pass
                    elif computation_type == _ComputationType.BACKWARD:
                        # If not the first stage, then receive bwd gradients
                        if stage_index - 1 in stage_index_to_stage:
                            # TODO: We are assuming that stage will always receive from stage+1
                            # however that is not necessarily true of get_bwd_recv_ops
                            stage = stage_index_to_stage[stage_index - 1]
                            ops.extend(stage.get_bwd_recv_ops(mb_index))
                    else:
                        raise ValueError(f"Unknown computation type {computation_type}")

            # do the communication
            if ops:
                _batch_p2p(ops).wait()
        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)


class ScheduleLoopedBFS(PipelineScheduleMulti):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that when microbatches are ready for multiple local
    stages, Looped BFS prioritizes the earlier stage, running all available
    microbatches at once.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        # ========================================================================
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank):
        n_local_stages = len(self._stages)
        stage_indices = range(
            rank, self.pp_group_size * n_local_stages, self.pp_group_size
        )

        # Store the list of operations used for that rank
        rank_ops: List[Optional[_Action]] = []
        # Pre-padding, rank starts with no-ops based on the warmup.
        for _ in range(rank):
            rank_ops.append(None)

        for stage_index in stage_indices:
            for mb_index in range(self._n_microbatches):
                rank_ops.append(
                    _Action(_ComputationType.FORWARD, mb_index, stage_index)
                )

        # wait for the first backward to trickle up
        # which is 2 for every hop away
        post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
        rank_ops.extend([None] * post_warmup_ops)

        for stage_index in reversed(stage_indices):
            for mb_index in reversed(range(self._n_microbatches)):
                rank_ops.append(
                    _Action(_ComputationType.BACKWARD, mb_index, stage_index)
                )
        return rank_ops
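
        # Worked example (illustrative): pp_group_size=2, n_local_stages=2,
        # n_microbatches=2. Rank 0 owns stages 0 and 2; rank 1 owns 1 and 3:
        #   rank 0: 0F0 0F1 2F0 2F1 __ __ 2B1 2B0 0B1 0B0
        #   rank 1: __  1F0 1F1 3F0 3F1 3B1 3B0 1B1 1B0
        # where __ marks a None (idle) slot.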


def _get_1f1b_rank_ops(
    n_local_stages,
    pp_group_size,
    warmup_ops,
    fwd_bwd_ops,
    cooldown_ops,
    rank,
    forward_stage_index,
    backward_stage_index,
):
    # All stages start with handling microbatch 0
    fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    # Store the list of operations used for that rank
    rank_ops: List[Optional[_Action]] = []
    # Pre-padding, rank starts with no-ops based on the warmup.
    for _ in range(rank):
        rank_ops.append(None)
    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
    # Formula:
    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
    # warmup_ops = calculated above
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)
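
    # Numeric check of the formula above (illustrative): with n_local_stages=2,
    # pp_group_size=4 and rank=0, the caller's warmup_ops is
    # (2 - 1) * 4 + 2 * 3 = 10, so
    #     post_warmup_ops = (2 * 4 + 2 * 3) - (10 + 0) = 4
    # i.e. rank 0 idles 4 steps while the first backward trickles back.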

    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

    for op in range(total_ops):
        # Warmup phase
        if op < warmup_ops:
            fwd_stage_index = forward_stage_index(op)
            # This will assign the current microbatch index and update it as well
            fwd_stage_mb_index[fwd_stage_index] = (
                mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(_ComputationType.FORWARD, mb_index, fwd_stage_index)
            )
            if op == warmup_ops - 1:
                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                rank_ops.extend([None] * post_warmup_ops)
        # 1F1B Phase (forward and backward)
        elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
            fwd_stage_index = forward_stage_index(op)
            fwd_stage_mb_index[fwd_stage_index] = (
                fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index)
            )
            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
            )
        # Cooldown phase
        else:
            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
            rank_ops.append(None)
            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
            )
    return rank_ops


class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        # TODO: is this limitation a must?
        if n_microbatches % self.pp_group_size != 0:
            raise ValueError(
                f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) "
                f"to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
            )

        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )

        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.group = stages[0].group

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup operations for the last stage
            warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
            # Increment warmup operations by 2 for each hop away from the last stage
            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
            # We cannot have more warmup operations than there are microbatch operations, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2

        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.pp_group_size) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )
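
    # Stage-index mapping sketch (illustrative): with pp_group_size=2 and
    # n_local_stages=2, rank 0 owns stages 0 and 2. forward_stage_index maps
    # steps 0,1 -> stage 0 and steps 2,3 -> stage 2, then wraps, giving the
    # depth-first warmup order 0F0 0F1 2F0 2F1 0F2 0F3 2F2 2F3 ...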


class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Flexible Interleaved 1F1B schedule.

    This schedule is mostly similar to the interleaved 1F1B schedule.
    It differs by relaxing the requirement that n_microbatches % pp_group_size == 0.
    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
    it works as long as n_microbatches % num_rounds is 0. As a few examples, this supports:

    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
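        # For example (illustrative): pp_group_size=4, n_microbatches=10 gives
        # number_of_rounds = max(1, 10 // 4) = 2 and microbatches_per_round =
        # 10 // 2 = 5; with n_microbatches=3, rounds = 1 and per-round = 3.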
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Flexible Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warmup operations for the last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by 2 for each hop away from the last stage
            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
            # We cannot have more warmup operations than there are microbatch operations, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )