This lowers the `reduce_dtype` retrieval to the `handle` instead of the `state` in preparation for `fully_shard`, and this adds a guard to avoid a no-op `to()` call.
Note that this change pretty much gets overridden in following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90615
Approved by: https://github.com/rohan-varma
Use register_state_dict_pre_hook in FSDP to simplify state_dict implementations & remove hacks. This removes `def state_dict` entirely and paves the path for composable API as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90436
Approved by: https://github.com/fegin
This saves a data structure `_stream_to_name: Dict[torch.cuda.Stream, str]` that maps each FSDP stream to its name. This can help in debugging by checking `_stream_to_name[torch.cuda.current_stream()]` to see if it is `"default"` or `"unshard"` in the post-backward hook for example.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90611
Approved by: https://github.com/rohan-varma
Continuation after https://github.com/pytorch/pytorch/pull/90163.
Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):
_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.
``` python
import ast
import os
import docstring_parser
for root, dirs, files in os.walk('.'):
for name in files:
if root.startswith("./.git/") or root.startswith("./third_party/"):
continue
if name.endswith(".py"):
full_name = os.path.join(root, name)
with open(full_name, "r") as source:
tree = ast.parse(source.read())
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
all_node_args = node.args.args
if node.args.vararg is not None:
all_node_args.append(node.args.vararg)
if node.args.kwarg is not None:
all_node_args.append(node.args.kwarg)
if node.args.posonlyargs is not None:
all_node_args.extend(node.args.posonlyargs)
if node.args.kwonlyargs is not None:
all_node_args.extend(node.args.kwonlyargs)
args = [a.arg for a in all_node_args]
docstring = docstring_parser.parse(ast.get_docstring(node))
doc_args = [a.arg_name for a in docstring.params]
clean_doc_args = []
for a in doc_args:
clean_a = ""
for c in a.split()[0]:
if c.isalnum() or c == '_':
clean_a += c
if clean_a:
clean_doc_args.append(clean_a)
doc_args = clean_doc_args
for a in doc_args:
if a not in args:
print(full_name, node.lineno, args, doc_args)
break
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
Adds 2 new hybrid sharding strategy to FSDP:
1. HYBRID_SHARD: applies zero-3 style sharding within a node, and data parallel across
2. HYBRID_SHARD_ZERO2: applies zero-2 style sharding within a node, and data parallel across
These are useful for medium sized models and aim to decrease communication volume, tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.
Hybrid sharding in general works by sharding the model using a process group within a single node, and creating intra-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups, or None, and we generate the process groups appropriately.
** Acknowledgements **
- @awgu 's excellent prototype: 5ad3a16d48
- @liangluofb For ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
- This PR introduces a new concept, the _communication module_ (denoted `comm_module`), that represents the module responsible for the unshard/reshard pair for a `FlatParamHandle`. This is well-defined because the current design assumes that each `FlatParamHandle` only has _one_ unshard/reshard pair for either the forward or backward pass.
- For the wrapper code path, the `comm_module` is exactly the module already being passed to the `FlatParamHandle` constructor.
- For the composable code path, the `comm_module` is not necessarily the module already being passed to the `FlatParamHandle`. This is because the module already being passed is always the local FSDP root module to give complete FQNs, instead of local FQNs. Distinguishing the communication module from the local FSDP root module can provide more flexibility for non-recursive wrapping designs in the future.
- This PR adds a unit test `test_unshard_reshard_order` that explicitly checks that `_unshard` and `_reshard` are called in the exactly the same order across the two code paths.
- This PR does not fix `test_checkpoint_fsdp_submodules_use_reentrant`. However, the error message changes, so this PR accommodates that.
- The error is now the same as if we used the equivalent wrapper FSDP:
```
test_model.u1 = FSDP(test_model.u1, use_orig_params=True)
test_model.u2 = FSDP(test_model.u2, use_orig_params=True)
```
- The error is also the same as if we used wrapper FSDP with `use_orig_params=False`, so it is not unique to `use_orig_params=True`.
---
**`comm_module` Example**
```
model = Model(
seq1: nn.Sequential(
nn.Linear
nn.ReLU
nn.Linear
nn.ReLU
)
seq2: nn.Sequential(
nn.Linear
nn.ReLU
nn.Linear
nn.ReLU
)
)
policy = ModuleWrapPolicy({nn.Sequential})
fully_shard(model, policy=policy)
FullyShardedDataParallel(model, auto_wrap_policy=policy)
```
- This policy constructs two `FlatParamHandle`s, one for `seq1` and one for `seq2`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `model` as the `module` argument to every `FlatParamHandle`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90387
Approved by: https://github.com/mrshenli
This is the last PR for integrating 2D into core distributed.
This PR does the following:
1. Add optimizer.py: this adds ability to load a state_dict in conjunction with FSDP sharded optimzer state.
2. Update default_planner.py to support 2D checkpoint.
3. Add test_fsdp_optim_state.py as a unit test for No. 1.
4. Fix bug in torch/testing/_internal/distributed/checkpoint_utils.py
5. Rename the filename for the APIs that should be private. Will organize and cleanup further in following PRs. #90328
Docstring and integration test will be added in the following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90212
Approved by: https://github.com/wanchaol
**Motivation:**
Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89899
Approved by: https://github.com/awgu
Observed by @aazzolini, some op might have Optional[Tensor] returns
where it return None (i.e. native_layer_norm_backward), it's a mismatch
between C++ aten op signature and python None, but we need to handle it
in the python side
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90241
Approved by: https://github.com/aazzolini
I need to rebase later after Shen's PRs land.
The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)
This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
This PR get rids of torchgen FunctionSchema parsing and parse
it manually, it should resolve torchgen package issue and also
provide some perf wins when running DTensor eagerly
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90106
Approved by: https://github.com/awgu
In pytorch, the optim state_dict will always use number to index optimizer state_dict for parameters.
Now composability workstream need a FQN based way to index optimizer state_dict for parameters..
For example, SGD optimizer might have something in its `state_dict` like:
```
{'state':
{0:
{'momentum_buffer': tensor(...)},
{1:
{'momentum_buffer': tensor(...)},
...
}
'param_groups':
[{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7]}]
}
```
And in NamedOptimizer we want the `state_dict` can be:
```
{'state':
{'net1.0.weight':
{'momentum_buffer': tensor(...)},
{'net1.0.bias':
{'momentum_buffer': tensor(...)},
...
}
'param_groups':
[{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': ['net1.0.weight', 'net1.0.bias', 'net2.0.weight', 'net2.0.bias', 'net3.weight', 'net3.bias', 'net4.1.weight', 'net4.1.bias']}]
}
```
We also want to support load_state_dict to enable optim `state_dict` override for NameOptimizer.
For the next couple PR/diffs, we also need to:
1. To make `NamedOptimizer` working with FSDP (like registering a hook for model wrapped with FSDP) and other PTD/PT components.
2. Make `NamedOptimizer` works well with apply_optim_in_backward
3. Upstream also `CombinedOptimizer`.
Differential Revision: [D41432088](https://our.internmc.facebook.com/intern/diff/D41432088/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41432088/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89480
Approved by: https://github.com/rohan-varma
Summary:
This diff is reverting D41682843
D41682843 has been identified to be causing the following test or build failures:
Tests affected:
- https://www.internalfb.com/intern/test/281475048939643/
Here's the Multisect link:
https://www.internalfb.com/intern/testinfra/multisect/1444954
Here are the tasks that are relevant to this breakage:
T93770103: 5 tests started failing for oncall assistant_multimodal in the last 2 weeks
We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.
Test Plan: NA
Reviewed By: zyan0, atuljangra, YazhiGao
Differential Revision: D41710749
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90132
Approved by: https://github.com/awgu
For PyTorch FSDP, the only way that gradients are in low precision is if `keep_low_precision_grads=True` or if the user turns on AMP. This PR adds tests for the former and improves the documentation for `clip_grad_norm_()`, especially around these non-full-precision cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90028
Approved by: https://github.com/rohan-varma
For any `flat_param.data = flat_param.to(...)` or `flat_param.grad.data = flat_param.grad.to(...)`, we must also refresh sharded parameter/gradient views, respectively, if the storage changes.
For `keep_low_precision_grads=True` and a sharded strategy, we cast the gradient back to the low precision using `.data` to bypass the PyTorch check that a parameter and its gradient have the same dtype. For `use_orig_params=True` before this PR, the gradient would incorrectly still be in full precision, not low precision, since we did not refresh views (this can actually be considered a memory leak since we have two copies of the gradient now, one in low precision and one in full precision). This PR refreshes the views.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90027
Approved by: https://github.com/mrshenli