Revert "[pipelining] [BE] Move pipeline_order validation to schedules.py (#129369)"
This reverts commitec789a3c9d. Reverted https://github.com/pytorch/pytorch/pull/129369 on behalf of https://github.com/clee2000 due to broke test/distributed/pipelining/test_schedule.py::ScheduleTest::test_non_symmetric_stage_ids_ScheduleClass0 on distributed cuda https://github.com/pytorch/pytorch/actions/runs/9766039400/job/26959115773ec789a3c9d. You can see the error on the PR, but Dr. CI classified it wrong ([comment](https://github.com/pytorch/pytorch/pull/129369#issuecomment-2204568418))
This commit is contained in:
parent b6f781e433
commit b5fdbc1a9f
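Note: throughout the diff below, pipeline_order maps each rank to a per-timestep list of actions, where each action is an _Action(computation_type, microbatch_index, stage_index) or None. The hand-built toy below is illustrative only ("F"/"B" strings and plain tuples stand in for the private _ComputationType and _Action types); it satisfies the five validation rules restored to the test file.

# 2 stages (one per rank), 2 microbatches; None marks an idle timestep.
toy_pipeline_order = {
    0: [("F", 0, 0), ("F", 1, 0), None, None, ("B", 0, 0), ("B", 1, 0)],
    1: [None, ("F", 0, 1), ("F", 1, 1), ("B", 0, 1), ("B", 1, 1), None],
}
# Each microbatch runs forward through stages 0 -> 1, then backward 1 -> 0,
# and no two ranks touch the same microbatch in the same timestep.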
--- a/test/distributed/pipelining/test_schedule.py
+++ b/test/distributed/pipelining/test_schedule.py
@@ -6,6 +6,7 @@ import os
 import sys
 import tempfile
 import unittest
+from typing import Dict, List, Optional, Tuple

 from model_registry import ModelWithKwargs, MultiMLP
 from schedule_registry import ScheduleUnbalanced, ScheduleVShaped, ScheduleWithW
@@ -21,10 +22,7 @@ from torch.distributed.pipelining import (
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
 )
-from torch.distributed.pipelining.schedules import (
-    _format_pipeline_order,
-    _validate_pipeline_order,
-)
+from torch.distributed.pipelining.schedules import _Action, _ComputationType
 from torch.distributed.pipelining.stage import _PipelineStageBase
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -610,7 +608,153 @@ class ScheduleTest(MultiProcContinousTest):
 instantiate_parametrized_tests(ScheduleTest)


+def format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]):
+    import itertools
+
+    # Calculate the maximum number of steps across all ranks
+    num_steps = max(len(actions) for actions in pipeline_order.values())
+    step_labels = [
+        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
+    ]
+    # Sort the dictionary by rank and retrieve each rank's actions in that order
+    rank_actions = [
+        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
+    ]
+    # Transpose the list of lists (rows to columns)
+    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
+    # Generate column labels for ranks
+    num_ranks = len(pipeline_order)
+    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
+    # Calculate the maximum length of each column, considering labels
+    max_lengths = [
+        max(len(str(item)) if item is not None else 0 for item in col)
+        for col in zip(step_labels, *transposed_actions)
+    ]
+    # Format the header row with rank labels
+    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
+        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
+    )
+    # Format each row with its corresponding label
+    formatted_rows = [
+        f"{label}: "
+        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
+        for label, row in zip(step_labels, transposed_actions)
+    ]
+    # Join the rows into a single string
+    formatted_table = (
+        "=========== ALL_RANK_ACTIONS ===========\n"
+        + header_row
+        + "\n"
+        + "\n".join(formatted_rows)
+        + "\n"
+    )
+    return formatted_table
+
+
 class TestSchedulePlan(unittest.TestCase):
+    def _validate_pipeline_order(
+        self,
+        pipeline_order: Dict[int, List[Optional[_Action]]],
+        num_microbatches: int,
+        num_stages: int,
+    ):
+        """
+        pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
+
+        Validates that the pipeline order follows these rules:
+        1. The Forward action for a microbatch must come before its Backward action
+        2. The Recv for a microbatch must come before its Send
+        3. Each stage handles microbatch indices in sequential order
+        4. A later stage cannot operate on a microbatch before all previous stages have operated on it
+        5. The same microbatch cannot be handled in the same time step across ranks
+        """
+        # microbatch_index: (current computation type, current stage)
+        microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
+        max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
+        for timestep in range(max_timestep):
+            error_msg: List[str] = []
+            current_timestep_actions = []
+            for rank in range(len(pipeline_order)):
+                action = (
+                    pipeline_order[rank][timestep]
+                    if timestep < len(pipeline_order[rank])
+                    else None
+                )
+                if action is not None:
+                    current_timestep_actions.append(action)
+
+            # TODO: enable this
+            # if len(current_timestep_actions) == 0:
+            #     error_msg.append(
+            #         "All actions were None, there is an unnecessary gap in the schedule"
+            #     )
+
+            # Ensure that no microbatch is operated on twice in current_timestep_actions
+            unique_microbatch_indices = {
+                action[1] for action in current_timestep_actions
+            }
+            if len(unique_microbatch_indices) != len(current_timestep_actions):
+                error_msg.append(
+                    "Duplicate microbatch index found in current_timestep_actions"
+                )
+
+            # Add additional checks for other rules here...
+            for action in current_timestep_actions:
+                computation_type, mb_index, stage_index = action
+
+                if mb_index >= num_microbatches:
+                    error_msg.append(f"Microbatch index {mb_index} out of range")
+
+                # first microbatch
+                if mb_index not in microbatch_process_info:
+                    if computation_type != _ComputationType.FORWARD or stage_index != 0:
+                        error_msg.append(f"Incorrect start for microbatch {mb_index}")
+                    microbatch_process_info[mb_index] = (computation_type, stage_index)
+                else:
+                    # if the microbatch was seen before, check that the current stage is right after prev
+                    prev_computation, prev_stage = microbatch_process_info[mb_index]
+                    if prev_computation == _ComputationType.FORWARD:
+                        if prev_stage == num_stages - 1:
+                            expected_stage = num_stages - 1
+                            expected_computation = _ComputationType.BACKWARD
+                        else:
+                            expected_stage = prev_stage + 1
+                            expected_computation = _ComputationType.FORWARD
+                    elif prev_computation == _ComputationType.BACKWARD:
+                        if prev_stage == 0:
+                            error_msg.append(
+                                f"[{mb_index=}] already finished backward computation"
+                            )
+                            expected_stage = None
+                            expected_computation = None
+                        else:
+                            expected_stage = prev_stage - 1
+                            expected_computation = _ComputationType.BACKWARD
+                    else:
+                        raise ValueError(
+                            f"Computation type {prev_computation} not supported"
+                        )
+
+                    if expected_computation is not None:
+                        if expected_computation != computation_type:
+                            error_msg.append(
+                                f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
+                            )
+
+                    if expected_stage != stage_index:
+                        error_msg.append(
+                            f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
+                        )
+
+                    microbatch_process_info[mb_index] = (
+                        expected_computation,
+                        expected_stage,
+                    )
+
+            if len(error_msg) != 0:
+                self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg))
+
     @parametrize(
         "ScheduleClass",
         [ScheduleFlexibleInterleaved1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS],
@@ -669,11 +813,8 @@ class TestSchedulePlan(unittest.TestCase):
             ]

             schedule = ScheduleClass(stages, num_microbatches)
-            formatted_pipeline_order = _format_pipeline_order(
-                schedule.pipeline_order
-            )
-            # print(formatted_pipeline_order)
-            _validate_pipeline_order(
+            # print(format_pipeline_order(schedule.pipeline_order))
+            self._validate_pipeline_order(
                 schedule.pipeline_order, num_microbatches, num_stages
             )

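Note: the sketch below shows roughly what the format_pipeline_order helper restored above produces. It is illustrative only: plain strings stand in for _Action entries (the helper only calls str() on each item), and exact padding depends on the real _Action repr.

toy_order = {
    0: ["F0_s0", "B0_s0"],
    1: ["F0_s1", None],
}
print(format_pipeline_order(toy_order))
# Prints something like:
# =========== ALL_RANK_ACTIONS ===========
#         Rank 0 Rank 1
# Step 0: F0_s0  F0_s1
# Step 1: B0_s0  None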
@@ -7,6 +7,7 @@ from typing import List, Tuple, Union
 import torch
 from torch import fx


 logger = logging.getLogger(__name__)
--- a/torch/distributed/pipelining/schedules.py
+++ b/torch/distributed/pipelining/schedules.py
@@ -2,7 +2,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates

 import csv
-import itertools
 import logging
 import re
 from abc import ABC, abstractmethod
@@ -93,144 +92,6 @@ class _Action(NamedTuple):
 )


-def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
-    """
-    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
-    and returns the formatted string
-    """
-
-    # Calculate the maximum number of steps across all ranks
-    num_steps = max(len(actions) for actions in pipeline_order.values())
-    step_labels = [
-        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
-    ]
-    # Sorting the dictionary by keys and retrieving values in that order
-    rank_actions = [
-        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
-    ]
-    # Transpose the list of lists (rows to columns)
-    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
-    # Generate column labels for ranks
-    num_ranks = len(pipeline_order)
-    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
-    # Calculate the maximum length of each column, considering labels
-    max_lengths = [
-        max(len(str(item)) if item is not None else 0 for item in col)
-        for col in zip(step_labels, *transposed_actions)
-    ]
-    # Format the header row with rank labels
-    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
-        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
-    )
-    # Format each row with its corresponding label
-    formatted_rows = [
-        f"{label}: "
-        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
-        for label, row in zip(step_labels, transposed_actions)
-    ]
-    # Join the rows into a single string
-    formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
-    return formatted_table
-
-
-def _validate_pipeline_order(
-    pipeline_order: Dict[int, List[Optional[_Action]]],
-    num_microbatches: int,
-    num_stages: int,
-):
-    """
-    pipeline_order[rank] = [(computation_type, microbatch_index, stage_index), ...]
-    Validating that the pipeline order follows the rules:
-    1. Forward action for a microbatch must be before the Backward action for that microbatch
-    2. Recv for a microbatch must be before the send for that microbatch
-    3. Microbatch index is handled in sequential order for each stage
-    4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
-    5. Same microbatch cannot be handled in the same time step across ranks
-    """
-    # microbatch_index: (current computation type, current stage)
-    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
-    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
-    for timestep in range(max_timestep):
-        error_msg: List[str] = []
-        current_timestep_actions = []
-        for rank in range(len(pipeline_order)):
-            action = (
-                pipeline_order[rank][timestep]
-                if timestep < len(pipeline_order[rank])
-                else None
-            )
-            if action is not None:
-                current_timestep_actions.append(action)
-
-        # TODO: enable this
-        # if len(current_timestep_actions) == 0:
-        #     error_msg.append(
-        #         "All actions were None, there is an unnecessary gap in the schedule"
-        #     )
-
-        # Ensure that no microbatch is operated on twice in current_timestep_actions
-        unique_microbatch_indices = {action[1] for action in current_timestep_actions}
-        if len(unique_microbatch_indices) != len(current_timestep_actions):
-            error_msg.append(
-                "Duplicate microbatch index found in current_timestep_actions"
-            )
-
-        for action in current_timestep_actions:
-            computation_type, mb_index, stage_index = action
-
-            if mb_index >= num_microbatches:
-                error_msg.append(f"Microbatch index {mb_index} out of range")
-
-            # first microbatch
-            if mb_index not in microbatch_process_info:
-                if computation_type != _ComputationType.FORWARD or stage_index != 0:
-                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
-                microbatch_process_info[mb_index] = (computation_type, stage_index)
-            else:
-                # if the microbatch is included, check that the current stage is right after prev
-                prev_computation, prev_stage = microbatch_process_info[mb_index]
-
-                if prev_computation == _ComputationType.FORWARD:
-                    if prev_stage == num_stages - 1:
-                        expected_stage = num_stages - 1
-                        expected_computation = _ComputationType.BACKWARD
-                    else:
-                        expected_stage = prev_stage + 1
-                        expected_computation = _ComputationType.FORWARD
-                elif prev_computation == _ComputationType.BACKWARD:
-                    if prev_stage == 0:
-                        error_msg.append(
-                            f"[{mb_index=}] already finished backward computation"
-                        )
-                        break
-                    else:
-                        expected_stage = prev_stage - 1
-                        expected_computation = _ComputationType.BACKWARD
-                else:
-                    raise ValueError(
-                        f"Computation type {prev_computation} not supported"
-                    )
-
-                if expected_computation is not None:
-                    if expected_computation != computation_type:
-                        error_msg.append(
-                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
-                        )
-
-                if expected_stage != stage_index:
-                    error_msg.append(
-                        f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
-                    )
-
-                microbatch_process_info[mb_index] = (
-                    expected_computation,
-                    expected_stage,
-                )
-
-        if len(error_msg) != 0:
-            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
-
-
 class _PipelineSchedule(ABC):
     def __init__(
         self,
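Note: the validator removed above (and restored into the test file earlier in this diff) encodes a simple per-microbatch state machine. The standalone sketch below restates that transition rule outside PyTorch, with plain "F"/"B" strings in place of _ComputationType, to make the expected stage progression explicit.

def expected_next(prev_computation, prev_stage, num_stages):
    # After a FORWARD at the last stage, the same stage runs BACKWARD;
    # otherwise FORWARD advances one stage and BACKWARD retreats one stage.
    if prev_computation == "F":
        if prev_stage == num_stages - 1:
            return ("B", num_stages - 1)
        return ("F", prev_stage + 1)
    if prev_computation == "B":
        if prev_stage == 0:
            return None  # microbatch has already finished its backward pass
        return ("B", prev_stage - 1)
    raise ValueError(f"Computation type {prev_computation} not supported")

# e.g. with 4 stages: after forward at the last stage, backward starts there
assert expected_next("F", 3, 4) == ("B", 3)
assert expected_next("B", 1, 4) == ("B", 0)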
@@ -943,106 +804,90 @@ class PipelineScheduleMulti(_PipelineSchedule):
                 all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

         for time_step, action in enumerate(self.pipeline_order[self.rank]):
-            try:
-                ops: List[dist.P2POp] = []
-                if action is not None:
-                    computation_type, mb_index, stage_index = action
-                    if computation_type == _ComputationType.FORWARD:
-                        # perform forward computation
-                        stage = stage_index_to_stage[stage_index]
-                        output = stage.forward_one_chunk(
-                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
-                        )
-                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)
-                        ops.extend(stage.get_fwd_send_ops(mb_index))
-                    elif computation_type == _ComputationType.BACKWARD:
-                        # perform backward computation
-                        stage = stage_index_to_stage[stage_index]
-                        loss = self._maybe_get_loss(stage, mb_index)
-                        stage.backward_one_chunk(
-                            mb_index, loss=loss, full_backward=self.use_full_backward
-                        )
-                        ops.extend(stage.get_bwd_send_ops(mb_index))
-                    elif computation_type == _ComputationType.WEIGHT:
-                        # perform weight update
-                        if self.use_full_backward:
-                            raise ValueError(
-                                f"We detected a weight update in the pipeline schedule, but \
-                                {self.use_full_backward=}"
-                            )
-                        stage = stage_index_to_stage[stage_index]
-                        stage.backward_weight_one_chunk(mb_index)
-                    else:
-                        raise ValueError(f"Unknown computation type {computation_type}")
-
-                # Look at the neighboring ranks for this current timestep and determine whether
-                # this current rank needs to do any recv communication
-                for prev_rank in all_prev_ranks:
-                    prev_rank_ops = self.pipeline_order[prev_rank]
-                    prev_rank_action = None
-                    if time_step < len(prev_rank_ops):
-                        prev_rank_action = prev_rank_ops[time_step]
-                    if prev_rank_action is not None:
-                        computation_type, mb_index, stage_index = prev_rank_action
-                        # Only handle sends for the forward from a previous rank
-                        if computation_type == _ComputationType.FORWARD:
-                            # If not the last stage, then receive fwd activations
-                            if stage_index + 1 in stage_index_to_stage:
-                                # TODO: We are assuming that stage will always receive from stage-1
-                                # however that is not necessarily true of get_fwd_recv_ops
-                                stage = stage_index_to_stage[stage_index + 1]
-                                ops.extend(stage.get_fwd_recv_ops(mb_index))
-                        elif (
-                            computation_type == _ComputationType.BACKWARD
-                            or computation_type == _ComputationType.WEIGHT
-                        ):
-                            # Previous rank doing backward or weight update has no influence for the current rank forward recv
-                            pass
-                        else:
-                            raise ValueError(
-                                f"Unknown computation type {computation_type}"
-                            )
-
-                for next_rank in all_next_ranks:
-                    next_rank_ops = self.pipeline_order[next_rank]
-                    next_rank_action = None
-                    if time_step < len(next_rank_ops):
-                        next_rank_action = next_rank_ops[time_step]
-                    if next_rank_action is not None:
-                        computation_type, mb_index, stage_index = next_rank_action
-                        # Only handle receives for the backwards from a next rank
-                        if (
-                            computation_type == _ComputationType.FORWARD
-                            or computation_type == _ComputationType.WEIGHT
-                        ):
-                            # Next rank doing forward or weight update has no influence for the current rank backward recv
-                            pass
-                        elif computation_type == _ComputationType.BACKWARD:
-                            # If not the first stage, then receive bwd gradients
-                            if stage_index - 1 in stage_index_to_stage:
-                                # TODO: We are assuming that stage will always receive from stage+1
-                                # however that is not necessarily true of get_bwd_recv_ops
-                                stage = stage_index_to_stage[stage_index - 1]
-                                ops.extend(stage.get_bwd_recv_ops(mb_index))
-                        else:
-                            raise ValueError(
-                                f"Unknown computation type {computation_type}"
-                            )
-
-                # do the communication
-                if ops:
-                    _batch_p2p(ops).wait()
-            except Exception as e:
-                logger.error(
-                    "[Rank %s] pipeline schedule %s caught the following exception \
-                     at time_step %s when running action %s",
-                    self.rank,
-                    self.__class__.__name__,
-                    time_step,
-                    action,
-                )
-                logger.error("%s", _format_pipeline_order(self.pipeline_order))
-                raise e
+            ops: List[dist.P2POp] = []
+            if action is not None:
+                computation_type, mb_index, stage_index = action
+                if computation_type == _ComputationType.FORWARD:
+                    # perform forward computation
+                    stage = stage_index_to_stage[stage_index]
+                    output = stage.forward_one_chunk(
+                        mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
+                    )
+                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
+                    ops.extend(stage.get_fwd_send_ops(mb_index))
+                elif computation_type == _ComputationType.BACKWARD:
+                    # perform backward computation
+                    stage = stage_index_to_stage[stage_index]
+                    loss = self._maybe_get_loss(stage, mb_index)
+                    stage.backward_one_chunk(
+                        mb_index, loss=loss, full_backward=self.use_full_backward
+                    )
+                    ops.extend(stage.get_bwd_send_ops(mb_index))
+                elif computation_type == _ComputationType.WEIGHT:
+                    # perform weight update
+                    if self.use_full_backward:
+                        raise ValueError(
+                            f"We detected a weight update in the pipeline schedule, but \
+                            {self.use_full_backward=}"
+                        )
+                    stage = stage_index_to_stage[stage_index]
+                    stage.backward_weight_one_chunk(mb_index)
+                else:
+                    raise ValueError(f"Unknown computation type {computation_type}")

+            # Look at the neighboring ranks for this current timestep and determine whether
+            # this current rank needs to do any recv communication
+            for prev_rank in all_prev_ranks:
+                prev_rank_ops = self.pipeline_order[prev_rank]
+                prev_rank_action = None
+                if time_step < len(prev_rank_ops):
+                    prev_rank_action = prev_rank_ops[time_step]
+                if prev_rank_action is not None:
+                    computation_type, mb_index, stage_index = prev_rank_action
+                    # Only handle sends for the forward from a previous rank
+                    if computation_type == _ComputationType.FORWARD:
+                        # If not the last stage, then receive fwd activations
+                        if stage_index + 1 in stage_index_to_stage:
+                            # TODO: We are assuming that stage will always receive from stage-1
+                            # however that is not necessarily true of get_fwd_recv_ops
+                            stage = stage_index_to_stage[stage_index + 1]
+                            ops.extend(stage.get_fwd_recv_ops(mb_index))
+                    elif (
+                        computation_type == _ComputationType.BACKWARD
+                        or computation_type == _ComputationType.WEIGHT
+                    ):
+                        # Previous rank doing backward or weight update has no influence for the current rank forward recv
+                        pass
+                    else:
+                        raise ValueError(f"Unknown computation type {computation_type}")
+            for next_rank in all_next_ranks:
+                next_rank_ops = self.pipeline_order[next_rank]
+                next_rank_action = None
+                if time_step < len(next_rank_ops):
+                    next_rank_action = next_rank_ops[time_step]
+                if next_rank_action is not None:
+                    computation_type, mb_index, stage_index = next_rank_action
+                    # Only handle receives for the backwards from a next rank
+                    if (
+                        computation_type == _ComputationType.FORWARD
+                        or computation_type == _ComputationType.WEIGHT
+                    ):
+                        # Next rank doing forward or weight update has no influence for the current rank backward recv
+                        pass
+                    elif computation_type == _ComputationType.BACKWARD:
+                        # If not the first stage, then receive bwd gradients
+                        if stage_index - 1 in stage_index_to_stage:
+                            # TODO: We are assuming that stage will always receive from stage+1
+                            # however that is not necessarily true of get_bwd_recv_ops
+                            stage = stage_index_to_stage[stage_index - 1]
+                            ops.extend(stage.get_bwd_recv_ops(mb_index))
+                    else:
+                        raise ValueError(f"Unknown computation type {computation_type}")

+            # do the communication
+            if ops:
+                _batch_p2p(ops).wait()
         # Return losses if there is a container passed in
         self._update_losses(self._stages, losses)
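Note: on both sides of the hunk above, each timestep's sends and receives are collected into one list of dist.P2POp and flushed together via _batch_p2p(ops).wait(). Below is a minimal standalone sketch of that pattern using the public torch.distributed API, assuming an already-initialized process group; the function name, peers, and tensor shapes are illustrative, not the pipelining helper itself.

import torch
import torch.distributed as dist

def flush_p2p(send_tensor: torch.Tensor, recv_tensor: torch.Tensor,
              send_to: int, recv_from: int) -> None:
    # Queue the isend/irecv pair, launch them as one batch, then wait on all.
    ops = [
        dist.P2POp(dist.isend, send_tensor, send_to),
        dist.P2POp(dist.irecv, recv_tensor, recv_from),
    ]
    for work in dist.batch_isend_irecv(ops):
        work.wait()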
--- a/torch/distributed/pipelining/stage.py
+++ b/torch/distributed/pipelining/stage.py
@@ -385,7 +385,7 @@ class _PipelineStageBase(ABC):
         else:
             if not (grad is None and grad_recv_stage is None):
                 raise RuntimeError(
-                    f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
+                    f"[{self.stage_index}] for chunk {bwd_chunk_id - 1} has gradients {grad} "
                     f"and is expecting to send gradients to stage {grad_recv_stage}"
                 )
         return ops
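Note: a plausible reading of this one-line fix (an assumption, not stated in the diff) is that the stage's backward-chunk counter has already been advanced past the chunk whose stray gradients are being reported, so the message must subtract one to name the offending chunk. Generic illustration only, not the PyTorch code:

bwd_chunk_id = 0
# ... backward for chunk 0 runs and produces gradients ...
bwd_chunk_id += 1  # counter now points at the next chunk to process
offending_chunk = bwd_chunk_id - 1  # the chunk that actually has the gradients
assert offending_chunk == 0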