mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Inspect schedule IR comms (#162996)
Small change to the util that allows us to see comms (e.g. `SEND`, `RECV`) in the schedule IR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162996. Approved by: https://github.com/fegin
This commit is contained in:
parent
f638854e1d
commit
9de22bc5da
|
|
@ -17,6 +17,7 @@ from torch.distributed.pipelining.schedules import (
|
|||
_Action,
|
||||
_ComputationType,
|
||||
_PipelineSchedule,
|
||||
_PipelineScheduleRuntime,
|
||||
get_schedule_class,
|
||||
PipelineScheduleMulti,
|
||||
PipelineScheduleSingle,
|
||||
|
|
@ -36,6 +37,7 @@ def get_schedule_ops(
|
|||
num_microbatches: int,
|
||||
num_stages_per_rank: Optional[int] = None,
|
||||
add_spacing: bool = False,
|
||||
with_comms: bool = False,
|
||||
) -> list[list[Optional[_Action]]]:
|
||||
"""
|
||||
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
|
||||
|
|
@ -43,6 +45,8 @@ def get_schedule_ops(
|
|||
|
||||
The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance.
|
||||
"""
|
||||
if add_spacing and with_comms:
|
||||
raise ValueError("Cannot add spacing and view comms at the same time")
|
||||
|
||||
if isinstance(schedule, str):
|
||||
schedule_class = get_schedule_class(schedule)
|
||||
|
|
@ -78,11 +82,18 @@ def get_schedule_ops(
|
|||
|
||||
# Instantiate the schedule class
|
||||
schedule_instance = schedule_class(stages, num_microbatches)
|
||||
assert schedule_instance.pipeline_order is not None
|
||||
|
||||
# Convert to List[List[_Action]]
|
||||
all_actions = []
|
||||
for rank in range(pp_degree):
|
||||
all_actions.append(schedule_instance.pipeline_order[rank])
|
||||
all_actions: list[list[Optional[_Action]]] = []
|
||||
if with_comms:
|
||||
runtime = _PipelineScheduleRuntime(stages, num_microbatches)
|
||||
runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order)
|
||||
for rank in range(pp_degree):
|
||||
all_actions.append(list(runtime.pipeline_order_with_comms[rank]))
|
||||
else:
|
||||
for rank in range(pp_degree):
|
||||
all_actions.append(schedule_instance.pipeline_order[rank])
|
||||
|
||||
# Add spacing
|
||||
if add_spacing:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user