This PR allows schedules loaded via CSV to automatically set their `stage_index_to_group_rank`, and removes the `stage_index_to_group_rank` argument from the `PipelineScheduleMulti` constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146217
Approved by: https://github.com/wconstab
ghstack dependencies: #146193
We use `stage_index_to_group_rank` in the stage to determine which send/recv ops to issue, and in the schedule for IR generation. However, we don't need to expose it as an argument on our schedule classes, so this stack of PRs removes it.
This PR creates a `stage_index_to_group_rank` utility function and removes the argument from the ZBV schedule. A following PR will add code to infer `stage_index_to_group_rank` for the CSV schedule path, at which point the argument can be removed from our classes entirely.
Related comment from @wconstab https://github.com/pytorch/torchtitan/issues/774#issuecomment-2619793741
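For illustration, here is a minimal sketch of what such a utility could look like. The function name matches the one described above, but the signature, the `style` parameter, and the exact mapping logic are assumptions, not the actual PyTorch implementation:

```python
# Hypothetical sketch; the real helper in pytorch/pytorch may differ
# in name, signature, and supported schedule styles.
def stage_index_to_group_rank(
    num_stages: int, group_size: int, style: str = "loop"
) -> dict[int, int]:
    """Map each pipeline stage index to the rank that owns it.

    style="loop": stages wrap around the ranks (0, 1, ..., n-1, 0, 1, ...).
    style="v":    stages traverse the ranks down and back up
                  (0, 1, ..., n-1, n-1, ..., 1, 0), as in ZBV schedules.
    """
    mapping = {}
    for stage in range(num_stages):
        if style == "loop":
            mapping[stage] = stage % group_size
        else:  # "v"
            round_idx, pos = divmod(stage, group_size)
            # Even rounds walk the ranks forward, odd rounds walk backward.
            mapping[stage] = pos if round_idx % 2 == 0 else group_size - 1 - pos
    return mapping

# 8 stages on 4 ranks, V-style: {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}
print(stage_index_to_group_rank(8, 4, style="v"))
```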
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146193
Approved by: https://github.com/wconstab
# Changes
* small fix in stage error message
* Move `format_pipeline_order` and `_validate_pipeline_order` out of `test_schedule.py` into `schedules.py`.
* Wrap the execution runtime in a try-except that, on error, logs the timestep and schedule plan before re-raising the exception (see the sketch below).
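The following is a rough sketch of the try-except wrapping described above, assuming hypothetical internals (`_ScheduleRuntime`, `self.pipeline_order`, `_format_pipeline_order` are illustrative names, not the exact PyTorch code):

```python
# Illustrative only: shows the error-handling pattern, not the real runtime.
import logging

logger = logging.getLogger(__name__)

def _format_pipeline_order(order):
    # Render {rank: [action, ...]} as one line per rank for debugging.
    return "\n".join(f"rank {r}: {acts}" for r, acts in order.items())

class _ScheduleRuntime:
    def __init__(self, pipeline_order):
        # pipeline_order: per-rank list of actions, one entry per timestep.
        self.pipeline_order = pipeline_order

    def _step(self, rank, timestep):
        ...  # issue the send/recv/compute ops for this timestep

    def run(self, rank):
        for timestep, _action in enumerate(self.pipeline_order[rank]):
            try:
                self._step(rank, timestep)
            except Exception:
                # Log where the schedule failed, then surface the original error.
                logger.error(
                    "Schedule failed at timestep %d on rank %d:\n%s",
                    timestep,
                    rank,
                    _format_pipeline_order(self.pipeline_order),
                )
                raise
```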
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129369
Approved by: https://github.com/wconstab
ghstack dependencies: #129368
`QualnameMapMixin` was intended to provide a mapping from the FQNs of the piped model to the FQNs of the original model. It existed because previous tracers, and the flattening performed during tracing, would modify the FQNs.
Now that we use the unflattener, the FQNs of the stage modules are the same as the original FQNs, so we no longer need `QualnameMapMixin`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127018
Approved by: https://github.com/H-Huang
Address the classes of user errors stemming from (possibly) unintentional dynamic-shapes usage, or from a mismatch between configuration-time and run-time data shapes/dtypes.
The goal is to raise a clear error rather than relying on some underlying error to bubble up when a tensor shape is incompatible, or worse, allowing a silent correctness issue.
**Classes of shape/dtype errors**
* (a) an error is thrown within the stage-module forward code, but may be hard to understand or trace back to an input issue
* (b) a silent correctness issue happens inside the stage-module forward, but the expected output shape is still produced
* (c) the stage-module produces an output that is locally correct but does not match the expectation of the following stage, leading to a hang or a correctness issue down the line
**How validation helps**
Input shape validation
- improves debuggability of case (a)
- guards against case (b)
- is only needed on the first stage, since subsequent stages use pre-allocated recv buffers that cannot change shape/size even if they wanted to

Output shape validation
- guards against case (c)

Validating the first stage's inputs and every stage's outputs inductively verifies all shapes.
Shape and dtype are the most critical properties, since they directly determine the number of bytes on the wire. Strides and other tensor properties may also matter, and the validation function can be adjusted accordingly if needed.
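A minimal sketch of the kind of check described above, assuming a hypothetical helper name (`_validate_tensors_metadata` is illustrative, not the exact PyTorch internal):

```python
# Illustrative sketch: compare actual tensors against expected metadata
# recorded at configuration time, raising a clear error on mismatch.
import torch

def _validate_tensors_metadata(desc, expected, actual):
    if len(expected) != len(actual):
        raise RuntimeError(
            f"{desc}: expected {len(expected)} tensors, got {len(actual)}"
        )
    for i, (exp, act) in enumerate(zip(expected, actual)):
        if exp.shape != act.shape or exp.dtype != act.dtype:
            raise RuntimeError(
                f"{desc}: tensor {i} has shape {act.shape} / dtype {act.dtype}, "
                f"expected shape {exp.shape} / dtype {exp.dtype}"
            )

# First stage: check user inputs against configuration-time shapes. Later
# stages recv into fixed-size buffers, so only their outputs need checking.
expected = (torch.empty(2, 4, dtype=torch.float32),)
_validate_tensors_metadata("stage 0 inputs", expected, (torch.randn(2, 4),))
```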
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126732
Approved by: https://github.com/kwen2501