pytorch/torch/distributed/pipelining/_schedule_visualizer.py
Howard Huang 5e8b95605f [PP] Support OVERLAP_F_B computation type (#158978)
Some changes to validation code and visualizer to support a new computation type that will be used in DualPipeV (see https://github.com/pytorch/pytorch/pull/159591)

The IR looks like:

```
[0F0, 0F1, 0F2, 0F3, 0F4, 0F5, 0F6, 7F0, 7I0, 7W0, 7F1, 7I1, 7W1, 7F2, 7I2, 7W2, 7F3, (0F7;7B3)OVERLAP_F_B, (7F4;0B0)OVERLAP_F_B, (0F8;7B4)OVERLAP_F_B, (7F5;0B1)OVERLAP_F_B, (0F9;7B5)OVERLAP_F_B, (7F6;0B2)OVERLAP_F_B, 7B6, (7F7;0B3)OVERLAP_F_B, 7B7, (7F8;0B4)OVERLAP_F_B, 7B8, (7F9;0B5)OVERLAP_F_B, 7B9, 0I6, 0W6, 0I7, 0W7, 0I8, 0W8, 0I9, 0W9]
[1F0, 1F1, 1F2, 1F3, 1F4, 6F0, 1F5, 6F1, 6I0, 6W0, 6F2, 6I1, 6W1, 6F3, (1F6;6B2)OVERLAP_F_B, (6F4;1B0)OVERLAP_F_B, (1F7;6B3)OVERLAP_F_B, (6F5;1B1)OVERLAP_F_B, (1F8;6B4)OVERLAP_F_B, (6F6;1B2)OVERLAP_F_B, (1F9;6B5)OVERLAP_F_B, (6F7;1B3)OVERLAP_F_B, 6B6, (6F8;1B4)OVERLAP_F_B, 6B7, (6F9;1B5)OVERLAP_F_B, 6B8, 1B6, 6I9, 1I7, 6W9, 1I8, 1W7, 1I9, 1W8, 1W9]
[2F0, 2F1, 2F2, 5F0, 2F3, 5F1, 2F4, 5F2, 5I0, 5W0, 5F3, (2F5;5B1)OVERLAP_F_B, (5F4;2B0)OVERLAP_F_B, (2F6;5B2)OVERLAP_F_B, (5F5;2B1)OVERLAP_F_B, (2F7;5B3)OVERLAP_F_B, (5F6;2B2)OVERLAP_F_B, (2F8;5B4)OVERLAP_F_B, (5F7;2B3)OVERLAP_F_B, (2F9;5B5)OVERLAP_F_B, (5F8;2B4)OVERLAP_F_B, 5B6, (5F9;2B5)OVERLAP_F_B, 5B7, 2B6, 5B8, 2I7, 5I9, 2I8, 2W7, 2I9, 5W9, 2W8, 2W9]
[3F0, 4F0, 3F1, 4F1, 3F2, 4F2, 3F3, 4F3, 3F4, 4B0, (4F4;3B0)OVERLAP_F_B, (3F5;4B1)OVERLAP_F_B, (4F5;3B1)OVERLAP_F_B, (3F6;4B2)OVERLAP_F_B, (4F6;3B2)OVERLAP_F_B, (3F7;4B3)OVERLAP_F_B, (4F7;3B3)OVERLAP_F_B, (3F8;4B4)OVERLAP_F_B, (4F8;3B4)OVERLAP_F_B, (3F9;4B5)OVERLAP_F_B, (4F9;3B5)OVERLAP_F_B, 4B6, 3B6, 4B7, 3B7, 4I8, 3I8, 4I9, 3I9, 4W8, 3W8, 4W9, 3W9]
```

In this PR, schedule execution simply treats an OVERLAP_F_B action as two separate operations, F and B (so there is no actual overlap). The next step is to allow users to plug in a custom function that defines what this operation does.

814629043a/torch/distributed/pipelining/schedules.py (L1205-L1216)
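
For illustration, a minimal, hedged sketch of that "two separate operations" flattening, using only the `computation_type` and `sub_actions` attributes that the visualizer below also relies on (the helper name is hypothetical):

```python
from torch.distributed.pipelining.schedules import _Action, _ComputationType


def flatten_overlap_f_b(action: _Action) -> list[_Action]:
    # An OVERLAP_F_B action carries its forward and backward halves as
    # sub_actions; until a fused implementation is plugged in, an executor
    # can simply run them back-to-back as ordinary F and B operations.
    if action.computation_type == _ComputationType.OVERLAP_F_B and action.sub_actions:
        return list(action.sub_actions)
    return [action]
```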

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158978
Approved by: https://github.com/wconstab
2025-08-01 20:22:30 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates
"""
This visualizer requires matplotlib to be installed.
Example usage:
ops = get_schedule_ops("InterleavedZeroBubble", 4, 8)
visualize_schedule(ops, "test.png")
"""

from typing import Optional, Union
from unittest import mock

from torch.distributed.pipelining.schedules import (
    _Action,
    _ComputationType,
    _PipelineSchedule,
    get_schedule_class,
    PipelineScheduleMulti,
    PipelineScheduleSingle,
)
from torch.distributed.pipelining.stage import PipelineStage


def get_schedule_ops(
    schedule: Union[str, type[_PipelineSchedule]],
    pp_degree: int,
    num_microbatches: int,
    num_stages_per_rank: Optional[int] = None,
) -> list[list[Optional[_Action]]]:
    """
    Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
    where each inner list represents a rank and each element in the inner list represents an action.
    The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance.
    """
    if isinstance(schedule, str):
        schedule_class = get_schedule_class(schedule)
    elif issubclass(schedule, _PipelineSchedule):
        schedule_class = schedule
    else:
        raise ValueError(f"Invalid schedule: {schedule}")

    # Create a mock of the PipelineStage class
    mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True)

    # Set the return values for group_rank and group_size methods
    mock_pipeline_stage.group_rank = 0
    mock_pipeline_stage.group_size = pp_degree
    mock_pipeline_stage.submod = None

    # Check num_stages_per_rank is valid
    if issubclass(schedule_class, PipelineScheduleSingle):
        if num_stages_per_rank is None:
            num_stages_per_rank = 1
        assert num_stages_per_rank == 1
        stages = mock_pipeline_stage
        stages.num_stages = num_stages_per_rank * pp_degree
    elif issubclass(schedule_class, PipelineScheduleMulti):
        if num_stages_per_rank is None:
            num_stages_per_rank = 2
        assert num_stages_per_rank >= 2
        stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
        for stage in stages:
            stage.num_stages = num_stages_per_rank * pp_degree
    else:
        raise ValueError(f"Invalid schedule: {schedule_class}")

    # Instantiate the schedule class
    schedule_instance = schedule_class(stages, num_microbatches)

    # Convert to List[List[_Action]]
    all_actions = []
    for rank in range(pp_degree):
        all_actions.append(schedule_instance.pipeline_order[rank])

    # Return the pipeline order
    return all_actions
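

# A minimal usage sketch (mirroring the module docstring; the schedule name and
# sizes are just illustrative). Each inner list holds one rank's actions in
# execution order, and None entries are idle slots that the visualizer renders as gaps:
#
#     ops = get_schedule_ops("InterleavedZeroBubble", pp_degree=4, num_microbatches=8)
#     for rank, actions in enumerate(ops):
#         print(rank, [str(a) if a is not None else "." for a in actions])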


class _ComputationTypeColor:
    def __init__(
        self,
        color: str,
        text: str = "",
        width: int = 1,
    ):
        self.color = color
        self.width = width
        self.text = text


# Map each computation type to its display color, legend label, and drawn width
action_type_to_color_mapping = {
    _ComputationType.FORWARD: _ComputationTypeColor("blue", "Forward"),
    _ComputationType.BACKWARD_INPUT: _ComputationTypeColor("teal", "Backward Input"),
    _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor("green", "Backward Weight"),
    _ComputationType.FULL_BACKWARD: _ComputationTypeColor("orange", "Full Backward", 2),
    _ComputationType.OVERLAP_F_B: _ComputationTypeColor("purple", "Overlap F+B", 3),
}


def visualize_schedule(
    schedule: list[list[Optional[_Action]]], filename: Optional[str] = None
) -> None:
    """
    Visualize the schedule using matplotlib.
    The schedule is a list of lists where each inner list represents a rank and each element in the inner list represents an action.
    The actions are represented as rectangles with different colors based on their computation type.
    The filename is optional and if provided, the plot will be saved to that file.
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    plt.rcParams["font.family"] = (
        "DejaVu Sans"  # or any other font available on your system
    )
    num_ranks = len(schedule)
    max_actions = max(len(rank) for rank in schedule)

    # Increase the figure size to provide more space for the legend
    fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2))
    max_draw_position = -1

    # Calculate dynamic font size based on figure size
    font_size = min(max_actions, num_ranks) + 4
    used_computation = set()

    for rank_idx, actions in enumerate(schedule):
        draw_position = 0  # Initialize drawing position for each rank
        for action in actions:
            if action is not None:
                comp_type_color = action_type_to_color_mapping.get(
                    action.computation_type, _ComputationTypeColor("black")
                )
                used_computation.add(action.computation_type)
                color = comp_type_color.color
                width = comp_type_color.width
                # Compound actions (those with sub_actions, e.g. OVERLAP_F_B) get a thicker border
                if action.sub_actions is not None:
                    linewidth = 2  # Thicker border for compound actions
                    text_weight = "normal"
                else:
                    linewidth = 1  # Default linewidth for regular actions
                    text_weight = "normal"  # Default text weight
                # Draw the rectangle to represent the action duration
                rect = Rectangle(
                    (draw_position, num_ranks - rank_idx - 1),
                    width,
                    1,
                    facecolor=color,
                    edgecolor="black",
                    linewidth=linewidth,
                )
                ax.add_patch(rect)

                # Draw the text centered within the rectangle
                ax.text(
                    draw_position + width / 2,
                    num_ranks - rank_idx - 1 + 0.5,
                    str(action),
                    ha="center",
                    va="center",
                    fontsize=font_size,
                    color="white",
                    weight=text_weight,
                )
                draw_position += width
            else:
                draw_position += 1  # Move to the next position
        max_draw_position = max(max_draw_position, draw_position)

    ax.set_xlim(-0.5, max_draw_position + 1)
    ax.set_ylim(-0.5, num_ranks + 0.5)  # Add extra space at the top

    # Set y-ticks to be in the middle of each rank's row
    ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)])
    ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size)
    ax.set_xticklabels([])

    # Remove grid lines and ticks
    ax.grid(False)

    # Add legend with larger font size
    legend_elements = [
        Rectangle(
            (0, 0),
            1,
            1,
            facecolor=action_type_to_color_mapping[comp_type].color,
            edgecolor="black",
            label=action_type_to_color_mapping[comp_type].text,
        )
        for comp_type in used_computation
    ]
    ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size)

    # Save to file if filename is provided, otherwise display the plot
    if filename:
        plt.savefig(filename, bbox_inches="tight")
    else:
        plt.show()
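

# Hedged demo mirroring the example in the module docstring (requires matplotlib);
# the schedule name and the pp_degree/num_microbatches values are just the
# docstring's example, not anything this module prescribes.
if __name__ == "__main__":
    ops = get_schedule_ops("InterleavedZeroBubble", pp_degree=4, num_microbatches=8)
    visualize_schedule(ops, "test.png")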