diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py index 81be2b17834..1230adc35bd 100644 --- a/torch/distributed/pipelining/_schedule_visualizer.py +++ b/torch/distributed/pipelining/_schedule_visualizer.py @@ -17,6 +17,7 @@ from torch.distributed.pipelining.schedules import ( _Action, _ComputationType, _PipelineSchedule, + _PipelineScheduleRuntime, get_schedule_class, PipelineScheduleMulti, PipelineScheduleSingle, @@ -36,6 +37,7 @@ def get_schedule_ops( num_microbatches: int, num_stages_per_rank: Optional[int] = None, add_spacing: bool = False, + with_comms: bool = False, ) -> list[list[Optional[_Action]]]: """ Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists @@ -43,6 +45,8 @@ def get_schedule_ops( The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance. """ + if add_spacing and with_comms: + raise ValueError("Cannot add spacing and view comms at the same time") if isinstance(schedule, str): schedule_class = get_schedule_class(schedule) @@ -78,11 +82,18 @@ def get_schedule_ops( # Instantiate the schedule class schedule_instance = schedule_class(stages, num_microbatches) + assert schedule_instance.pipeline_order is not None # Convert to List[List[_Action]] - all_actions = [] - for rank in range(pp_degree): - all_actions.append(schedule_instance.pipeline_order[rank]) + all_actions: list[list[Optional[_Action]]] = [] + if with_comms: + runtime = _PipelineScheduleRuntime(stages, num_microbatches) + runtime._prepare_schedule_with_comms(schedule_instance.pipeline_order) + for rank in range(pp_degree): + all_actions.append(list(runtime.pipeline_order_with_comms[rank])) + else: + for rank in range(pp_degree): + all_actions.append(schedule_instance.pipeline_order[rank]) # Add spacing if add_spacing: