Eventually, we should just have one unified way to check for parity between a `DTensor`-sharded model and a replicated model. This PR is a small refactor to work toward that. One current gap in using this `check_sharded_parity` function for 2D is that FSDP's `(Shard(0), Shard(0))` layout differs from that of the `DTensor` APIs, since FSDP shards on dim-0 after TP has already sharded on dim-0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121357
Approved by: https://github.com/weifengpy
ghstack dependencies: #121360
Since we are already checking if the RNG tracker is initialized, there is no real performance difference between erroring vs. just initializing a default RNG tracker (which we choose to be the `OffsetBasedRNGTracker`).
```
pytest test/distributed/_composable/fsdp/test_fully_shard_init.py -k test_meta
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121328
Approved by: https://github.com/wanchaol
ghstack dependencies: #120351
This PR adds initial support for meta-device initialization for pre-training without loading from a state dict. The idea is to allow `fully_shard(module)` to return and still have sharded parameters on meta device. Then, the user is free to initialize them as they please, e.g. using `to_empty()`.
We override `_apply` to achieve the following:
- Reshard the parameters to ensure that sharded parameters are registered (for correctness) -- we will always need this
- Pad new local tensors and use the padded local tensors (to handle uneven sharding) -- we will remove this once `DTensor` pads its local tensor
We use the `swap_tensors` path in `_apply`. For now, this requires setting `torch.__future__.set_swap_module_params_on_conversion(True)`; however, in the future, this may be enabled by default for wrapper subclasses and will not need any explicit API call. If requiring this call is too intrusive in the short term, we can also call it in `_apply` or when importing `fully_shard`.
```
# Pre-training flow (no checkpoint)
global_mesh = init_device_mesh(..., mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
with torch.device("meta"):
model = ...
parallelize_module(model, tp_mesh, ...)
fully_shard(model, mesh=dp_mesh, ...)
for param in model.parameters():
assert param.device.type == "meta"
model.to_empty(device="cuda")
random.manual_seed(42, global_mesh)
for module in model.modules():
if hasattr(module, "reset_parameters"):
module.reset_parameters()
```
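As noted above, the `swap_tensors` path currently requires an explicit opt-in (this is the flag referenced before the example; it may become the default for wrapper subclasses later):
```
import torch

# Required for now so that nn.Module._apply swaps tensors (instead of copying)
# for the wrapper-subclass (DTensor) parameters.
torch.__future__.set_swap_module_params_on_conversion(True)
```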
This PR includes some minor changes to allow the user to similarly cast the module to a different dtype after construction time but before forward.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120351
Approved by: https://github.com/wanchaol
This PR uses `ncclAvg` op (via `ReduceOp.AVG`) if doing fp32 reduce-scatter. This allows the division by world size to happen in the reduce-scatter kernel itself, which seems to save extra memory read/write for dividing. This yields ~1.5% speedup on the Llama-7B workload (and makes per-parameter FSDP faster than flat-parameter FSDP 😅 ).
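A sketch of the change in terms of public collectives (not FSDP's internal code; assumes a flat, world-size-divisible fp32 gradient buffer):
```
import torch
import torch.distributed as dist

def fp32_reduce_scatter(unsharded_grad: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    output = unsharded_grad.new_empty(unsharded_grad.numel() // world_size)
    # Before: sum-reduce, then divide, costing an extra read/write of `output`:
    #   dist.reduce_scatter_tensor(output, unsharded_grad, op=dist.ReduceOp.SUM, group=group)
    #   output.div_(world_size)
    # After: divide by world size inside the reduce-scatter kernel itself:
    dist.reduce_scatter_tensor(output, unsharded_grad, op=dist.ReduceOp.AVG, group=group)
    return output
```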
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120919
Approved by: https://github.com/yifuwang, https://github.com/wanchaol
ghstack dependencies: #120238, #120910
This PR adds support for `clip_grad_norm_(foreach=True)` by implementing `aten._foreach_norm.Scalar` and `aten._foreach_mul_.Tensor`. `foreach=True` is required to get competitive performance with `DTensor`.
`foreach=True` reduces CPU overhead for Llama-7B from 388 ms to 63 ms. Existing flat-parameter FSDP's `clip_grad_norm_` takes 3 ms on CPU 😢 .
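For example (assuming `model` holds `DTensor` parameters, e.g. after applying `fully_shard`):
```
import torch.nn as nn

def clip_grads(model: nn.Module, max_norm: float = 1.0):
    # foreach=True dispatches aten._foreach_norm / aten._foreach_mul_, which this
    # PR implements for DTensor; foreach=False falls back to slower per-tensor ops.
    return nn.utils.clip_grad_norm_(model.parameters(), max_norm, foreach=True)
```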
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120910
Approved by: https://github.com/wanchaol, https://github.com/janeyx99
ghstack dependencies: #120238
This PR adds `DTensor` support for `aten.linalg_vector_norm.default` and `aten.stack.default` so that we can run `clip_grad_norm_` (with `foreach=False`).
To implement `linalg_vector_norm`, we introduce a `_NormPartial` placement since the reduction op for norm is the norm itself.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120238
Approved by: https://github.com/wanchaol
Goal: all unit tests currently target eager; we want them to also test torch.compile by default.
This PR adds ``@test_compiled_fsdp(compile_compute_on_module=None/TransformerBlock)`` to the unit tests. For now, it compiles compute only, as follows:
```
module.compile() # include user registered hooks if any
fully_shard(module)
```
torch.compile does not yet work with the following components:
* compiling AC
* compiling reshard_after_forward=2
* delayed_all_gather, delayed_reduce_scatter
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119933
Approved by: https://github.com/awgu, https://github.com/jansel
This PR adds a way to do gradient accumulation without collectives (i.e. reduce-scatter for FSDP and reduce-scatter/all-reduce for HSDP, though HSDP is not yet implemented). Since the `no_sync()` context manager has received some feedback, we simply define a method on the module to set whether the module requires gradient synchronization or not, where this method can recurse or not.
```
# Before with `no_sync()`:
with fsdp_model.no_sync() if not is_last_microbatch else contextlib.nullcontext():
    ...  # Forward/backward
# After with a setter:
fsdp_model.set_requires_gradient_sync(not is_last_microbatch)
# Forward/backward
```
Having the method be able to recurse or not also gives some flexibility. For example, some large modules can still reduce-scatter, while some smaller modules can avoid it to save communication bandwidth:
```
fsdp_modules_to_reduce_scatter: Set[nn.Module] = ...
for module in fsdp_model.modules():
    if isinstance(module, FSDP) and module not in fsdp_modules_to_reduce_scatter:
        module.set_requires_gradient_sync(not is_last_microbatch)
# Forward/backward
```
(Separately, we may expose a helper that returns `[module for module in model.modules() if isinstance(module, FSDP)]`.)
---
To show the spirit of this API choice, I also included `set_requires_all_reduce` that would give us the ability to only reduce-scatter but not all-reduce for HSDP (originally from the MiCS paper). If we want to flexibly support heterogeneous sharding where FSDP is applied to some modules and HSDP to others in the same model, then having a module-level method that has the option to not recurse makes sense to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118298
Approved by: https://github.com/wconstab, https://github.com/wanchaol
ghstack dependencies: #119550, #118136, #118223, #118755, #119825
This PR adds mixed precision configured via `MixedPrecisionPolicy`.
- By default (`cast_forward_inputs=True`), each FSDP module will cast forward floating-point input tensors to `param_dtype` if specified. If the user wants to own the cast, then the user can disable it by passing `False`.
- Symmetrically, by default (`output_dtype=None`) each FSDP module will not cast the forward output. If the user wants to customize the output dtype, then the user can pass a `torch.dtype`.
- `param_dtype` configures the unsharded parameters' dtype for forward/backward computation and hence the all-gather dtype.
- `reduce_dtype` configures the gradient reduction dtype. If `reduce_dtype=None` and `param_dtype is not None`, then `reduce_dtype` inherits from `param_dtype` for simplicity.
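For reference, a usage sketch (the import path and the `mp_policy` keyword are assumptions based on this stack's composable `fully_shard`; `model` is a placeholder):
```
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # unsharded-parameter / all-gather dtype
    reduce_dtype=torch.float32,  # gradient reduce-scatter dtype
    # output_dtype=None,         # default: do not cast forward outputs
    # cast_forward_inputs=True,  # default: cast floating-point forward inputs
)
fully_shard(model, mp_policy=mp_policy)
```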
We test against a manually implemented reference implementation instead of comparing against existing FSDP since the comparison is more direct to what we want to test.
---
**Overhead benchmarks to inform design**
The dilemma is as follows:
- The common path for FSDP is bf16 parameter mixed precision, where we cast sharded parameters from fp32 to bf16 before all-gathering them.
- The baseline implementation is to `torch._foreach_copy_` the sharded parameters to the flat `all_gather_input`, which gets passed to `dist.all_gather_into_tensor`.
- The baseline incurs 1 extra fp32 read and 1 extra bf16 write per parameter because `_foreach_copy` takes the slow path, calling `copy_` in a loop, and `copy_` calls `dst.copy_(src.to(bf16))` where `dst` is bf16 and `src` is fp32.
- These `copy_` calls stay in C++ and do not require calling `at::as_strided`.
- The issue with this baseline implementation is that it requires knowing that all parameters in the group will be cast from fp32 to bf16 to do this `_foreach_copy_` from fp32 sources to a bf16 destination.
- We want per-parameter FSDP to support mixed dtype all-gathers, which would involve different parameters providing different dtype all-gather inputs and viewing them as uint8 for a combined flat all-gather input, where this viewing-as-uint8 step is only needed in the mixed dtype case.
- However, this incurs more CPU overhead, so we want to investigate this in more detail.
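For concreteness, a rough sketch (not the PR's actual copy-in code) of the two variants being benchmarked below:
```
import torch

def baseline_copy_in(fp32_shards, all_gather_input: torch.Tensor) -> None:
    # Single known target dtype: foreach-copy fp32 sources into the flat bf16
    # buffer; copy_ performs the fp32 -> bf16 cast.
    splits = torch.split(all_gather_input, [t.numel() for t in fp32_shards])
    torch._foreach_copy_(list(splits), [t.view(-1) for t in fp32_shards])

def mixed_dtype_copy_in(shards, all_gather_input_uint8: torch.Tensor) -> None:
    # Mixed dtypes: cast each shard to its all-gather dtype, reinterpret as uint8,
    # and copy into a flat uint8 buffer (extra CPU overhead from .to()/.view()).
    srcs = [t.to(torch.bfloat16).view(torch.uint8).view(-1) for t in shards]
    splits = torch.split(all_gather_input_uint8, [s.numel() for s in srcs])
    torch._foreach_copy_(list(splits), srcs)
```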
We consider 150 `nn.Parameter`s with shapes taken from an internal model (where the shapes only affect the copy bandwidth, not the CPU overhead). We focus on world size 128 first. We consider two experiments: (1) run the copy-in with no head start, allowing CPU boundedness to affect GPU time, and (2) run the copy-in with a CPU head start, removing CPU overhead from affecting GPU time.
No head start:
- Baseline `torch._foreach_copy_`: 0.525 ms CPU; 0.528 ms GPU
- `.to(bf16)` before `torch._foreach_copy_`: 0.828 ms CPU; 0.836 ms GPU
- `.to(bf16).view(uint8)` before `torch._foreach_copy_`: 0.933 ms CPU; 0.937 ms GPU
Head start (removing CPU boundedness from GPU times):
- Baseline `torch._foreach_copy_`: 0.393 ms GPU
- `.to(bf16)` before `torch._foreach_copy_`: 0.403 ms GPU
- `.to(bf16).view(uint8)` before `torch._foreach_copy_`: 0.403 ms GPU
Some other interesting notes:
- Constructing a set of all all-gather input dtypes: ~0.015 ms -- this would be the overhead cost of checking whether we need to view as uint8 (i.e. whether we have mixed dtype); alternatively, we could always view as uint8 (but that loses the mixed precision policy info from the profiler trace)
- Changing from `[t.to(bf16).view(uint8) for t in ts]` to two list comprehensions like `[t.to(bf16) for t in ts]; [t.view(uint8) for t in ts]` actually reduces CPU overhead 🤔 (by ~0.04 ms)
We see that the main difference is just CPU overhead. The GPU times are almost the same. (Actually, sweeping over world sizes 8, 16, 32, and 64, we do see a difference in GPU time inversely proportional to world size, as expected since smaller world sizes copy more data. However, even at world size 8, the difference is only 0.407 ms vs. 0.445 ms GPU time.) Note though that the CPU overhead differences are exacerbated when the PyTorch profiler is turned on, and how much so seems to depend on the CPU capability.
Seeing these numbers, I am inclined to just incur the CPU overhead, especially given that if we want to support the mixed dtype case for fp8 all-gather, we will need to incur this anyway. If the CPU overhead becomes a problem on a real workload, then we will need to figure out options then, one possibility being `torch.compile`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118223
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #119550, #118136
This PR adds tests for autograd (mainly backward hooks), memory, overlap, and frozen parameters.
- Autograd: unused forward output, unused forward module, non-tensor activations (common in internal models)
- Memory: expected GPU memory usage after init, forward, backward, and optimizer step
- Overlap: communication/computation overlap in forward and backward
- Frozen: expected reduce-scatter size, training parity
This PR adds some initial 2D (FSDP + TP) training and model state dict tests. The only change required for model sharded state dict is to make sure parameters are sharded before save and load.
This PR adds tests that `fully_shard` can use `torch.utils.checkpoint`, `_composable.checkpoint`, and `CheckpointWrapper` on a transformer.
(I squashed all of these into one PR now to save CI cost.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118136
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #119550
This PR adds explicit backward prefetching to overlap communication and computation in backward (namely, needed for `reshard_after_forward=True` or `reshard_after_forward: int`). We do this by recording the post-forward order and using its reverse to approximate the backward order.
This works for the typical 1 forward / 1 backward training. However, for more complex schedules, this can run into some gaps:
- We need to know the _true end of backward_.
- At the true end of backward, we can clear our recorded post-forward order and pre-backward hook state, and we should wait on gradient reductions.
- There is no easy way to know whether the current backward marks the true end of backward. Therefore, we introduce an API for the user to set this: `fsdp_module.set_is_last_backward(bool)`. For example, for pipeline parallelism's DFS cooldown backward, we can call `fsdp_module.set_is_last_backward(is_last_microbatch)`.
- When the user runs backward through only part of the model, our reverse-post-forward-order heuristic risks _mistargeted prefetches_ for unused modules, which would mean the module's parameters are all-gathered and not freed until the end of backward.
- To error on the side of less memory usage (but no overlap), this PR introduces logic to check whether a module will need its unshard in the current backward (by recording the module's `forward` outputs' `grad_fn`s and querying the autograd engine).
- Note that there may be _no_ overlap in backward for some parts due to no prefetching.
- Note further that when running multiple backwards, if the user does not use `set_is_last_backward`, we may not be able to provide a meaningful error message, as the pre-backward hook could be erroneously cleared on the 1st backward.
- In the future, we may expose more APIs from the autograd engine (similar to `_current_graph_task_execution_order`) to make the prefetching exact. (Currently, `_current_graph_task_execution_order` requires running under `with torch.autograd.set_multithreading_enabled(False)`, which is too hard a constraint since we cannot easily modify users' training loops. We can replace the multi-threading check with a device check. Moreover, in the partial backward case in this PR's unit test, I still hit an [internal assertion](b816760a2f/torch/csrc/autograd/engine.cpp (L476)), so some follow-up is required.)
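For illustration, a hypothetical microbatching loop using `set_is_last_backward` (`fsdp_model` and `microbatches` are placeholders):
```
def run_microbatches(fsdp_model, microbatches):
    for i, microbatch in enumerate(microbatches):
        # Mark whether this backward is the true end of backward so that FSDP can
        # clear its recorded post-forward order and wait on gradient reductions.
        fsdp_model.set_is_last_backward(i == len(microbatches) - 1)
        fsdp_model(microbatch).sum().backward()
```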
<details>
<summary> Old Discussion </summary>
For discussion:
- The PR includes a counter `expected_backward_unshard_count` to mitigate mistargeted prefetches in backward. However, it can be seen as a necessary but not sufficient solution.
- If a module's outputs do not require gradient, then we certainly do not need to unshard the module in backward.
- However, if a module's outputs do require gradient, then we still may not need to unshard the module for _this_ backward (e.g. if the module did not contribute to `loss` for the current `loss.backward()`).
- This counter will only address the first case but not the second. If we want to address the second, then we may need more info from the autograd engine.
- For now, I did not include any unit test to cover these behaviors, as I do not have a good example yet.
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118118
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #118017
**Summary**
The reducer of `DistributedDataParallel` is implemented in C++, so it is not easy to trace the allreduces it launches. This PR modifies `DistributedDataParallel` to launch one allreduce per gradient when `compiled_autograd` is enabled. This allows `compiled_autograd` to trace the allreduces so that they can later be optimized (fused) by Inductor.
**Key Logic**
1. If `ddp_python_hook` is True, we assume `compiled_autograd` is used. `DistributedDataParallel` registers `compiled_accum_grad_hook` for all parameters.
2. In the first forward() call, if `DistributedDataParallel` is not compiled, all `compiled_accum_grad_hook` are deregistered. If `DistributedDataParallel` is compiled, all `compiled_accum_grad_hook` will be compiled by `compiled_autograd`.
3. `compiled_accum_grad_hook` launches an allreduce to reduce the gradient of the parameter.
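For illustration only, a sketch in the spirit of `compiled_accum_grad_hook` using the public post-accumulate-grad hook API (the real hook is registered internally by `DistributedDataParallel` and differs in details):
```
import torch
import torch.distributed as dist

def register_per_grad_allreduce_hooks(model: torch.nn.Module, group: dist.ProcessGroup) -> None:
    def hook(param: torch.nn.Parameter) -> None:
        # One allreduce per gradient: traceable by compiled_autograd and
        # potentially fused later by Inductor.
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=group)

    for param in model.parameters():
        if param.requires_grad:
            param.register_post_accumulate_grad_hook(hook)
```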
**Bucketing**
The compiled backward is slow because there is no bucketing for the allreduces. We rely on Inductor to bucket the allreduces.
The bucketing is done in a separate PR.
Differential Revision: [D49428482](https://our.internmc.facebook.com/intern/diff/D49428482/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110662
Approved by: https://github.com/wconstab
This PR adds the `reshard_after_forward: Union[bool, int]` arg and a `reshard()` method. The `reshard_after_forward` argument trades off communication and memory.
- `reshard_after_forward=True`: reshard parameters after forward; unshard (all-gather) in backward
- `reshard_after_forward=False`: no reshard of parameters after forward; no unshard (all-gather) in backward
- `reshard_after_forward: int`: reshard parameters to a smaller world size; unshard (all-gather) over small world size in backward
In comparison with DeepSpeed and existing FSDP:
- `reshard_after_forward=True` == `FULL_SHARD` == ZeRO-3
- `reshard_after_forward=False` == `SHARD_GRAD_OP` == ZeRO-2
- `reshard_after_forward=8` == ZeRO++
ZeRO-1 is `reshard_after_forward=False` without gradient reduction (implemented in a later PR). If we need gradient reduction on an iteration, then ZeRO-2 supersedes ZeRO-1.
We prefer a simple state transition between `SHARDED` / `SHARDED_POST_FORWARD` and `UNSHARDED`, where the state directly defines what tensors are registered to the module. In particular, we _do not_ have a state where the sharded parameters are registered but the unsharded parameters are still in GPU memory. This greatly simplifies our state transitions, but it means that parameters may be non-intuitively registered to the module (e.g. if only the root does not reshard after forward, then the root will be the only one without sharded parameters registered). To address this, we introduce a simple `reshard()` method that can force-reshard the parameters. This makes sense to me because the typical case does not care about the registered parameters after forward (in fact, for existing FSDP with `use_orig_params=False`, the unsharded parameters are still registered and are dangling tensors without storage).
I plan to expose a complementary `unshard(async_op: bool = True)` method in the future.
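A usage sketch (the import path and `model` are placeholders; assumes `model.layers` is a list of blocks):
```
from torch.distributed._composable.fsdp import fully_shard  # import path may differ

fully_shard(model.layers[0], reshard_after_forward=True)   # ZeRO-3-like
fully_shard(model.layers[1], reshard_after_forward=False)  # ZeRO-2-like
fully_shard(model.layers[2], reshard_after_forward=8)      # ZeRO++-like: reshard to a world size of 8
fully_shard(model)

# ... after forward, force-reshard a module that did not reshard after forward:
model.layers[1].reshard()
```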
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118017
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
This PR adds the pre- and post-backward logic:
- **Pre-backward hook:** `FSDPState` and `FSDPParamGroup` define this, and `FSDPState` is responsible for registering since its pre-backward should run even if the `FSDPState` does not manage any parameters (in case it is the root).
- **Post-backward hook:** Only `FSDPParamGroup` defines this since the post-backward hook reshards parameters and reduce-scatters gradients (functionality only needed with managed parameters). The `FSDPParamGroup` is responsible for registering this.
- **Post-backward final callback:** `FSDPState` defines this, and each `FSDPParamGroup` defines a `finalize_backward()` to call in the final callback.
### Pre-Backward
The pre-backward hook is registered on the module outputs (that require gradient), and it should run when the first such output has its gradient computed. The hook may run multiple times per backward, once per module forward. Specifically, there will be one `(pre-backward, post-backward)` interval for each of the module's `forward()` calls. This is in contrast with the existing FSDP semantics, which only defines a single `(pre-backward, post-backward)` interval that is equivalent to the union of this FSDP's `(pre-backward, post-backward)` intervals. This avoids spiking memory from having multiple modules not resharding and avoids some autograd edge cases.
We implement the pre-backward hook by having a flag that is set upon the 1st call to disable subsequent calls. This flag could be maintained by FSDP, but for a cleaner design, we augment `register_multi_grad_hook` with a `mode="any"` option and use that instead.
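A minimal standalone example of the `mode="any"` option (independent of FSDP):
```
import torch
from torch.autograd.graph import register_multi_grad_hook

x = torch.randn(2, requires_grad=True)
out1, out2 = x.sin(), x.cos()

def pre_backward_hook(grad):
    # Fires once, when the first of (out1, out2) has its gradient computed,
    # instead of waiting for all of them (mode="all", the default).
    print("pre-backward fired")

register_multi_grad_hook((out1, out2), pre_backward_hook, mode="any")
(out1 + out2).sum().backward()
```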
### Post-Backward
The post-backward hook is equivalent to a module full backward hook (`nn.Module.register_full_backward_hook`) except it adds pytree logic to work with data structures other than just flat `Tensor` args passed to `nn.Module.forward`. If we were to use `register_full_backward_hook`, then the hook could fire early (before all gradients for the module have been computed). Most internal models use custom data structures as `forward` inputs, and they find that unifying under pytree is an acceptable solution.
Unlike existing FSDP, we are able to reshard the parameters in the post-backward hook _before_ 'concatenating' the autograd-computed gradients, achieving a lower peak memory usage. (Existing FSDP has `SplitWithSizesBackward` that calls a `CatArrayBatched`, and here we have the reduce-scatter copy-in.)
### Final Callback
The final callback runs as a queued callback to the autograd engine, meaning that it runs at the end of backward.
In the future, if we do not want to wait for the reduce-scatter (or similar for CPU offloading), we can augment the final callback. The code is written such that each reduce-scatter can be waited on separately (via CUDA event).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118004
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #117950, #117955, #117973, #117975
This PR adds the FSDP reduce-scatter (the copy-in/reduce-scatter collective/view-out).
- We use gradient pre- and post-divide factors like existing FSDP (mainly for fp16 reduction).
- We use a separate CUDA stream for the reduce-scatter to conveniently handle additional kernels surrounding the collective as a separate 'thread of execution' (e.g. pre/post-divide and later the D2H gradient offload).
- ~~The implementation in this PR is more complicated to _try_ to reduce CPU overhead by using `torch.split` instead of a Python for-loop. The challenge comes from the fact that the autograd-computed unsharded gradients do not have padding. We prefer to not do an intermediate padding step and instead directly copying to the big reduce-scatter input.~~ For simplicity, I changed the implementation to include intermediate padding steps, as it can still achieve ~250 GB/s, and it avoids any `O(NP)` tensor materialization for world size `N` and `P` `nn.Parameter`s.
<details>
<summary> Recall: Copy-in/All-Gather/Copy-Out Example </summary>
Suppose we have 2 parameters with shapes `(3, 3)` (denoted with `A`s) and `(2, 2)` (denoted with `B`s) and 2 ranks, where `P` represents padding and `E` represents empty:
```
Given:
(3, 3): AAAAAAAAA
(2, 2): BBBB
Sharded parameters/all-gather inputs:
Rank 0: AAAAAA, BB
Rank 1: AAAPPP, BB
Each rank allocate group's all-gather output:
EEEEEEEEEEEEEEEE
Each rank copy-in:
Rank 0: AAAAAABBEEEEEEEE
Rank 1: EEEEEEEEAAAPPPBB
Each rank all-gather:
Rank 0: AAAAAABBAAAPPPBB
Rank 1: AAAAAABBAAAPPPBB
Each rank copy-out:
Rank 0: AAAAAAAAAPPP, BBBB
Rank 1: AAAAAAAAAPPP, BBBB
```
</details>
<details>
<summary> Copy-in/Reduce-Scatter/View-Out Example </summary>
Suppose we have 2 gradients with shapes `(3, 3)` (denoted with `a`s when not-yet-reduced and `A`s after reduced) and `(2, 2)` (denoted with `b`s and `B`s similarly) and 2 ranks, where `E` represents empty:
```
Given from autograd:
(3, 3): aaaaaaaaa
(2, 2): bbbb
Unsharded gradients/reduce-scatter inputs (no padding!):
Rank 0: aaaaaaaaa, bbbb
Rank 1: aaaaaaaaa, bbbb
Each rank allocate group's reduce-scatter input:
EEEEEEEEEEEEEEEE
Each rank copy-in:
Rank 0: aaaaaabbaaaEEEbb
Rank 1: aaaaaabbaaaEEEbb
Each rank reduce-scatter:
Rank 0: AAAAAABBAAAEEEBB
Rank 1: AAAAAABBAAAEEEBB
Each rank view-out:
Rank 0: AAAAAA, BB
Rank 1: AAA, BB
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117975
Approved by: https://github.com/weifengpy, https://github.com/yifuwang
ghstack dependencies: #117950, #117955, #117973
This PR adds the all-gather and free logic required for forward.
- We define the logical all-gather as two ops: (1) unshard and (2) wait for unshard. This abstraction allows capturing both implicit forward prefetching (using multiple streams and `async_op=False`) and explicit forward prefetching (using `async_op=True`).
- Symmetrically, we define the reshard op to free the unsharded parameters.
Some other notes:
- The `FSDPParamGroup` and its `FSDPParam`s transition their sharded states together. This invariant allows us to reason about the parameters by group rather than individually with respect to whether they are sharded or unsharded.
---
### How Does the Overlap Work for All-Gather?
For context, the all-gather consists of three steps: (1) copy-in, (2) all-gather collective, and (3) copy-out.
<details>
<summary> Example </summary>
Suppose we have 2 parameters with shapes `(3, 3)` (denoted with `A`s) and `(2, 2)` (denoted with `B`s) and 2 ranks, where `P` represents padding and `E` represents empty:
```
Given:
(3, 3): AAAAAAAAA
(2, 2): BBBB
Sharded parameters/all-gather inputs:
Rank 0: AAAAAA, BB
Rank 1: AAAPPP, BB
Each rank allocate group's all-gather output:
EEEEEEEEEEEEEEEE
Each rank copy-in:
Rank 0: AAAAAABBEEEEEEEE
Rank 1: EEEEEEEEAAAPPPBB
Each rank all-gather:
Rank 0: AAAAAABBAAAPPPBB
Rank 1: AAAAAABBAAAPPPBB
Each rank copy-out:
Rank 0: AAAAAAAAAPPP, BBBB
Rank 1: AAAAAAAAAPPP, BBBB
```
</details>
`dist.all_gather_into_tensor()` always has the PG's NCCL stream wait for the current stream before running the collective. `async_op=False` means that the function waits on the work, having the current stream wait for the NCCL stream before returning. `async_op=True` means it returns a `Work` object, which the user can wait on later.
#### Implicit Prefetching
Implicit prefetching achieves communication/computation overlap without changing the CPU issue order:
- We use separate streams for copy-in and for issuing the `dist.all_gather_into_tensor()`. The copy-in stream allows us to overlap the copy-in with all-gather/reduce-scatter in backward, and the all-gather stream allows us to overlap the all-gather with forward compute (issued before it).
- Because `dist.all_gather_into_tensor()` always has the PG's NCCL stream wait for the current stream, we need this "dummy" all-gather stream to prevent the all-gather from waiting on the forward compute with which it should overlap.
- Without the separate copy-in stream, we cannot overlap all-gather copy-in with all-gather in forward.
- We copy-out in the default stream after having the default stream wait for the all-gather. This means that the autograd leaves are allocated in the default stream and autograd will not call `recordStream`.
Implicit prefetching does not require knowing the execution order ahead of time. However, when overlapping the next all-gather with the current compute, there may be a gap from the CPU thread issuing the current compute. If the CPU thread can run ahead, then this is not an issue.
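A sketch of the stream structure described above (not the actual implementation; the flat buffers and their setup are assumed):
```
import torch
import torch.distributed as dist

all_gather_copy_in_stream = torch.cuda.Stream()
all_gather_stream = torch.cuda.Stream()

def implicitly_prefetched_all_gather(
    sharded_params,                   # this rank's flat dim-0 shards
    all_gather_input: torch.Tensor,   # flat per-rank chunk
    all_gather_output: torch.Tensor,  # flat buffer of size world_size * chunk
    group: dist.ProcessGroup,
) -> None:
    # (1) Copy-in on its own stream so it can overlap with collectives in backward.
    with torch.cuda.stream(all_gather_copy_in_stream):
        splits = torch.split(all_gather_input, [t.numel() for t in sharded_params])
        torch._foreach_copy_(list(splits), [t.view(-1) for t in sharded_params])
    # (2) Issue the collective from the "dummy" all-gather stream so the NCCL
    #     stream waits on it, not on the forward compute it should overlap with.
    all_gather_stream.wait_stream(all_gather_copy_in_stream)
    with torch.cuda.stream(all_gather_stream):
        dist.all_gather_into_tensor(all_gather_output, all_gather_input, group=group)
    # (3) Copy-out runs later in the default stream after it waits on the
    #     all-gather stream, so autograd leaves are allocated in the default stream.
    torch.cuda.current_stream().wait_stream(all_gather_stream)
```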
#### Explicit Prefetching
Explicit prefetching achieves communication/computation overlap by changing the CPU issue order, namely by reordering the all-gather to be before the compute with which it should overlap.
- Because we reorder, we do not need any separate streams, and we can use `async_op=True` for overlap.
- We can expose this explicit prefetching as a module-level `unshard()` op (e.g. `module.unshard(async_op: bool)`), and we can use it as a primitive for implementing the explicit forward prefetching in existing FSDP.
Explicit prefetching requires knowing the execution order.
---
Disclaimer: The testing is relatively lighter in this PR. I did not want to spend too much time writing new forward-only tests. The stream usage will be exercised thoroughly once we have backward too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117973
Approved by: https://github.com/weifengpy, https://github.com/yifuwang
ghstack dependencies: #117950, #117955
This PR adds the FSDP all-gather (the copy-in/all-gather collective and the copy-out) and the unsharded parameter concept to `FSDPParam`. This is to prepare for being able to run the forward pass.
- We implement all-gather as two functions: `foreach_all_gather` (copy-in/all-gather collective) and `foreach_all_gather_copy_out` (copy-out).
- In the future, there will be two paths: `async_op=True` in the default stream for explicit prefetching and `async_op=False` in separate streams for implicit prefetching.
- In the future, we will use `torch.split_with_sizes_copy` in the copy-out when it has the CUDA fast path.
- We have the functions operate on `List[FSDPParam]` instead of passing the `torch.Tensor` and metadata mainly so that the `all_gather_input` can be computed under the `all_gather_copy_in_stream`. Since the two functions are specific to FSDP, I did not see motivation for avoiding this at the cost of entering/exiting the `all_gather_copy_in_stream` context twice (which incurs some CPU overhead).
- The `init_all_gather_output()` and `init_unsharded_parameter()` functions may seem unintuitive. The reason we initialize them once and write to them in-place thereafter is for autograd. See the note `[Note: FSDP and autograd]` in the code.
- We expand our 'FSDP tensors' definition to include the all-gather input and all-gather output in addition to the sharded and unsharded parameters. This distinction might seem unnecessary or pedantic, but it enables a language for describing pre- and post-all-gather transformations.
- We use the `_unsafe_preserve_version_counters` context when copying out because otherwise autograd will complain of a version mismatch in backward due to writing to the leaf tensors. (An alternative would be to use `.data`, but we are avoiding that 😄 .)
---
<details>
<summary> Copy-in/All-Gather/Copy-Out Example </summary>
Suppose we have 2 parameters with shapes `(3, 3)` (denoted with `A`s) and `(2, 2)` (denoted with `B`s) and 2 ranks, where `P` represents padding and `E` represents empty:
```
Given:
(3, 3): AAAAAAAAA
(2, 2): BBBB
Sharded parameters/all-gather inputs:
Rank 0: AAAAAA, BB
Rank 1: AAAPPP, BB
Each rank allocate group's all-gather output:
EEEEEEEEEEEEEEEE
Each rank copy-in:
Rank 0: AAAAAABBEEEEEEEE
Rank 1: EEEEEEEEAAAPPPBB
Each rank all-gather:
Rank 0: AAAAAABBAAAPPPBB
Rank 1: AAAAAABBAAAPPPBB
Each rank copy-out:
Rank 0: AAAAAAAAAPPP, BBBB
Rank 1: AAAAAAAAAPPP, BBBB
```
</details>
---
For context, we use the copy-in/all-gather/copy-out strategy instead of NCCL group coalescing for two reasons:
1. One large NCCL all-gather is still noticeably faster than several NCCL all-gathers using group coalescing of the same total bytes (even after NCCL 2.18.3). We prefer to tradeoff extra device-to-device copies (using GPU high-bandwidth memory) to save communication time, which does not improve as fast from hardware generation to generation.
2. Copying out of the all-gather buffer tensor simplifies multi-stream memory handling because there is a constant number of such all-gather tensors alive at once. (The copy-out is done in the default/compute stream.) If we directly used the all-gather tensor memory for computation, then the number of such alive tensors is linear in the module depth and hence dependent on the particular model.
---
Disclaimer: This PR has some extraneous code, but I did not want to simplify too much since that code will be added back soon anyway (e.g. for overlapping, mixed precision, and ZeRO++). Hopefully it does not hinder code review too much.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117950
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
This PR adds the initial `_lazy_init`. Lazy initialization marks the point when the FSDP structure is finalized and is typically the beginning of the first forward. This would be after any meta-device initialization.
- Lazy initialization is distinct from construction time because when processing `fully_shard(module)`, there is no way to know whether a parent of `module` will have `fully_shard` applied as well. This is a consequence of `fully_shard` having to be applied bottom-up.
- At lazy initialization, we now have the concept of a _root_. The root FSDP module is the one whose `forward` runs first and ends last (and hence similarly for its backward). Having a single root simplifies handling logic that should only run "once per forward/backward/iteration". We may consider relaxing this in the future, but it will add more complexity to the design.
- Once we have a root, we can define _fully-qualified names_ (FQNs) for both parameters and modules. To aid debugging, we store `_param_fqn` and `_module_fqn` on `FSDPParam` and `FSDPParamGroup`, respectively. Note that we can have a unique `_module_fqn` for `FSDPParamGroup` since we currently assume a 1:1 relationship.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117881
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #118525, #117814, #117867, #117877
This PR adds logic to shard the managed parameters on dim-0. This is like `distribute_tensor()` with two differences:
1. `distribute_tensor()` today cannot accept a `DTensor` and reshard it to the parent mesh (https://github.com/pytorch/pytorch/issues/116101).
2. `DTensor` does not pad its local shard on any `Shard` dimensions (https://github.com/pytorch/pytorch/issues/113045).
As such, the `FSDPParam._init_sharded_param()` derives the global `DTensor` metadata itself and pads the local tensor on dim-0. The padding helps make the all-gather copy-in more efficient since the all-gather buffer will require padding.
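A simplified sketch of the shard-and-pad step (the `shape`/`stride` kwargs to `DTensor.from_local` and the import path are assumptions; the real `FSDPParam._init_sharded_param()` differs in details):
```
import torch
from torch.distributed._tensor import DTensor, Shard  # import path may differ

def init_sharded_param(param: torch.Tensor, mesh, rank: int, world_size: int) -> torch.nn.Parameter:
    chunks = torch.chunk(param, world_size, dim=0)
    padded_dim0 = chunks[0].size(0)  # every rank pads to the first chunk's dim-0 size
    local = chunks[rank] if rank < len(chunks) else param.new_zeros(0, *param.shape[1:])
    padded_local = param.new_zeros(padded_dim0, *param.shape[1:])
    padded_local[: local.size(0)].copy_(local)
    # Derive the global DTensor metadata ourselves since DTensor does not pad its local shard.
    sharded = DTensor.from_local(
        padded_local, mesh, [Shard(0)], run_check=False,
        shape=param.shape, stride=param.stride(),
    )
    return torch.nn.Parameter(sharded)
```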
---
Some details:
- We free the original parameter manually after constructing the sharded parameter. This lowers the peak memory during construction time slightly (since not _all_ parameters in the group must be sharded before the original parameters are freed) and is not strictly necessary.
- We bypass `nn.Module.__setattr__` because the checks are slow and unnecessary. The drawback is that we would ignore a user-defined override of `__setattr__`; however, since we have never encountered this in practice, I am okay with this. Notably, user calls to `setattr` would still use the override; FSDP only uses `setattr` as a mechanism for switching between sharded and unsharded parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117877
Approved by: https://github.com/wanchaol
ghstack dependencies: #118525, #117814, #117867
This PR adds the initial `FSDPParamGroup` and `FSDPParam` classes, and it focuses on the `ParamModuleInfo` data structure.
- `ParamModuleInfo` has the info needed to `setattr` a managed parameter, where it must account for shared parameters and shared modules.
```
# Shared parameter
lin1.weight = lin2.weight
# Shared module
mlp.lin1 = mlp.lin2
```
- In order for FSDP to find shared modules' parameters, we must use `remove_duplicate=False`. See https://github.com/pytorch/pytorch/pull/99448/ for the original context. Finding shared modules' parameters is not necessary for the `setattr` logic, but in case we need it in the future (like for existing FSDP's state dict), we include that info for now.
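A toy example of why `remove_duplicate=False` matters (not FSDP code):
```
import torch.nn as nn

mlp = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
mlp[1].weight = mlp[0].weight  # shared parameter across two modules

print(len(list(mlp.named_parameters())))                        # 3: the shared weight is de-duplicated
print(len(list(mlp.named_parameters(remove_duplicate=False))))  # 4: the shared weight appears under both modules
```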
With this PR, we see the general system architecture:
- 1 `module` : 1 `fully_shard`
- 1 `fully_shard` : 1 `FSDPParamGroup`
- 1 `FSDPParamGroup` : k `FSDPParam`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117867
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #118525, #117814
Squashed to include https://github.com/pytorch/pytorch/pull/117861, https://github.com/pytorch/pytorch/pull/117852
---
This PR adds `_get_managed_modules()` to determine which modules a `fully_shard(module)` call manages. The rule is defined as:
> `fully_shard(module)` manages all modules in `module.modules()` except those already managed by a nested `fully_shard()` or a nested non-composable API (e.g. `replicate()` or TorchRec).
Practically, this can be implemented as a graph search from `module` that does not proceed into any module with `fully_shard` or a non-composable API applied. Because the non-composable APIs follow the same rule, this rule is correct inductively.
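A sketch of that graph search (the `is_managed_elsewhere` predicate is a stand-in for the real check against the composable-API registry):
```
from typing import Callable, List
import torch.nn as nn

def get_managed_modules(root: nn.Module, is_managed_elsewhere: Callable[[nn.Module], bool]) -> List[nn.Module]:
    managed: List[nn.Module] = []

    def visit(module: nn.Module) -> None:
        if module is not root and is_managed_elsewhere(module):
            return  # do not proceed into modules managed by a nested API
        managed.append(module)
        for child in module.children():
            visit(child)

    visit(root)
    return managed
```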
---
This PR adds `_get_managed_states(managed_modules)` to return the managed parameters and buffers given the managed modules.
- Without an extra mechanism to ignore specific parameters or buffers, the rule currently is simply to get the directly managed state (i.e. parameters/buffers) from each managed module while de-duplicating shared ones.
- However, we prefer this translation from managed modules to managed states to accommodate ignoring specific states in the future (which has appeared in various open-source use cases).
---
This PR adds the `mesh` argument to `fully_shard` and some helper data structures specific to FSDP/HSDP that pre-compute useful info like rank/world size for each mesh dim.
- The `mesh` defines the FSDP/HSDP algorithm. 1D mesh means FSDP, and 2D mesh means HSDP, where we assume sharding on the last dimension.
- We can revisit the HSDP sharding-dim assumption if needed in the future.
- The default (if `mesh is None`) is that `fully_shard` calls `init_device_mesh` following the global process group.
- The helper data structures are the various `*MeshInfo`s. I included up to `HSDPMeshInfo`, even though it is not immediately used, to show the spirit of the design. We want to tag both the shard and replicate dims.
- The `mesh_info` variable in `fully_shard` is not used for now. It will be passed downstream in the future.
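A usage sketch (import paths, `model_a`/`model_b`, and an initialized process group of 8 ranks are assumed):
```
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard  # import path may differ

# FSDP: 1D mesh, sharding over all 8 ranks
fully_shard(model_a, mesh=init_device_mesh("cuda", (8,)))

# HSDP: 2D mesh, replicating over dim 0 and sharding over the last dim
hsdp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
fully_shard(model_b, mesh=hsdp_mesh)
```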
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117814
Approved by: https://github.com/wanchaol, https://github.com/wconstab
ghstack dependencies: #118525
This PR introduces the initial `fully_shard` frontend without any distributed logic that will be built into per-parameter-sharding FSDP.
- We design `fully_shard` to be a _module-level_ API (taking in an `nn.Module`), e.g. as opposed to a tensor-level one.
- We define a `FSDP` class and use a dynamic class swap, setting `module.__class__` to a newly created class that subclasses `FSDP` and `type(module)`, to allow FSDP to override and add methods on the module.
- We name this class as `FSDP<type(module)>`, e.g. `FSDPLinear` for `Linear`.
- We disable the `deepcopy` because the state object inserted on the module will not be trivially `deepcopy`-able.
- Calling `fully_shard(module)` inserts a state object on `module` but not any of its children. This state object will be used for any FSDP-specific state.
- We raise an error on `ModuleList` or `ModuleDict` since they do not implement `forward()`, and FSDP will rely on `forward()` to insert logic (https://github.com/pytorch/pytorch/issues/113794).
- In the future, we will deprecate the existing `fully_shard` that calls into the same backend logic as `FullyShardedDataParallel` as there is no adoption for that and we prefer to reuse that name.
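A standalone sketch of the dynamic class swap (the `FSDP` stand-in here is not the real class):
```
import torch.nn as nn

class FSDP:
    # Stand-in for the mixin that fully_shard adds; real methods omitted.
    def reshard(self):
        ...

def swap_fsdp_class(module: nn.Module) -> None:
    orig_cls = type(module)
    # Create e.g. `FSDPLinear` subclassing (FSDP, Linear) and swap it in so that
    # FSDP can override/add methods without wrapping the module.
    module.__class__ = type(f"FSDP{orig_cls.__name__}", (FSDP, orig_cls), {})

lin = nn.Linear(4, 4)
swap_fsdp_class(lin)
assert type(lin).__name__ == "FSDPLinear"
assert isinstance(lin, nn.Linear) and isinstance(lin, FSDP)
```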
**Reland details:** I removed `test/distributed/_composable/fsdp/_test_fully_shard_common.py` and moved its contents to the existing `torch/testing/_internal/common_fsdp.py`, which is already a target for internal tests.
Differential Revision: [D53187509](https://our.internmc.facebook.com/intern/diff/D53187509)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118525
Approved by: https://github.com/wanchaol
This PR introduces the initial `fully_shard` frontend without any distributed logic that will be built into per-parameter-sharding FSDP.
- We design `fully_shard` to be a _module-level_ API (taking in an `nn.Module`), e.g. as opposed to a tensor-level one.
- We define a `FSDP` class and use a dynamic class swap, setting `module.__class__` to a newly created class that subclasses `FSDP` and `type(module)`, to allow FSDP to override and add methods on the module.
- We name this class as `FSDP<type(module)>`, e.g. `FSDPLinear` for `Linear`.
- We disable the `deepcopy` because the state object inserted on the module will not be trivially `deepcopy`-able.
- Calling `fully_shard(module)` inserts a state object on `module` but not any of its children. This state object will be used for any FSDP-specific state.
- We raise an error on `ModuleList` or `ModuleDict` since they do not implement `forward()`, and FSDP will rely on `forward()` to insert logic (https://github.com/pytorch/pytorch/issues/113794).
- In the future, we will deprecate the existing `fully_shard` that calls into the same backend logic as `FullyShardedDataParallel` as there is no adoption for that and we prefer to reuse that name.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117776
Approved by: https://github.com/wconstab, https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #117994, #118186, #117984
I prefer to not modify the module if it does not have any of our APIs applied. The side effect of inserting a registry on the module when calling a getter is non-intuitive to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113654
Approved by: https://github.com/fegin
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.
The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
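A fuller usage sketch (`TransformerBlock` is a placeholder module, and an initialized process group is assumed):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import CustomPolicy

class TransformerBlock(nn.Module):  # placeholder for a real block type
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)

    def forward(self, x):
        return self.lin(x)

def lambda_fn(module: nn.Module):
    if isinstance(module, TransformerBlock):
        # Wrap every block, overriding one kwarg and inheriting the rest from the root
        return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
    return False  # no wrapping for other modules

model = nn.Sequential(*[TransformerBlock() for _ in range(4)])
fsdp_model = FSDP(model, auto_wrap_policy=CustomPolicy(lambda_fn))
```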
---
After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.
This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.
This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma