There was an issue reported internally that with `sync_module_states=True`, if the model had buffers on CPU, even with `device_id` specified, FSDP would try to broadcast CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```
After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.
The issue is that we always used the _parameters_ as the proxy to tell whether we should move module states to the device specified by `device_id`. However, a module (often the root) may not have any parameters but have some buffers! In that case, the buffers are left on CPU even if `device_id` is specified. This PR fixes this by considering both parameters and buffers for movement to `device_id`.
Note that this PR preserves the logic that `ignored_modules` / `ignored_parameters` are not considered for this movement, meaning that ignored parameters are moved to `device_id`.
Note also that I had to move the unit test back from using MTPG to the normal PG since otherwise, I could not repro the original error. (It seems like MTPG does not complain if we try to use `dist._broadcast_coalesced()` with CPU tensors.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
Enables MTPG for some FSDP tests in this file. Tests that need the
backward pass and warning logging are left as follow up work.
Backward pass issue: It seems that there is a hang with all_gather. Will sync with @kumpera on this.
Warning issue: We have a couple tests that regex check on warnings, but in the
multithreaded scenario these warnings are somehow not logged.
Differential Revision: [D43209769](https://our.internmc.facebook.com/intern/diff/D43209769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102043
Approved by: https://github.com/awgu
Add 'ignored_states' that accepts either a list of ignored_parameters or a list of nn modules for FSDP model wrapper and fully_shard composable APIs, it is recommended to use 'ignored_states' over 'ignored_modules' moving forward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
FSDP creates communication groups for intra-node communication through dist.new_subgroups. Previously, dist.new_subgroups only supported creation based on the number of CUDA devices. However, issue #99706 removed the avaliable-check for CUDA devices, allowing for custom backend create group based on num of custom devices per node.
This PR allows FSDP to explicitly pass device num within the node when creating communication groups for intra-node communication, instead of defaulting to the number of CUDA devices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100622
Approved by: https://github.com/awgu
Custom backend implementation based on privateuse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device', and registered as the same module name torch.my_device.
This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.
The main issues addressed are:
#### 1. Device decision for FSDP wrapping of Modules without Parameters
Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.
#### 2. Abstraction of a cuda-like device
Now, in addition to compute_device, _FSDPState includes a device_handler member. In fact, this device_handler is now just a reference to either torch.cuda or torch.my_device. From now on, code that works based on _FSDPState should use state.device_handler to operate streams create, wait or sync, just like using torch.cuda previously.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).
The previous PR is reverted due to some modules overwriting the signature of `named_parameters()`. This new PR adds a workaround for the case.
Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:
Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)
Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:
We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.
If you believe this diff has been generated in error you may Commandeer and Abandon it.
Test Plan: NA
Reviewed By: fegin
Differential Revision: D45027286
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).
Differential Revision: [D44897935](https://our.internmc.facebook.com/intern/diff/D44897935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98912
Approved by: https://github.com/awgu
Previously, `_need_to_materialize_module` would return false because:
* `managed_params =_get_orig_params(module, ignored_params)` returns a generator
* `is_meta_module = any(param.is_meta for param in managed_params)` exhausts the generator in its check
* `any(fake.is_fake(param) for param in managed_params)` would try to iterate over the empty generator and get an empty sequence, thus returning `False`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97488
Approved by: https://github.com/ngimel, https://github.com/awgu
**Overview**
This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization since the materialization may create new variables.
This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.
**Discussion**
The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.
**For Reviewers**
- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameters/buffers to group together by those variables and because materialization may create new variables.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94196
Approved by: https://github.com/rohan-varma
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
We pass `fully_sharded_module`, not `root_module`, after recent refactoring to unify composable and wrapper FSDP for now. This PR removes the comment explaining why before we passed in `root_module`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92739
Approved by: https://github.com/mrshenli
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.
There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. `lintrunner f` before you finalize your PR, which would now be enforced by CI after this PR.
The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.
---
I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules are ignored modules. This means that all submodules in the `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.
---
This PR reworks some tree traversal.
One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
submod1: Submodule()
submod2: Submodule(
subsubmod: Subsubmodule(),
),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, the left-to-right vs. right-to-left corresponds to `.children()` vs. `reversed(.children())` respectively.) Then, reverse left-to-right DFS is `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod` since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`, so this is the reverse reverse left-to-right DFS order, which is equivalent to the left-to-right DFS order. Thus, `state._handles` has the same order for both paths.
Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while preserving that `_get_fsdp_states()` follows `.modules()` order, we again use a left-to-right DFS.
The reason the DFSs may look strange is because I implemented them non-recursively, which requires a stack.
- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.
---
Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and any usages must import the entire file like `import torch.distributed.fsdp._traversal_utils as traversal_utils` instead of `from torch.distributed.fsdp._traversal_utils import ...`.
The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91044
Approved by: https://github.com/mrshenli
Motivations for this change:
1. TorchRec returns inconsistent results on `m.named_parameters()`
and `m.m1.named_parameters()` if m1 is a `ShardedModule`. Basically,
`ShardedModule` appears in `m.named_modules()`, but its parameters
are not in `m.named_parameters()`. As a result, when we identify
`ShardedModule` and pass them as `ignored_modules` to FSDP, FSDP
complains about key error in `_get_ignored_params`.
2. If users are manually wrapping submodules with FSDP, it could be
easier for them to keep a global set of ignored parameters, instead
of create a new collection for every FSDP invocation.
Given the above two reasons, we allow FSDP to have ignored modules
out of the wrapped root module.
Differential Revision: [D42132394](https://our.internmc.facebook.com/intern/diff/D42132394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91079
Approved by: https://github.com/awgu
This PR removes the "communication module" (comm. module / `comm_module`) concept from the FSDP code base since it causes disproportionate confusion compared to its benefit for now.
Instead, we introduce the term "fully sharded module" as the single concept to unify the wrapper and non-wrapper code paths. The definition is presented in a note at the top of `flat_param.py`. I reproduce it here:
---
We define the **"fully sharded module"** to be the original `nn.Module` that owns a `FlatParamHandle`. It is the *single* module logically responsible for the *single* unshard/reshard pair for the handle's `FlatParameter` for a given forward or backward pass. The fully sharded module should be passed to the `FlatParamHandle` constructor.
For the wrapper code path:
- The `FullyShardedDataParallel` module wrapping the fully sharded module runs the unshard/reshard on behalf of the fully sharded module by overriding `nn.Module.forward`.
- The fully sharded module is exactly the module passed to the `FullyShardedDataParallel` constructor's `module` argument and is saved in `_fsdp_wrapped_module`.
For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to `fully_shard` or a submodule chosen by the provided wrapping policy.
---
After this PR, `handle.flat_param._fqns`, `_param_infos`, and `_shared_param_infos` all prefix names from the same module, namely the fully sharded module. This should make state dict less confusing.
---
As an example, consider:
```
mod: Module(
sub1: Submodule(
subsub1: Subsubmodule(),
subsub2: Subsubmodule(),
),
sub2: Submodule(
subsub1: Subsubmodule(),
subsub2: Subsubmodule(),
),
)
```
For wrapper FSDP manual wrap:
```
mod.sub1 = FSDP(mod.sub1)
mod.sub2 = FSDP(mod.sub2)
mod = FSDP(mod)
```
For wrapper FSDP auto wrap:
```
mod = FSDP(mod, auto_wrap_policy=ModuleWrapPolicy({Submodule}))
```
(WIP) For non-wrapper FSDP manual wrap:
```
fully_shard(mod.sub1)
fully_shard(mod.sub2)
fully_shard(mod)
```
For non-wrapper FSDP auto wrap:
```
fully_shard(mod, policy=ModuleWrapPolicy({Submodule}))
```
The fully sharded module **in all cases** are `mod`, `mod.sub1`, `mod.sub2`, and notably, `subsub1` and `subsub2`s are not fully sharded modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90933
Approved by: https://github.com/rohan-varma
This PR
- Removes `_module_to_handles` since it is no longer used. We instead use `_comm_module_to_handles`.
- Removes `HandleConfig` and stores its fields directly as attributes on `FlatParamHandle`.
- Uses the term `fqn`/`fqns` uniformly in `flat_param.py` instead of `prefixed_param_name` / `prefixed_param_names`.
- Clarifies some documentation.
I am including all of these BE items in the same PR to save CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90840
Approved by: https://github.com/rohan-varma
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.
**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype` since that is BC breaking. I recommend reading the issue before preceding.
For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.
This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check if mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`. It is no longer to check against `None`.
This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either relies on no-ops silently or errors (such as one case reported by MosaicML).
**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.
**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
This saves a data structure `_stream_to_name: Dict[torch.cuda.Stream, str]` that maps each FSDP stream to its name. This can help in debugging by checking `_stream_to_name[torch.cuda.current_stream()]` to see if it is `"default"` or `"unshard"` in the post-backward hook for example.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90611
Approved by: https://github.com/rohan-varma
Adds 2 new hybrid sharding strategy to FSDP:
1. HYBRID_SHARD: applies zero-3 style sharding within a node, and data parallel across
2. HYBRID_SHARD_ZERO2: applies zero-2 style sharding within a node, and data parallel across
These are useful for medium sized models and aim to decrease communication volume, tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.
Hybrid sharding in general works by sharding the model using a process group within a single node, and creating intra-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups, or None, and we generate the process groups appropriately.
** Acknowledgements **
- @awgu 's excellent prototype: 5ad3a16d48
- @liangluofb For ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
- This PR introduces a new concept, the _communication module_ (denoted `comm_module`), that represents the module responsible for the unshard/reshard pair for a `FlatParamHandle`. This is well-defined because the current design assumes that each `FlatParamHandle` only has _one_ unshard/reshard pair for either the forward or backward pass.
- For the wrapper code path, the `comm_module` is exactly the module already being passed to the `FlatParamHandle` constructor.
- For the composable code path, the `comm_module` is not necessarily the module already being passed to the `FlatParamHandle`. This is because the module already being passed is always the local FSDP root module to give complete FQNs, instead of local FQNs. Distinguishing the communication module from the local FSDP root module can provide more flexibility for non-recursive wrapping designs in the future.
- This PR adds a unit test `test_unshard_reshard_order` that explicitly checks that `_unshard` and `_reshard` are called in the exactly the same order across the two code paths.
- This PR does not fix `test_checkpoint_fsdp_submodules_use_reentrant`. However, the error message changes, so this PR accommodates that.
- The error is now the same as if we used the equivalent wrapper FSDP:
```
test_model.u1 = FSDP(test_model.u1, use_orig_params=True)
test_model.u2 = FSDP(test_model.u2, use_orig_params=True)
```
- The error is also the same as if we used wrapper FSDP with `use_orig_params=False`, so it is not unique to `use_orig_params=True`.
---
**`comm_module` Example**
```
model = Model(
seq1: nn.Sequential(
nn.Linear
nn.ReLU
nn.Linear
nn.ReLU
)
seq2: nn.Sequential(
nn.Linear
nn.ReLU
nn.Linear
nn.ReLU
)
)
policy = ModuleWrapPolicy({nn.Sequential})
fully_shard(model, policy=policy)
FullyShardedDataParallel(model, auto_wrap_policy=policy)
```
- This policy constructs two `FlatParamHandle`s, one for `seq1` and one for `seq2`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `model` as the `module` argument to every `FlatParamHandle`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90387
Approved by: https://github.com/mrshenli
I need to rebase later after Shen's PRs land.
The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)
This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
**Overview**
This PR removes an outdated TODO:
```
# TODO (awgu): When exposing the original parameters, we need to also
# use this attribute to prevent re-synchronizing parameters.
```
**Justification**
We only pass `managed_params` to `_sync_module_params_and_buffers()`, where `managed_params` is defined as
```
managed_params = list(_get_orig_params(root_module, state._ignored_params))
```
This `_get_orig_params()` call excludes parameters already flattened by FSDP. Thus, `_sync_module_params_and_buffers()` will not re-sync already-synchronized parameters. Each parameter appears in `managed_params` for some FSDP instance exactly once and hence is only synchronized once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89217
Approved by: https://github.com/mrshenli
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.
This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code).
In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.
**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).
`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.
This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
**What**
This PR completely removes the `FullyShardedDataParallel` dependency from `_state_dict_utils` -- `_state_dict_utils` now depends only on `_FSDPState` and all the utils modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88637
Approved by: https://github.com/awgu
This PR refactors and fixes `_cast_buffers()`.
**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.
- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules. This is because of this outer loop over `self.modules()`:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L700)
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test _coincidentally_ passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`). However, the keys are unprefixed:
c40033be16/torch/distributed/fsdp/fully_sharded_data_parallel.py (L712-L717)
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and had the same original and low-precision dtypes and hence were cast correctly.
If we change each submodule to have its own distinct buffer name, then the unit test fails. This PR makes such a change to showcase the progression granted by this PR.
**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This is to separate the different use cases (casting for computation and casting for checkpointing) and the corresponding code paths. Plus, the signature for `_cast_buffers_to_dtype_and_device()` makes it clear exactly what buffers are being cast and to what dtype.
Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean since the buffer cast to original dtype happens in a `summon_full_params()` context, which cleans the names.
**Follow-Ups**
- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py` in a follow-up.
- We may want to move to per-module buffer casting (i.e. do not have the root module cast for all submodules).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87935
Approved by: https://github.com/mrshenli
This PR introduces the composable FSDP API (with constructor semantics only) along with some further constructor refactoring. A notable contribution here is `_get_submodule_to_states()`, which performs auto wrapping without actually wrapping.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87924
Approved by: https://github.com/mrshenli
This PR makes a second pass over the constructor. The logic has been grouped into `_init_<...>` functions based on intent (e.g. `_init_prefetching_state()` or `_init_runtime_state()`). This makes the initialization code for composable FSDP much cleaner than having to re-write the same sequences of lower-level helper calls.
This PR also moves `_ExecOrderData` into its own file `_exec_order_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87923
Approved by: https://github.com/mrshenli