This PR allows schedules loaded via CSV to automatically set their `stage_index_to_group_rank` and removes the `stage_index_to_group_rank` argument from the `PipelineScheduleMulti` constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146217
Approved by: https://github.com/wconstab
ghstack dependencies: #146193
We use `stage_index_to_group_rank` in the stage to determine which send/recv ops to issue, and in the schedule for IR generation. However, we don't need to expose this as an argument in our schedule class, so this stack of PRs removes it.
This PR creates a `stage_index_to_group_rank` utility function and removes the arg from the ZBV schedule. In a following PR I will add code to infer `stage_index_to_group_rank` for the CSV schedule path, after which we will be able to remove this argument from our classes entirely.
Related comment from @wconstab https://github.com/pytorch/torchtitan/issues/774#issuecomment-2619793741
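For reference, a minimal sketch of what such a utility could compute for a simple "loop"-style interleaved placement (stage i on rank i % group size); the name and signature here are illustrative, and the real helper also needs to cover V-style placement:
```
# hypothetical sketch, not the exact helper added in this PR
def stage_index_to_group_rank(num_stages: int, pp_group_size: int) -> dict[int, int]:
    # "loop" placement: stages wrap around the pipeline ranks
    return {stage_idx: stage_idx % pp_group_size for stage_idx in range(num_stages)}
```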
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146193
Approved by: https://github.com/wconstab
Adds a grad-scaling method `perform_pp_grad_scaling()` that divides gradients by the number of microbatches.
Enables grad scaling by default, unless it is disabled because the loss function sums rather than averages the per-microbatch losses.
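As a rough sketch of the scaling being enabled (illustrative names, not the exact method body): each parameter gradient is divided by the number of microbatches so that per-microbatch mean losses average out correctly.
```
import torch

def scale_grads(module: torch.nn.Module, num_microbatches: int) -> None:
    # divide accumulated grads so N microbatches of mean loss behave like one full batch
    for p in module.parameters():
        if p.grad is not None:
            p.grad.div_(num_microbatches)
```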
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144352
Approved by: https://github.com/H-Huang
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Changes for Reland**
- Preserved the public objects from `torch/distributed/_composable/fsdp/fully_shard.py` so that the import path still works internally
- Added a unit test that we can do `from torch.distributed._composable.fsdp.fully_shard import FSDPModule`
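A minimal sketch of that check, assuming both import paths export the same class as described above:
```
# old (internal) import path, preserved for backward compatibility
from torch.distributed._composable.fsdp.fully_shard import FSDPModule as LegacyFSDPModule
# new public import path introduced by this PR
from torch.distributed.fsdp import FSDPModule

# expectation: both paths resolve to the same class
assert LegacyFSDPModule is FSDPModule
```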
Differential Revision: [D66890387](https://our.internmc.facebook.com/intern/diff/D66890387)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy, https://github.com/fegin, https://github.com/XilunWu
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
**Overview**
This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes public APIs available from `torch.distributed.fsdp`, e.g.:
```
from torch.distributed.fsdp import fully_shard
```
This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.
**Follow-Ups**
- [x] Add some explanation in the docs about FSDP1 vs. FSDP2
- [ ] Move unit tests from `test/distributed/_composable/fsdp` to `test/distributed/fsdp/fully_shard/`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868
Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions, which are faster and do not leak the loop variable into the enclosing scope (see the example below).
* List comprehensions not only often type-check better, but are also 50%+ faster than for loops in terms of loop overhead. They also preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt.
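A typical instance of the rewrite, for illustration:
```
# before: append-in-a-loop, leaks `x` into the enclosing scope
squares = []
for x in range(10):
    squares.append(x * x)

# after: equivalent list comprehension, no leaked loop variable and less call overhead
squares = [x * x for x in range(10)]
```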
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Since any stage can run a mixture of full backwards and split backwards,
it is important to compare the sum of (full_backwards + backward_weight)
against the number of microbatches to determine the last backward.
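A minimal sketch of that check (hypothetical counter names, not the actual stage fields):
```
def is_last_backward(num_full_backwards: int, num_backward_weights: int, num_microbatches: int) -> bool:
    # a microbatch finishes its backward either via FULL_BACKWARD or via the
    # BACKWARD_WEIGHT that follows an earlier BACKWARD_INPUT, so the two counts
    # together must cover every microbatch
    return (num_full_backwards + num_backward_weights) == num_microbatches
```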
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139415
Approved by: https://github.com/H-Huang
Also, update tests to use I (BACKWARD_INPUT) vs B (FULL_BACKWARD)
consistently.
Previously, schedules would issue a 'B' operation and leave it ambiguous
whether that operation should be BACKWARD_INPUT or FULL_BACKWARD,
depending on a separate flag (use_full_backward) passed to the schedule
class, which would determine which behavior was taken at runtime.
Now, use_full_backward is removed and the schedule class is required to
produce unambiguous IR. The logic for 'use_full_backward' is removed
from the runtime.
_validate_pipeline_order is replaced with _simulate_comms_compute. Both
offer similar functionality, to validate the correctness of a schedule
IR. 'validate' operates on compute-only IR, while simulate operates on
compute + comm IR. To convert from using validate to simulate, you have
to first insert comm actions via '_add_send_recv'.
'simulate' was written inefficiently before this PR and needed to be
optimized to run quickly for the extra-large schedules (more than 32 ranks and
microbatches per rank) used in some unit tests.
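The migration path, sketched with the internal helpers named above (`schedule.pipeline_order`, `stage_to_rank`, and `num_stages` stand in for the schedule object and its placement info; exact signatures may differ):
```
# internal helpers from torch/distributed/pipelining/schedules.py
from torch.distributed.pipelining.schedules import _add_send_recv, _simulate_comms_compute

# compute-only IR, e.g. {rank: [F/B/W actions]}, as produced by a schedule class
compute_only_ir = schedule.pipeline_order

# lower to compute + comm IR first ...
comms_ir = _add_send_recv(compute_only_ir, stage_to_rank=stage_to_rank, num_stages=num_stages)

# ... then check it with the simulator (replaces _validate_pipeline_order)
_simulate_comms_compute(comms_ir, stage_to_rank=stage_to_rank, num_stages=num_stages)
```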
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138886
Approved by: https://github.com/H-Huang
Used in both the simulator and the add_send_recv pass, the ready_to_schedule
logic works by looking at all the previously scheduled ops on a rank to
see if any of them 'unblocks' the current op to be scheduled. For example,
to schedule a FORWARD op, a previous RECV_F op is needed, unless this is
stage 0 or there is a previous stage on the same rank that ran FORWARD
already.
The old implementation iteratively compared the candidate op to the
previous ops. The new implementation uses set lookups to reduce
complexity. It also maintains the set of previous ops as ops are
scheduled rather than constructing a set on demand.
I did not save benchmark results, but this results in a 10-100x speedup
which is most noticeable for unit tests with artificially huge schedule
IR, the largest of which took longer than 20m before (I never let it
finish) but now takes less than 14s. Most schedules take less than
10ms.
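An illustrative sketch of the set-based check for a FORWARD op (the real helper covers more op types; the tuple keys here are assumptions):
```
from collections import namedtuple

Op = namedtuple("Op", ["type", "stage", "mb"])  # e.g. Op("RECV_F", 2, 0)

def forward_ready(op: Op, prev_ops: set, stage_to_rank: dict, rank: int) -> bool:
    if op.stage == 0:
        return True  # stage 0 never waits on an incoming activation
    # unblocked either by a RECV_F already scheduled on this rank, or by the
    # previous stage's FORWARD having already run on this same rank
    recv_done = Op("RECV_F", op.stage, op.mb) in prev_ops
    prev_stage_local = (
        stage_to_rank.get(op.stage - 1) == rank and Op("F", op.stage - 1, op.mb) in prev_ops
    )
    return recv_done or prev_stage_local
```
As each action is committed, the scheduler adds it to `prev_ops`, rather than rebuilding the set from the list of prior ops on every check.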
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138924
Approved by: https://github.com/H-Huang
ghstack dependencies: #138928, #131762
### Separate dI / dW:
PipelineScheduleRuntime now supports execution of merged FULL_BACKWARD
or separate dI / dW operations.
Separating the B and W may add execution overhead or may be suboptimal
in cases where B and W are 'fused', but it is worthwhile when separating them
lets the schedule be more efficient by filling in bubbles. In some
cases, the schedule will still issue a B immediately followed by a W for the
same stage and microbatch; in those cases we merge them back into BW ops and
execute them as full backwards rather than as a B followed by a W.
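A sketch of that merge, with actions written as (op, stage, microbatch) tuples ("I" = BACKWARD_INPUT, "W" = BACKWARD_WEIGHT, "B" = FULL_BACKWARD); this is illustrative, not the runtime's actual data structures:
```
def merge_adjacent_bw(actions):
    merged, i = [], 0
    while i < len(actions):
        op, stage, mb = actions[i]
        if op == "I" and i + 1 < len(actions) and actions[i + 1] == ("W", stage, mb):
            merged.append(("B", stage, mb))  # run one full backward instead of I then W
            i += 2
        else:
            merged.append(actions[i])
            i += 1
    return merged
```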
### V-schedules:
V-schedules have a special case where the last rank has 2 adjacent
stages.
E.g. if rank3 had stage 3 and stage 4, then we should implement direct
transfer of stage3 outputs to stage4 inputs without a
send/recv.
In the scheduling logic, we must also allow the stage 4 forward to be
scheduled after the stage 3 forward has run, without expecting a stage 4
RECV_F.
In the runtime, we pass activations between adjacent stages without
using SEND/RECV ops since the stages are on the same rank/process. We
add new APIs to the PipelineStage abstraction for passing the activations
both during forward and backward. Currently the implementation directly
modifies the 'recv buffers' the stage is managing, so the
forward/backward execution logic does not need to know the difference.
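A self-contained toy illustration of the same-rank handoff pattern (plain modules standing in for PipelineStage objects; the real implementation writes into the stage's recv buffers as noted above):
```
import torch
import torch.nn as nn

# toy stand-ins for two adjacent pipeline stages placed on the same rank
stage3, stage4 = nn.Linear(8, 8), nn.Linear(8, 8)

x = torch.randn(2, 8)
out3 = stage3(x)

# same-rank handoff of the activation: no SEND_F/RECV_F, just reuse the tensor
# (detached, since each stage runs its own backward, as in real pipelining)
in4 = out3.detach().requires_grad_(True)
out4 = stage4(in4)

# stage 4's backward produces the grad that stage 3's backward consumes directly,
# where separate ranks would have needed SEND_B/RECV_B
out4.sum().backward()
out3.backward(in4.grad)
```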
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131762
Approved by: https://github.com/H-Huang
ghstack dependencies: #138928
The special case was added during experimentation with batched send/recv
ops. The ops needed to be jointly scheduled or the simulator would
think that each op was unschedulable since each contained a recv that
depended on the other's send. The workaround I added was to let the
scheduler 'peek' one op ahead for unblocking, which let batched ops be
scheduled but also changed the behavior of non-batched ops. It let RECV
ops be simulated one step earlier than the unblocking SEND ops, which
shortened the simulated duration of schedules.
Removing this workaround simplifies the simulator and, more importantly,
lends itself to optimizing the simulator's runtime by making it much
easier to avoid copying or extending lists of previous ops on each
iteration. It also restores the output of the simulator for non-batched
ops to a more natural output where RECV must happen at the same time or
later than matching SEND, rather than possibly a step earlier.
For example, for this test:
`python test/distributed/pipelining/test_schedule.py -k test_send_recv_test_info0`
Before:
```
Step 0: 0F0 1RECV_F0
Step 1: 0SEND_F0
Step 2: 0F1 1RECV_F1
Step 3: 0SEND_F1 1F0
Step 4: 0RECV_B0 1B0
Step 5: 0B0 1SEND_B0
Step 6: 1F1
Step 7: 0RECV_B1 1B1
Step 8: 0B1 1SEND_B1
```
After:
```
Rank 0 Rank 1
Step 00: 0F0
Step 01: 0SEND_F0 1RECV_F0
Step 02: 0F1
Step 03: 0SEND_F1 1RECV_F1
Step 04: 1F0
Step 05: 1B0
Step 06: 0RECV_B0 1SEND_B0
Step 07: 0B0 1F1
Step 08: 1B1
Step 09: 0RECV_B1 1SEND_B1
Step 10: 0B1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138928
Approved by: https://github.com/H-Huang
The schedule simulator is useful for detecting hangs in schedules and
validating that they won't hang. It also inserts bubbles (None actions)
at any timestep where a rank cannot enqueue its next action due to
unmet dependencies, which can serve as a rough metric for schedule
efficiency. The output can be visualized. The simulator expects a full
comm + compute schedule as input.
Chrometrace dump is a basic visualization utility. It currently just
renders one 'process' per rank, and lets users visualize the schedule in
a UI instead of as text.
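A minimal sketch of such a dump (the schedule is assumed to be a dict of per-rank action lists with None for bubbles; the field names follow the Chrome trace event format):
```
import json

def dump_chrometrace(pipeline_order: dict, path: str) -> None:
    events = []
    for rank, actions in pipeline_order.items():
        for step, action in enumerate(actions):
            if action is None:           # bubble inserted by the simulator
                continue
            events.append({
                "name": str(action),
                "ph": "X",               # complete event
                "ts": step,              # use the simulator timestep as the timestamp
                "dur": 1,
                "pid": rank,             # one 'process' per rank
                "tid": 0,
            })
    with open(path, "w") as f:
        json.dump({"traceEvents": events}, f)
```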
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138134
Approved by: https://github.com/H-Huang
NOTE: this PR removes `ScheduleFlexibleInterleaved1F1B`; let me know if there are any concerns.
`ScheduleFlexibleInterleaved1F1B` is a superset of `Interleaved1F1B` and uses most of the same implementation, but relaxes the condition that `n_microbatches % pp_size == 0`. This PR folds that implementation into `Interleaved1F1B` and then removes `ScheduleFlexibleInterleaved1F1B`, since having two schedules with similar names is confusing. It also refactors the zero bubble logic to live in the `ZeroBubble` schedule class.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137783
Approved by: https://github.com/wconstab
Performs shape inference at runtime using user-provided real tensors.
- avoids the need for users to precompute shapes which is difficult and error prone
- lets us remove args from the PipelineStage ctor (in a later PR)
- deprecates the existing inference helper in the PipelineStage constructor, mainly because it's problematic to have to reason about the stage submodule being on the right device for shape inference
The current state as of this PR:
- Users should not pass any input or output shapes into PipelineStage ctor, and shape inference will run automatically
- To override shape inference, they can continue to pass input/output args as previously
Currently, this does not add a barrier after shape inference, which essentially pipelines shape inference with the subsequent schedule action for that stage. If this complicates debugging, we could add a barrier (it comes at a cost, but only during the first step).
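A hedged usage sketch of the state described above (assumes an initialized process group; `model_chunk`, `rank`, `world_size`, `device`, `loss_fn`, and the input `x` come from the surrounding training script):
```
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# no input_args/output_args passed: shape inference runs automatically
stage = PipelineStage(model_chunk, stage_index=rank, num_stages=world_size, device=device)
schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=loss_fn)

if rank == 0:
    schedule.step(x)       # first stage feeds the real inputs
else:
    schedule.step()        # later stages receive activations from upstream
```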
Testing:
- Removed input args from all PP test cases, thus exposing them all to shape-inference.
- Verified visually (nvidia-smi) that the torchtitan PP 3D test runs shape inference fine without creating extra CUDA contexts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136912
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
Fix two more leaks of the same variety as #136507 (see that PR desc and attached gdoc for debug details).
This time, also add a test-time check that helped to discover new leaks and ensure we won't accidentally regress.
Adds `check_tensor_leak` util which internally asserts no tensors are being kept alive by other objects involved in py ref cycles.
Uses objgraph for a nice debug utility when a leak is found.
Credit to @H-Huang for pointing out objgraph and helping debug the `param_group["intermediates"]` leak.
I manually confirmed that all 3 of the leaks identified/fixed so far are caught by the unit test and checker.
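A rough sketch of such a checker (the real `check_tensor_leak` lives in the test utilities and differs in details):
```
import gc
import warnings

import objgraph
import torch

def check_tensor_leak() -> list:
    gc.set_debug(gc.DEBUG_SAVEALL)   # keep everything the collector frees in gc.garbage
    gc.collect()
    leaked = [o for o in gc.garbage if isinstance(o, torch.Tensor)]
    if leaked:
        warnings.warn(f"{len(leaked)} tensors were found in the garbage. "
                      "Did you introduce a reference cycle?")
        # render the reference graph of the first offender for debugging
        objgraph.show_backrefs(leaked[:1], max_depth=5)
    gc.set_debug(0)
    return leaked
```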
Sample output, if I re-introduce a leak by commenting out `del param_group["intermediates"]` in _backward.py,
and run `python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`:
```
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5341: UserWarning: 34 tensors were found in the garbage. Did you introduce a reference cycle?
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5347: UserWarning: Dumping first 1 objgraphs of leaked tensors rendered to png
Graph written to /tmp/objgraph-ztz642h3.dot (19 nodes)
Graph viewer (xdot) not found, generating a png instead
Image generated as /tmp/objgraph-ztz642h3.png
```
rendering of ` /tmp/objgraph-ztz642h3.png`:
<img width="1671" alt="image" src="https://github.com/user-attachments/assets/9098ff29-224c-4533-935b-83c210ac2e22">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136584
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
ghstack dependencies: #136507
Co-authored-by: Howard Huang <howardhuang@fb.com>
Avoid allocating memory or dry-running the submodule during stage init.
Save user-provided input/output metadata during stage init, to allow
lazily initializing the buffers before the first step call.
Later, we plan to build on top of this to add lazy shape inference
(#130856) so that no input/output shapes are required at stage init.
For now, we require input/output tensors for stage init, but these
should be on meta device and stage should not allocate any real memory.
Note: this needs more thorough testing and review, but it worked on the
torchtitan 3d test.
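A hedged usage sketch of what this enables (assumes an initialized process group; `block`, `rank`, `world_size`, `device`, and the shape constants are placeholders from the surrounding script): the I/O metadata is described with meta tensors, so stage init allocates no real memory and buffers are materialized lazily before the first `step()`.
```
import torch
from torch.distributed.pipelining import PipelineStage

# metadata-only tensors: shapes/dtypes without any backing storage
example_input = torch.empty(microbatch_size, hidden_dim, device="meta")
example_output = torch.empty(microbatch_size, hidden_dim, device="meta")

stage = PipelineStage(
    block, stage_index=rank, num_stages=world_size, device=device,
    input_args=(example_input,), output_args=(example_output,),
)
```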
TODO:
- delete 'device' arg from PipelineStage ctor? (maybe infer it from the arg
tensors passed to the first step call; separate PR)
- delete 'output_args' from PipelineStage ctor? We don't actually need
it, but we use it to do shape validation, which is why I didn't remove
it in this PR. Proposal: leave it until we add lazy shape inference?
Fixes #136225, #136226
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136243
Approved by: https://github.com/H-Huang, https://github.com/kwen2501
Zero bubble can be expressed through `ScheduleFlexibleInterleaved1F1B` by setting `enable_zero_bubble=True`. But instead of having to include this flag in schedule initialization, we should create a separate ZeroBubbleSchedule and also transition `Interleaved1F1B` to derive from `ScheduleFlexibleInterleaved1F1B`. Then we don't need to expose `ScheduleFlexibleInterleaved1F1B`, since its naming is not obvious.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133467
Approved by: https://github.com/wconstab
ghstack dependencies: #132691
Creates a new runtime that shifts complexity from runtime to
ahead-of-time.
The existing runtime (PipelineScheduleMulti) accepts a
compute-only schedule (only forward, backward, and weight actions are
specified) and infers the communication operations at runtime.
Compared to that runtime, PipelineScheduleRuntime has less logic that
happens at runtime and relies on lowering passes to transform the
compute-only schedule to add communications.
Advantages include
- easier to verify the correctness by dumping a compute+comm schedule
- possible to manually edit the compute+comm schedule if the lowering
heuristics are insufficient
Functionality included inside the PipelineScheduleRuntime is limited to
- accepting a compute-only schedule and lowering it to add comms
- executing the compute or comm operations specified by the given
schedule
- handling work.wait() automatically by calling it just before the
matching compute operation (for RECV ops) or at the end of step (for
SEND ops)
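A conceptual, self-contained sketch of that wait policy (plain Python stand-ins for `dist.Work` handles, not the actual runtime class):
```
class FakeWork:
    def wait(self) -> None:
        pass  # a real dist.Work.wait() blocks until the communication completes

def run_step(lowered_actions):
    pending_recvs, pending_sends = {}, []
    for op, stage, mb in lowered_actions:        # actions as (op, stage, microbatch)
        if op.startswith("RECV"):
            pending_recvs[(op, stage, mb)] = FakeWork()   # issued asynchronously
        elif op.startswith("SEND"):
            pending_sends.append(FakeWork())
        else:                                    # compute: F / B / W
            work = pending_recvs.pop(("RECV_" + op, stage, mb), None)
            if work is not None:
                work.wait()                      # wait only right before the consumer
            # ... run the stage's forward/backward/weight computation here ...
    for work in pending_sends:
        work.wait()                              # drain outstanding sends at end of step
```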
Follow-ups for later PRs
- Some refactoring should be done to replace PipelineScheduleMulti with
this runtime
- Optimizer execution is not considered (e.g. for zero-bubble cases)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130488
Approved by: https://github.com/H-Huang
Inserts send/recv ops where needed in a compute-only pipeline schedule.
Any F or B action will require a recv op for its input and a send op
for its output, except at the ends of the pipeline.
To avoid hangs caused by mixed-up orderings of sends/recvs across ranks,
we pick one compute action at a time and insert both its send op (on
that rank's schedule) and the matching recv op for the recipient stage
(on the schedule of the rank that owns that stage).
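A conceptual sketch of the pass for forward activations only (actions as (op, stage, microbatch) tuples; the real pass also handles backward grads and the TODO items below):
```
def add_send_recv(compute_only: dict, stage_to_rank: dict, num_stages: int) -> dict:
    comms = {rank: [] for rank in compute_only}
    num_steps = max(len(actions) for actions in compute_only.values())
    for step in range(num_steps):                       # pick one compute action at a time
        for rank in sorted(compute_only):
            actions = compute_only[rank]
            if step >= len(actions) or actions[step] is None:
                continue
            op, stage, mb = actions[step]
            comms[rank].append((op, stage, mb))
            if op == "F" and stage < num_stages - 1:    # output must reach the next stage
                dest = stage_to_rank[stage + 1]
                comms[rank].append(("SEND_F", stage, mb))
                comms[dest].append(("RECV_F", stage + 1, mb))
    return comms
```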
TODO
Currently ignores a couple of edge cases
- ignores batching (which is an optimization)
- ignores cases where a stage sends to another stage on the same rank,
and should skip the send/recv and directly access memory
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130378
Approved by: https://github.com/H-Huang
ghstack dependencies: #129810
Adds fsdp unshard/reshard ops to a compute-only schedule.
Operates on one pp-rank's schedule at a time, since there is no
cross-pp-rank coordination needed for FSDP. (Unshard/Reshard is across
DP ranks within a PP group).
Uses a heuristic based on examining the next N stages to run compute
operations on this rank, evicting (resharding) and fetching (unsharding)
ahead of time to give unshard operations a chance to overlap with
compute and PP comms.
- this heuristic has not been validated and may not be optimal
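An illustrative sketch of the lookahead (the window size and action names are assumptions; actions are (op, stage, microbatch) tuples for one rank):
```
def add_unshard_reshard(compute_actions: list, lookahead: int = 2) -> list:
    out, unsharded = [], set()
    for i, action in enumerate(compute_actions):
        # stages this rank will compute on within the next `lookahead` actions
        window = {stage for (_op, stage, _mb) in compute_actions[i : i + lookahead]}
        for stage in sorted(window - unsharded):
            out.append(("UNSHARD", stage))   # prefetch the FSDP all-gather early
        for stage in sorted(unsharded - window):
            out.append(("RESHARD", stage))   # evict stages not needed soon
        unsharded = window
        out.append(action)
    return out
```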
Makes the assumption that it's fine to add the UNSHARD/RESHARD actions
to the schedule regardless of whether FSDP will actually be used.
- this way, users do not have to tell us at PP schedule creation time if
they plan to use FSDP or DDP
- it is trivial to implement UNSHARD/RESHARD as no-ops inside the
runtime, if FSDP is not detected on the stage module
TODO
- also add FSDP's reduce-scatter? or is it sufficient to leave this
handled by PipelineStage at 'last backward' time
- validate 'next N stages' heuristic and expose an API if needed
- add an e2e test
Co-authored-by: Howard Huang <howardhuang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129810
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
This PR fixes a bug in `test_correct_module_names` introduced in #130497. It also addresses post-fix test failures in:
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`
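For illustration, the shape of these fixes on a toy module (not the actual PyTorch code):
```
__all__ = ["public_helper"]              # explicit public surface; avoids re-exporting imports

def _round_up_to_multiple(x: int, multiple: int) -> int:
    # leading underscore keeps an internal helper out of the public API
    return ((x + multiple - 1) // multiple) * multiple

def public_helper(x: int) -> int:
    return _round_up_to_multiple(x, 8)

public_helper.__module__ = "mypackage"   # report the public path, not the defining submodule
```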
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD