Inspect schedule IR comms (#162996)

Small change to the util to allow us to see comms (e.g. `SEND`, `RECV`) in the schedule IR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162996
Approved by: https://github.com/fegin
This commit is contained in:
Howard Huang 2025-09-16 07:11:36 -07:00 committed by PyTorch MergeBot
parent f638854e1d
commit 9de22bc5da

View File

@ -17,6 +17,7 @@ from torch.distributed.pipelining.schedules import (
_Action,
_ComputationType,
_PipelineSchedule,
_PipelineScheduleRuntime,
get_schedule_class,
PipelineScheduleMulti,
PipelineScheduleSingle,
@ -36,6 +37,7 @@ def get_schedule_ops(
num_microbatches: int,
num_stages_per_rank: Optional[int] = None,
add_spacing: bool = False,
with_comms: bool = False,
) -> list[list[Optional[_Action]]]:
"""
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
@ -43,6 +45,8 @@ def get_schedule_ops(
The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance.
"""
if add_spacing and with_comms:
raise ValueError("Cannot add spacing and view comms at the same time")
if isinstance(schedule, str):
schedule_class = get_schedule_class(schedule)
@ -78,11 +82,18 @@ def get_schedule_ops(
# Instantiate the schedule class
schedule_instance = schedule_class(stages, num_microbatches)
assert schedule_instance.pipeline_order is not None
# Convert to List[List[_Action]]
all_actions = []
for rank in range(pp_degree):
all_actions.append(schedule_instance.pipeline_order[rank])
all_actions: list[list[Optional[_Action]]] = []
if with_comms:
runtime = _PipelineScheduleRuntime(stages, num_microbatches)
runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order)
for rank in range(pp_degree):
all_actions.append(list(runtime.pipeline_order_with_comms[rank]))
else:
for rank in range(pp_degree):
all_actions.append(schedule_instance.pipeline_order[rank])
# Add spacing
if add_spacing: