mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This PR makes some changes to the validation code and the visualizer to support a new computation type, OVERLAP_F_B, which will be used in DualPipeV (see https://github.com/pytorch/pytorch/pull/159591).
The IR looks like:
```
[0F0, 0F1, 0F2, 0F3, 0F4, 0F5, 0F6, 7F0, 7I0, 7W0, 7F1, 7I1, 7W1, 7F2, 7I2, 7W2, 7F3, (0F7;7B3)OVERLAP_F_B, (7F4;0B0)OVERLAP_F_B, (0F8;7B4)OVERLAP_F_B, (7F5;0B1)OVERLAP_F_B, (0F9;7B5)OVERLAP_F_B, (7F6;0B2)OVERLAP_F_B, 7B6, (7F7;0B3)OVERLAP_F_B, 7B7, (7F8;0B4)OVERLAP_F_B, 7B8, (7F9;0B5)OVERLAP_F_B, 7B9, 0I6, 0W6, 0I7, 0W7, 0I8, 0W8, 0I9, 0W9]
[1F0, 1F1, 1F2, 1F3, 1F4, 6F0, 1F5, 6F1, 6I0, 6W0, 6F2, 6I1, 6W1, 6F3, (1F6;6B2)OVERLAP_F_B, (6F4;1B0)OVERLAP_F_B, (1F7;6B3)OVERLAP_F_B, (6F5;1B1)OVERLAP_F_B, (1F8;6B4)OVERLAP_F_B, (6F6;1B2)OVERLAP_F_B, (1F9;6B5)OVERLAP_F_B, (6F7;1B3)OVERLAP_F_B, 6B6, (6F8;1B4)OVERLAP_F_B, 6B7, (6F9;1B5)OVERLAP_F_B, 6B8, 1B6, 6I9, 1I7, 6W9, 1I8, 1W7, 1I9, 1W8, 1W9]
[2F0, 2F1, 2F2, 5F0, 2F3, 5F1, 2F4, 5F2, 5I0, 5W0, 5F3, (2F5;5B1)OVERLAP_F_B, (5F4;2B0)OVERLAP_F_B, (2F6;5B2)OVERLAP_F_B, (5F5;2B1)OVERLAP_F_B, (2F7;5B3)OVERLAP_F_B, (5F6;2B2)OVERLAP_F_B, (2F8;5B4)OVERLAP_F_B, (5F7;2B3)OVERLAP_F_B, (2F9;5B5)OVERLAP_F_B, (5F8;2B4)OVERLAP_F_B, 5B6, (5F9;2B5)OVERLAP_F_B, 5B7, 2B6, 5B8, 2I7, 5I9, 2I8, 2W7, 2I9, 5W9, 2W8, 2W9]
[3F0, 4F0, 3F1, 4F1, 3F2, 4F2, 3F3, 4F3, 3F4, 4B0, (4F4;3B0)OVERLAP_F_B, (3F5;4B1)OVERLAP_F_B, (4F5;3B1)OVERLAP_F_B, (3F6;4B2)OVERLAP_F_B, (4F6;3B2)OVERLAP_F_B, (3F7;4B3)OVERLAP_F_B, (4F7;3B3)OVERLAP_F_B, (3F8;4B4)OVERLAP_F_B, (4F8;3B4)OVERLAP_F_B, (3F9;4B5)OVERLAP_F_B, (4F9;3B5)OVERLAP_F_B, 4B6, 3B6, 4B7, 3B7, 4I8, 3I8, 4I9, 3I9, 4W8, 3W8, 4W9, 3W9]
```
In this PR, schedule execution simply treats OVERLAP_F_B as two separate operations, F and B, so there is no actual overlap yet. The next step is to let users plug in a custom function that defines what this operation does.
Relevant code: `torch/distributed/pipelining/schedules.py`, lines 1205–1216, at commit 814629043a.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158978
Approved by: https://github.com/wconstab
203 lines · 7.3 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
"""
|
|
This visualizer requires matplotlib to be installed.
|
|
|
|
Example usage:
|
|
|
|
ops = get_schedule_ops("InterleavedZeroBubble", 4, 8)
|
|
visualize_schedule(ops, "test.png")
|
|
"""
|
|
|
|
from typing import Optional, Union
|
|
from unittest import mock
|
|
|
|
from torch.distributed.pipelining.schedules import (
|
|
_Action,
|
|
_ComputationType,
|
|
_PipelineSchedule,
|
|
get_schedule_class,
|
|
PipelineScheduleMulti,
|
|
PipelineScheduleSingle,
|
|
)
|
|
from torch.distributed.pipelining.stage import PipelineStage
|
|
|
|
|
|
def get_schedule_ops(
|
|
schedule: Union[str, type[_PipelineSchedule]],
|
|
pp_degree: int,
|
|
num_microbatches: int,
|
|
num_stages_per_rank: Optional[int] = None,
|
|
) -> list[list[Optional[_Action]]]:
|
|
"""
|
|
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
|
|
where each inner list represents a rank and each element in the inner list represents an action.
|
|
|
|
The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance.
|
|
"""
|
|
|
|
if isinstance(schedule, str):
|
|
schedule_class = get_schedule_class(schedule)
|
|
elif issubclass(schedule, _PipelineSchedule):
|
|
schedule_class = schedule
|
|
else:
|
|
raise ValueError(f"Invalid schedule: {schedule}")
|
|
|
|
# Create a mock of the PipelineStage class
|
|
mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True)
|
|
# Set the return values for group_rank and group_size methods
|
|
mock_pipeline_stage.group_rank = 0
|
|
mock_pipeline_stage.group_size = pp_degree
|
|
mock_pipeline_stage.submod = None
|
|
|
|
# Check num_stages_per_rank is valid
|
|
if issubclass(schedule_class, PipelineScheduleSingle):
|
|
if num_stages_per_rank is None:
|
|
num_stages_per_rank = 1
|
|
assert num_stages_per_rank == 1
|
|
stages = mock_pipeline_stage
|
|
stages.num_stages = num_stages_per_rank * pp_degree
|
|
elif issubclass(schedule_class, PipelineScheduleMulti):
|
|
if num_stages_per_rank is None:
|
|
num_stages_per_rank = 2
|
|
assert num_stages_per_rank >= 2
|
|
stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
|
|
for stage in stages:
|
|
stage.num_stages = num_stages_per_rank * pp_degree
|
|
|
|
else:
|
|
raise ValueError(f"Invalid schedule: {schedule_class}")
|
|
|
|
# Instantiate the schedule class
|
|
schedule_instance = schedule_class(stages, num_microbatches)
|
|
|
|
# Convert to List[List[_Action]]
|
|
all_actions = []
|
|
for rank in range(pp_degree):
|
|
all_actions.append(schedule_instance.pipeline_order[rank])
|
|
|
|
# Return the pipeline order
|
|
return all_actions
|
|
|
|
|
|
class _ComputationTypeColor:
    """Drawing style for one computation type in the schedule plot.

    Attributes:
        color: matplotlib color used to fill the action's rectangle.
        text: label shown for this computation type in the legend.
        width: horizontal extent of the action's rectangle, in slots.
    """

    def __init__(
        self,
        color: str,
        text: str = "",
        width: int = 1,
    ):
        self.color, self.text, self.width = color, text, width
|
|
|
|
|
|
# Drawing style (fill color, legend label, slot width) for each computation
# type. visualize_schedule() draws types missing from this table in black.
# Wider entries occupy more horizontal slots in the plot.
action_type_to_color_mapping = {
    _ComputationType.FORWARD: _ComputationTypeColor(
        color="blue", text="Forward"
    ),
    _ComputationType.BACKWARD_INPUT: _ComputationTypeColor(
        color="teal", text="Backward Input"
    ),
    _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor(
        color="green", text="Backward Weight"
    ),
    _ComputationType.FULL_BACKWARD: _ComputationTypeColor(
        color="orange", text="Full Backward", width=2
    ),
    _ComputationType.OVERLAP_F_B: _ComputationTypeColor(
        color="purple", text="Overlap F+B", width=3
    ),
}
|
|
|
|
|
|
def visualize_schedule(
    schedule: list[list[Optional[_Action]]], filename: Optional[str] = None
) -> None:
    """
    Visualize the schedule using matplotlib.

    The schedule is a list of lists where each inner list represents a rank and
    each element in the inner list represents an action. ``None`` entries are
    drawn as empty one-slot gaps.

    Each action is drawn as a rectangle. Its color and width come from
    ``action_type_to_color_mapping``. Computation types missing from that
    mapping are drawn in black, and the legend labels them with their string
    form.

    Args:
        schedule: per-rank action lists, e.g. as returned by get_schedule_ops().
        filename: if provided, the plot is saved to this file and the figure is
            closed; otherwise the plot is displayed with ``plt.show()``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    # NOTE: this sets the font family in the global matplotlib rcParams.
    plt.rcParams["font.family"] = (
        "DejaVu Sans"  # or any other font available on your system
    )
    num_ranks = len(schedule)
    # default=0 keeps an empty schedule from crashing max().
    max_actions = max((len(rank) for rank in schedule), default=0)

    # Increase the figure size to provide more space for the legend
    fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2))
    max_draw_position = -1
    # Calculate dynamic font size based on figure size
    font_size = min(max_actions, num_ranks) + 4

    # Style of every computation type actually drawn. A dict (rather than a
    # set) keeps first-seen order, so the legend order is deterministic.
    used_styles: dict[_ComputationType, _ComputationTypeColor] = {}

    for rank_idx, actions in enumerate(schedule):
        draw_position = 0  # Initialize drawing position for each rank
        row_y = num_ranks - rank_idx - 1  # rank 0 is drawn as the top row
        for action in actions:
            if action is None:
                draw_position += 1  # Empty slot: leave a one-unit gap
                continue

            comp_type = action.computation_type
            style = action_type_to_color_mapping.get(
                comp_type, _ComputationTypeColor("black")
            )
            used_styles.setdefault(comp_type, style)
            width = style.width
            # Compound actions (those with sub_actions) get a thicker border.
            linewidth = 2 if action.sub_actions is not None else 1

            # Draw the rectangle to represent the action duration
            ax.add_patch(
                Rectangle(
                    (draw_position, row_y),
                    width,
                    1,
                    facecolor=style.color,
                    edgecolor="black",
                    linewidth=linewidth,
                )
            )
            # Draw the text centered within the rectangle
            ax.text(
                draw_position + width / 2,
                row_y + 0.5,
                str(action),
                ha="center",
                va="center",
                fontsize=font_size,
                color="white",
                weight="normal",
            )
            draw_position += width
        max_draw_position = max(max_draw_position, draw_position)

    ax.set_xlim(-0.5, max_draw_position + 1)
    ax.set_ylim(-0.5, num_ranks + 0.5)  # Add extra space at the top
    # Set y-ticks to be in the middle of each rank's row
    ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)])
    ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size)
    ax.set_xticklabels([])

    # Remove grid lines
    ax.grid(False)

    # Legend entries: mapped types in the mapping's order, then any unmapped
    # types in first-seen order (sorted() is stable). Unmapped types have no
    # text, so fall back to the computation type's string form for the label.
    type_order = {t: i for i, t in enumerate(action_type_to_color_mapping)}
    legend_types = sorted(
        used_styles, key=lambda t: type_order.get(t, len(type_order))
    )
    legend_elements = [
        Rectangle(
            (0, 0),
            1,
            1,
            facecolor=used_styles[t].color,
            edgecolor="black",
            label=used_styles[t].text or str(t),
        )
        for t in legend_types
    ]
    ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size)

    # Save to file if filename is provided, otherwise display the plot
    if filename:
        fig.savefig(filename, bbox_inches="tight")
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
|