A recent PR deprecated `torch.testing.assert_allclose` in favor of `torch.testing.assert_close` and left a `TODO`. This PR follows up to confirm that we do intend to have `check_dtype=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90251
Approved by: https://github.com/rohan-varma
`FSDP.clip_grad_norm_()` is tested separately in `test_fsdp_clip_grad_norm.py`. This PR removes the dead non-run code from `common_fsdp.py` and `test_fsdp_core.py` related to `clip_grad_norm_()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90250
Approved by: https://github.com/rohan-varma
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.
This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code).
In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.
**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).
`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.
This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
**`_init_param_attributes()` -> `init_flat_param_attributes()`**
We move `_init_param_attributes()` to `FlatParamHandle.init_flat_param_attributes()` (as already marked as to-do during previous refactoring).
**`_reset_lazy_init()`**
We no longer delete `_local_shard` from each `FlatParameter` in `_reset_lazy_init()`.
**Analysis**
Thus, the two semantic differences are that we remove the initial `if hasattr(p, "_local_shard")` early return in `_init_param_attributes()` and the `delattr(p, "_local_shard")` in `_reset_lazy_init()`.
This is safe because
- If we never call `_reset_lazy_init()`, then `init_flat_param_attributes()` is only called once. There is no opportunity for an early return.
- If we call `_reset_lazy_init()`, then `init_flat_param_attributes()` will be called again in the next `_lazy_init()`. However, since we removed the early return, all of the attributes initialized in `init_flat_param_attributes()` simply get re-initialized and override any existing attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87938
Approved by: https://github.com/mrshenli
This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one is per FSDP instance and one is per `FlatParamHandle`/`FlatParameter`.
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87916
Approved by: https://github.com/mrshenli
This PR reworks FSDP's `clip_grad_norm_()` and its unit tests. The unit tests in `test_fsdp_core.py` still need to be revisited and will be done in follow-up work.
Some details in arbitrary order:
- This renames `_calc_grad_norm()` to `_get_grad_norm()`. This is to simplify our verb usage in method names. Otherwise, we may diverge to different verbs like "compute", "calculate", "get", "find" etc. I am open to discussion here.
- Because we call `torch.linalg.vector_norm()` as the underlying norm calculation subroutine, which can take infinity as input for the norm type, there is no reason to have a separate conditional branch for the infinity norm.
- This removes a host-device synchronization point from `clip_grad_norm_()` by using the same trick from `torch.nn.utils.clip_grad_norm_()`. This may improve throughput for workloads like metaseq, which computes gradient norms regularly.
- This returns the total norm from `clip_grad_norm_()` as mentioned in the docstring. Before nothing was returned.
- This rewrites the unit tests, which were slightly problematic. Much of the logic to verify gradient norms were computed correctly were exactly the same as the logic used to compute them in FSDP (i.e. `^p`, sum via all-reduce, `^(1/p)`). This defeats the purpose of unit testing. There were some other oddities like `input = torch.rand(14, 2, device=self.rank); in_data = torch.tensor(input[self.rank], device=self.rank)`, where we materialize a full `(14, 2)` shape but only ever use the first two rows (assuming world size 2).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87479
Approved by: https://github.com/rohan-varma
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.
For more detailed design explanation, refer to the Quip shared internally.
**Follow-Ups**
See 85831 (removing link to avoid spamming the issue whenever I update this PR).
`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84911
Approved by: https://github.com/rohan-varma
- This implements explicit forward prefetching following the static 1st iteration's pre-forward order when `forward_prefetch=True` in the FSDP constructor.
- This has the same unit test coverage as the original `forward_prefetch`.
- I checked via print statements that the prefetches are happening, but since I cannot get a good CPU bound workload, it is hard to tell via traces that the prefetch is working.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85177
Approved by: https://github.com/zhaojuanmao
### Additional Constructor Changes
- `self.sharding_strategy`
- If the world size is 1, I clamp the sharding strategy to `NO_SHARD`, regardless of the passed-in sharding strategy, since the behavior is fully equivalent. This absolves the need for `p._is_sharded or self.world_size == 1` checks in the core code. Once we fully shift the paradigm to using handles, this should result in a clear net positive. However, for now, we still have some places where we interface directly with the `FlatParameter`, in which case we have some temporary hacky code.
- `HandleConfig`
- As a part of the new design abstraction, much logic is lowered to the `FlatParamHandle`. This requires the handle be aware of mixed precision, CPU offloading, sharding strategy, and the process group (for world size > 1). To be less error-prone, I re-defined the `dataclass`s and `enum`s for the handle. These can be removed and coalesced with the existing ones.
- The drawback is that the `FlattenParamsWrapper` constructor now takes in the `HandleConfig` to forward it to the `FlatParamHandle` constructor. I tolerate this since we plan to retire the FPW. For now, the handle's process group attributes are set later when we call `handle.shard()`.
- We will dive into this logic lowering later. For now, the idea is we need to pass some extra info to the handle, which must go through the FPW.
- `FullyShardedDataParallel._shard_parameters()` -> `FlatParamHandle.shard()`
- [Important] Generalizing attributes to remove the 1 `FullyShardedDataParallel` : 1 `FlatParameter` assumption
- **Before:** `_fsdp_graph_order`, `_pre_backward_hook_full_params_prefetched`, `_forward_full_params_prefetched`, `reshard_after_forward` are with respect to 1 `FullyShardedDataParallel`
- **After:** (1) We use `FlatParamHandle` in place of `FullyShardedDataParallel`. (2) The atomic unit for forward and pre-backward is a _group_ of handles involved in the same module's forward/pre-backward. This is represented as `Tuple[FlatParamHandle, ...]`. For now, this is **always a singleton tuple**, but this shift enables a module having multiple FSDP parameters (which we have use cases for).
- `_reset_lazy_init()` attributes
- The prefetched flags are merged into `self._handles_prefetched`, which is directly defined in the constructor. `reshard_after_forward` is retired since it can be fully determined by other attributes (`_is_root` and `sharding_strategy`).
## FSDP Runtime: Unshard
The first step is to read the existing `_rebuild_full_params()`. A few notable observations:
- It returns `Tuple[Tensor, bool]`. The first element is the _padded unsharded flattened parameter_, and the second element is whether we can free it upon exiting `summon_full_params()`. This return value is **only used in `summon_full_params()`**.
- If parameter mixed precision is enabled and the `FlatParameter` is already unsharded, then the low precision shard (`_mp_shard`) is still re-allocated on GPU. (It is freed at the end of the method.)
- If CPU offloading is enabled and the `FlatParameter` is already unsharded, then there is a no-op `p.data = p.data.to(self.compute_device, non_blocking=True)`.
- Inside `summon_full_params()`, `mixed_precision_cast_ran` is always `False`. Therefore, the return value for the `not p._is_sharded and mixed_precision_cast_ran` branch is unused.
-`summon_full_params()` can only be called (before forward or after backward) or (between forward and backward). Given this, I cannot think of a case where we call `summon_full_params()`, the `FlatParameter` is already unsharded, but `reshard_after_forward` is `True`. The `FlatParameter` should be sharded (before forward or after backward), and the `FlatParameter` may only be unsharded (between forward and backward) if `reshard_after_forward` is `False`.
- If parameter mixed precision is enabled and the sharding strategy is a sharded one, then inside `summon_full_params()`, the `FlatParameter` is unsharded in full precision. This involves allocating a new padded unsharded flattened parameter on GPU in full precision since `_full_param_padded` is in the low precision.
Some comments:
- Ideally, we reduce the complexity of the core code path: i.e. unshard for forward and pre-backward. If the return value is only used for `summon_full_params()`, we should consider if we can compartmentalize that logic.
- The branching is complex, and some return values are never used, where this fact is not immediately obvious. We should see if we can reduce the branch complexity.
Disclaimer: The difference in attribute semantics between `NO_SHARD` and the sharded strategies makes it challenging to unify the cases. This PR does not attempt to address that since it requires more design thought. However, it does attempt to reduce the complexity for the sharded strategies.
### Unshard: Core Code Path
Let us trace through the new logical unshard.
1. `FullyShardedDataParallel._unshard(self, handles: List[FlatParamHandle], prepare_gradient: bool)`
- This iterates over the handles and calls `handle.pre_unshard()`, `handle.unshard()`, and `handle.post_unshard(prepare_gradient)` in the all-gather stream.
2. `FlatParamHandle.needs_unshard(self)`
- We take an aside to look at this key subroutine.
- For `NO_SHARD`, this returns `False`.
- For sharded strategies, this checks if the padded unsharded flattened parameter is allocated. The padded unsharded flattened parameter is the base tensor for the unpadded unsharded flattened parameter, which is a view into the padded one. Thus, the padded one's allocation fully determines if the `FlatParameter` is unsharded.
- For sharded strategies, to accommodate the parameter mixed precision + `summon_full_params()` case, we introduce `_full_prec_full_param_padded`, which is the padded unsharded flattened parameter in full precision. The helper `_get_padded_unsharded_flat_param()` takes care of this casing and returns the padded unsharded flattened parameter. Instead of allocating a new tensor each time, we manually manage `_full_prec_full_param_padded`'s storage just like for `_full_param_padded`.
3. `FlatParamHandle.pre_unshard(self)`
- For sharded strategies, the postcondition is that the handle's `FlatParameter` points to the tensor to all-gather. This should be on the communication device and in the desired precision. The allocation and usage of the low precision shard for parameter mixed precision and the CPU -> GPU copy for CPU offloading both classify naturally in the pre-unshard.
- For sharded strategies, if the `FlatParameter` does not need to be unsharded, `pre_unshard()` is a no-op. This avoids unnecessarily allocating and freeing the low precision shard.
- For `NO_SHARD`, we simply preserve the existing semantics.
4. `FlatParamHandle.unshard(self)`
- If the handle was resharded without freeing the padded unsharded flattened parameter (e.g. `summon_full_params()` between forward and backward when `reshard_after_forward=False`), then the `FlatParameter` points to the sharded flattened parameter. We need to switch to using the unsharded parameter. This is a design choice. Alternatively, we may not switch to using the sharded flattened parameter in `reshard()` if we do not free the padded unsharded flattened parameter. However, the postcondition that the `FlatParameter` points to the sharded flattened parameter after `reshard()` is helpful logically, so I prefer this approach.
- Otherwise, this allocates the padded unsharded flattened parameter, all-gathers, and switches to using the unpadded unsharded flattened parameter.
- In the future, we may add an option to `unshard()` that additionally all-gathers the gradient.
5. `FlatParamHandle.post_unshard(self, prepare_gradient: bool)`
- For sharded strategies, if using parameter mixed precision, this frees the low precision shard. More generally, this should free any sharded allocations made in `pre_unshard()` since the all-gather has been launched. If using CPU offloading, the GPU copy of the local shard goes out of scope after `unshard()` and is able to be garbage collected. **We should understand if there is any performance difference between manually freeing versus deferring to garbage collection since our usage is inconsistent.** For now, I preserve the existing semantics here.
- `prepare_gradient` is meant to be set to `True` for the pre-backward unshard and `False` for the forward unshard. This runs the equivalent logic of `_prep_grads_for_backward()`.
- This post-unshard logic (notably the gradient preparation) now runs in the all-gather stream, which is fine because we always have the current stream wait for the all-gather stream immediately after `FullyShardedDataParallel._unshard()`. IIUC, we do not need to call `_mp_shard.record_stream(current_stream)` (where `current_stream` is the default stream) because `_mp_shard` is allocated and freed in the same (all-gather) stream.
- A postcondition is that the `FlatParameter` is on the compute device. It should also have the unpadded unsharded size (though I do not have a check for this at the moment).
### Unshard: `summon_full_params()`
Now that we see how the logical unshard has been reorganized for the core code path, let us dive into `summon_full_params()`.
The two constraints are:
1. If using parameter mixed precision, we should unshard in full precision.
2. We must determine if we should free the padded unsharded flattened parameter upon exiting.
The first constraint is addressed as described before in the core unshard code path, so it remains to explore the second constraint.
I propose a simple rule: **We free iff we actually unshard the `FlatParameter` in `summon_full_params()`** (i.e. it was not already unsharded). We perform a case analysis:
**Parameter mixed precision enabled:**
* `NO_SHARD`: `flat_param.data` points to `flat_param._local_shard`, which is the full precision unsharded flattened parameter. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We force full precision and all-gather to `_full_prec_full_param_padded`. We do not support `nested summon_full_params()`, so `_full_prec_full_param_padded` must be unallocated. We unshard, and it is **safe to free**.
**Parameter mixed precision disabled:**
* `NO_SHARD`: This is the same as with mixed precision enabled. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We all-gather to `_full_param_padded`. It may already be unsharded.
* Already unsharded: The unshard is a no-op. This is **not safe to free**.
* For `FULL_SHARD`, this can happen for the root FSDP instance after `forward()` but before backward.
* For `SHARD_GRAD_OP`, this can happen for all FSDP instances after `forward()` but before backward.
* Needs unshard: We unshard. This is **safe to free**.
Therefore, we see that it is not safe to free when using `NO_SHARD` and when using a sharded strategy but the `FlatParameter` is already unsharded. This is precisely the proposed rule.
There were two notable edge cases that the existing code did not address.
1. The existing code tests if the `FlatParameter` is already unsharded by checking the allocation status of `_full_param_padded`. When using parameter mixed precision, this is the incorrect tensor to check. If `_full_param_padded` is allocated (e.g. when `reshard_after_forward=False` and calling `summon_full_params()` between forward and backward), the already-unsharded check is a false positive, and `summon_full_params()` does not correctly force full precision. https://github.com/pytorch/pytorch/issues/83068
- This PR's `needs_unshard()` check correctly routes to the appropriate padded unsharded flattened parameter depending on the calling context (i.e. if it needs to force full precision or not).
2. The existing code does not free the GPU copy of the padded unsharded flattened parameter when calling `summon_full_params(offload_to_cpu=True)`. It unshards the `FlatParameter`, moves the padded unsharded flattened parameter to CPU, and sets the `FlatParameter` data to be the appropriate unpadded view into the padded unsharded flattened parameter on CPU. However, `_full_param_padded` still points to the all-gathered padded unsharded flattened parameter on GPU, which is kept in memory. https://github.com/pytorch/pytorch/issues/83076
- This PR frees the GPU copy and reallocates it upon exiting `summon_full_params()`. This is essential for avoiding peak GPU memory usage from increasing as we recurse through the module tree. There may be some cases where we can avoid reallocation altogether, but that can be addressed in a follow-up PR.
- This PR offloads the *unpadded* unsharded flattened parameter to CPU directly instead of the *padded* one. As far as I can tell, there is no need to include the padding since unflattening the original parameters does not require the padding.
- The relevant code is in the context manager `FlatParamHandle.to_cpu()`.
### Unshard: Mixed-Precision Stream
This PR removes the mixed precision stream usage. As is, I do not think there is any extra overlap being achieved by the stream usage.
The low precision shard is allocated and copied to in the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1401-L1412))), and the current stream (in this case the all-gather stream) waits for the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1414))). However, we immediately schedule an all-gather that communicates that exact low precision shard ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L3338))) with no other meaningful computation between. If we remove the mixed precision stream, the low precision shard is allocated and copied to in the all-gather stream (including the non-blocking CPU -> GPU copy if using CPU offloading).
Under this PR's design, we may consider a "pre-unshard" stream for all logical pre-unshard data transfers if we want to overlap in the future. IIUC, the overlap opportunity exists if there are multiple `FlatParameter`s per module, and we only have the all-gather stream wait for the data transfer corresponding to the local shard it communicates, not the others.
If we agree on removing the mixed-precision stream for now, I will remember to delete it from `_init_streams()`.
## FSDP Runtime: Reshard
Like with unshard, the first step is the look at the existing `_free_full_params()` and `_use_param_local_shard()`. A few notable observations:
- For only `NO_SHARD`, `_free_full_params()` includes a call to `_free_mp_shard()`.
- For `summon_full_params()`, there is a separate `_free_full_params_and_use_local_shard()` that duplicates the main logic of `_free_full_params()` and calls `_use_param_local_shard()`.
- In `forward()`, if `reshard_after_forward=True`, we call `_free_full_params()` and then `_free_mp_shard()`. Hence, for `NO_SHARD`, the `_free_mp_shard()` is a no-op.
- In the post-backward hook, we typically call `_free_full_params()` and `_free_mp_shard()`. The `_free_mp_shard()` is a no-op for `NO_SHARD` and if `reshard_after_forward=True`.
Some comments:
- The code certainly works, but some of the no-ops are subtle. When possible, we should make it clear when calls are no-ops or not. It is good that the existing code documents that `_free_mp_shard()` is a no-op in the post-backward hook when `reshard_after_forward=True`. However, there are still some non-obvious no-ops (around `NO_SHARD`).
- We should see if we can avoid the duplicate `_free_full_params_and_use_local_shard()`.
Let us trace through the logical reshard:
1. `FullyShardedDataParallel._reshard(self, handles: List[FlatParamHandle], free_unsharded_flat_params: List[bool])`
- The two args should have the same length since they are to be zipped.
- The goal of having `free_unsharded_flat_params` is that the caller should be explicit about whether the (padded) unsharded flattened parameter should be freed. The low precision shard is always meant to be freed (as early as possible), so there is no corresponding `List[bool]`.
2. `FlatParamHandle.reshard(self, free_unsharded_flat_param: bool)`
- This frees the (padded) unsharded flattened parameter if `free_unsharded_flat_param` and switches to using the sharded flattened parameter.
- Echoing back to forcing full precision in `summon_full_params()`, `_free_unsharded_flat_param()` frees the correct tensor by using `_get_padded_unsharded_flat_parameter()`.
3. `FlatParamHandle.post_reshard(self)`
- I am not fully content with the existence of this method, but this seems to be an unavoidable consequence of `NO_SHARD`. Perhaps, this may be useful in the future for other reasons though.
- Right now, this method is only meaningful for `NO_SHARD` + parameter mixed precision + outside `summon_full_params()`. `_mp_shard` is not freed in the post-unshard since it is also the low precision _unsharded_ flattened parameter, so we must delay the free until the the post-reshard.
Below the `FlatParamHandle.reshard()` and `post_reshard()` layer, there should not be any no-ops.
One final comment I will mention is that I like the `pre_unshard()`, `unshard()`, `post_unshard()`, and `reshard()`, `post_reshard()` organization because it makes it clear what the boundaries are and their temporal relationship. Through that, we can set pre- and post-conditions. Furthermore, we can eventually convert logic to hooks that may be registered on the `FlatParamHandle` (for `pre_unshard()`, `post_unshard()`, and `post_reshard()`). This may improve the customizability of FSDP.
## FSDP Runtime: `forward()`
- This PR reorganizes `forward()` in preparation for non-recursive wrapping, which uses pre-forward and post-forward hooks that expect the signature `hook(module, input)`. For FSDP, the `module` and `input` arguments are not used.
- This PR creates a new method `_fsdp_root_pre_forward()` to handle the logic only the root FSDP should run.
## FSDP Prefetching
Finally, we dive into the prefetching changes. Some highlights:
1. This PR unifies the execution order validation and prefetching implementations.
- Both involve the execution order and can be unified to share some boilerplate.
2. Execution order validation only runs when the distributed debug level is `INFO`.
- We have yet to have one success case where we actually catch an unintended source of dynamism. The warning is also too verbose. Hence, we are gating it by the `INFO` level.
3. This PR moves prefetching to be with respect to groups of handles (as mentioned in the constructor comment).
- This is essential for supporting prefetching with non-recursive wrapping.
4. This PR does not include "bubbles", i.e. modules with no handles, in the recorded execution order(s). This deviates from the existing implementation.
- This makes prefetching possibly more aggressive (when there are such bubbles), but it should not have significant performance implications either way.
5. This PR changes backward prefetching to reset the post-forward order each iteration (as intended).
6. This PR changes forward prefetching to use the first iteration's pre-forward order instead of the first iteration's post-forward order. (We can discuss whether we want this in this PR or not. Otherwise, I can keep it as using the post-forward order to preserve the existing semantics.) This PR also removes the `all_gather_stream.wait_stream(current_stream)` before forward prefetching because it does not help with high GPU reserved memory. We can add that back if desired.
### Appendix
#### Reverse Post-Forward Order Is Not Always the Pre-Backward Order
The existing PT-D FSDP pre-backward prefetching uses the reverse post-forward order.
<details>
<summary>Model Code</summary>
```
class Model(nn.Module):
def __init__(self):
super().__init__()
self.block1 = nn.Sequential(
nn.Conv2d(3, 4, kernel_size=3),
nn.BatchNorm2d(4),
nn.ReLU(inplace=True),
)
self.block2 = nn.Sequential(
nn.Conv2d(4, 4, kernel_size=3),
nn.BatchNorm2d(4),
nn.ReLU(inplace=False),
)
self.block3 = nn.Linear(12, 8)
self.head = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten(),
nn.Linear(4, 10),
)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
return self.head(x)
model = Model().cuda()
fsdp_kwargs = {}
model.block1[1] = FSDP(model.block1[1], **fsdp_kwargs) # BN2d
model.block2[1] = FSDP(model.block2[1], **fsdp_kwargs) # BN2d
model.block1 = FSDP(model.block1, **fsdp_kwargs)
model.block2 = FSDP(model.block2, **fsdp_kwargs)
model.block3 = FSDP(model.block3, **fsdp_kwargs)
model = FSDP(model, **fsdp_kwargs)
```
</details>
<details>
<summary>Execution Orders </summary>
```
Pre-backward hook for ('head.2.weight', 'head.2.bias') 140339520587136 (model)
Pre-backward hook for ('weight', 'bias') 140339461194656 (block3)
Pre-backward hook for ('0.weight', '0.bias') 140339520589776 (block2)
Pre-backward hook for ('weight', 'bias') 140339520587664 (block2 BN)
Pre-backward hook for ('weight', 'bias') 140339520586656 (block1 BN)
Pre-backward hook for ('0.weight', '0.bias') 140339520588768 (block1)
Pre-forward order:
('head.2.weight', 'head.2.bias') 140339520587136 (model)
('0.weight', '0.bias') 140339520588768 (block1)
('weight', 'bias') 140339520586656 (block1 BN)
('0.weight', '0.bias') 140339520589776 (block2)
('weight', 'bias') 140339520587664 (block2 BN)
('weight', 'bias') 140339461194656 (block3)
Reverse post-forward order:
('head.2.weight', 'head.2.bias') 140339520587136 (model)
('weight', 'bias') 140339461194656 (block3)
('0.weight', '0.bias') 140339520589776 (block2)
('weight', 'bias') 140339520587664 (block2 BN)
('0.weight', '0.bias') 140339520588768 (block1)
('weight', 'bias') 140339520586656 (block1 BN)
```
</details>
Differential Revision: [D39293429](https://our.internmc.facebook.com/intern/diff/D39293429)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83665
Approved by: https://github.com/zhaojuanmao
We are removing the `forward_prefetch` option. By the nature of async GPU kernel execution, launching the CPU kernel for the next layer's all-gather early does not actually improve performance. Moreover, the existing `forward_prefetch` uses the post-forward order instead of the pre-forward order, which leads to mis-targeted prefetched all-gathers.
Differential Revision: [D39454217](https://our.internmc.facebook.com/intern/diff/D39454217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84600
Approved by: https://github.com/zhaojuanmao
This moves the forward and backward prefetching to be subtests.
**On the AI AWS cluster, this reduces the `test_fsdp_core.py` TTS from ~2200 seconds (36 minutes) to 480 seconds (8 minutes).**
This introduces `run_subtests()` in `common_fsdp.py` and `_get_subtest_config()` in `test_fsdp_core.py`. Feel free to give suggestions for a cleaner way to do this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80908
Approved by: https://github.com/rohan-varma
**Overview**
Please refer to https://github.com/pytorch/pytorch/issues/80867 first.
This addresses:
> Goal 3: Refactor model construction to enable simpler testing for the non-recursive wrapping path.
The idea is that we have an abstract class `FSDPTestModel` that defines the interface expected from the parity check and training boilerplate. This PR refactors the models in `common_fsdp.py` used in `test_fsdp_core.py` to implement this interface. Further unification under this interface is coming in follow-up PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80873
Approved by: https://github.com/rohan-varma
This PR does a number of things:
- Move linalg.vector_norm to structured kernels and simplify the logic
- Fixes a number of prexisting issues with the dtype kwarg of these ops
- Heavily simplifies and corrects the logic of `linalg.matrix_norm` and `linalg.norm` to be consistent with the docs
- Before the `_out` versions of these functions were incorrect
- Their implementation is now as efficient as expected, as it avoids reimplementing these operations whenever possible.
- Deprecates `torch.frobenius_norm` and `torch.nuclear_norm`, as they were exposed in the API and they are apparently being used in mobile (??!!) even though they were not documented and their implementation was slow.
- I'd love to get rid of these functions already, but I guess we have to go through their deprecation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76547
Approved by: https://github.com/mruberry
Sometimes we randomly see unrelated FSDP CI failures such as https://github.com/pytorch/pytorch/runs/6298275361?check_suite_focus=true which are unrelated to the diff at hand. Suspicion is that because some other tests set `BACKEND` which is a generic env var for distributed tests, if those tests are run in same CI container before, they won't get unset and we'll use gloo for FSDP backend.
But gloo is not currently supported, and this was mostly added for easy testing during early FSDP development, so remove this entirely.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76878
Approved by: https://github.com/awgu
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74860
These pre/post hooks must be registered even if the FlatParamsWrapper does not flatten any parameters; any submodule inside FlatParamsWrapper should be pre/post processed by the hooks.
ghstack-source-id: 152594052
Test Plan: CI
Reviewed By: rohan-varma
Differential Revision: D35194483
fbshipit-source-id: c25d7846f317c7ce78d77d335d041fed8db8f3a1
(cherry picked from commit db2cc311714e579362f5201922be715a626d48df)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74452
Useful clarifications while reviewing the diff:
How fairscale implements MP for buffers:
- Accepts buffer_dtype argument in ctor that is the dtype for computation for buffers. By default this is the compute_dtype.
- During _lazy_init, for root module, _cast_buffers is called which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, then checkpoint is taken. They are restored back to buffer_dtype after that.
How PT FSDP implements MP for buffers in this diff:
- Rather than buffer_dtype in ctor, we accept MixedPrecision.buffer_dtype which is the compute type for buffers.
- During lazy_init, similar to fairsacle we cast the buffers to the type given by the MP config. In the case of no mixed precision the default behavior is maintained.
- In _cast_buffers, we remember the original buffer dtype into a member variable. We then may cast them to a new dtype if given by the user.
- During state_dict, we use the above remembered type (stored as self._orig_buffer_dtype) and restore this type to the buffers prior to taking checkpoint. After state_dict, we restore it back to the casted type as buffers remain in this mixed precision type even after forward/backwards passes (so this is done for consistency).
- The improvement here is that we remember and restore the correct dtype of buffer the model originally had. However we assume all buffers are of the same dtype, which can be relaxed depending on use cases.
Why rebuild_full_params checks for summon_full_params training state:
- summon_full_params needs to return the full module parameters in the original precision for checkpoint to work as expected (users don't want to checkpoint the fp16 params generally). Thus, _rebuild_full_params will do this check. This is exactly the same reasoning as "force_full_precision" arg in fairscale.
- Concretely, if we're in summon_full_params, we 1) Don't cast shards to param_dtype, 2) all_gather with a full precision input rather than _full_param_paded.
Test coverage:
[ ] Test1
ghstack-source-id: 152654758
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D35000703
fbshipit-source-id: 4bd7937ff36bdb3afd60eda981afc9d8731b823a
(cherry picked from commit 6ed6721aaf18f323656686200465fc78cef1d0dd)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74215
### Overview of API
This PR introduces full optimizer state dict checkpointing.
- This allows users to save the optimizer state for a `torch.nn.Module` (not necessarily a `FullyShardedDataParallel` instance) that contains `FullyShardedDataParallel` instances and later load that optimizer state.
- This supports loading to a module with a different world size, but the `FSDP` wrapping scheme must be the same.
To **save** the optimizer state, run the following (on all ranks):
```
model: torch.nn.Module = ...
optim = torch.optim.Adam(model.parameters(), ...)
# Train for some steps...
full_osd = FSDP.full_optim_state_dict(model, optim) # returns non-empty dict only on rank 0
if rank == 0:
torch.save(full_osd, ...)
```
To **load** the optimizer state, run the following (on all ranks):
```
new_model: torch.nn.Module = ... # may use different world size
full_osd = torch.load(...)
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
optim = torch.optim.Adam(new_model.parameters(), ...)
optim.load_state_dict(sharded_osd)
```
To support **multiple parameter groups**, we require using an additional argument `optim_input`, which is the first argument that the user passes into the optimizer constructor.
```
optim_input = ...
optim = torch.optim.Adam(optim_input, ...)
FSDP.full_optim_state_dict(model, optim, optim_input) # one more argument
...
new_optim_input = ...
new_optim = torch.optim.Adam(new_optim_input, ...)
FSDP.shard_full_optim_state_dict(full_osd, new_model, new_optim_input) # one more argument
```
One caveat is that the user should be careful of generators, which are exhausted after their first use. The `optim_input` passed into the `FSDP` APIs should be refreshed version of the generator if using generators.
### Test Plan
**`full_optim_state_dict()`**
- [x] `full_optim_state_dict()` for a non-`FSDP` root model matches that of an equivalent local model, up to parameter IDs being rearranged, when optimizer input is `model.parameters()`.
- [x] `full_optim_state_dict()` for a non-`FSDP` root model matches that of an equivalent local model, up to parameter IDs being rearranged, when optimizer input is multiple parameter groups (changing parameter order).
**`shard_full_optim_state_dict()`**
- [x] `shard_full_optim_state_dict()` for a non-`FSDP` root model matches the local `optim.state_dict()` of the same model with halved world size, when optimizer input is `model.parameters()`.
- [x] `shard_full_optim_state_dict()` for a non-`FSDP` root model matches the local `optim.state_dict()` of the same model with halved world size, when optimizer input is multiple parameter groups (changing parameter order).
- [x] `shard_full_optim_state_dict()` raises a `ValueError` when changing the `FSDP` wrapping scheme.
On the AWS cluster, the TTS contribution for these tests is ~45 seconds.
### Developer Notes
**Relaxing the Problem**
For optimizer state checkpointing, we have relaxed the problem to **not support changing the `FSDP` wrapping scheme** between save and load time. It is unclear how to solve without this relaxation. This was the least restrictive way to relax the problem since it does not affect most expected use cases. Rather, the expected change between save and load time is the **world size**, which this implementation **does support**.
Even with the relaxation, the `optim_input` argument is necessary to determine the `flat_param_id_to_param` mapping, which is important to know which parameter IDs in the flattened space correspond to `FlatParameter`s that hence need to be unflattened.
**Differences with Local Equivalent**
Suppose `full_osd = full_optim_state_dict()` and `local_osd = state_dict()` for a purely local equivalent. The difference between `full_osd` and `local_osd` is that the parameter IDs of unflattened parameters comprising a single flattened parameter are always consecutive in `full_osd`, while they may be non-consecutive in `local_osd`. Suppose in the following that each layer has 1 parameter `param`:
```
FSDP(model)
layer1
FSDP(layer2)
layer3
```
`layer1.param` and `layer3.param` are flattened and attributed to `model`. `layer2.param` is flattened and attributed to itself.
- In `local_osd`, the parameter IDs would be `0: layer1.param`, `1: layer2.param`, and `2: layer3.param`.
- In `full_osd`, the parameter IDs would be `0: layer1.param`, `1: layer3.param`, and `2: layer2.param`. (Parameter IDs of unflattened parameters sharing a flattened parameter are consecutive.)
The idea is that as long as `full_optim_state_dict()` and `shard_full_optim_state_dict()` are internally consistent, then there is no need to match the local equivalent (assuming no change in `FSDP` wrapping).
### Follow-Ups
**API**
- If needed, we can follow-up this PR by adding an argument `key_by_name: bool = False` to both methods that may be set to `True` to key parameters by `str` names instead of `int` parameter IDs. We still need to investigate if keying by name enables changing the `FSDP` wrapping scheme.
**Refactoring**
- In this optimizer state checkpointing, all optimizer state is saved to CPU on rank 0 (set as `OPTIM_TARGET_RANK`). We should unify and refactor these assumptions with model state checkpointing.
**Testing**
- The code path for unused parameters is not tested. The testing and any needed implementation fixes can be done in a follow-up.
- The code path for non-tensor states (e.g. `Adam` `"step"` as `float` instead of as zero-dimension `FloatTensor`) is not tested. However, it is identical to that of zero-dimension tensor states, so I have some confidence. If needed, I can add tests for it in a follow-up.
- Would I have to write my own optimizer? I do not want to introduce dependencies on third party libraries like Nvidia `apex`.
- We may want to add end-to-end checkpointing tests that include both model state dict and optimizer state dict.
Test Plan: Imported from OSS
Reviewed By: zhaojuanmao
Differential Revision: D35045121
Pulled By: awgu
fbshipit-source-id: 33c650dc960acbd7613d4f444a852b9f76ca4a9b
(cherry picked from commit 2bbc2e344296dc455cf686f3a9b097989504be81)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73819
adding a new sharding_strategy config in FSDP API to support different data parallel algorithm. also add support for zero2 algorithm, which will only shard optimizer states and grads
ghstack-source-id: 151454460
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D34662583
fbshipit-source-id: 14c6e0c0054692ecd76512c025d60deb4964ec5f
(cherry picked from commit 51382e882447b4756c4ee6d94ce0939a25955b00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73535
**Overview**
- This adds FSDP gradient accumulation without `no_sync()`, which comparatively has more network bandwidth demand but less GPU memory requirement per worker.
- This fixes a bug in the `no_sync()` testing, where the CPU offloading and backward prefetch arguments were not propagating to the `FullyShardedDataParallel` constructor.
- This adds `p_assert()` (taken from Fairscale), which prints the assert error message before raising the `AssertionError`. It is meant to be used when running in the autograd backward context since otherwise the error message is swallowed, giving a unhelpful error like:
```
<built-in method run_backward of torch._C._EngineBase object at 0x7f1fd518dc80> returned NULL without setting an error
```
NOTE: Gradient accumulation without `no_sync()` is not currently compatible with CPU offloading.
**Test Plan**
I augmented the tests to test gradient accumulation interleaving iterations accumulating with and without `no_sync()`.
After this diff:
- QPS (ResNet): f328439897
- QPS (RoBERTa): f328440141
- Accuracy: f328442119
Before this diff (trunk):
- QPS (ResNet): f328432756
- QPS (RoBERTa): f328436766
- Accuracy: f328437896
Test Plan: Imported from OSS
Reviewed By: zhaojuanmao
Differential Revision: D34533546
Pulled By: awgu
fbshipit-source-id: 821d762dfad5f2b1e59adcb8e5cb7c277399040c
(cherry picked from commit 746a5ea2720dcf87c376229b405a318396fe5769)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73116
Users may need summon_full_params() to get the original parameters.
ghstack-source-id: 150134237
Test Plan: CI
Reviewed By: rohan-varma
Differential Revision: D34353034
fbshipit-source-id: ac69cc032da177903cd9969094f3f82dc6a61636
(cherry picked from commit 55d34fdee3778110a165a13ae987d0339e8d33c7)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73366
Adds state_dict() save/reload in parity with DDP test to ensure
checkpointing doesn't cause issue with accuracy/model params.
ghstack-source-id: 150114251
Test Plan: CI
Reviewed By: fegin
Differential Revision: D34434358
fbshipit-source-id: fb0787486b383cfcbec7cc1325a486c8d9b1e2ea
(cherry picked from commit e3bcc7733cb5a497a640007044b1138dfee3a532)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73324
Implements `state_dict` and `load_state_dict` APIs for FSDP, with the following limitations:
1. Does not support `state_dict_device` (i.e. specifying which device params should be on) which fairscale does currently support
2. Does not yet support offload of state_dict onto CPU
3. Loads state_dict on all ranks currently. In the future we could add support for loading this on only rank 0, to avoid redundancy across ranks as usually only one rank is responsible for saving/loading the model. Along with (2) this would enable larger models to have state_dict called.
As discussed in FSDP checkpoint API proposal, `state_dict` will basically be a `full_state_dict` where full parameters are returned on all ranks. As a result this implies that the model must actually be able to fit on a single GPU.
ghstack-source-id: 150012240
Test Plan: ci
Reviewed By: zhaojuanmao
Differential Revision: D34433514
fbshipit-source-id: 3eb1d679b2236264f9f423e761d1720f9aaec73a
(cherry picked from commit a451d5a08ebfa14a229a25fea35b9ca59fe91a59)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73314
Needs to synchronize all_gather stream. Added test fails without this
fix
ghstack-source-id: 149800363
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D34430602
fbshipit-source-id: 4ce07e2d098a4f07ac640285db1d0ff64fd42232
(cherry picked from commit 24c756e7bba69017b9358bf824589b2aeb366b5e)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
make fsdp folder to be public
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70235
address comments in https://github.com/pytorch/pytorch/pull/69282:
Have fixed a few corner cases for prefetching full parameters in post backward hook.
After benchmarking, prefetching full parameters in the pre-backward hook has the best performance and stable but at cost of increased memory; prefetching full parameters in the post-backward hook did not see expected performance, also failed in a few corner cases (fixed) although there is no memory increase. The main issue is that post backward hook fire order is not consistent with opposite of forward computation order, so incorrectly prefetched all gather could delay the really needed all gather in the single NCCL stream and cause some layer's computation delay.
So putting these two algorithms as two configurable experimental algorithms for now
prefetch full parameters at pre-backward hook:
It is observed from past traces that all gather ops are not triggered until current layer's backward pass starts to compute, also for some models previous layers' reduce scatter is scheduled before next layer's all gather ops, since all gather and reduce scatter are in the same nccl stream, this case could result in backward pass has no communication and computation overlap.
To explicitly make next layers' all gather scheduled while previous layers' backward computation is running, we can prefetch next layers' all gather full params. This can help 1) both all gather and reduce scatter are overlapped with computation deterministically 2) only prefetch one layer's all gather full parameters, to avoid increasing too much memories.
The implementation borrowed the idea from facebookresearch/fairscale#865, where forward graph order is recorded in the forward pass.
In the backward pass, this PR prefetches all gather full parameter in current layer's pre-backward hook, instead of prefetching in current layer's post backward hook in facebookresearch/fairscale#865. Also make sure all gather streams are synced properly.
Experiments showed 10% memory increase and 20% latency speed up for 1GB roberta model in a slow network environment.
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D33252795
fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955
Implements a checkpoint_wrapper function, which wraps nn.Module with checkpointing so user won't have to call checkpoint() everytime they want to checkpoint the module.
Currently only support for reentrant-based checkpointing is added and only tested with FSDP to unblock a use case.
Future work is to add support for new checkpointing API, add more tests, upstream to torch.utils.checkpoint.
ghstack-source-id: 145811242
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33107276
fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68308
export CPUOffload in _fsdp package, as cpu_offload config in FSDP API needs to import this class
ghstack-source-id: 143560608
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D32408719
fbshipit-source-id: ee5c40ec91a423fbd58872fbdeb5f2dda8a3d89e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67135
Add ability to use env var backend for quicker testing (and gloo2 in
the future)
ghstack-source-id: 142274304
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D31878285
fbshipit-source-id: 80ae7107cd631a1a15ebc23262b27d8192cfe4b6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67249
Implements CPU offload for model parameters in FSDP.
- CPU offload class with only offload_params attribute is created
- If this is specified in FSDP ctor, model parameters are moved back to CPU after sharding in __init__
- In forward pass, during lazy init, p._local_shard gets set to p.data so it is on CPU. We pin_memory here.
- In forward pass, in _rebuild_full_params, we move p.data back to self.compute_device if necessary. Note that we don't use the device of p._full_param_padded because we don't always have this attr, but when we do its always the same as compute_device.
- The same logic as above applies to the beginning of backwards pass.
- At end of fwd and end of bwd, `_use_param_local_shard` takes care to ensure the parameters are offloaded to CPU again, by pointing it to p._local_shard, which is always on CPU.
Regarding tests:
- We tests 3 different types of init: 1) CUDA the model before wrapping with FSDP, 2) CUDA the model after wrapping with FSDP, 3) never CUDA the model.
- Case 1 is always supported. Case 2 is not supported with CPU offload and throws an error during fwd pass. Case 3 is only supported with CPU offload at the moment.
- Verifies all params are offloaded to CPU after init.
- Verifies all params are offloaded to CPU after forward and backward.
- Note that there is an issue with verifying exact parity when CPU offloading, but it appears to be related to transfering model back and forth cpu/CUDA. More details in https://github.com/pytorch/pytorch/pull/66961
ghstack-source-id: 141851903
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D31911085
fbshipit-source-id: 3ddf73c070b55ce383e62251868d609004fc30e7