Commit Graph

316 Commits

Author SHA1 Message Date
Shen Li
7ec1cb8553 [FSDP] Fix _pre_forward type annotation (#90621)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90621
Approved by: https://github.com/awgu, https://github.com/Skylion007
2022-12-11 06:39:38 +00:00
Shen Li
80542add73 [FSDP] Allow MixedPrecision to skip inputs (#90620)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90620
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-12-11 06:39:38 +00:00
Andrew Gu
31351c61dd [FSDP] Tighten post-bwd cast to reduce_dtype (#90615)
This lowers the `reduce_dtype` retrieval to the `handle` instead of the `state` in preparation for `fully_shard`, and this adds a guard to avoid a no-op `to()` call.

Note that this change pretty much gets overridden in the following PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90615
Approved by: https://github.com/rohan-varma
2022-12-11 06:39:34 +00:00
Rohan Varma
c7d2fb7f86 Adopt state_dict_pre_hook in FSDP (#90436)
Use `register_state_dict_pre_hook` in FSDP to simplify the state_dict implementation and remove hacks. This removes `def state_dict` entirely and paves the way for the composable API as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90436
Approved by: https://github.com/fegin
2022-12-11 03:54:26 +00:00
Andrew Gu
e7efeb5282 [FSDP] Save _stream_to_name for debugging (#90611)
This saves a data structure `_stream_to_name: Dict[torch.cuda.Stream, str]` that maps each FSDP stream to its name. This can help in debugging by checking `_stream_to_name[torch.cuda.current_stream()]` to see if it is `"default"` or `"unshard"` in the post-backward hook for example.
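A hedged sketch of how such a mapping could be consulted (names here are illustrative, not the PR's exact code; assumes a CUDA device):
```
from typing import Dict
import torch

# Illustrative only: FSDP would record its streams here during lazy init.
_stream_to_name: Dict[torch.cuda.Stream, str] = {
    torch.cuda.current_stream(): "default",
}

def _current_stream_name() -> str:
    # Handy inside a hook: is this running in the "default" or "unshard" stream?
    return _stream_to_name.get(torch.cuda.current_stream(), "unknown")
```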
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90611
Approved by: https://github.com/rohan-varma
2022-12-11 03:46:18 +00:00
Andrew Gu
9eccfedca2 [Reland][FSDP] Another fix for DTensor, use_orig_params=True (#90562)
This is a reland of https://github.com/pytorch/pytorch/pull/89845 with nothing changed. This should avoid the internal breakage now that `DTensor` does not import `torchgen` (https://github.com/pytorch/pytorch/pull/90106).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90562
Approved by: https://github.com/fduwjj
2022-12-10 22:50:30 +00:00
Shen Li
082450609c [FSDP] Allow nested FSDP wrapper to use different mixed precision (#90523)
The main change is to move the `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.
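A usage sketch of what this enables (assumes a process group is already initialized and a CUDA device is available):
```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
# The inner FSDP instance casts its inputs to bf16; the root casts to fp16.
model[0] = FSDP(model[0], mixed_precision=MixedPrecision(param_dtype=torch.bfloat16))
fsdp_model = FSDP(model, mixed_precision=MixedPrecision(param_dtype=torch.float16))
```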
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90523
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-12-09 20:06:05 +00:00
Andrew Gu
2cf703214b [Composable API][Easy] Fix some follow-ups (#90471)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90471
Approved by: https://github.com/mrshenli
2022-12-09 00:26:38 +00:00
Rohan Varma
43660051d8 [Ez] Omit HSDP Z2 from doc (#90503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90503
Approved by: https://github.com/awgu
2022-12-08 23:05:49 +00:00
Rohan Varma
793a999ce0 Hybrid Sharded Data Parallel (#89915)
Adds 2 new hybrid sharding strategies to FSDP:
1. HYBRID_SHARD: applies ZeRO-3-style sharding within a node and data parallelism across nodes
2. HYBRID_SHARD_ZERO2: applies ZeRO-2-style sharding within a node and data parallelism across nodes

These are useful for medium-sized models and aim to decrease communication volume. Tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding in general works by sharding the model using a process group within a single node and creating inter-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups or None, and we generate the process groups appropriately.
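A usage sketch under these assumptions (process group already initialized; the module is a placeholder):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

fsdp_model = FSDP(
    nn.Linear(8, 8).cuda(),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # or the ZeRO-2 variant above
    # process_group: either a tuple of the two process groups described above,
    # or None, in which case FSDP constructs the groups itself.
)
```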

**Acknowledgements**
- @awgu 's excellent prototype: 5ad3a16d48
- @liangluofb For ideation, feedback, and initial implementation and experimentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
2022-12-08 16:18:03 +00:00
Andrew Gu
21a0e809c2 [Composable API] Match fully_shard() comm. schedule with wrapper FSDP (#90387)
- This PR introduces a new concept, the _communication module_ (denoted `comm_module`), that represents the module responsible for the unshard/reshard pair for a `FlatParamHandle`. This is well-defined because the current design assumes that each `FlatParamHandle` only has _one_ unshard/reshard pair for either the forward or backward pass.
    - For the wrapper code path, the `comm_module` is exactly the module already being passed to the `FlatParamHandle` constructor.
    - For the composable code path, the `comm_module` is not necessarily the module already being passed to the `FlatParamHandle`. This is because the module already being passed is always the local FSDP root module to give complete FQNs, instead of local FQNs. Distinguishing the communication module from the local FSDP root module can provide more flexibility for non-recursive wrapping designs in the future.
- This PR adds a unit test `test_unshard_reshard_order` that explicitly checks that `_unshard` and `_reshard` are called in exactly the same order across the two code paths.
- This PR does not fix `test_checkpoint_fsdp_submodules_use_reentrant`. However, the error message changes, so this PR accommodates that.
    - The error is now the same as if we used the equivalent wrapper FSDP:
    ```
    test_model.u1 = FSDP(test_model.u1, use_orig_params=True)
    test_model.u2 = FSDP(test_model.u2, use_orig_params=True)
    ```
    - The error is also the same as if we used wrapper FSDP with `use_orig_params=False`, so it is not unique to `use_orig_params=True`.

---

**`comm_module` Example**

```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
# `fully_shard` import path assumed as of this PR; it may differ in later releases:
from torch.distributed._composable import fully_shard

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq1 = nn.Sequential(
            nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU()
        )
        self.seq2 = nn.Sequential(
            nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU()
        )

    def forward(self, x):
        return self.seq2(self.seq1(x))

model = Model()
policy = ModuleWrapPolicy({nn.Sequential})
fully_shard(model, policy=policy)                           # composable code path
# FullyShardedDataParallel(model, auto_wrap_policy=policy)  # wrapper code path
```
- This policy constructs two `FlatParamHandle`s, one for `seq1` and one for `seq2`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `model` as the `module` argument to every `FlatParamHandle`.
- `FullyShardedDataParallel` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.
- `fully_shard()` will pass `seq1` and `seq2` as the `comm_module` argument to the two `FlatParamHandle`s, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90387
Approved by: https://github.com/mrshenli
2022-12-08 15:55:20 +00:00
Chien-Chin Huang
44779d9bc6 [FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param (#89899)
**Motivation:**
Add a helper to map from the FQN to the corresponding flat_param. The helper gets the flat_param directly from the fsdp_state and the flat handle, since the flat_param is not registered to the module when `use_orig_params` is True.
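A hedged sketch of the shape of that mapping (class and field names here are assumptions for illustration, not the PR's actual definitions):
```
from typing import Dict, NamedTuple
import torch.nn as nn

class _FQNParamInfo(NamedTuple):
    fsdp_state: object         # the owning FSDP state
    flat_param: nn.Parameter   # the FlatParameter containing the original param

# e.g. "seq.0.weight" -> info about the flat_param that holds it
fqn_to_info: Dict[str, _FQNParamInfo] = {}
```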

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89899
Approved by: https://github.com/awgu
2022-12-07 19:40:47 +00:00
Ram Rachum
351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation to #90134 and hopefully the final PR in this series.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00
Andrew Gu
45b40be078 [FSDP()] Fix fully_shard fwd hook registration (#90201)
I need to rebase later after Shen's PRs land.

The idea is to only register the pre/post-forward hook on the _root modules_ among the modules that consume a `FlatParameter`. (Yes, the term _root module_ is heavily overloaded. We may want to clarify that at some point. Here, _root_ is being used in the graph sense, meaning parent-less, and the scope is only among the modules consuming a `FlatParameter`.)

This avoids unnecessary pre/post-forward hooks running, which would lead to errors because the unshard is not truly idempotent.
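A minimal sketch of the graph-sense "root" computation described above (assumptions only; not FSDP's actual helper):
```
from typing import Dict, Optional, Set
import torch.nn as nn

def _roots_among(consumers: Set[nn.Module], model: nn.Module) -> Set[nn.Module]:
    # A module is a "root" here if no ancestor of it also consumes the FlatParameter.
    parent: Dict[nn.Module, Optional[nn.Module]] = {}
    for module in model.modules():
        for child in module.children():
            parent[child] = module
    roots = set()
    for module in consumers:
        ancestor = parent.get(module)
        while ancestor is not None and ancestor not in consumers:
            ancestor = parent.get(ancestor)
        if ancestor is None:  # reached the top without finding a consuming ancestor
            roots.add(module)
    return roots
```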
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90201
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
2022-12-06 06:09:03 +00:00
Chien-Chin Huang
72fdfad4ad [FSDP][optim_state_dict][1/N] Restructure _optim_state_dict to prepare the support of use_orig_param (#89898)
**Motivation:**
Restructure some APIs in _optim_state_dict.py to allow for better future extension, mostly to support use_orig_params. No logic changes in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89898
Approved by: https://github.com/awgu
2022-12-05 21:01:48 +00:00
Shen Li
7a08261a9c Fix fully_shard error when policy is not provided (#90151)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90151
Approved by: https://github.com/awgu
2022-12-05 15:21:47 +00:00
Zheng Yan
c00d395f05 Revert D41682843: Multisect successfully blamed D41682843 for test or build failures (#90132)
Summary:
This diff is reverting D41682843
D41682843 has been identified to be causing the following test or build failures:
Tests affected:
- https://www.internalfb.com/intern/test/281475048939643/

Here's the Multisect link:
https://www.internalfb.com/intern/testinfra/multisect/1444954
Here are the tasks that are relevant to this breakage:
T93770103: 5 tests started failing for oncall assistant_multimodal in the last 2 weeks
We're generating a revert to back out the changes in this diff; please note that the backout may land if someone accepts it.

Test Plan: NA

Reviewed By: zyan0, atuljangra, YazhiGao

Differential Revision: D41710749

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90132
Approved by: https://github.com/awgu
2022-12-04 05:35:17 +00:00
Andrew Gu
e47af44eb8 [FSDP][Easy] Remove unused methods (#89229)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89229
Approved by: https://github.com/mrshenli
2022-12-03 17:55:27 +00:00
Andrew Gu
1ee189ce8e [FSDP] Issue warning when clamping to NO_SHARD (#90060)
Fixes https://github.com/pytorch/pytorch/issues/90050. I hope that this was not meant as an onboarding task :/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90060
Approved by: https://github.com/zhaojuanmao
2022-12-03 15:58:25 +00:00
Andrew Gu
eb56b08f96 [FSDP] Fix clip_grad_norm_() for low prec grads (#90028)
For PyTorch FSDP, the only way that gradients are in low precision is if `keep_low_precision_grads=True` or if the user turns on AMP. This PR adds tests for the former and improves the documentation for `clip_grad_norm_()`, especially around these non-full-precision cases.
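A usage sketch of the case being tested (assumes a process group is already initialized and a CUDA device is available):
```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

fsdp_model = FSDP(
    nn.Linear(8, 8).cuda(),
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        keep_low_precision_grads=True,  # gradients stay in fp16 after reduction
    ),
)
fsdp_model(torch.randn(4, 8, device="cuda")).sum().backward()
total_norm = fsdp_model.clip_grad_norm_(max_norm=1.0)  # handles low-precision grads
```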
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90028
Approved by: https://github.com/rohan-varma
2022-12-02 21:10:45 +00:00
Andrew Gu
688b767265 [FSDP] Fix keep_low_precision_grads=True for use_orig_params=True (#90027)
For any `flat_param.data = flat_param.to(...)` or `flat_param.grad.data = flat_param.grad.to(...)`, we must also refresh sharded parameter/gradient views, respectively, if the storage changes.

For `keep_low_precision_grads=True` and a sharded strategy, we cast the gradient back to the low precision using `.data` to bypass the PyTorch check that a parameter and its gradient have the same dtype. For `use_orig_params=True` before this PR, the gradient would incorrectly still be in full precision, not low precision, since we did not refresh views (this can actually be considered a memory leak since we have two copies of the gradient now, one in low precision and one in full precision). This PR refreshes the views.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90027
Approved by: https://github.com/mrshenli
2022-12-02 21:10:45 +00:00
Zheng Yan
6efedfd774 Revert D41609017: Multisect successfully blamed D41609017 for test or build failures (#90034)
Summary:
This diff is reverting D41609017
D41609017 has been identified to be causing the following test or build failures:
Tests affected:
- https://www.internalfb.com/intern/test/281475052567659/
- https://www.internalfb.com/intern/test/562950029295825/

Here's the Multisect link:
https://www.internalfb.com/intern/testinfra/multisect/1440332
Here are the tasks that are relevant to this breakage:
T93368156: 5 tests started failing for oncall admarket_predictor_pushmaster in the last 2 weeks
We're generating a revert to back out the changes in this diff; please note that the backout may land if someone accepts it.

Test Plan: NA

Reviewed By: zyan0

Differential Revision: D41656946

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90034
Approved by: https://github.com/awgu
2022-12-02 01:31:50 +00:00
Andrew Gu
808cb2e86d [FSDP][Dynamo] Define annotation attributes as globals (#89913)
This was separated out from the previous PR to decouple the changes. Since not all builds include `torch.distributed`, we should define the globals in the dynamo file and import them into distributed instead of vice versa. Unlike the version from the previous PR, this PR prefixes the globals with `_` to future-proof against `_dynamo/` eventually becoming public.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89913
Approved by: https://github.com/wconstab
2022-12-01 13:25:54 +00:00
Shen Li
dfbc4e5473 [Easy][FSDP] Fix pyre error (#89930)
This PR attempts to fix the following pyre error:

```
Incompatible parameter type [6]: In call
`dist.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.__init__`,
for 7th parameter `auto_wrap_policy` expected
`Optional[typing.Callable[..., typing.Any]]` but got
`Optional[_FSDPPolicy]`.
```

This also removes the type inconsistency between the code and the docstring.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89930
Approved by: https://github.com/awgu
2022-11-30 21:33:00 +00:00
Andrew Gu
6fb8423904 [FSDP] Slightly refactor fx symbolic tracer (#89917)
I made a pass over Linjian's `_symbolic_trace.py` and tidied it up a bit. Aside from simple stylistic changes, this PR makes the following changes:
- Save `visited_params: Set[nn.Parameter]` to avoid linear overhead when checking whether a parameter has already been visited before appending to the parameter execution order list (`param_forward_order`)
- Move the tracer patching logic to a class `_ExecOrderTracer` to have a reference to `self.exec_info` without having a fragmented 2-step initialization (like the old `_init_execution_info(root_module)` plus `_patch_tracer(tracer, root_module, execution_info)`)
- Define `_ParamUsageInfo` to formalize the `Tuple[nn.Module, List[Tuple[str, nn.Parameter]]]` elements being mapped to in the execution info `dict`, and clarify the documentation regarding what this represents
- Change the unit test to use `TestCase`, not `FSDPTest`, to avoid initializing a process group
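A hedged sketch of what `_ParamUsageInfo` formalizes (field names are assumptions):
```
from typing import List, NamedTuple, Tuple
import torch.nn as nn

class _ParamUsageInfo(NamedTuple):
    module: nn.Module                             # module that used the parameters
    named_params: List[Tuple[str, nn.Parameter]]  # (name, parameter) pairs used
```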

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89917
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin
2022-11-30 20:31:55 +00:00
Andrew Gu
89769d84eb [FSDP][BE] Move dynamo annotation to separate file (#89890)
This PR makes two minor changes: It (1) moves the recently-added module annotation logic for dynamo support to a separate file `torch/distributed/fsdp/_dynamo_utils.py` and ~~(2) saves the annotated attribute names to global variables `FSDP_MANAGED_MODULE` and `FSDP_USE_ORIG_PARAMS`~~.
Update: Since the distributed package may not be included in some builds, it is not safe to import from `torch.distributed...` to a file in `_dynamo/`. I will not include change (2) in this PR. The alternative is to define those globals (privately) in the dynamo file and import from there in the FSDP file.
- The first change is mainly a personal choice, where I wanted to avoid the dynamo explanation from dominating the FSDP constructor space-wise. I added the `(see function for details)` to the inline comment to forward interested readers.
- The second change follows the custom we have taken in the past for such attributes (e.g. `FSDP_FLATTENED`). My understanding (in the past as well as currently) is that this is a good practice.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89890
Approved by: https://github.com/wconstab
2022-11-30 20:29:41 +00:00
Andrew Gu
5949d5fed5 [FSDP][Easy] Remove internal default arg (#89227)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89227
Approved by: https://github.com/mrshenli
2022-11-30 13:34:05 +00:00
Andrew Gu
ce17bb95fc [FSDP] Include module classes in ModuleWrapPolicy.__repr__ (#89058)
Before:
```
<torch.distributed.fsdp.wrap.ModuleWrapPolicy object at 0x7fd4280f0fd0>
```
After:
```
<torch.distributed.fsdp.wrap.ModuleWrapPolicy object at 0x7fd4280f0fd0>({<class 'transformers.models.t5.modeling_t5.T5Block'>})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89058
Approved by: https://github.com/mrshenli
2022-11-30 02:27:02 +00:00
Andrew Gu
c8aaad040e [FSDP] Limit all gather after pre-unshard (#89057)
To reuse memory when allocating the unsharded `FlatParameter` in the unshard stream, we only need to block the CPU thread on the preceding free event (i.e. `event.synchronize()`) before allocating the unsharded memory, which happens in `handle.unshard()`. Notably, this can be done after the pre-unshard logic, which at most performs _sharded_ allocations (low precision shard or H2D sharded `FlatParameter` copy) in its own pre-unshard stream. This enables the pre-unshard to overlap with any pending ops.

With this change, I believe that we should use `limit_all_gathers=True` all the time to stay true to FSDP's proposed memory semantics.

If a user wants to set `limit_all_gathers=False`, that would mean that he/she wants to overlap ops that are issued after the unshard logic's all-gather with ops that are pending at the time when FSDP _would_ block the CPU thread via `event.synchronize()`.
- If the user is willing to not reuse memory for that all-gather, then the user may as well have applied `NO_SHARD` and optionally ZeRO-1 (if this niche is important, then maybe we should consider hardening ZeRO-1). This is because now the unsharded memory for the all-gather additionally contributes to peak memory since it cannot reuse memory.
- If the user wanted to reuse memory for that all-gather, then we needed to block the CPU thread. There is no way around that given the caching allocator semantics.
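A minimal sketch of the ordering described above (not FSDP's code; assumes a CUDA device):
```
import torch

pre_unshard_stream = torch.cuda.Stream()
unshard_stream = torch.cuda.Stream()
prev_free_event = torch.cuda.Event()  # recorded when the previous unsharded
                                      # FlatParameter's memory was freed

def unshard(pre_unshard_work, unshard_work):
    with torch.cuda.stream(pre_unshard_stream):
        pre_unshard_work()         # only sharded-size allocations happen here
    prev_free_event.synchronize()  # block the CPU thread only now, so the
                                   # unsharded allocation below can reuse memory
    with torch.cuda.stream(unshard_stream):
        unshard_work()             # allocate the unsharded FlatParameter + all-gather
```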
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89057
Approved by: https://github.com/mrshenli
2022-11-30 02:27:02 +00:00
Andrew Gu
be80b72add [FSDP] Remove unneeded stream sync from clip_grad_norm_() (#89308)
We do not need to have the pre-unshard and unshard streams wait for the computation stream because we are not using the pre-unshard or unshard streams in `clip_grad_norm_()`.

The other change is simply avoiding a loop to get `grads`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89308
Approved by: https://github.com/mrshenli
2022-11-30 02:14:09 +00:00
Andrew Gu
d1760d7a42 [FSDP][Easy] Remove outdated TODO (#89217)
**Overview**
This PR removes an outdated TODO:
```
# TODO (awgu): When exposing the original parameters, we need to also
# use this attribute to prevent re-synchronizing parameters.
```

**Justification**
We only pass `managed_params` to `_sync_module_params_and_buffers()`, where `managed_params` is defined as
```
managed_params = list(_get_orig_params(root_module, state._ignored_params))
```
This `_get_orig_params()` call excludes parameters already flattened by FSDP. Thus, `_sync_module_params_and_buffers()` will not re-sync already-synchronized parameters. Each parameter appears in `managed_params` for some FSDP instance exactly once and hence is only synchronized once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89217
Approved by: https://github.com/mrshenli
2022-11-30 01:42:16 +00:00
Andrew Gu
c599cf24ad [FSDP] Another fix for DTensor, use_orig_params=True (#89845)
The issue for `test_2d_parallel.py` is that `DTensor` does not support the idiom `param.data = view` where `view` is a `DTensor`. To work around this, we do not preserve the parameter variable `param` and instead create a new parameter variable altogether via `nn.Parameter(view)`. Preserving the parameter variable when unsharded was not a strict requirement -- it just made sense to do that if we are already doing that when _sharded_, where it _is_ a strict requirement to support the optimizer step. The sharded case is not an issue for 2D because sharded implies local tensor, not `DTensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89845
Approved by: https://github.com/zhaojuanmao
2022-11-29 22:29:41 +00:00
Andrew Gu
6e2da426f0 [FSDP] Relax post-backward assert (#89791)
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89791
Approved by: https://github.com/zhaojuanmao
2022-11-29 17:25:56 +00:00
Will Constable
705ad36cc5 Dynamo asserts FSDP wrapped modules use_orig_param (#89523)
- This is a strict requirement given the way dynamo+FSDP is implemented,
  but isn't convenient to assert.
- By plumbing the use_orig_params field on all wrapped modules, we can
  do this assertion inside dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89523
Approved by: https://github.com/awgu
2022-11-29 05:27:23 +00:00
Will Constable
77df2ca9b6 Special-case fsdp wrapped modules to be Unspecialized (#89330)
### Summary
Making dynamo treat the nn.Modules inside FSDP wrappers as 'Unspecialized'
results in dynamo-produced graphs where nn.module parameters are inputs
to the graph rather than attributes of the outer graphmodule.

This helps in FSDP since it forces dynamo to pick the latest copy
of the parameters off the user's nn.Module (which FSDP mutates every pre_forward),
solving the ordering issue in backward.

### Details
Imagine this toy model
```
import torch
import torch.nn as nn

class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super(MyModule, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.net(x)

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net = nn.Sequential(
            *[MyModule(10, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)
```
FSDP is recursively wrapped around each `MyModule` and the model is then dynamo-compiled, with dynamo already configured to skip/break in FSDP code. You'd expect to get 3 compiled AOT functions, corresponding to the contents of `MyModule`, and then see FSDP's communication ops happen in between them (eagerly). This almost happens (everything works out fine in forward), but in backward there is an ordering issue.

FSDP creates a flat buffer for all the parameters that are bucketed together, and then creates views into this buffer to replace the original parameters.  On each iteration of forward, it creates a new view after 'filling' the flatbuffer with data from an all-gather operation, to 'unshard' the parameters from remote devices.  Dynamo traces the first such view and stores it in a compiled graphmodule.

During tracing, we see (1) a view created for the first MyModule, (2) the first MyModule compiled, (3) ... for the rest of the layers.

Then during runtime, we see (A) a view created for the first MyModule (and orphaned), (B) the first compiled MyModule executed, using the old view, ...

This is a problem, because we want backward hooks to run right after each compiled-backward, but autograd executes those hooks in an order mirroring their execution order during forward.  Since we are forever using the views created during steps (1, 3, ..  N), which all happen before the steps (A, B, ...),  this means that all the hooks will happen after all the compiled backwards.  An illustration of the problem - a torchviz graph showing the 2 possible orderings of autograd, and a profile showing the view-backwards ops happening after all the compiled backwards, and before all the backward hooks.

<img width="2069" alt="image" src="https://user-images.githubusercontent.com/4984825/202828002-32dbbd15-8fc3-4281-93e9-227ab5e32683.png">
<img width="2069" alt="image" src="https://user-images.githubusercontent.com/4984825/202828632-33e40729-9a7f-4e68-9ce1-571e3a8dd2dd.png">

A solution is to make dynamo not specialize on these nn.Modules. It is worth pointing out that this nn.Module specialization is de facto failing, as we are modifying .parameters and this bypasses dynamo's __setattr__ monkeypatch, which should have automatically kicked us out to Unspecialized and forced a recompile.

After unspecializing, the new views (created during steps A, C, ...) are actually _used_ at runtime by the module, so their creation order is interleaved, and autograd executes their backwards interleaved.

The new torchviz graph (this time with names added for the view tensors):
<img width="2043" alt="image" src="https://user-images.githubusercontent.com/4984825/202828480-d30005ba-0d20-45d8-b647-30b7ff5e91d3.png">

And a new profile showing the interleaving of compiled backwards and hooks, allowing overlapping of reduce-scatter.
<img width="2293" alt="image" src="https://user-images.githubusercontent.com/4984825/202828533-bb20a041-19b8-499c-b3cf-02808933df47.png">

@jansel @davidberard98 @aazzolini @mrshenli @awgu @ezyang @soumith @voznesenskym @anijain2305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89330
Approved by: https://github.com/davidberard98
2022-11-29 01:24:03 +00:00
Andrew Gu
943acd4d27 [FSDP] Fix nn.Parameter usage for 2D and use_orig_params=True (#89782)
This ensures that all elements of `FlatParameter._params` and `FlatParameter._shared_params` are `nn.Parameter`s (as expected). This was violated by the local tensor of a `DTensor` when using 2D parallelism. To fix the breakage, we simply wrap with `nn.Parameter` if needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89782
Approved by: https://github.com/fduwjj
2022-11-28 23:56:38 +00:00
Andrew Gu
090fc62b24 [FSDP()] Register root pre-forward hook (#89572)
- This PR registers the FSDP root pre-forward hook as a module forward pre-hook following the recently added support for kwargs for those hooks.
- This PR also passes `prepend=True` for the normal (not root) pre-forward hook. This is not strictly required for this PR, but I believe it is needed for composability with activation checkpointing. (We want to run FSDP logic on the outside and AC logic on the inside, just like how we recommend `FSDP(AC(module))` for the wrapper versions.)
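A hedged sketch of the registration pattern from the bullets above (the hook body is a placeholder):
```
import torch.nn as nn

def _pre_forward(module, args, kwargs):
    # FSDP would unshard here and may adjust args/kwargs
    return args, kwargs

m = nn.Linear(4, 4)
# prepend=True so this hook runs before hooks registered by inner wrappers such
# as activation checkpointing, mirroring the FSDP(AC(module)) ordering.
m.register_forward_pre_hook(_pre_forward, prepend=True, with_kwargs=True)
```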

Fun fact: I originally chose the `[FSDP()]` prefix in the PR titles when we still referred to composable FSDP as functional-like FSDP, in which case `FSDP()` approximated "functional FSDP". I am preserving this usage to make searching for PRs relating to composable FSDP easier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89572
Approved by: https://github.com/mrshenli
2022-11-28 16:56:32 +00:00
Andrew Gu
ee1d375bf9 [FSDP] Add fast path for NO_SHARD clip_grad_norm_() (#89137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89137
Approved by: https://github.com/rohan-varma
2022-11-16 22:08:50 +00:00
Andrew Gu
4f2639e56a [FSDP] Fix FSDP.clip_grad_norm_() for NO_SHARD (#88955)
This PR fixes `FSDP.clip_grad_norm_()` for `NO_SHARD`, which previously "double-counted" each gradient `world_size`-many times.
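An illustration (not FSDP code) of the scale of the over-counting for the L2 norm:
```
import torch

g = torch.arange(4.0)  # stand-in for a full (unsharded) gradient
world_size = 2
correct = g.norm(2)
# Counting every gradient world_size times multiplies the summed squares by
# world_size, so the reported norm is sqrt(world_size) times too large.
double_counted = (world_size * g.norm(2) ** 2).sqrt()
assert torch.isclose(double_counted, world_size ** 0.5 * correct)
```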

This does not address any discrepancies between `FULL_SHARD` and DDP. (Note that the unit tests do show parity between `FULL_SHARD` and DDP when using `FSDP.clip_grad_norm_()` and `nn.utils.clip_grad_norm_()` respectively on one iteration.)

The added unit test code path tests mixing nested FSDP instances with both `FULL_SHARD` and `NO_SHARD` to ensure that the `local_sharded_norm` and `local_nonsharded_norm` computations are interoperating correctly. I want to test a non-FSDP root instance in the future, but this is BC breaking since we need to make `clip_grad_norm_()` a static method, which would require a different method call syntax (`FSDP.clip_grad_norm_(root_module, ...)` vs. `root_module.clip_grad_norm_(...)`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88955
Approved by: https://github.com/zhaojuanmao
2022-11-13 02:38:38 +00:00
Andrew Gu
d01bf1d1f1 [FSDP] Introduce ModuleWrapPolicy for simplicity (#88450)
**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks code that passes `unwrapped_params` as a keyword argument, but I did not see anything that does that (except one internal benchmark file, which does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao
2022-11-12 04:14:32 +00:00
Chien-Chin Huang
ae4074669e [FSDP][state_dict][6/N] Remove most FSDP module dependency from _optim_utils (#88638)
**What**
This PR removes most `FullyShardedDataParallel` dependencies from `_optim_utils`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88638
Approved by: https://github.com/awgu
2022-11-12 03:16:37 +00:00
Will Constable
a3f3ec8fac [FSDP+dynamo]: forward treats parameter-views as params (#88781)
Dynamo+AotAutograd needs a way to wrap all tensors (whether
inputs or params/buffers) in FakeTensor wrappers, and
FSDP's mangling of parameters hides them from this wrapping.

This PR unblocks running hf_bert and hf_T5 with FSDP under dynamo, whether using recursive wrapping around transformer layers or only applying FSDP around the whole model.  Perf/memory validation and possibly optimization is the next step.
`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager --fsdp_wrap`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager --fsdp_wrap`

The problem:
Dynamo (actually aot_autograd) trips up with FSDP because it must
wrap all input tensors in FakeTensor wrappers, and it only knows
to wrap graph inputs or named_(parameters, buffers). FSDP's
pre_forward hook sets views (which are not nn.Parameters) into the flatparam
as attrs on the module with the same name as the original param, but
they will not show up in named_parameters.

- in use_orig_params mode, FSDP still de-registers
  params during pre-forward hook, then re-registers them
  post-forward
- during forward (between the hooks), the params are setattr'd
  on the module as regular view tensors, not nn.Parameters
- note: use_orig_params is the recommended way to use FSDP,
  and use_orig_params=False is being deprecated. So I only consider
  use_orig_params=True for this enablement

The solution:
- adding them to named_buffers is not possible because it interferes
  with how FSDP's `_apply` works
- since they are not actual nn.parameters, register_parameter will
  complain about registering them
- simply setting `module._parameters[name] = view` seems to be a viable
  workaround, despite being hacky, and FSDP code does modify _parameters
  directly already.
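A minimal sketch of the `_parameters` workaround named in "The solution" above (illustrative only; not FSDP's real code):
```
import torch
import torch.nn as nn

m = nn.Linear(4, 4)
flat_buffer = torch.randn(m.weight.numel())  # stand-in for the flat param
view = flat_buffer.view_as(m.weight)         # a plain tensor view, not an nn.Parameter
# register_parameter() would reject a non-Parameter, but writing _parameters
# directly makes the view visible to named_parameters() (and thus to dynamo).
m._parameters["weight"] = view
assert any(p is view for _, p in m.named_parameters())
```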

Note: Manual checkpointing still isn't working with FSDP+dynamo,
so that will have to be addressed in a follow up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88781
Approved by: https://github.com/ezyang, https://github.com/awgu
2022-11-12 01:17:23 +00:00
Chien-Chin Huang
7aa144ac54 [FSDP][state_dict][5/N] Remove the FSDP module dependency from _state_dict_utils (#88637)
**What**
This PR completely removes the `FullyShardedDataParallel` dependency from `_state_dict_utils` -- `_state_dict_utils` now depends only on `_FSDPState` and all the utils modules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88637
Approved by: https://github.com/awgu
2022-11-11 21:22:13 +00:00
Chien-Chin Huang
3d1c5c89ed [FSDP][state_dict][4/N] Move the core logic of summon full parameters to _unshard_params_utils.py (#88636)
**What**
`_summon_full_params` is required for state_dict. To enable composable FSDP state_dict, `_summon_full_params` must be accessible without `FullyShardedDataParallel`. This PR moves the core logic of `_summon_full_params` to `_unshard_params_utils`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88636
Approved by: https://github.com/awgu
2022-11-11 18:30:57 +00:00
Chien-Chin Huang
3c7f96665e [FSDP][state_dict][3/N] Change how state_dict utils access attributes in _FSDPState (#88635)
**What This PR Does**
`_state_dict_utils` currently accesses the FSDP states through `module`. To enable composable FSDP state_dict, these accesses need to go through `_FSDPState`. `module` is still required for most APIs, as state_dict has to access per-module information.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88635
Approved by: https://github.com/awgu
2022-11-11 15:20:36 +00:00
Chien-Chin Huang
324ac93a43 [FSDP][state_dict][2/N] Move state_dict related enums/dataclasses/states to state_dict_utils.py, api.py and init_state_dict() (#88481)
**Motivation**:
Several enums, dataclasses, and states defined in fully_sharded_data_parallel.py should be moved to a place where the composable FSDP can access them. This PR does the move.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88481
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-11-11 12:28:37 +00:00
Chien-Chin Huang
c4fc5d372f [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900)
This is one step toward the ultimate goal: remove the overwritten state_dict in FSDP. All the logic should be either in `pre_state_dict_hook` or `post_state_dict_hook`.

Since the current `nn.Module` does not support `pre_state_dict_hook`, this PR mimics `pre_state_dict_hook` by calling the pre hook inside the post hook, effectively ditching all the work done by `nn.Module.state_dict`. Once `pre_state_dict_hook` is supported by `nn.Module`, these pre hook calls can be moved out from the post hooks and registered to `nn.Module.pre_state_dict_hook`.

The major issue of this temporary solution is that `post_state_dict_hook` is called from the leaf node to the root node. This makes calling `module._lazy_init()` there invalid, as FSDP assumes `_lazy_init()` is called from the root. As a result, `FSDP.state_dict` currently contains only one piece of logic -- calling `module._lazy_init()`.
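A hedged sketch of the interim pattern described above (hook signatures simplified; names are placeholders):
```
def _pre_state_dict_hook(module, prefix: str) -> None:
    # logic that should eventually run before nn.Module.state_dict()
    pass

def _post_state_dict_hook(module, state_dict: dict, prefix: str, *args):
    _pre_state_dict_hook(module, prefix)  # mimic the missing pre-hook for now
    # ... actual post-processing of the state_dict entries ...
    return state_dict
```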

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87900
Approved by: https://github.com/rohan-varma
2022-11-11 03:41:40 +00:00
Andrew Gu
6bf2776ac1 [FSDP][Perf] Do not call pad in no-padding case (#88769)
- Calling `F.pad()` issues a pad kernel from the CPU even if there is no padding needed, which can incur some non-negligible overhead. This PR removes that unnecessary call for the no-padding case.
- This PR also does not zero the newly-allocated sharded gradient tensor before the reduce-scatter if `use_orig_params=True` because there is no need. The reduce-scatter will fill the tensor anyway, and we do not care about the values in the padding. For `use_orig_params=False`, the padding is exposed to the user, so we preserve the existing semantics of zeroing it. I left a to-do to follow-up since we may optimize that.
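A minimal sketch of the guard described in the first bullet (illustrative; names are placeholders):
```
import torch
import torch.nn.functional as F

def _pad_if_needed(sharded_flat_tensor: torch.Tensor, numel_to_pad: int) -> torch.Tensor:
    # Skip the pad kernel entirely when there is nothing to pad.
    if numel_to_pad == 0:
        return sharded_flat_tensor
    return F.pad(sharded_flat_tensor, [0, numel_to_pad])
```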
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88769
Approved by: https://github.com/zhaojuanmao
2022-11-10 18:18:55 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
Andrew Gu
a84d68cdfd [FSDP][Docs] Reword sharding_strategy docs and other minor doc changes (#88431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88431
Approved by: https://github.com/mrshenli
2022-11-03 23:32:41 +00:00