Commit Graph

1521 Commits

Author SHA1 Message Date
Rodrigo Kumpera
14dd5db2f5 [fsdp] Fix test for 2d parallel integration to trigger the load hooks. (#86272)
nit: replaced the empty-array bool test with an explicit check of its length.
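
A generic illustration of the nit (hypothetical variable name, not the exact diff):

```
# Prefer an explicit length check over relying on list truthiness when the
# intent is "is this empty?"; the explicit form states the intent directly.
hooks = []

if len(hooks) == 0:   # explicit length test (what the nit switches to)
    print("no hooks registered")

if not hooks:         # implicit bool test (what it replaces)
    print("no hooks registered")
```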

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86272
Approved by: https://github.com/awgu
2022-10-13 20:28:44 +00:00
Kshiteej K
54ee95c8ec [nn] module: full_backward_pre_hook (#86700)
Fixes https://github.com/pytorch/pytorch/issues/42824

* [x] Test
* [x] Doc
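
A minimal usage sketch, assuming the new hook is registered via `nn.Module.register_full_backward_pre_hook` (mirroring the existing `register_full_backward_hook`):

```
import torch
import torch.nn as nn

def backward_pre_hook(module, grad_output):
    # Runs before the module's backward; may return a modified grad_output.
    print(f"pre-backward on {module.__class__.__name__}")

lin = nn.Linear(4, 2)
handle = lin.register_full_backward_pre_hook(backward_pre_hook)
lin(torch.randn(3, 4)).sum().backward()
handle.remove()
```
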
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86700
Approved by: https://github.com/soulitzer
2022-10-13 17:36:39 +00:00
Colin Taylor
25811663af [FSDP] restricts meta model check to non ignored modules in FSDP (#86766)
Summary: as title

Test Plan: see test plan D40287799

Differential Revision: D40287890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86766
Approved by: https://github.com/awgu
2022-10-13 16:48:24 +00:00
Jerry Zhang
c12f829cce [nn] Add remove_duplicate flag to named_buffers (#674) (#85903)
Summary:
X-link: https://github.com/pytorch/torchrec/pull/674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84984

This allows `named_buffers` to return the same buffer object multiple times under different names, which is needed by internal use cases.
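
A minimal sketch of the flag, assuming the signature `named_buffers(prefix='', recurse=True, remove_duplicate=True)`:

```
import torch
import torch.nn as nn

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(2))

root = nn.Module()
root.a = Sub()
root.b = root.a   # the same submodule (and buffer) registered under a second name

print([n for n, _ in root.named_buffers()])                        # deduplicated
print([n for n, _ in root.named_buffers(remove_duplicate=False)])  # both names
```
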
ghstack-source-id: 168589597

Test Plan:
python test/test_nn.py -k test_buffers_and_named_buffers

Imported from OSS

Reviewed By: albanD

Differential Revision: D39493161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85903
Approved by: https://github.com/albanD
2022-10-11 18:49:09 +00:00
eqy
352d926482 [CUBLAS][CUDA GRAPHS] (re-re-re-re-open of #83461) Explicitly set the workspace for cuBLAS handles (#86645)
re-opening (again) in hopes of working around failed/stuck CLA check

CC @ptrblck @ngimel @huydhn
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86645
Approved by: https://github.com/zdevito
2022-10-11 16:03:49 +00:00
Andrew Gu
6ab07febce [FSDP][Easy] Rename _prefixed_param_names -> _fqns for consistency (#86653)
This renames `_prefixed_param_names` to `_fqns` to help converge on the terminology.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86653
Approved by: https://github.com/rohan-varma
2022-10-11 12:49:45 +00:00
Louis Feng
55479fe80e Enable capturing of comm collective parameters (#98) (#85368)
Summary:
X-link: https://github.com/facebookresearch/torch_ucc/pull/98

Add tensor input, output, and other metadata for PyTorch comms.

Test Plan: P517138779

Reviewed By: Pavani-Panakanti

Differential Revision: D38357077

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85368
Approved by: https://github.com/H-Huang
2022-10-11 04:38:26 +00:00
Andrew Gu
ce7751188a [DDP] Add PackedSequence support when device_ids is specified (#86614)
Before this PR, if a user runs DDP with `device_ids` specified and with a `PackedSequence` input, then the execution will error with something like:
```
raise ValueError(
  ValueError: batch_sizes should always be on CPU. Instances of PackedSequence should never be created manually. They should be instantiated by
 functions like pack_sequence and pack_padded_sequences in nn.utils.rnn. https://pytorch.org/docs/stable/nn.html...
```
This is because the DDP forward calls `_to_kwargs()`, which calls `_recursive_to()`, which moves the inputs to GPU. However, `_is_namedtuple(packed_sequence)` returns `True`, taking the branch `return [type(obj)(*args) for args in zip(*map(to_map, obj))]`, which tries to construct a `PackedSequence` directly via `type(obj)(*args)` and raises the error above.

Repro for `_is_namedtuple(packed_sequence)` returning `True`:
```
import random

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.nn.parallel.scatter_gather import _is_namedtuple

def _ordered_sequence(tensor_type):
    seqs = [tensor_type(random.randint(1, 256))
            for _ in range(32)]
    seqs = [s.random_(-128, 128) for s in seqs]
    ordered = sorted(seqs, key=len, reverse=True)
    return ordered

def _padded_sequence(tensor_type):
    ordered = _ordered_sequence(tensor_type)
    lengths = [len(i) for i in ordered]
    padded_tensor = rnn_utils.pad_sequence(ordered)
    return padded_tensor, lengths

padded, lengths = _padded_sequence(torch.Tensor)
packed = rnn_utils.pack_padded_sequence(
    padded, lengths, enforce_sorted=False)
print(type(packed), packed.data.device)
print(_is_namedtuple(packed))
```

Test Plan:
```
python test/distributed/test_c10d_nccl.py -k test_ddp_packed_sequence
```
Without the fix, the added unit test fails with the expected error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86614
Approved by: https://github.com/rohan-varma
2022-10-10 21:50:59 +00:00
anjali411
e2a4dfa468 Add correct __all__ for torch.distributed and torch.cuda submodules (#85702)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85702
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/rohan-varma
2022-10-10 19:15:24 +00:00
Andrew Gu
5102f0cffc [FSDP][1/N] Retire FlattenParamsWrapper (#86117)
This deprecates `FlattenParamsWrapper`'s usage for "unflattening" the original parameters. After this PR, FPW only serves to register and de-register its `FlatParameter` for the parent `FullyShardedDataParallel` instance.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86117
Approved by: https://github.com/zhaojuanmao
2022-10-10 11:38:44 +00:00
Andrew Gu
5844f00bbf [FSDP] Add low_prec prefix to param and reduce dtype varnames (#86512)
This PR renames `param_dtype` and `reduce_dtype` in `HandleConfig` to `low_prec_param_dtype` and `low_prec_reduce_dtype` to emphasize that they are meant to be of the low precision (if not `None`).

(In my mind, mixed precision refers to the paradigm of using both full and low precision together during training. "Reduced" and "low precision" mean the same thing, but I prefer the term "low precision" in the code since it is shorter. A particular dtype can be a low precision dtype or a full precision dtype.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86512
Approved by: https://github.com/zhaojuanmao
2022-10-10 09:33:33 +00:00
Andrew Gu
cc5de7f1ac [FSDP] Remove utils.py (moved to _utils.py) (#86528)
I messed up my git with an earlier PR, where I did not actually remove `utils.py` when moving it to `_utils.py`. This removes `utils.py`, which is now outdated and unused.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86528
Approved by: https://github.com/H-Huang
2022-10-10 09:31:01 +00:00
Andrew Gu
af9c6bc851 [FSDP] Add keep_low_precision_grads support when CPU offloading (#86495)
When CPU offloading, FSDP uses `_cpu_grad`, not `_saved_grad_shard`. This adds support for `keep_low_precision_grads` for that case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86495
Approved by: https://github.com/rohan-varma
2022-10-08 03:26:40 +00:00
Andrew Gu
a95889ba7c [FSDP] Add initial summon_full_params(with_grads=True) (#85738)
This adds `summon_full_params(with_grads=True)` for `use_orig_params=True` and `offload_to_cpu=False`. Filling in the `use_orig_params=False` case requires some already-planned refactoring, and the `offload_to_cpu=True` case needs some additional work as well.

Adding this is helpful for debugging `use_orig_params=True` to make sure gradients are being updated correctly.
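
A minimal usage sketch, assuming a process group is already initialized (e.g. via `torchrun`), one GPU per rank, and the supported `use_orig_params=True` / `offload_to_cpu=False` configuration:

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(8, 8).cuda(), use_orig_params=True)
model(torch.randn(4, 8, device="cuda")).sum().backward()

# Gradients are gathered alongside the parameters for inspection/debugging.
with FSDP.summon_full_params(model, with_grads=True):
    for name, p in model.named_parameters():
        print(name, None if p.grad is None else p.grad.shape)
```
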
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85738
Approved by: https://github.com/rohan-varma
2022-10-07 21:03:18 +00:00
Andrew Gu
be682befbc [FSDP] Add use_orig_params (#84911)
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.

For more detailed design explanation, refer to the Quip shared internally.
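
A minimal sketch of the option (assuming an initialized process group and one GPU per rank; the parameter-group split is purely illustrative):

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(16, 16).cuda(), use_orig_params=True)

# named_parameters() now yields the original parameters (weight, bias),
# so they can be assigned to different optimizer parameter groups.
decay, no_decay = [], []
for name, p in model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(p)

optim = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
)
```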

**Follow-Ups**
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84911
Approved by: https://github.com/rohan-varma
2022-10-07 18:07:17 +00:00
PyTorch MergeBot
0e639ff45c Revert "Cleanup PT-D imports (#85781)"
This reverts commit 9a170b24f6.

Reverted https://github.com/pytorch/pytorch/pull/85781 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-10-07 14:55:44 +00:00
Dennis van der Staay
9a170b24f6 Cleanup PT-D imports (#85781)
Summary:
The flow logic around torch.dist imports results in a large number of pyre errors (hundreds); it would be preferable to raise on import rather than fail silently.

Con: some percentage of users (macOS?) may have notebooks that import PT-D, though that share is likely small, since any attempt to call parts of the library would just fail...

TODO: assuming this is OK, remove the tens to hundreds of unused pyre ignores that are no longer required.

Test Plan: existing unit tests

Differential Revision: D39842273

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85781
Approved by: https://github.com/mrshenli
2022-10-07 00:29:32 +00:00
Rohan Varma
f0977c4658 [FSDP] Doc to explain running submodules (#86343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86343
Approved by: https://github.com/awgu
2022-10-06 23:10:23 +00:00
Rohan Varma
3db8ddcac1 [FSDP] Fix clip_grad_norm for CPU offload (#86337)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86337
Approved by: https://github.com/awgu
2022-10-06 23:10:23 +00:00
Rohan Varma
adfd8f3823 [FSDP] assert to runtime error (#86336)
Prefer raising an error over `assert`, which should mostly indicate a developer bug; a user can reach this error path.
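
A generic illustration of the preference (hypothetical condition and message, not the exact diff):

```
def check_sharded(handle_is_sharded: bool) -> None:
    # Before: assert handle_is_sharded, "expected a sharded handle"
    # `assert` is stripped under `python -O` and signals a developer invariant;
    # a user-reachable failure should raise a real error instead.
    if not handle_is_sharded:
        raise RuntimeError("expected a sharded handle")
```
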
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86336
Approved by: https://github.com/awgu
2022-10-06 23:10:21 +00:00
Bin Chen
8c6d352bcf Log a new "timer expired" event to Scuba in file_based_local_timer (#85861)
Summary: The "kill worker process" event was logged to Scuba only when the worker process was really reaped. We want to add a new event "timer expired", no matter the worker process will be reaped or not. This will help collect data before we enable the JustKnob to kill the worker process on timeout.

Test Plan:
### Unit Test
```
buck test mode/dev-nosan //caffe2/test/distributed/elastic/agent/server/test:local_agent_test
```
```
Test Session: https://www.internalfb.com/intern/testinfra/testrun/7318349508929624
RE: reSessionID-ea464c43-54e7-44f2-942b-14ea8aa98c74  Up: 10.5 KiB  Down: 1.1 MiB
Jobs completed: 100. Time elapsed: 3206.9s. Cache hits: 91%. Commands: 11 (cached: 10, remote: 1, local: 0)
Tests finished: Pass 55. Fail 0. Fatal 0. Skip 0. 0 builds failed
```
--------
```
buck test mode/dev-nosan //caffe2/test/distributed/elastic/agent/server/test/fb:local_agent_fb_internal_test
```
```
Test Session: https://www.internalfb.com/intern/testinfra/testrun/6473924579130483
RE: reSessionID-231a47b7-a43d-4c0f-9f73-64713ffcbbd3  Up: 5.7 MiB  Down: 1.9 GiB
Jobs completed: 182156. Time elapsed: 282.4s. Cache hits: 99%. Commands: 72112 (cached: 72107, remote: 1, local: 4)
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. 0 builds failed
```

Differential Revision: D39903376

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85861
Approved by: https://github.com/d4l3k
2022-10-05 18:23:53 +00:00
Andrew Gu
67eb2d5952 [FSDP] Dequeue one instead of flush (#86165)
For the rate limiter, I initially implemented the approach of only dequeueing a single event, but there was concern about blocking the CPU _every_ iteration. The landed approach instead blocks every `_max_num_inflight_all_gathers` iterations and flushes the entire queue.

However, upon further analysis, the approach of dequeueing a single event should be more performant with the same memory usage -- as the name suggests, both keep at most `_max_num_inflight_all_gathers` concurrently inflight all-gathers. What matters is not the cost of blocking the CPU thread but how long the CPU thread actually stays blocked, and this PR's approach reduces that duration.
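
An illustrative sketch of the two draining strategies (hypothetical helper class, not FSDP's internal code); the queue holds CUDA events recorded after each all-gather and is capped at `_max_num_inflight_all_gathers`:

```
import collections
import torch

class AllGatherRateLimiter:
    def __init__(self, max_inflight: int = 2):
        self._events = collections.deque()
        self._max_inflight = max_inflight

    def push(self, event: torch.cuda.Event) -> None:
        if len(self._events) >= self._max_inflight:
            # New approach: block only until the *oldest* all-gather completes.
            self._events.popleft().synchronize()
            # Old approach (flush): drain the entire queue at once.
            # while self._events:
            #     self._events.popleft().synchronize()
        self._events.append(event)
```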

**Fast Communication; Slow Computation**
<img width="1235" alt="Screen Shot 2022-10-04 at 4 15 13 PM" src="https://user-images.githubusercontent.com/31054793/193917536-f1491803-9578-45ea-ba6e-e735c1bf7784.png">

**Slow Communication; Fast Computation**
<img width="718" alt="Screen Shot 2022-10-04 at 4 34 15 PM" src="https://user-images.githubusercontent.com/31054793/193921508-f2a4fd22-2b03-4a8e-b6ca-634c584c70e2.png">

**T5-11B**
2 nodes / 16 40 GB A100s with EFA and batch size 6:
- [Old] 5.81 s / batch; 24 and 20 CUDA malloc retries on local rank 0s; 35.234 GB peak active; 38.806 GB peak reserved
- [New] 5.10 s / batch; 25 and 29 CUDA malloc retries on local rank 0s; 35.234 GB peak active; 38.868 GB peak reserved

4 nodes / 32 40 GB A100s with EFA and batch size 7:
- [Old] 5.21 s / batch; 0, 0, 0, 0 CUDA malloc retries on local rank 0s; 33.695 GB peak active; 38.494 GB peak reserved
- [New] 4.93 s / batch; 1, 0, 0, 0 CUDA malloc retries on local rank 0s; 33.678 GB peak active; 38.792 GB peak reserved

The new version changes the fragmentation in the allocator. It is possible that by blocking the CPU thread more in the old approach, the initial blocks used to serve the all-gather stream allocations are different compared to the new approach. Even though the number of CUDA malloc retries increases slightly, the net result is a speedup with the new approach.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86165
Approved by: https://github.com/zhaojuanmao
2022-10-05 11:28:12 +00:00
Andrew Gu
248796987e [FSDP] Expose internal prefetch limits (#86198)
This PR refactors the prefetching implementation to enable a module to prefetch more than one all-gather.
- The motivation is for backward prefetching, but forward prefetching is included in the change as well.
- The prefetching limit is a _limit_. In some edge cases (e.g. dynamic graph or first/last module), the limit may not be reached.
- The prefetching limit is kept as internal in this PR -- it is set as local variables `backward_prefetch_limit` and `forward_prefetch_limit` in the `FullyShardedDataParallel` constructor and passed to the `_ExecOrderData()` constructor.
- This PR additionally includes some clean up for forward prefetching but does not change any semantics assuming static graph.

If we increase the `backward_prefetch_limit` to `2`, then a typical pattern is that the first module in the pre-backward prefetches 2, but each subsequent module prefetches only 1, since its first target was already prefetched by the previous module. Without this behavior, the prefetching would run further and further ahead as the number of modules grows.
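
An illustrative sketch of how a prefetch limit might be applied (hypothetical names, not the internal implementation): starting from the current position in the recorded execution order, prefetch at most `limit` targets that are not already prefetched.

```
def prefetch_next(exec_order, current_index, prefetched, limit, unshard):
    launched = 0
    for handles_key in exec_order[current_index + 1:]:
        if launched >= limit:
            break
        if not prefetched.get(handles_key, False):
            unshard(handles_key)              # issue the all-gather early
            prefetched[handles_key] = True
            launched += 1
```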

**`_handles_prefetched`**
- This is used to avoid multiple modules prefetching the same handles keys.
- `_handles_prefetched[handles_key]` is set to `True` when the prefetch for `handles_key` happens from the CPU thread (`_prefetch_handles()`).
- `_handles_prefetched[handles_key]` is set to `False` when any handle in `handles_key` is resharded (`_reshard()`).
- `_handles_prefetched` is cleared at the end of the backward (`_wait_for_post_backward()`).

**`_needs_pre_backward_unshard`**
- This is used to determine if a handles key should be backward prefetched at all.
- `_needs_pre_backward_unshard[handles_key]` is set to `False` in the post-forward (`_register_pre_backward_hooks()`).
- `_needs_pre_backward_unshard[handles_key]` is set to `True` in the post-forward if the forward outputs include tensors that require gradient (`_register_pre_backward_hook()`).
- `_needs_pre_backward_unshard[handles_key]` is set to `False` in the pre-backward hook, after unsharding (`_pre_backward_hook()`).

**`_needs_pre_forward_unshard`**
- This is used to determine if a handles key should be forward prefetched at all.
- `_needs_pre_forward_unshard[handles_key]` is set to `True` in the root's pre-forward (`_fsdp_root_pre_forward()`).
- `_needs_pre_forward_unshard[handles_key]` is set to `False` in the pre-forward unshard (`_pre_forward_unshard()`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86198
Approved by: https://github.com/zhaojuanmao
2022-10-04 22:37:22 +00:00
Chien-Chin Huang
2067b768fc [FSDP] Delay moving tensor to CPU until necessary for optim_state_dict() (#85761)
The optimizer state_dict currently moves tensors to CPU immediately after the all-gather. However, for the sharded optimizer state_dict, this move is duplicated; we should wait until all the sharding is done. This PR may slightly reduce the performance of the full optimizer state_dict, as it has to allocate more memory than before, but the benchmark shows the extra memory allocation is light.

Differential Revision: [D39855912](https://our.internmc.facebook.com/intern/diff/D39855912/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39855912/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85761
Approved by: https://github.com/rohan-varma
2022-10-03 17:23:23 +00:00
Jesus Magana
c670bad72f Update dist.scatter() documentation (#86069)
Update documentation for dist.scatter()

Fixes #84566

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86069
Approved by: https://github.com/rohan-varma, https://github.com/H-Huang
2022-10-03 17:22:08 +00:00
Rohan Varma
2b5625a726 Update hierarchical_model_averager.py (#85648)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85648
Approved by: https://github.com/wayi1, https://github.com/H-Huang
2022-10-03 06:15:20 +00:00
Ke Wen
05d1128106 [c10d] Start deprecating *_multigpu APIs (#85961)
### Deprecation reasons:
- For most users, training uses one GPU per process, so these APIs are rarely used
- They added one more API dimension
- They can be expressed in a composed manner
- They are not abstracted; they are specific to GPUs
- They caused backend APIs and implementations to have nested `std::vector<std::vector<Tensor>>`, which is hard to read or maintain

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85961
Approved by: https://github.com/XilunWu, https://github.com/H-Huang
2022-10-01 00:59:39 +00:00
Ke Wen
463283e016 [c10d] Start deprecating *_coalesced APIs (#85959)
- We consider that general users need not use the `*_coalesced` APIs unless there is an extreme concern about performance.

- We are investigating a context manager named `coalescing_manager` that wraps multiple individual collectives to compose the coalescing hint, rather than giving each collective a `*_coalesced` variant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85959
Approved by: https://github.com/XilunWu, https://github.com/H-Huang
2022-10-01 00:55:27 +00:00
Chien-Chin Huang
be29ca9716 [FSDP] Ignore buffers that are non-persistent. (#85740)
A buffer can be registered as non-persistent. A non-persistent buffer won't be in the state_dict.
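
A minimal sketch of a non-persistent buffer (standard `nn.Module` API); only the persistent buffer appears in the state_dict:

```
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_stat", torch.zeros(2))               # persistent
        self.register_buffer("scratch", torch.zeros(2), persistent=False)  # skipped

print(list(M().state_dict().keys()))   # ['running_stat']
```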

Differential Revision: [D39858689](https://our.internmc.facebook.com/intern/diff/D39858689/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85740
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-10-01 00:28:17 +00:00
PyTorch MergeBot
71eb04403c Revert "[CUBLAS][CUDA GRAPHS] (re-re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85447)"
This reverts commit b04b2fa9aa.

Reverted https://github.com/pytorch/pytorch/pull/85447 on behalf of https://github.com/seemethere due to Caused a CUDA memory leak, detected by our performance benchmark suite
2022-09-30 20:53:41 +00:00
Rohan Varma
3a13c8493a [1.13] Mention optim_input future BC breakage (#85963)
We should remove this arg in the release after 1.13; for now, enhance the warning to indicate it will be gone. We can do this because FSDP is still beta and can be BC-breaking until we stabilize the API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85963
Approved by: https://github.com/awgu
2022-09-30 16:28:17 +00:00
Ke Wen
ade1c19612 Add reduce_scatter_tensor in place of _reduce_scatter_base (#85867)
This is a twin PR similar to the one for `all_gather_into_tensor` (#85686).
The philosophy for renaming `_reduce_scatter_base` instead of merging it is described in #85686.
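
A minimal usage sketch (assumes ranks launched with `torchrun`, the NCCL backend, and one GPU per rank on a single node):

```
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

inp = torch.ones(world_size * 2, device="cuda")   # each rank contributes world_size chunks
out = torch.empty(2, device="cuda")               # each rank receives its reduced chunk
dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM)
print(rank, out)                                  # every element equals world_size
```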

Cc @rohan-varma @H-Huang @crcrpar @ptrblck @mrshenli

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85867
Approved by: https://github.com/crcrpar, https://github.com/H-Huang
2022-09-30 05:48:16 +00:00
Saliya Ekanayake
941d7a31f6 Pass group ranks and options to third party distributed backends (#73164)
Fixes #73163

PyTorch's [_new_process_group_helper()](9f541aa3ac/torch/distributed/distributed_c10d.py (L633)) does not pass the group's participating ranks to the backend.

This PR adds the above capability. Also, refactors some variables for better clarity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73164
Approved by: https://github.com/kumpera
2022-09-29 17:28:58 +00:00
Masaki Kozuki
5f26df0345 resubmit: "resubmit: [mta] APEX style Fused Adam (#81705) (#85507)" (#85739)
Embarrassingly move the pow implementations around [ATen/native/cuda/PowKernel.cu#L21-L66](849b08f14b/aten/src/ATen/native/cuda/PowKernel.cu (L21-L66)) to a new header file and let FusedAdam use them to tame MSVC, hopefully.

cc @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85739
Approved by: https://github.com/ngimel
2022-09-29 16:58:59 +00:00
PyTorch MergeBot
6fae62b35f Revert "C10D extension to enable per-thread PG (#84153)"
This reverts commit 5cbffbbac9.

Reverted https://github.com/pytorch/pytorch/pull/84153 on behalf of https://github.com/kumpera due to broke internal stuff
2022-09-29 13:51:05 +00:00
Andrew Gu
ff71f45788 [FSDP] Add FSDPExtensions for TP support (#85039)
This adds `FSDPExtensions` to enable TP + FSDP composability. To be agnostic to both `ShardedTensor` and `DistributedTensor`, the design relies on customizable hooks.

Some notes:
- I preferred the `_ext` prefix (short for "extension") over `_param_extension` simply because it is shorter. It should not matter much because it is purely internal facing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85039
Approved by: https://github.com/kumpera, https://github.com/fegin
2022-09-28 18:34:17 +00:00
Eddie Yan
b04b2fa9aa [CUBLAS][CUDA GRAPHS] (re-re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85447)
Now includes @dagitses 's optimizations and fixes for teardown

CC @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85447
Approved by: https://github.com/malfet
2022-09-28 16:04:58 +00:00
Chien-Chin Huang
1c1f3a99dc [FSDP] Handle the state_dict on CPU cases (#85640)
state_dict may not be on GPUs. We need to move it to the compute_device in order to gather the ShardedTensor.

Differential Revision: [D39681730](https://our.internmc.facebook.com/intern/diff/D39681730/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85640
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-09-28 01:04:42 +00:00
Ke Wen
775a22c7c6 Add all_gather_into_tensor in place of _all_gather_base (#85686)
### Description
- This PR renames `_all_gather_base` to `all_gather_into_tensor` so that it is clearer in meaning.
- The `all_gather_into_tensor` API differs from the `all_gather` API in the output it accepts -- a single, large tensor instead of a list of tensors.
- This PR also adds deprecation warning to `_all_gather_base`.

### Issue
`_all_gather_base` was implemented in https://github.com/pytorch/pytorch/pull/33924 to avoid unnecessary flattening. There was previous effort (#82639) to merge `_all_gather_base` with the existing `all_gather` API by detecting the parameter type passed in for the output.

There are, however, two "blockers" that make the merge difficult:
(i) The merge leads to a backward-compatibility break. We would need to change the parameter name `tensor_list` in `all_gather` to a general name `output` that can cover both a tensor and a tensor list.
(ii) Recently, the `all_gather` API has added uneven tensor support, utilizing the tensor boundaries implied by the list. We are, however, not sure whether to add such support to `_all_gather_base`, because that would require users to pass in additional tensor boundary information.

In view of the above, we decided to productize `_all_gather_base` as a separate function, but with a clearer name.

### Testing
Added tests:
- `test_all_gather_into_cat_tensor_cuda` -- output form as with `torch.cat`. For example:
```
        >>> tensor_in
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> tensor_out
        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
```
- `test_all_gather_into_stack_tensor_cuda` -- output form as with `torch.stack`. For example:
```
        >>> tensor_out2
        tensor([[1, 2],
                [3, 4]], device='cuda:0') # Rank 0
        tensor([[1, 2],
                [3, 4]], device='cuda:1') # Rank 1
```
The output form is determined by the shape of the output tensor passed by the user, no flag used.
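
A minimal usage sketch of the two output forms (assumes ranks launched with `torchrun`, the NCCL backend, and one GPU per rank on a single node):

```
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

tensor_in = torch.tensor([2 * rank + 1, 2 * rank + 2], device="cuda")

# Cat-style output: shape (world_size * 2,)
out_cat = torch.empty(world_size * 2, dtype=tensor_in.dtype, device="cuda")
dist.all_gather_into_tensor(out_cat, tensor_in)

# Stack-style output: shape (world_size, 2) -- chosen purely by the output shape
out_stack = torch.empty(world_size, 2, dtype=tensor_in.dtype, device="cuda")
dist.all_gather_into_tensor(out_stack, tensor_in)
```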

Cc @rohan-varma @mrshenli @crcrpar @ptrblck @H-Huang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85686
Approved by: https://github.com/rohan-varma, https://github.com/crcrpar
2022-09-27 22:50:22 +00:00
Rodrigo Kumpera
5cbffbbac9 C10D extension to enable per-thread PG (#84153)
Move a bunch of globals to instance methods and replace all uses of them.

We move all PG-related globals under World and use a singleton instance under _world.

This creates an undocumented extension point to inject full control of how c10d state behaves.

One simple hack is to change _world to an implementation that uses a threadlocal and enables per-thread PGs.

It almost gets DDP working; the PG is only missing an implementation of all_reduce.

This enables notebook usage of PTD, which is a big deal for learning it:
https://gist.github.com/kumpera/32cb051fa26b8cad8bdf671f968dcd68

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84153
Approved by: https://github.com/rohan-varma
2022-09-27 21:42:31 +00:00
PyTorch MergeBot
7167996346 Revert "resubmit: [mta] APEX style Fused Adam (#81705) (#85507)"
This reverts commit 4615d1bcfa.

Reverted https://github.com/pytorch/pytorch/pull/85507 on behalf of https://github.com/atalman due to Break internal windows builds
2022-09-27 16:59:35 +00:00
Bin Chen
0f561f0bd2 Log Watchdog events to scuba (#85391)
Summary: This diff logs some events of FileTimerServer to a scuba table. The events include "server started", "server stopped", "set timer", "clear timer" and "kill worker process".

Test Plan:
### Unit Test
```
buck test mode/dev-nosan //caffe2/test/distributed/elastic/agent/server/test:local_agent_test
```
```
Test Session: https://www.internalfb.com/intern/testinfra/testrun/1407375146936031
RE: reSessionID-2224cf79-6a28-4762-ab7c-9875adb244dc 3.4 KiB▲,  0.0 B▼
Jobs completed: 57. Time elapsed: 3084.4s.
Tests finished: Pass 55. Fail 0. Fatal 0. Skip 0. 0 builds failed
```

Differential Revision: D39665560

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85391
Approved by: https://github.com/d4l3k
2022-09-26 16:05:17 +00:00
Rohan Varma
a8074a1a0b [Checkpoint] rename apply_ac_wrapper (#85449)
Per title

Differential Revision: [D39714855](https://our.internmc.facebook.com/intern/diff/D39714855/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85449
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
Rohan Varma
cc64f64670 [Docs] Minor fix to apply_ac doc (#85448)
Per title

Created from CodeHub with https://fburl.com/edit-in-codehub

Differential Revision: [D39714530](https://our.internmc.facebook.com/intern/diff/D39714530/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85448
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
Masaki Kozuki
4615d1bcfa resubmit: [mta] APEX style Fused Adam (#81705) (#85507)
This PR implements an APEX style FusedAdam in PyTorch. This is different from the APEX one in that this is compatible with `torch.cuda.amp.GradScaler` by setting `_step_supports_amp_scaling` to `True` and unscales gradients inside its CUDA kernel.
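
A minimal usage sketch, assuming the fused path is selected via the `fused=True` flag on the stock optimizer (CUDA parameters required):

```
import torch

model = torch.nn.Linear(32, 32).cuda()
optim = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    optim.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 32, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.step(optim)   # gradients are unscaled inside the fused kernel
    scaler.update()
```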

related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167 possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81705
Approved by: https://github.com/ngimel

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85507
Approved by: https://github.com/ngimel
2022-09-23 18:56:00 +00:00
Andrew Gu
56c0c0af5b [ShardedTensor] Add is_floating_point (#85483)
This adds `is_floating_point()` support to `ShardedTensor`. This is needed for `ShardedTensor` + FSDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85483
Approved by: https://github.com/wanchaol
2022-09-23 04:48:03 +00:00
Andrew Gu
c8f78d417b [ShardedTensor] Add is_meta (#85482)
This adds `is_meta` support to `ShardedTensor`. This is needed for `ShardedTensor` + FSDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85482
Approved by: https://github.com/wanchaol
2022-09-23 04:48:03 +00:00
Andrew Gu
05d0eb2aee [FSDP] Make _ran_pre_backward_hook check more robust (#85481)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85481
Approved by: https://github.com/rohan-varma
2022-09-23 04:48:01 +00:00
Andrew Gu
cf0de77c2c [Easy][FSDP] Simplify assert to p_assert (#85479)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85479
Approved by: https://github.com/rohan-varma
2022-09-23 02:27:09 +00:00
Rohan Varma
5f6735ea97 [FSDP] Address comments on previous PR (#85490)
Address follow ups on https://github.com/pytorch/pytorch/pull/85223/

Differential Revision: [D39740878](https://our.internmc.facebook.com/intern/diff/D39740878/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85490
Approved by: https://github.com/awgu
2022-09-23 00:26:49 +00:00
PyTorch MergeBot
e505360eb8 Revert "[mta] APEX style Fused Adam (#81705)"
This reverts commit 7a6c4d0c50.

Reverted https://github.com/pytorch/pytorch/pull/81705 on behalf of https://github.com/dagitses due to broke internal builds, details to come
2022-09-22 19:37:29 +00:00
Edward Z. Yang
61b4e8a7bf More SymFloat support (#85411)
- Support storing SymFloat in IValue
- Add SymFloat to JIT type system (erases to float)
- Printing support for SymFloat
- add/sub/mul/truediv operator support for SymFloat
- Support truediv on integers, it returns a SymFloat
- Support parsing SymFloat from Python object

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85411
Approved by: https://github.com/albanD
2022-09-22 08:07:22 +00:00
anjali411
85073b8ddc Add __all__ to fx, distributed and cuda submodules (#85080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85080
Approved by: https://github.com/albanD
2022-09-21 18:04:58 +00:00
PyTorch MergeBot
0ac6311356 Revert "[CUBLAS][CUDA GRAPHS] (re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85292)"
This reverts commit 4012e623e8.

Reverted https://github.com/pytorch/pytorch/pull/85292 on behalf of https://github.com/dagitses due to broke an internal test during shutdown. Re-submit with #85399 in stack
2022-09-21 17:57:49 +00:00
Andrew Gu
125e9256f4 [FSDP] Add back forward_prefetch (#85177)
- This implements explicit forward prefetching following the static 1st iteration's pre-forward order when `forward_prefetch=True` in the FSDP constructor.
- This has the same unit test coverage as the original `forward_prefetch`.
- I checked via print statements that the prefetches are happening, but since I cannot get a good CPU bound workload, it is hard to tell via traces that the prefetch is working.
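
A minimal usage sketch of the constructor flag (assumes an initialized process group and one GPU per rank):

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(
    torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).cuda(),
    forward_prefetch=True,  # prefetch the next all-gather following the
)                           # pre-forward order recorded in the first iteration
```
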
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85177
Approved by: https://github.com/zhaojuanmao
2022-09-21 14:40:37 +00:00
Andrew Gu
977f8fce3c [FSDP] Simplify backward prefetch implementation (#85176)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85176
Approved by: https://github.com/zhaojuanmao
2022-09-21 14:40:37 +00:00
Masaki Kozuki
7a6c4d0c50 [mta] APEX style Fused Adam (#81705)
This PR implements an APEX style FusedAdam in PyTorch.
This is different from the APEX one in that this is compatible with `torch.cuda.amp.GradScaler` by setting `_step_supports_amp_scaling` to `True` and unscales gradients inside its CUDA kernel.

related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167
possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81705
Approved by: https://github.com/ngimel
2022-09-20 17:18:33 +00:00
eqy
4012e623e8 [CUBLAS][CUDA GRAPHS] (re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85292)
re-open of #83461 with fix for 10.2 build

CC @ngimel @malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85292
Approved by: https://github.com/malfet
2022-09-20 16:31:54 +00:00
anjali411
cf2f552cd8 Add __all__ to torch.{fx, distributed, backends} submodules (#85079)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85079
Approved by: https://github.com/rohan-varma
2022-09-20 12:51:08 +00:00
Rohan Varma
7df0878b99 [FSDP] Option to keep grads in lower prec (#85223)
Reland of https://github.com/pytorch/pytorch/pull/85134; the fix is to use fp16 instead of bf16, which is not supported on all platforms.

Differential Revision: [D39565189](https://our.internmc.facebook.com/intern/diff/D39565189/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85223
Approved by: https://github.com/awgu
2022-09-18 22:15:28 +00:00
PyTorch MergeBot
14b3bdc025 Revert "[FSDP] Option to keep grads in lower prec (#85134)"
This reverts commit 607eccb13c.

Reverted https://github.com/pytorch/pytorch/pull/85134 on behalf of https://github.com/ZainRizvi due to broke trunk, failing the tests test_grads_reduced_precision (__main__.TestFSDPMixedPrecisionUnsharded)
2022-09-16 22:33:06 +00:00
Andrew Gu
c6c3346d5a [FSDP] Short-term fix to remove optim_input (#84201)
This is a short-term quick fix to accommodate using the existing optimizer state APIs without passing `optim_input`. It preserves the existing `optim_input` code path but if `optim_input` is `None` while `optim` is not, then the APIs will use the new code path that relies on `self.param_groups` to get the information previously provided by `optim_input`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84201
Approved by: https://github.com/rohan-varma
2022-09-16 21:24:15 +00:00
Rohan Varma
607eccb13c [FSDP] Option to keep grads in lower prec (#85134)
Differential Revision: [D39565189](https://our.internmc.facebook.com/intern/diff/D39565189)

Rehash of a similar PR from a month ago that got stale. Adds a config to FSDP MP so that gradients can be kept in lower precision, to support optimizers such as AnyPrecisionOptimizer which would like to keep grads in bf16.

To do this, for sharded cases, we cannot simply omit the cast back to the full precision param dtype, otherwise when setting `p.grad = p._saved_grad_shard` in finalize_params, autograd will throw an error indicating that the grad dtype should match the param dtype when it is being set.

As a workaround, we re-cast after setting this. Although, this means that for cases that use gradient accumulation, p._saved_grad_shard will be of the reduced dtype because it is set to p.grad in `_prep_grad_for_backward`. As a result, add a check + recast here as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85134
Approved by: https://github.com/awgu
2022-09-16 20:20:04 +00:00
Andrew Gu
5652ab22f6 [FSDP] Add _set_flattened(); _is_flattened() (#85038)
For both exposing the original parameters and for TP integration, we cannot only rely on `isinstance(param, FlatParameter)` to ignore already-flattened parameters in `.named_parameters()`. As a simple workaround, we can mark original parameters or `ShardedTensor`s with an attribute `_fsdp_flattened` (saved as a string variable `FSDP_FLATTENED`) to indicate that the parameter/tensor has already been flattened. This issue only arises for recursive/nested FSDP wrapping.

This PR also changes `isinstance(param, FlatParameter)` checks to `type(param) is FlatParameter` because all tensor subclasses that have `_is_param == True` will return `True` for `isinstance(param, <any subclass with _is_param == True>)`. This means that a `ShardedTensor` parameter will return `True` for `isinstance(st, FlatParameter)`, which is not what we want.
5271494ef2/torch/nn/parameter.py (L8-L10)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85038
Approved by: https://github.com/rohan-varma
2022-09-16 03:45:29 +00:00
Rodrigo Kumpera
7dcc723d35 [c10d] Ensure collectives are called with the same dtype for all tensor params. (#84664)
While passing tensors with different dtypes doesn't crash, it doesn't produce sensible results.

We see data tearing instead of casting.

It's not clear we want to support transparent casting so, for now, we fail when such input is presented.

Fixes #84525

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84664
Approved by: https://github.com/rohan-varma
2022-09-15 22:32:51 +00:00
Andrew Gu
25ecc4889d [FSDP] Fix memory regression! (#85087)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85087
Approved by: https://github.com/zhaojuanmao
2022-09-15 19:21:07 +00:00
Salahuddin
6bd7d0f856 doc string fixed in torch.distributed.reduce_scatter (#84983)
Fixes #84865

Previous `torch.distributed.reduce_scatter`:

```
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.
```

Fixed:

```
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Args:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84983
Approved by: https://github.com/H-Huang
2022-09-15 18:17:10 +00:00
Andrew Gu
62af1c9eed [Easy][FSDP] Change assert -> p_assert (#85052)
This changes a few `assert`s to `p_assert()`s because they can run in the backward (some are in the forward, but AC can make them run in the backward).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85052
Approved by: https://github.com/zhaojuanmao
2022-09-15 02:05:34 +00:00
Andrew Gu
cdd625ba70 [Easy][FSDP] Remove outdated comment (#85051)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85051
Approved by: https://github.com/zhaojuanmao
2022-09-15 01:50:10 +00:00
Andrew Gu
cc62ad79c7 [FSDP] Fix pin_memory() for CPU offloading (#85048)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85048
Approved by: https://github.com/zhaojuanmao
2022-09-15 01:50:10 +00:00
Rohan Varma
8cb7826889 [CheckpointWrapper] Reentrant kwarg support (#84908)
A temporary patch to support keyword args when the reentrant checkpoint wrapper is used. This is needed to unblock some crucial workloads; the ideal fix would be checking this directly into torch.utils.checkpoint.

Differential Revision: [D39453453](https://our.internmc.facebook.com/intern/diff/D39453453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84908
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rohan Varma
55ca6901a7 [CheckpointWrapper] Decouple CPU offload (#84907)
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:

```
with save_on_cpu:
    checkpoint(module_forward())
```

which would not offload any activation tensors to CPU, as those activations would already not be saved by autograd, since the checkpoint implementation takes priority.

Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.

These wrappers can be composed, i.e. if we have

`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`

we can do

`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`

and inner tensors would be checkpointed while outers will be offloaded.

Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Shen Li
0f30059227 Remove eager mode support form CommTensor (#84978)
We don't need eager mode support (automatic wait on read) for now.
Removing that to simplify the code. We can always add this back if
necessary in the future.

Note that, we still need the eager mode code in `__torch_dispatch__`,
as `make_fx` will also run the ops in eager mode to get the output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84978
Approved by: https://github.com/wanchaol
2022-09-14 17:23:23 +00:00
Shen Li
8cbbd3a25f Avoid nested CommTensor wrapping (#84963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84963
Approved by: https://github.com/wanchaol
2022-09-14 01:22:45 +00:00
Rodrigo Kumpera
38192f63cd Add __all__ for a few distributed modules plus a little typing (reland) (#84872)
This handles distributed_c10d (which is massive) and ddp_comm_hooks.

This relands #84119 with the required fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84872
Approved by: https://github.com/rohan-varma
2022-09-13 21:57:49 +00:00
Andrew Gu
33352336b4 [FSDP] Add rate limiter (#83917)
**Overview**
This PR adds a `bool` argument `limit_all_gathers` to the FSDP constructor, defaulted to `False`.
- Setting `limit_all_gathers=True` limits the max number of inflight all-gathers to 2 (an empirically chosen constant), preventing a fast CPU thread from over-allocating blocks to the all-gather stream.
- When experiencing a high number of CUDA malloc retries, the limiter can help reduce the number and hence lead to QPS improvement.

**Exploration**
I experimented with both a count-based limiter and size-based limiter (where the size is based on the inflight all-gather size in bytes).
- The size-based limiter did not provide any advantage, only confusing the developer and user alike on what threshold to set.
- For the count-based approach, I decided not to expose the max number of inflight all-gathers to the user since values other than 2 do not show improvements and exposing the knob may confuse users.

**T5-11B**
T5-11B evidences the performance gain from enabling the limiter and that a limit of 2 is a reasonable choice. This is run on an AWS cluster with 8 A100s per node and EFA. For both 2 and 4 nodes, we scale the batch size maximally before hitting OOM, which is a common practice.

<p float="left">
  <img src="https://user-images.githubusercontent.com/31054793/188936036-04427da9-f492-4e50-9b35-ff64665d9815.png" width="400" />
  <img src="https://user-images.githubusercontent.com/31054793/188936045-f44e659f-1e18-4ea7-8c78-0fce4ff8fb48.png" width="400" />
</p>

For 2 nodes, the limit of 2 yields 3.01x QPS improvement, and for 4 nodes, the limit of 2 yields 2.87x QPS improvement.

We need more data points, but the limiter may simplify the batch size scaling workflow. Normally, a practitioner may scale until hitting OOM and back off until there are few CUDA malloc retries. However, now the practitioner may be able to scale until hitting OOM and simply turn on the limiter to reduce the number of retries instead of backing off.

Differential Revision: [D39331201](https://our.internmc.facebook.com/intern/diff/D39331201)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83917
Approved by: https://github.com/zhaojuanmao
2022-09-13 17:15:41 +00:00
Andrew Gu
39676a977f [FSDP][Easy] Save unpadded/padded unsharded sizes as attributes (#84366)
Differential Revision: [D39331199](https://our.internmc.facebook.com/intern/diff/D39331199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84366
Approved by: https://github.com/rohan-varma
2022-09-13 17:09:20 +00:00
Andrew Gu
afcc7c7f5c [FSDP] Generalize prefetching; lower unshard/reshard to handle (#83665)
### Additional Constructor Changes
- `self.sharding_strategy`
    - If the world size is 1, I clamp the sharding strategy to `NO_SHARD`, regardless of the passed-in sharding strategy, since the behavior is fully equivalent. This absolves the need for `p._is_sharded or self.world_size == 1` checks in the core code. Once we fully shift the paradigm to using handles, this should result in a clear net positive. However, for now, we still have some places where we interface directly with the `FlatParameter`, in which case we have some temporary hacky code.
- `HandleConfig`
    - As a part of the new design abstraction, much logic is lowered to the `FlatParamHandle`. This requires the handle be aware of mixed precision, CPU offloading, sharding strategy, and the process group (for world size > 1). To be less error-prone, I re-defined the `dataclass`s and `enum`s for the handle. These can be removed and coalesced with the existing ones.
    - The drawback is that the `FlattenParamsWrapper` constructor now takes in the `HandleConfig` to forward it to the `FlatParamHandle` constructor. I tolerate this since we plan to retire the FPW. For now, the handle's process group attributes are set later when we call `handle.shard()`.
    - We will dive into this logic lowering later. For now, the idea is we need to pass some extra info to the handle, which must go through the FPW.
- `FullyShardedDataParallel._shard_parameters()` -> `FlatParamHandle.shard()`
- [Important] Generalizing attributes to remove the 1 `FullyShardedDataParallel` : 1 `FlatParameter` assumption
    - **Before:** `_fsdp_graph_order`, `_pre_backward_hook_full_params_prefetched`, `_forward_full_params_prefetched`, `reshard_after_forward` are with respect to 1 `FullyShardedDataParallel`
    - **After:** (1) We use `FlatParamHandle` in place of `FullyShardedDataParallel`. (2) The atomic unit for forward and pre-backward is a _group_ of handles involved in the same module's forward/pre-backward. This is represented as `Tuple[FlatParamHandle, ...]`. For now, this is **always a singleton tuple**, but this shift enables a module having multiple FSDP parameters (which we have use cases for).
- `_reset_lazy_init()` attributes
    - The prefetched flags are merged into `self._handles_prefetched`, which is directly defined in the constructor. `reshard_after_forward` is retired since it can be fully determined by other attributes (`_is_root` and `sharding_strategy`).

## FSDP Runtime: Unshard

The first step is to read the existing `_rebuild_full_params()`. A few notable observations:
- It returns `Tuple[Tensor, bool]`. The first element is the _padded unsharded flattened parameter_, and the second element is whether we can free it upon exiting `summon_full_params()`. This return value is **only used in `summon_full_params()`**.
- If parameter mixed precision is enabled and the `FlatParameter` is already unsharded, then the low precision shard (`_mp_shard`) is still re-allocated on GPU. (It is freed at the end of the method.)
- If CPU offloading is enabled and the `FlatParameter` is already unsharded, then there is a no-op `p.data = p.data.to(self.compute_device, non_blocking=True)`.
- Inside `summon_full_params()`, `mixed_precision_cast_ran` is always `False`. Therefore, the return value for the `not p._is_sharded and mixed_precision_cast_ran` branch is unused.
- `summon_full_params()` can only be called (before forward or after backward) or (between forward and backward). Given this, I cannot think of a case where we call `summon_full_params()`, the `FlatParameter` is already unsharded, but `reshard_after_forward` is `True`. The `FlatParameter` should be sharded (before forward or after backward), and the `FlatParameter` may only be unsharded (between forward and backward) if `reshard_after_forward` is `False`.
- If parameter mixed precision is enabled and the sharding strategy is a sharded one, then inside `summon_full_params()`, the `FlatParameter` is unsharded in full precision. This involves allocating a new padded unsharded flattened parameter on GPU in full precision since `_full_param_padded` is in the low precision.

Some comments:
- Ideally, we reduce the complexity of the core code path: i.e. unshard for forward and pre-backward. If the return value is only used for `summon_full_params()`, we should consider if we can compartmentalize that logic.
- The branching is complex, and some return values are never used, where this fact is not immediately obvious. We should see if we can reduce the branch complexity.

Disclaimer: The difference in attribute semantics between `NO_SHARD` and the sharded strategies makes it challenging to unify the cases. This PR does not attempt to address that since it requires more design thought. However, it does attempt to reduce the complexity for the sharded strategies.

### Unshard: Core Code Path
Let us trace through the new logical unshard.
1. `FullyShardedDataParallel._unshard(self, handles: List[FlatParamHandle], prepare_gradient: bool)`
    - This iterates over the handles and calls `handle.pre_unshard()`, `handle.unshard()`, and `handle.post_unshard(prepare_gradient)` in the all-gather stream.
2. `FlatParamHandle.needs_unshard(self)`
    - We take an aside to look at this key subroutine.
    - For `NO_SHARD`, this returns `False`.
    - For sharded strategies, this checks if the padded unsharded flattened parameter is allocated. The padded unsharded flattened parameter is the base tensor for the unpadded unsharded flattened parameter, which is a view into the padded one. Thus, the padded one's allocation fully determines if the `FlatParameter` is unsharded.
    - For sharded strategies, to accommodate the parameter mixed precision + `summon_full_params()` case, we introduce `_full_prec_full_param_padded`, which is the padded unsharded flattened parameter in full precision. The helper `_get_padded_unsharded_flat_param()` takes care of this casing and returns the padded unsharded flattened parameter. Instead of allocating a new tensor each time, we manually manage `_full_prec_full_param_padded`'s storage just like for `_full_param_padded`.
3. `FlatParamHandle.pre_unshard(self)`
    - For sharded strategies, the postcondition is that the handle's `FlatParameter` points to the tensor to all-gather. This should be on the communication device and in the desired precision. The allocation and usage of the low precision shard for parameter mixed precision and the CPU -> GPU copy for CPU offloading both classify naturally in the pre-unshard.
    - For sharded strategies, if the `FlatParameter` does not need to be unsharded, `pre_unshard()` is a no-op. This avoids unnecessarily allocating and freeing the low precision shard.
    - For `NO_SHARD`, we simply preserve the existing semantics.
4. `FlatParamHandle.unshard(self)`
    - If the handle was resharded without freeing the padded unsharded flattened parameter (e.g. `summon_full_params()` between forward and backward when `reshard_after_forward=False`), then the `FlatParameter` points to the sharded flattened parameter. We need to switch to using the unsharded parameter. This is a design choice. Alternatively, we may not switch to using the sharded flattened parameter in `reshard()` if we do not free the padded unsharded flattened parameter. However, the postcondition that the `FlatParameter` points to the sharded flattened parameter after `reshard()` is helpful logically, so I prefer this approach.
    - Otherwise, this allocates the padded unsharded flattened parameter, all-gathers, and switches to using the unpadded unsharded flattened parameter.
    - In the future, we may add an option to `unshard()` that additionally all-gathers the gradient.
5. `FlatParamHandle.post_unshard(self, prepare_gradient: bool)`
    - For sharded strategies, if using parameter mixed precision, this frees the low precision shard. More generally, this should free any sharded allocations made in `pre_unshard()` since the all-gather has been launched. If using CPU offloading, the GPU copy of the local shard goes out of scope after `unshard()` and is able to be garbage collected. **We should understand if there is any performance difference between manually freeing versus deferring to garbage collection since our usage is inconsistent.** For now, I preserve the existing semantics here.
    - `prepare_gradient` is meant to be set to `True` for the pre-backward unshard and `False` for the forward unshard. This runs the equivalent logic of `_prep_grads_for_backward()`.
    - This post-unshard logic (notably the gradient preparation) now runs in the all-gather stream, which is fine because we always have the current stream wait for the all-gather stream immediately after `FullyShardedDataParallel._unshard()`. IIUC, we do not need to call `_mp_shard.record_stream(current_stream)` (where `current_stream` is the default stream) because `_mp_shard` is allocated and freed in the same (all-gather) stream.
    - A postcondition is that the `FlatParameter` is on the compute device. It should also have the unpadded unsharded size (though I do not have a check for this at the moment).

### Unshard: `summon_full_params()`
Now that we see how the logical unshard has been reorganized for the core code path, let us dive into `summon_full_params()`.

The two constraints are:
1. If using parameter mixed precision, we should unshard in full precision.
2. We must determine if we should free the padded unsharded flattened parameter upon exiting.

The first constraint is addressed as described before in the core unshard code path, so it remains to explore the second constraint.

I propose a simple rule: **We free iff we actually unshard the `FlatParameter` in `summon_full_params()`** (i.e. it was not already unsharded). We perform a case analysis:

**Parameter mixed precision enabled:**
* `NO_SHARD`: `flat_param.data` points to `flat_param._local_shard`, which is the full precision unsharded flattened parameter. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We force full precision and all-gather to `_full_prec_full_param_padded`. We do not support `nested summon_full_params()`, so `_full_prec_full_param_padded` must be unallocated. We unshard, and it is **safe to free**.

**Parameter mixed precision disabled:**
* `NO_SHARD`: This is the same as with mixed precision enabled. This is **not safe to free**.
* `FULL_SHARD` / `SHARD_GRAD_OP`: We all-gather to `_full_param_padded`. It may already be unsharded.
    * Already unsharded: The unshard is a no-op. This is **not safe to free**.
        * For `FULL_SHARD`, this can happen for the root FSDP instance after `forward()` but before backward.
        * For `SHARD_GRAD_OP`, this can happen for all FSDP instances after `forward()` but before backward.
    * Needs unshard: We unshard. This is **safe to free**.

Therefore, we see that it is not safe to free when using `NO_SHARD` and when using a sharded strategy but the `FlatParameter` is already unsharded. This is precisely the proposed rule.
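
To make the rule concrete, here is a hedged sketch of how `summon_full_params()` could record the decision. `needs_unshard()`, `unshard()`, and `reshard(free_unsharded_flat_param)` are the handle methods referenced in this PR; the context manager itself is illustrative.

```python
from contextlib import contextmanager

@contextmanager
def summon_full_params_sketch(handles):
    # Hedged sketch of the freeing rule: free on exit iff this context actually
    # performed the unshard (False for NO_SHARD and for already-unsharded FlatParameters).
    frees = [handle.needs_unshard() for handle in handles]
    for handle in handles:
        handle.unshard()  # no-op if already unsharded
    try:
        yield
    finally:
        for handle, free in zip(handles, frees):
            handle.reshard(free_unsharded_flat_param=free)
```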

There were two notable edge cases that the existing code did not address.
1. The existing code tests if the `FlatParameter` is already unsharded by checking the allocation status of `_full_param_padded`. When using parameter mixed precision, this is the incorrect tensor to check. If `_full_param_padded` is allocated (e.g. when `reshard_after_forward=False` and calling `summon_full_params()` between forward and backward), the already-unsharded check is a false positive, and `summon_full_params()` does not correctly force full precision. https://github.com/pytorch/pytorch/issues/83068
    - This PR's `needs_unshard()` check correctly routes to the appropriate padded unsharded flattened parameter depending on the calling context (i.e. if it needs to force full precision or not).
2. The existing code does not free the GPU copy of the padded unsharded flattened parameter when calling `summon_full_params(offload_to_cpu=True)`. It unshards the `FlatParameter`, moves the padded unsharded flattened parameter to CPU, and sets the `FlatParameter` data to be the appropriate unpadded view into the padded unsharded flattened parameter on CPU. However, `_full_param_padded` still points to the all-gathered padded unsharded flattened parameter on GPU, which is kept in memory. https://github.com/pytorch/pytorch/issues/83076
    - This PR frees the GPU copy and reallocates it upon exiting `summon_full_params()`. This is essential for avoiding peak GPU memory usage from increasing as we recurse through the module tree. There may be some cases where we can avoid reallocation altogether, but that can be addressed in a follow-up PR.
    - This PR offloads the *unpadded* unsharded flattened parameter to CPU directly instead of the *padded* one. As far as I can tell, there is no need to include the padding since unflattening the original parameters does not require the padding.
    - The relevant code is in the context manager `FlatParamHandle.to_cpu()`.
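
As a rough illustration of the offloading behavior described in this item, assuming `flat_param.data` is the unpadded unsharded view and `padded_unsharded` is the GPU all-gather buffer (the real logic lives in `FlatParamHandle.to_cpu()`):

```python
from contextlib import contextmanager
import torch

@contextmanager
def to_cpu_sketch(flat_param: torch.nn.Parameter, padded_unsharded: torch.Tensor):
    # Hedged sketch: offload only the *unpadded* unsharded data to CPU, free the GPU
    # all-gather buffer while inside the context, and restore it on exit.
    numel_padded = padded_unsharded.numel()
    flat_param.data = flat_param.data.cpu()   # unpadded unsharded data -> CPU copy
    padded_unsharded.storage().resize_(0)     # free the GPU copy to cap peak memory
    try:
        yield
    finally:
        padded_unsharded.storage().resize_(numel_padded)                # reallocate on exit
        padded_unsharded[: flat_param.numel()].copy_(flat_param.data)   # copy back to GPU
        flat_param.data = padded_unsharded[: flat_param.numel()]        # use the GPU view again
```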

### Unshard: Mixed-Precision Stream

This PR removes the mixed precision stream usage. As is, I do not think there is any extra overlap being achieved by the stream usage.

The low precision shard is allocated and copied to in the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1401-L1412))), and the current stream (in this case the all-gather stream) waits for the mixed precision stream ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1414))). However, we immediately schedule an all-gather that communicates that exact low precision shard ([code](1f99bdfcc4/torch/distributed/fsdp/fully_sharded_data_parallel.py (L3338))) with no other meaningful computation between. If we remove the mixed precision stream, the low precision shard is allocated and copied to in the all-gather stream (including the non-blocking CPU -> GPU copy if using CPU offloading).

Under this PR's design, we may consider a "pre-unshard" stream for all logical pre-unshard data transfers if we want to overlap in the future. IIUC, the overlap opportunity exists if there are multiple `FlatParameter`s per module, and we only have the all-gather stream wait for the data transfer corresponding to the local shard it communicates, not the others.
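
For illustration, a hypothetical pre-unshard stream could look like the following; every name here is an assumption rather than an existing FSDP API.

```python
import torch

def unshard_with_pre_unshard_stream(handles, pre_unshard_stream, all_gather_stream):
    # Hypothetical sketch of the overlap idea above: issue each handle's data transfer
    # (low-precision cast / H2D copy) in a dedicated stream and record a per-handle
    # event, so each all-gather only waits for the transfers issued up to and including
    # its own rather than for the entire pre-unshard stream.
    events = []
    for handle in handles:
        with torch.cuda.stream(pre_unshard_stream):
            handle.pre_unshard()
            event = torch.cuda.Event()
            event.record(pre_unshard_stream)
            events.append(event)
    for handle, event in zip(handles, events):
        with torch.cuda.stream(all_gather_stream):
            all_gather_stream.wait_event(event)
            handle.unshard()
```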

If we agree on removing the mixed-precision stream for now, I will remember to delete it from `_init_streams()`.

## FSDP Runtime: Reshard

Like with unshard, the first step is to look at the existing `_free_full_params()` and `_use_param_local_shard()`. A few notable observations:
- For only `NO_SHARD`, `_free_full_params()` includes a call to `_free_mp_shard()`.
- For `summon_full_params()`, there is a separate `_free_full_params_and_use_local_shard()` that duplicates the main logic of `_free_full_params()` and calls `_use_param_local_shard()`.
- In `forward()`, if `reshard_after_forward=True`, we call `_free_full_params()` and then `_free_mp_shard()`. Hence, for `NO_SHARD`, the `_free_mp_shard()` is a no-op.
- In the post-backward hook, we typically call `_free_full_params()` and `_free_mp_shard()`. The `_free_mp_shard()` is a no-op for `NO_SHARD` and if `reshard_after_forward=True`.

Some comments:
- The code certainly works, but some of the no-ops are subtle. When possible, we should make it clear when calls are no-ops or not. It is good that the existing code documents that `_free_mp_shard()` is a no-op in the post-backward hook when `reshard_after_forward=True`. However, there are still some non-obvious no-ops (around `NO_SHARD`).
- We should see if we can avoid the duplicate `_free_full_params_and_use_local_shard()`.

Let us trace through the logical reshard:
1. `FullyShardedDataParallel._reshard(self, handles: List[FlatParamHandle], free_unsharded_flat_params: List[bool])`
    - The two args should have the same length since they are to be zipped.
    - The goal of having `free_unsharded_flat_params` is that the caller should be explicit about whether the (padded) unsharded flattened parameter should be freed. The low precision shard is always meant to be freed (as early as possible), so there is no corresponding `List[bool]`.
2. `FlatParamHandle.reshard(self, free_unsharded_flat_param: bool)`
    - This frees the (padded) unsharded flattened parameter if `free_unsharded_flat_param` and switches to using the sharded flattened parameter.
    - Echoing back to forcing full precision in `summon_full_params()`, `_free_unsharded_flat_param()` frees the correct tensor by using `_get_padded_unsharded_flat_parameter()`.
3. `FlatParamHandle.post_reshard(self)`
    - I am not fully content with the existence of this method, but this seems to be an unavoidable consequence of `NO_SHARD`. Perhaps, this may be useful in the future for other reasons though.
    - Right now, this method is only meaningful for `NO_SHARD` + parameter mixed precision + outside `summon_full_params()`. `_mp_shard` is not freed in the post-unshard since it is also the low precision _unsharded_ flattened parameter, so we must delay the free until the post-reshard.

Below the `FlatParamHandle.reshard()` and `post_reshard()` layer, there should not be any no-ops.
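
A minimal sketch of this reshard flow (illustrative; the method names come from the list above):

```python
from typing import List

def reshard_handles(handles: List, free_unsharded_flat_params: List[bool]) -> None:
    # Illustrative sketch of the logical reshard described above.
    assert len(handles) == len(free_unsharded_flat_params)
    for handle, free_unsharded in zip(handles, free_unsharded_flat_params):
        # Free the padded unsharded flat parameter if requested and switch the
        # FlatParameter back to the sharded flat parameter.
        handle.reshard(free_unsharded)
        # Only meaningful for NO_SHARD + parameter mixed precision outside
        # summon_full_params(): frees the low precision _mp_shard.
        handle.post_reshard()
```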

One final comment I will mention is that I like the `pre_unshard()`/`unshard()`/`post_unshard()` and `reshard()`/`post_reshard()` organization because it makes it clear what the boundaries are and their temporal relationship. Through that, we can set pre- and post-conditions. Furthermore, we can eventually convert logic to hooks that may be registered on the `FlatParamHandle` (for `pre_unshard()`, `post_unshard()`, and `post_reshard()`). This may improve the customizability of FSDP.

## FSDP Runtime: `forward()`

- This PR reorganizes `forward()` in preparation for non-recursive wrapping, which uses pre-forward and post-forward hooks that expect the signature `hook(module, input)`. For FSDP, the `module` and `input` arguments are not used.
- This PR creates a new method `_fsdp_root_pre_forward()` to handle the logic only the root FSDP should run.

## FSDP Prefetching

Finally, we dive into the prefetching changes. Some highlights:
1. This PR unifies the execution order validation and prefetching implementations.
    - Both involve the execution order and can be unified to share some boilerplate.
2. Execution order validation only runs when the distributed debug level is `INFO`.
    - We have yet to have one success case where we actually catch an unintended source of dynamism. The warning is also too verbose. Hence, we are gating it by the `INFO` level.
3. This PR moves prefetching to be with respect to groups of handles (as mentioned in the constructor comment).
    - This is essential for supporting prefetching with non-recursive wrapping.
4. This PR does not include "bubbles", i.e. modules with no handles, in the recorded execution order(s). This deviates from the existing implementation.
    - This makes prefetching possibly more aggressive (when there are such bubbles), but it should not have significant performance implications either way.
5. This PR changes backward prefetching to reset the post-forward order each iteration (as intended).
6. This PR changes forward prefetching to use the first iteration's pre-forward order instead of the first iteration's post-forward order. (We can discuss whether we want this in this PR or not. Otherwise, I can keep it as using the post-forward order to preserve the existing semantics.) This PR also removes the `all_gather_stream.wait_stream(current_stream)` before forward prefetching because it does not help with high GPU reserved memory. We can add that back if desired.
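
As an illustration of the unified recording and handle-group prefetching, here is a hypothetical structure (not the actual execution order data):

```python
class ExecOrderSketch:
    # Hypothetical sketch: record the pre-forward order of handle groups in the first
    # iteration (used for forward prefetching) and reset the post-forward order each
    # iteration (its reverse approximates the pre-backward order).
    def __init__(self):
        self.pre_forward_order = []    # recorded once on the first iteration
        self.post_forward_order = []   # reset at the start of every iteration

    def record_pre_forward(self, handles_key, is_first_iter: bool):
        if is_first_iter and handles_key not in self.pre_forward_order:
            self.pre_forward_order.append(handles_key)

    def get_handles_to_forward_prefetch(self, handles_key):
        # Prefetch the group that ran after `handles_key` in the first iteration.
        order = self.pre_forward_order
        index = order.index(handles_key)
        return order[index + 1] if index + 1 < len(order) else None
```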

### Appendix
#### Reverse Post-Forward Order Is Not Always the Pre-Backward Order
The existing PT-D FSDP pre-backward prefetching uses the reverse post-forward order.
<details>
  <summary>Model Code</summary>

  ```
  import torch.nn as nn
  from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

  class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(4, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=False),
        )
        self.block3 = nn.Linear(12, 8)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(4, 10),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.head(x)

  model = Model().cuda()
  fsdp_kwargs = {}
  model.block1[1] = FSDP(model.block1[1], **fsdp_kwargs)  # BN2d
  model.block2[1] = FSDP(model.block2[1], **fsdp_kwargs)  # BN2d
  model.block1 = FSDP(model.block1, **fsdp_kwargs)
  model.block2 = FSDP(model.block2, **fsdp_kwargs)
  model.block3 = FSDP(model.block3, **fsdp_kwargs)
  model = FSDP(model, **fsdp_kwargs)
  ```
</details>

<details>
  <summary>Execution Orders </summary>

  ```
  Pre-backward hook for ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  Pre-backward hook for ('weight', 'bias') 140339461194656 (block3)
  Pre-backward hook for ('0.weight', '0.bias') 140339520589776 (block2)
  Pre-backward hook for ('weight', 'bias') 140339520587664 (block2 BN)
  Pre-backward hook for ('weight', 'bias') 140339520586656 (block1 BN)
  Pre-backward hook for ('0.weight', '0.bias') 140339520588768 (block1)

  Pre-forward order:
  ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  ('0.weight', '0.bias') 140339520588768 (block1)
  ('weight', 'bias') 140339520586656 (block1 BN)
  ('0.weight', '0.bias') 140339520589776 (block2)
  ('weight', 'bias') 140339520587664 (block2 BN)
  ('weight', 'bias') 140339461194656 (block3)

  Reverse post-forward order:
  ('head.2.weight', 'head.2.bias') 140339520587136 (model)
  ('weight', 'bias') 140339461194656 (block3)
  ('0.weight', '0.bias') 140339520589776 (block2)
  ('weight', 'bias') 140339520587664 (block2 BN)
  ('0.weight', '0.bias') 140339520588768 (block1)
  ('weight', 'bias') 140339520586656 (block1 BN)
  ```
</details>

Differential Revision: [D39293429](https://our.internmc.facebook.com/intern/diff/D39293429)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83665
Approved by: https://github.com/zhaojuanmao
2022-09-13 17:05:10 +00:00
Andrew Gu
a2acead002 [FSDP][Easy] Minor cleanup (#84761)
This PR simply pulls out some minor changes from the next (monolithic) PR.

Differential Revision: [D39392147](https://our.internmc.facebook.com/intern/diff/D39392147)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84761
Approved by: https://github.com/zhaojuanmao
2022-09-13 17:01:52 +00:00
Andrew Gu
9d5b3e4da8 [FSDP] Remove forward_prefetch (#84600)
We are removing the `forward_prefetch` option. By the nature of async GPU kernel execution, launching the CPU kernel for the next layer's all-gather early does not actually improve performance. Moreover, the existing `forward_prefetch` uses the post-forward order instead of the pre-forward order, which leads to mis-targeted prefetched all-gathers.

Differential Revision: [D39454217](https://our.internmc.facebook.com/intern/diff/D39454217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84600
Approved by: https://github.com/zhaojuanmao
2022-09-13 02:45:07 +00:00
Andrew Gu
c304a1206b [FSDP][Easy] Remove unused functions (#84598)
This removes some leftover functions from the constructor refactor.

Differential Revision: [D39293430](https://our.internmc.facebook.com/intern/diff/D39293430)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84598
Approved by: https://github.com/zhaojuanmao
2022-09-12 20:04:23 +00:00
PyTorch MergeBot
219ff26172 Revert "Add __all__ for a few distributed modules plus a little typing (#84119)"
This reverts commit 6f21680563.

Reverted https://github.com/pytorch/pytorch/pull/84119 on behalf of https://github.com/izaitsevfb due to breaking internal builds, see D39386448
2022-09-09 20:01:07 +00:00
Chien-Chin Huang
28c830ac07 [FSDP] Optimizer states may be on CPU, copy them to GPU before gathering (#84708)
**Background**:
Optimizer states may not always be on GPUs. Some examples include: 1.) CPU offload is enabled, 2.) after the Lightning trainer's fit() is called.

**What Does This PR Do?**
If states are not on GPUs, move them to GPUs before gathering the global states.
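
A minimal sketch of the idea, assuming a per-rank local state tensor gathered with `all_gather` (the helper name is illustrative, not the actual FSDP code):

```python
import torch
import torch.distributed as dist

def gather_optim_state_value(value, device: torch.device):
    # Illustrative sketch: move a local optimizer state tensor to the GPU before the
    # collective, since gathering with the NCCL backend requires CUDA tensors.
    if not torch.is_tensor(value):
        return value                      # e.g. "step" counts are returned as-is
    if value.device.type == "cpu":
        value = value.to(device)
    gathered = [torch.empty_like(value) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, value)
    return gathered
```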

Differential Revision: [D39332300](https://our.internmc.facebook.com/intern/diff/D39332300/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84708
Approved by: https://github.com/awgu
2022-09-09 17:06:10 +00:00
Chien-Chin Huang
1840f24df7 [FSDP] Ensure that all ranks use the same order to iterate through optimizer states (#84654)
**Background:**
Optimizer states are of the type `Dict[int, Dict[str, torch.Tensor]]`, and the order of `dict.items()` is the creation order of the keys. Without a checkpoint (state_dict/load_state_dict), the creation order of the keys depends on the implementation of the optimizer (e.g., Adam seems to create `exp_avg` then `exp_avg_sq`). However, when loading states from a checkpoint, since the optimizer states are lazily initialized, the order depends on the user code (reading state_dict from IO). See the following example:

```
optimizer_state_dict = USER_CODE_TO_READ_STATE_FROM_IO()
optimizer.load_state_dict(optimizer_state_dict)
```
The key order of `optimizer_state_dict` depends on `USER_CODE_TO_READ_STATE_FROM_IO` and there is no guarantee that the order is the same across ranks.

**What Can Go Wrong?**
After the first checkpoint load, the key order of the optimizer states may not be the same on different ranks. When users try to save another checkpoint, `_unflatten_optim_state()` is called to save the optimizer states. Inside `_unflatten_optim_state()`, `dict.items()` is called to iterate over all the local optimizer states, and `all_gather()` is used to gather the local states. Since the order may differ across ranks, the gathered states are not correct.

We have seen some models get NaN loss after the second checkpoint load because of this issue.

**What This PR Does?**
This PR implements a `sorted_items()` helper to return sorted `(key, value)` pairs, as sketched below. We can do this because the keys are either integers or strings.
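
A minimal sketch of such a helper (illustrative; the actual implementation may differ):

```python
def sorted_items(d):
    # Keys are either ints (parameter IDs) or strs (parameter names), so sorting by
    # (key type, key) gives every rank the same traversal order before all_gather().
    return sorted(d.items(), key=lambda kv: (isinstance(kv[0], str), kv[0]))
```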

Differential Revision: [D39315184](https://our.internmc.facebook.com/intern/diff/D39315184/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84654
Approved by: https://github.com/awgu
2022-09-09 07:19:01 +00:00
Shen Li
2211949513 Moving CommTensor from tests to private _spmd folder (#84719)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84719
Approved by: https://github.com/wanchaol
2022-09-09 06:25:42 +00:00
Rodrigo Kumpera
6f21680563 Add __all__ for a few distributed modules plus a little typing (#84119)
This handles distributed_c10d, which is massive and ddp_comm_hooks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84119
Approved by: https://github.com/rohan-varma
2022-09-08 23:28:31 +00:00
PyTorch MergeBot
a6e6276c8b Revert "Moving CommTensor from tests to private _spmd folder (#84655)"
This reverts commit 07dad15583.

Reverted https://github.com/pytorch/pytorch/pull/84655 on behalf of https://github.com/kit1980 due to Several test failures on trunk 07dad15583, PR also had failures
2022-09-08 19:28:38 +00:00
Shen Li
07dad15583 Moving CommTensor from tests to private _spmd folder (#84655)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84655
Approved by: https://github.com/wanchaol
2022-09-08 17:25:38 +00:00
Rodrigo Kumpera
e96fb5d58c [c10d] Fix docstring of scatter_object_list (#84596)
The docstring for scatter_object_list mentions that it doesn't work with NCCL, but this was fixed in #79034

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84596
Approved by: https://github.com/H-Huang
2022-09-07 14:49:45 +00:00
Bin Chen
06ebe2d5bc Add watchdog to TorchElastic agent and trainers (#84081)
Summary:
D38604238 (3b11b80fc3) introduced a named pipe based watchdog timer.

This diff uses the named pipe based watchdog timer in the TorchElastic agent and training worker processes (in the StuckJobDetector class) to allow the TorchElastic agent to detect when a training process is stuck and kill the process to create a core dump.

Test Plan:
```
buck test mode/dev-nosan //caffe2/test/distributed/elastic/agent/server/test:local_agent_test
```
```
RemoteExecution session id: reSessionID-0bfcacef-24d1-42bc-a1d3-f3058fc42b2f-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/7318349503394739
    ✓ ListingSuccess: caffe2/test/distributed/elastic/agent/server/test:local_agent_test : 55 tests discovered (22.699)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_barrier_failed_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (47.140)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_distributed_sum_homogeneous_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (49.198)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_happy_function_c10d (local_elastic_agent_test.LocalElasticAgentTest) (46.387)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_happy_function_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (46.094)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_bipolar_function_etcd (local_elastic_agent_test.LocalElasticAgentTest) (106.342)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_correct_rank_assignment_homogeneous_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (64.888)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_correct_rank_assignment_homogeneous_etcd (local_elastic_agent_test.LocalElasticAgentTest) (69.158)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_agent_local_watchdog_setup_enabled_etcd (local_elastic_agent_test.LocalElasticAgentTest) (46.965)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_double_agent_elastic_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (79.626)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_function_with_return_value_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (46.113)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_sad_function_etcd (local_elastic_agent_test.LocalElasticAgentTest) (46.487)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_shutdown_called_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (24.358)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_torch_rpc_c10d (local_elastic_agent_test.LocalElasticAgentTest) (48.216)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_distributed_sum_homogeneous_c10d (local_elastic_agent_test.LocalElasticAgentTest) (48.433)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_torch_rpc_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (47.029)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_simple_dist_sum_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (44.357)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_check_master_addr_port_override_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (45.176)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_check_nccl_async_error_handling_env_default_c10d (local_elastic_agent_test.LocalElasticAgentTest) (45.980)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_simple_dist_sum_c10d (local_elastic_agent_test.LocalElasticAgentTest) (47.151)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_simple_dist_sum_etcd (local_elastic_agent_test.LocalElasticAgentTest) (44.614)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_correct_rank_assignment_heterogeneous_etcd (local_elastic_agent_test.LocalElasticAgentTest) (69.099)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_agent_local_watchdog_setup_enabled_c10d (local_elastic_agent_test.LocalElasticAgentTest) (45.367)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_shutdown_called_etcd (local_elastic_agent_test.LocalElasticAgentTest) (22.804)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_double_agent_elastic_c10d (local_elastic_agent_test.LocalElasticAgentTest) (77.560)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_dummy_compute_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (46.050)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_distributed_sum_heterogeneous_c10d (local_elastic_agent_test.LocalElasticAgentTest) (48.088)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_double_agent_elastic_etcd (local_elastic_agent_test.LocalElasticAgentTest) (77.286)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_double_agent_fault_tolerance_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (50.670)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_check_master_addr_port_override_etcd (local_elastic_agent_test.LocalElasticAgentTest) (45.631)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_distributed_sum_heterogeneous_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (50.867)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_double_agent_fault_tolerance_etcd (local_elastic_agent_test.LocalElasticAgentTest) (51.095)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_happy_function_etcd (local_elastic_agent_test.LocalElasticAgentTest) (45.000)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_sad_function_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (45.197)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_distributed_sum_homogeneous_etcd (local_elastic_agent_test.LocalElasticAgentTest) (46.873)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_shutdown_called_c10d (local_elastic_agent_test.LocalElasticAgentTest) (23.160)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_barrier_failed_etcd (local_elastic_agent_test.LocalElasticAgentTest) (43.632)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_torch_rpc_etcd (local_elastic_agent_test.LocalElasticAgentTest) (44.536)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_bipolar_function_c10d (local_elastic_agent_test.LocalElasticAgentTest) (89.859)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_workers_drift_fail_etcd (local_elastic_agent_test.LocalElasticAgentTest) (48.277)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_check_nccl_async_error_handling_env_c10d (local_elastic_agent_test.LocalElasticAgentTest) (43.930)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_bipolar_function_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (87.677)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_workers_drift_success_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (48.965)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_workers_drift_fail_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (50.143)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_workers_drift_success_etcd (local_elastic_agent_test.LocalElasticAgentTest) (46.781)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_function_with_return_value_etcd (local_elastic_agent_test.LocalElasticAgentTest) (45.152)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_barrier_failed_c10d (local_elastic_agent_test.LocalElasticAgentTest) (44.832)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_function_with_return_value_c10d (local_elastic_agent_test.LocalElasticAgentTest) (45.281)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_correct_rank_assignment_heterogeneous_etcd_v2 (local_elastic_agent_test.LocalElasticAgentTest) (74.968)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_agent_local_watchdog_setup_disabled_c10d (local_elastic_agent_test.LocalElasticAgentTest) (46.141)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_dummy_compute_c10d (local_elastic_agent_test.LocalElasticAgentTest) (44.960)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_dummy_compute_etcd (local_elastic_agent_test.LocalElasticAgentTest) (45.292)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_agent_local_watchdog_setup_disabled_etcd (local_elastic_agent_test.LocalElasticAgentTest) (44.611)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_check_env_function_etcd (local_elastic_agent_test.LocalElasticAgentTest) (44.939)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_distributed_sum_heterogeneous_etcd (local_elastic_agent_test.LocalElasticAgentTest) (47.609)
    ✓ Pass: caffe2/test/distributed/elastic/agent/server/test:local_agent_test - test_run_sad_function_c10d (local_elastic_agent_test.LocalElasticAgentTest) (45.628)
Summary
  Pass: 55
  ListingSuccess: 1
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/7318349503394739
```
-----------
```
buck test caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test
```
```
RemoteExecution session id: reSessionID-607a0028-4095-4dfc-b657-55f0807fe621-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/8162774432794818
    ✓ ListingSuccess: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test : 11 tests discovered (39.037)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_thrift_api_called (caffe2.torch.fb.trainer.stuck_detection.tests.collect_quickstack_test.CollectQuickstackTrace) (0.655)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_setup_local_watchdog (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (36.510)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_dont_print_when_job_normal (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (36.727)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_send_watchdog_request_on_batch_callbacks_no_server (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (37.060)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_quickstack_stuck_job (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (37.242)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_setup_local_watchdog_disabled (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (37.243)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_print_stack_trace_when_job_stuck (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (37.590)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_print_when_stuck (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (37.590)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_setup_local_watchdog_no_file (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (37.589)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_signposts_stack_trace_when_job_stuck (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (38.132)
    ✓ Pass: caffe2/torch/fb/trainer/stuck_detection/tests:stuck_job_detector_test - test_send_watchdog_request_on_batch_callbacks (caffe2.torch.fb.trainer.stuck_detection.tests.stuck_job_detector_test.StuckJobDetectorTest) (38.133)
Summary
  Pass: 11
  ListingSuccess: 1
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/8162774432794818
```

Differential Revision: D38930476

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84081
Approved by: https://github.com/d4l3k
2022-09-07 00:17:20 +00:00
Masaki Kozuki
ab6c57217a Add NCCL PreMul Sum to c10d reduce ops (#84243)
This is based on #81272, but this version conforms to the TorchScript compiler.

## TODO
- [ ] Update abaf8112e6/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp (L64-L73) to use `ReduceOp::RedOpType`. In my first try with `USE_SYSTEM_UCC=1`, this change wasn't necessary (I think) because of `ReduceOp::RedOpType` operator. That being said, I want to make it more explicit.

cc @ptrblck @kwen2501 @aazzolini
cc @zasdfgbnm for visibility to the TODO above
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84243
Approved by: https://github.com/kwen2501
2022-09-02 21:57:45 +00:00
Andrew Gu
88802719b6 [FSDP][Easy] Move utils to _utils.py (#84212)
I pulled this out into a separate PR. This just moves some utility functions to `_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84212
Approved by: https://github.com/rohan-varma
2022-09-01 19:27:51 +00:00
Rodrigo Kumpera
7a348a1d4a Fix internal breakage caused by #82134 (#84363)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84363
Approved by: https://github.com/rohan-varma, https://github.com/mehtanirav
2022-09-01 17:54:10 +00:00
Chien-Chin Huang
305c6a6c35 [FSDP] Fix the FQN not found issue for load sharded_state_dict when using activation checkpoint (#84253)
The current sharded_state_dict load will fail if activation checkpoint is also enabled. This PR fixes the issue.

Differential Revision: [D39125431](https://our.internmc.facebook.com/intern/diff/D39125431/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84253
Approved by: https://github.com/awgu
2022-08-31 23:05:46 +00:00
Andrew Gu
84ceebebf9 [FSDP] ufmt flat_param.py, flatten_params_wrapper.py (#83664)
I think we can move FSDP code to start using ufmt (https://ufmt.omnilib.dev/en/stable/) to unify formatting across developers. ufmt is the recommended formatter for PyTorch's Python code. If we have consensus, I can ufmt all of the FSDP code in follow-ups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83664
Approved by: https://github.com/rohan-varma
2022-08-31 19:18:48 +00:00
Andrew Gu
762890d11e [FSDP] Retire self.device_id; clean up ctor (#83663)
### Overview
This PR retires `self.device_id` by coalescing it with `self.compute_device` and more generally cleans up the FSDP constructor.

### Existing FSDP Constructor Semantics (In Order)
1. Compute the ignored parameters/modules from `ignored_modules` and the buffer names (to avoid cloning in `state_dict()`)
2. Recursively auto wrap if needed
3. Define process group attributes
4. Determine `device_id`
5. Materialize the wrapped module if using meta device or `torchdistX` deferred initialization
6. Move the module if needed (based on `self.device_id`)
7. Determine `compute_device`
8. Define `training_state`, gradient divide factors, FSDP feature-related attributes (`cpu_offload`, `forward_prefetch`, `backward_prefetch`, `sharding_strategy`, `mixed_precision`), `_orig_buffer_dtypes`
9. Determine the parameters to flatten
10. Sync module states if `sync_module_states`
11. Initialize the `FlattenParamsWrapper` with the parameters to flatten and the wrapped module, which constructs the `FlatParameter`
12. Shard the `FlatParameter` (in-place)
13. Define `_is_root`, shared attributes (`_streams`, `_fsdp_graph_order`), prefetching attributes (`_my_fsdp_idx_in_graph`, `_pre_backward_hook_full_params_prefetched`, `_forward_full_params_prefetched`), `reshard_after_forward` -- all of this is done in `_reset_lazy_init()`
14. Define `_require_backward_grad_sync` to configure `no_sync()`
15. Define state dict attributes (`_state_dict_type`, `_state_dict_config`) and register state dict hooks
16. Define backward pass flags (`_pre_backward_hook_has_run`, `_need_rebuild_full_params`)
17. Move `FlatParameter`s to CPU if `cpu_offload.offload_params`
18. Define `_exec_order_data` for execution order validation
19. Define communication hook attributes (`communication_hook`, `communication_hook_state`, `_hook_registered`)

### Notable Changes
- `self.mixed_precision`
    - **Before:** `self.mixed_precision` itself could be `None`. Equivalently, `self.mixed_precision` could be `MixedPrecision(None, None, None)`. Both would disable mixed precision completely.
    - **After:** `self.mixed_precision` itself is never `None`. We only have `MixedPrecision(None, None, None)` (default construction of the `dataclass`) to disable mixed precision. This catches the issue that for `test_summon_full_params.py`, we were passing `MixedPrecision(None, None, None)` when we wanted to actually enable mixed precision.
- `cpu_offload.offload_params=True` + `device_id`
    - **Before:** For nested FSDP and `device_id` specified, `FlatParameter`s already offloaded to CPU are moved back to GPU and not re-offloaded to CPU.
    - **After:** The nested `FlatParameter`s are re-offloaded to CPU. This is a temporary hack. The ideal solution removes the `module = module.to(<GPU device>)` in the first place and only moves the relevant parameters. Because the `module.to()` implementation has some complexity, I did not want to remove that call in this PR.
- `device_id` and `compute_device`
    -  **Before:** `self.device_id` is either `None` or equal to `self.compute_device`. `self.device_id` is not used after the FSDP constructor.
    - **After:** `self.device_id` is removed and instead coalesced with `self.compute_device`. The only semantic change is that `test_module_device_mismatches_device_id()` errors earlier (but importantly, still errors).
    - This PR also uses a helper method `_get_orig_params()`, which is more robust and may avoid issues like https://github.com/pytorch/pytorch/issues/82891 without having to gate higher-level logic.
- `_reset_lazy_init()` attributes
    - **Before:** Some attributes were being _defined_ in `_reset_lazy_init()` (which may not be obvious to all devs).
    - **After:** For this PR, we define these attributes in the constructor but leave `_reset_lazy_init()` as is. In the follow-ups, this gets further refactored.
- Otherwise, I simply moved some logic into their own methods and reorganized the attribute definitions to be grouped logically.

### Follow-Ups
1. What should the specification be for `device_id` + `ignored_modules`?
2. Investigate removing the `module = module.to(<GPU device>)` in favor of moving per parameter.
3. Should we call `_reset_lazy_init()` in `register_comm_hook()`?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83663
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
2022-08-31 18:24:37 +00:00
Andrew Gu
b8ee810144 [Easy][FSDP] Update StateDictType doc (#84200)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84200
Approved by: https://github.com/rohan-varma
2022-08-30 18:31:46 +00:00
Andrew Gu
7f58db7424 [Easy][FSDP] ufmt _optim_utils.py (#84199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84199
Approved by: https://github.com/rohan-varma
2022-08-30 18:31:33 +00:00
Rodrigo Kumpera
65dc5dd3f3 [c10d] Introduce dist.get_local_rank, dist.get_global_rank and dist.get_global_ranks (#82134)
Those functions enable membership introspection into a ProcessGroup. A common scenario
that needs this is library code that consumes a PG but doesn't create it, which means
it likely doesn't know the global ranks used to create it.

Translating from local to global rank is necessary when using c10d collectives like broadcast, so if your library code adopts the convention of using local rank 0, it needs to do the following:

```python
import torch.distributed as dist

my_pg: dist.ProcessGroup = ...

def my_library_bcast(tensor):
    dist.broadcast(tensor, src=dist.get_global_rank(my_pg, local_rank=0), group=my_pg)

```

This implements some of the helpers needed to implement the `clone` API from: https://github.com/pytorch/pytorch/issues/81291
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82134
Approved by: https://github.com/rohan-varma
2022-08-30 17:45:00 +00:00
Andrew Gu
f0efc1c2d1 [Easy][FSDP] Fix sharded optim state dict doc formatting (#84198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84198
Approved by: https://github.com/rohan-varma
2022-08-30 15:52:10 +00:00
Rohan Varma
8acc92eb00 [FSDP] Print exec order only in debug mode (#83868)
Since the exec order warning can result in a very long module name printout, this gates it to print only in debug mode. Oftentimes, such as in multimodal training, there is not a lot we can do about this warning since some modules go unused in certain iterations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83868
Approved by: https://github.com/awgu
2022-08-29 17:10:25 +00:00
Rohan Varma
1a53e35b9d Enforce explicit ProcessGroup passed into DefaultState (#84105)
Would prefer to enforce that users pass an explicit PG into these state objects when using comm hooks with FSDP, so that it is clear and easily debuggable which processes communication is taking place over.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84105
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao
2022-08-29 14:52:58 +00:00
Rodrigo Kumpera
f66be71d77 [checkpoint] Adopt Planner interface across the board. (#83781)
Change StorageReader and StorageWriter to follow the new SavePlanner / LoadPlanner design.

Add optional planner param to load_state_dict and save_state_dict and implement the new protocol.

This includes a small rework of FileSystem layer to support single file per rank and making fsync optional to match torch.save behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83781
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-08-29 14:38:32 +00:00
PyTorch MergeBot
5cf4542f86 Revert "Enforce explicit ProcessGroup passed into DefaultState (#84105)"
This reverts commit adc9a1e2fb.

Reverted https://github.com/pytorch/pytorch/pull/84105 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 14:30:18 +00:00
Rohan Varma
adc9a1e2fb Enforce explicit ProcessGroup passed into DefaultState (#84105)
Would prefer to enforce that users pass an explicit PG into these state objects when using comm hooks with FSDP, so that it is clear and easily debuggable which processes communication is taking place over.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84105
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao
2022-08-27 03:12:20 +00:00
Rohan Varma
d2f37401b8 Silence namedtuple warning in dist (#84072)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84072
Approved by: https://github.com/awgu
2022-08-26 00:24:28 +00:00
Rohan Varma
b35e7c5da7 Fix FSDP not all outputs used in loss (#83195)
There are a couple issues / assumptions within FSDP today that this PR attempts to fix:

- In wait_for_post_backward, we assume that if a param required grad, its post backward was called, but this is not true, i.e. if its output did not participate in grad computation, it would not have called post backward. To fix this we simply removed those assertions.
- There is a deeper issue where in `_finalize_params`, we could end up assigning a grad of the sharded shape to an unsharded parameter's gradient field, which would raise a shape error. This can happen, for example, if a parameter's usage transitions from used --> unused. In this case, when the parameter was used, it would have had a gradient; then the user could have called `zero_grad()`, and p.grad would not be `None`. Then, in `_prep_grad_for_backward`, we would assign a `_saved_grad_shard` to this gradient field, which would be the sharded shape. In `_finalize_params`, our parameter would be unsharded (since post_backward was not called), but we'd try to assign, raising the shape issue. This issue is fixed by checking `_post_backward_called`. If this is False, we simply skip the assignment because there is no new gradient to update.
- A final issue as mentioned above is that if post_backward is not called, we never reshard the full param. This is fixed by checking if we haven't resharded (basically if post_backward_called == False), and if so, performing a reshard.

A few things to note:
- This logic may have to be revisited when non-recursive wrapping lands as there are multiple FlatParams per FSDP unit
- This logic may not work when post_backward_hook fires but p.grad is None, i.e. the short-circuiting here: f534b2c627/torch/distributed/fsdp/fully_sharded_data_parallel.py (L2884). As a quick fix, we could just move `_post_backward_called` flag change to after this, or just perform a reshard before returning early. I am not sure how to repro a case where p.grad == None but we call the post-backward hook, https://github.com/pytorch/pytorch/issues/83197 might be a possibility, but I think it is fine to not support this yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83195
Approved by: https://github.com/awgu
2022-08-26 00:24:28 +00:00
PyTorch MergeBot
1f61c39ac4 Revert "Support NCCL Premul Sum (#81272)"
This reverts commit 432c508e71.

Reverted https://github.com/pytorch/pytorch/pull/81272 on behalf of https://github.com/weiwangmeta due to breaking internal builds
2022-08-25 05:01:37 +00:00
Bin Chen
3b11b80fc3 Named pipe based watchdog timer (#83695)
Summary:
This diff implements a named pipe based watchdog timer (`FileTimerClient` and `FileTimerServer`). This is similar to the existing `LocalTimerClient` and `LocalTimerServer` (https://fburl.com/code/j4b9pyya).

The motivation comes from the need to handle various timeout issues. The training process occasionally gets stuck. We need a proper watchdog to monitor the liveness of the training processes. This timer allows the TorchElastic agent (as the watchdog) to monitor the progress of the training processes that it spawned. If a timeout occurs, the TorchElastic agent can take action to kill the stuck process and create a core dump for it.

`LocalTimerClient` and `LocalTimerServer` require a `multiprocessing.Queue()` to work, so they can only be used between `multiprocessing` parent and child processes.

`FileTimerClient` and `FileTimerServer` do not have such a limitation.

Test Plan:
### Unit Test
```
buck test mode/opt caffe2/test/distributed/elastic/timer:file_based_timer_test
```
```
RemoteExecution session id: reSessionID-06d70a77-043c-4d9d-b0f2-94c24460740a-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/844425186732666
    ✓ ListingSuccess: caffe2/test/distributed/elastic/timer:file_based_timer_test : 12 tests discovered (2.177)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_happy_path (file_based_local_timer_test.FileTimerTest) (2.463)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_expired_timers (file_based_local_timer_test.FileTimerServerTest) (1.889)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_send_request_release (file_based_local_timer_test.FileTimerServerTest) (1.700)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_valid_timers (file_based_local_timer_test.FileTimerServerTest) (1.873)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_watchdog_call_count (file_based_local_timer_test.FileTimerServerTest) (1.715)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_watchdog_empty_queue (file_based_local_timer_test.FileTimerServerTest) (1.609)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_exception_propagation (file_based_local_timer_test.FileTimerTest) (1.633)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_multiple_clients_interaction (file_based_local_timer_test.FileTimerTest) (2.189)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_get_timer_recursive (file_based_local_timer_test.FileTimerTest) (2.295)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_no_client (file_based_local_timer_test.FileTimerTest) (1.753)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_timer (file_based_local_timer_test.FileTimerTest) (2.151)
    ✓ Pass: caffe2/test/distributed/elastic/timer:file_based_timer_test - test_client_interaction (file_based_local_timer_test.FileTimerTest) (1.895)
Summary
  Pass: 12
  ListingSuccess: 1
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/844425186732666
```

Differential Revision: D38604238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83695
Approved by: https://github.com/d4l3k
2022-08-24 22:16:12 +00:00
Masaki Kozuki
432c508e71 Support NCCL Premul Sum (#81272)
This PR adds the support for https://docs.nvidia.com/deeplearning/nccl/archives/nccl_21212/user-guide/docs/api/ops.html?highlight=premul#c.ncclRedOpCreatePreMulSum.

The major changes include
- convert enum ReduceOp to struct
- add premul sum specific paths to init.cpp and Ops.cpp.

note:
- For pip wheels / conda binaries to support this, ~~I think https://github.com/pytorch/pytorch/pull/79132 would be needed~~ https://github.com/pytorch/pytorch/pull/82775 landed

The commit titled "add nccl premul" whose current hash is cb99ad6744 was authored by @mcarilli and @ptrblck.

cc @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81272
Approved by: https://github.com/kwen2501
2022-08-24 04:53:25 +00:00
Sergii Dymchenko
591222f5d9 Fix use-dict-literal lint (#83718)
Fix use-dict-literal pylint suggestions by changing `dict()` to `{}`. This PR should do the change for every Python file except test/jit/test_list_dict.py, where I think the intent is to test the constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
2022-08-24 00:26:46 +00:00
Rohan Varma
b29a074882 [BE] Revert distributed change in https://github.com/pytorch/pytorch/pull/68779 (#83181)
https://github.com/pytorch/pytorch/issues/82641 points out a regression in how inputs / outputs are processed by DDP, blocking their HF use case. It was narrowed down to https://github.com/pytorch/pytorch/pull/68779 and reverting the distributed change there fixes the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83181
Approved by: https://github.com/kumpera
2022-08-23 02:38:23 +00:00
Rohan Varma
4e90526a4f [FSDP] Remove unneeded checks (#83150)
@awgu pointed out that these checks aren't really doing anything, as they just make sure we're setting training state in certain ways throughout FSDP, which is sort of arbitrary. So, we remove them to avoid confusion.

We still keep the checking around `_post_backward_called` because this is needed in `finalize_params` for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83150
Approved by: https://github.com/awgu
2022-08-23 02:38:23 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
Taylor Robie
1fa9a377d0 [Profiler] Start moving python bindings out of autograd (#82584)
A lot of profiler code still lives in autograd for historic reasons. However as we formalize and clean up profiler internals it makes sense to pull more and more into the profiler folders/namespace. For now I'm just moving some of the core config data structures and those related to `torch::profiler::impl::Result` to keep the scope manageable.

Differential Revision: [D37961462](https://our.internmc.facebook.com/intern/diff/D37961462/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D37961462/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82584
Approved by: https://github.com/albanD, https://github.com/Gamrix
2022-08-19 17:15:18 +00:00
Rodrigo Kumpera
d11d3dd036 [dist.cp] Introduce LoadPlanner and SavePlanner extensibility API. (#83419)
The planners come with default implementations in default_planner.py.

The default planners expose their core functionality as separate functions
to make it easy for other checkpoint implementations to use this functionality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83419
Approved by: https://github.com/wanchaol
2022-08-18 19:40:15 +00:00
Olga Andreeva
f204afc2bb Added communication hook for sharded cases (#83254)
Fixes https://github.com/pytorch/pytorch/issues/79114

An implementation of the FSDP communication hook interface for sharded strategies:

- Added `reduce_scatter_hook` to the default hooks. Note the difference between `reduce_scatter` and `all_reduce`: it requires two tensors, `input_gradient` and `output`, and stores the result in `output`, which is further used as the summed gradient shard (a hedged sketch follows this list).
- Adjusted the FSDP logic to return `reduce_scatter_hook` as the default communication hook for sharded strategies; `DefaultState` is the same for sharded and non-sharded strategies.
- Adjusted the low-precision hooks to work with both `all_reduce` and `reduce_scatter`, depending on whether the `output` tensor is provided.
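
For reference, a hedged sketch of the hook shape described above; the `state.process_group` attribute and the pre-division are assumptions for illustration, not the exact default-hook implementation.

```python
import torch
import torch.distributed as dist

def reduce_scatter_hook_sketch(state, grad: torch.Tensor, output: torch.Tensor) -> None:
    # Hedged sketch: reduce-scatter the padded, unsharded gradient so that `output`
    # ends up holding this rank's averaged gradient shard.
    world_size = dist.get_world_size(state.process_group)
    grad.div_(world_size)  # pre-divide so the sum across ranks yields an average
    # Assumes `grad` is padded to a multiple of world_size so it chunks evenly.
    dist.reduce_scatter(output, list(grad.chunk(world_size)), group=state.process_group)
```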

Test plan:

Added all existing sharded strategies as input parameters to the existing tests.
For `test_default_communication_hook_behaviour`, I double-checked how a linear layer is sharded across workers. This test creates a simple ``1 x N`` net, where ``N`` is the number of workers. For sharded cases, the ``N`` parameters are sharded across the ``N`` workers. This test checks that after backward, each worker has the proper value in its chunk of the gradient, or that the whole gradient on every worker is equal to the expected value.

Checked that low-precision tests work for sharded cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83254
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-08-18 18:41:14 +00:00
Chien-Chin Huang
3e1fc85b23 [FSDP] Implement sharded_optim_state_dict and flatten_sharded_optim_state_dict. (#77628)
As title

Differential Revision: [D36436496](https://our.internmc.facebook.com/intern/diff/D36436496/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77628
Approved by: https://github.com/awgu
2022-08-18 16:38:58 +00:00
Chien-Chin Huang
244690205f [FSDP] Use _init_from_local_tensor to create ShardedTensor to avoid communication overhead (#82911)
FSDP originally uses `_init_from_local_shards_and_global_metadata()` to create a ShardedTensor for sharded_state_dict(). We have seen some non-trivial overhead if the number of tensors is large. Using `_init_from_local_tensor` can significantly reduce the overhead. For a model with ~250 tensors in the state_dict trained with 16 GPUs, the original `sharded_state_dict` takes ~1.7 seconds and this PR reduces the overhead to ~0.6 seconds.

Differential Revision: [D38452170](https://our.internmc.facebook.com/intern/diff/D38452170/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82911
Approved by: https://github.com/awgu
2022-08-17 16:40:20 +00:00
fduwjj
d3a176a156 [PT-D][BE][TP perf 1/N] Get rid of unnecessary collectives in Embedding/EmbeddingBag and use autograd-enabled collectives (#81853)
These two ops (Embedding and EmbeddingBag for ShardedTensor), especially for row-wise sharding, are very inefficient and hard to fit into the concept of the future design. So this PR tries to:
1. Remove all unnecessary collective communications. Only one gather and one reduce (or reduce-scatter) is needed.
2. Use auto-grad enabled collectives so that we can use these ops in real model training.
3. Some minor code cleaning
4. Treat the input differently when it is a replicated tensor. (Will add more on this in the next few PRs.)

Differential Revision: [D37965687](https://our.internmc.facebook.com/intern/diff/D37965687/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81853
Approved by: https://github.com/wanchaol
2022-08-17 04:32:41 +00:00
Rob Zinkov
ff75562cff Adding maximize to rprop (#81864)
Added the maximize flag (#68052) to the Rprop optimizer and updated the respective tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81864
Approved by: https://github.com/albanD
2022-08-16 08:19:46 +00:00
Rohan Varma
794ae64174 [FSDP] Pass kwargs to load_state_dict (#83309)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83309
Approved by: https://github.com/awgu
2022-08-16 00:34:58 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Rohan Varma
9690fbf9a8 FSDP namedtuple support (#83055)
- NamedTuple support is blocking MultiModal adoption. TODO: add test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83055
Approved by: https://github.com/awgu
2022-08-10 16:44:37 +00:00
Xiang Gao
cda210e23b UCC PG build in CI (#81583)
- Modifies the current cmake build definitions to use `find_package` to find UCX and UCC installed in the system
- Install UCX and UCC in CUDA dockers
- Build PyTorch with `USE_UCC=1` in pipelines
- Currently, we are not running unit tests with the UCC PG. Those tests will be added in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81583
Approved by: https://github.com/vtlam, https://github.com/malfet
2022-08-10 00:23:47 +00:00
Rohan Varma
5b2c03823d Generalize CheckpointWrapper (#83035)
Allow checkpoint_wrapper to take in the checkpoint functional impl. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu
2022-08-09 23:35:39 +00:00
Sergii Dymchenko
a0b3854548 Change seperate -> separate (#83056)
One instance was caught by Meta-internal "exact-word-misspell" linter in D38505529.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83056
Approved by: https://github.com/huydhn, https://github.com/seemethere
2022-08-09 23:11:34 +00:00
Aashaka Shah
24a084eda6 [c10d] Fix async error in batch_isend_irecv (#82450)
Summary:
`batch_isend_irecv` previously required the use of `torch.cuda.synchronize` to avoid data race conditions. This was because the ncclStreams were recorded in the returned ncclWork object _before_ the ncclGroupEnd from the `_batch_p2p_manager` was issued. Thus, `req.wait()` was effectively waiting on nothing, leading to later operators working on incorrect intermediate data.

This fix:
- keeps track of ncclStreams to wait on, and records them in the work objects after the batch manager issues a ncclGroupEnd
- renames the `_batch_p2p_manager` to `_coalescing_manager` for generality
- removes the explicit check for the NCCL backend inside `_batch_p2p_manager` in distributed_c10d.py and moves the manager start/end to ProcessGroup.hpp, in order to transparently work with all process groups

Test Plan: Modified the unittest for `batch_isend_irecv` to check that received tensors are the same as expected tensors. Verified that the test fails before the change, and passes after the change.
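
For context, a hedged sketch of the user-facing pattern this change targets; after the fix, waiting on the returned work objects should be sufficient without an extra `torch.cuda.synchronize()` (shapes and peers are illustrative):
```python
import torch
import torch.distributed as dist

def ring_exchange(rank: int, world_size: int) -> torch.Tensor:
    # Each rank sends to the next rank and receives from the previous one.
    send_t = torch.full((4,), float(rank), device="cuda")
    recv_t = torch.empty(4, device="cuda")
    ops = [
        dist.P2POp(dist.isend, send_t, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv_t, (rank - 1) % world_size),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()  # with the fix, this waits on the actual NCCL work
    return recv_t
```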

Differential Revision: D38100789

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82450
Approved by: https://github.com/kwen2501
2022-08-08 17:50:22 +00:00
Rohan Varma
ab3c039910 Fix FSDP device_id when CPU offloading (#82892)
See https://github.com/pytorch/pytorch/issues/82891 for full context.

When we init FSDP with device_id + CPU offload, we could potentially hit a crash when an outer FSDP unit does not manage any params. What was happening is that it would end up getting a flat param of a child FSDP module, check its device, see that it is CPU, and throw an error.

The fix is to avoid this check if we hit a flat param. Also fixes up the documentation of the function.
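
A hedged sketch of the configuration described above (the model itself is illustrative):
```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # constructed on CPU

fsdp_model = FSDP(
    model,
    device_id=torch.cuda.current_device(),        # initialize/compute on this GPU
    cpu_offload=CPUOffload(offload_params=True),  # keep sharded params on CPU
)
```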
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82892
Approved by: https://github.com/awgu
2022-08-07 19:06:32 +00:00
edward-io
e7ff9d44ad [fsdp] add ability to iterate through dataclasses in fsdp.utils (#82638)
### Description

previously FSDP was failing on a torchmultimodal model because `_apply_to_tensors` couldn't iterate over dataclasses.

### Issue

None

### Testing

unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82638
Approved by: https://github.com/rohan-varma
2022-08-05 18:34:31 +00:00
Wanchao Liang
cda8635a5e [_shard] only check shard metadata for copy_ (#82655)
`copy_` does not restrict tensor properties; it does not check things like requires_grad or dtype, so we only check whether the shard metadata are the same.

Differential Revision: [D38359176](https://our.internmc.facebook.com/intern/diff/D38359176/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82655
Approved by: https://github.com/fduwjj
2022-08-04 06:12:19 +00:00
Chien-Chin Huang
b750c10fbe [FSDP] Move the sharded_state_dict logic to the post hook to avoid OOM (#82613)
The original implementation put the call to `_summon_full_params()` in `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` will also behave like the recursive version even if `recursive` is set to False. This PR puts the logic in the post hook to solve the OOM issue.

Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82613
Approved by: https://github.com/rohan-varma
2022-08-03 17:16:13 +00:00
Rodrigo Kumpera
f4ee37453c [dist.checkpoint] Change metadata format and improve error reporting (#82078)
This PR implements the following changes.

Move to new checkpoint metadata format with split between logical and storage data.
This is a step in the direction of supporting extensible checkpointing as it moves us away from the hardcoded storage model enforced by the FileSystem storage layer.

Change CheckpointException to include the exception traceback. Exception tracebacks are not serializable, so we need to take care of that; otherwise we provide very poor errors to users.

Finally, remove `validate_state_dict` as it lost its usefulness. Loading is becoming more and more flexible to the point that the only reasonable way to verify if it's possible to load a given configuration is to actually try it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82078
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-08-03 17:00:12 +00:00
Chien-Chin Huang
c05c3952cd [FSDP] Implement _param_fqns() to return all parameter FQNs for the FSDP module (#82595)
_param_fqns() returns a (fqn, param_name, module_name) tuple for each of the parameters wrapped by the current FSDP module.

Differential Revision: [D38325223](https://our.internmc.facebook.com/intern/diff/D38325223/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82595
Approved by: https://github.com/rohan-varma
2022-08-02 17:31:56 +00:00
Wanchao Liang
48a34acf13 [_shard] add copy_ to shardedtensor (#82508)
as titled

Differential Revision: [D38290442](https://our.internmc.facebook.com/intern/diff/D38290442)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82508
Approved by: https://github.com/fduwjj
2022-08-01 23:52:19 +00:00
ProGamerGov
71d50f4f89 Change docstring type callable to Callable for consistency (#82487)
### Description

Across PyTorch's docstrings, both `callable` and `Callable` are used for variable types. `Callable` should be capitalized, as we are referring to the `Callable` type and not the Python `callable()` function.

### Testing

There shouldn't be any testing required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
2022-08-01 17:26:09 +00:00
ProGamerGov
8def154e00 Fix multiple docstring type mistakes (#82474)
### Description

* Docstrings using `(tuple of ints)` shows up as `(tuple of python:ints)`, so I fixed them by making the `int` no longer plural. Example: https://pytorch.org/docs/stable/generated/torch.permute.html#torch.permute
* A docstring type in JIT had one of its types incorrectly highlighted as code. Example: https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script
* I found some docstring type usages of `string` that had not yet been converted to `str` after #82410
* Some docstrings incorrectly listed their defaults inside the docstring types.
* I also found a docstring that was missing its type

### Testing
No testing should be required.

---

In the developer guidelines, there should probably be standards listed for the docstring types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82474
Approved by: https://github.com/albanD
2022-07-29 17:45:37 +00:00
Andrew Gu
4630b9f44e [Easy][FSDP] Remove variable shadowing (#82386)
I unintentionally had the sub-tensors returned by `split()` bound to `tensor`, which shadows the original full `tensor`. This is a bad practice and is error-prone.

Test Plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82386
Approved by: https://github.com/rohan-varma
2022-07-28 23:15:47 +00:00
ProGamerGov
357b7d589c Fix docstring inconsistencies: string -> str, boolean -> bool (#82410)
### Description

Throughout the PyTorch docs and codebase, the `string` type in docstrings is referred to by two separate names. This leads to inconsistent docs, like you can see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d

This PR fixes this issue by ensuring that all mentions of the string type in docstrings, are using the same format that Sphinx generates hyperlinks for.

### Testing
No testing should be required for this change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82410
Approved by: https://github.com/jbschlosser
2022-07-28 21:29:57 +00:00
Rodrigo Kumpera
69eecdbc9c Introduce MetadataIndex and helper to use it. (#81909)
MetadataIndex simplifies indexing into state dict and Metadata.

This includes a `find_state_dict_object` helper that searches into a state dict.

This PR doesn't include search over Metadata, as it requires changes that will land
in a subsequent PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81909
Approved by: https://github.com/wanchaol
2022-07-28 12:20:58 +00:00
Wanchao Liang
28c43190b8 [_shard] Add ShardedTensorBase (#82291)
This PR adds ShardedTensorBase, which is the base class of
ShardedTensor. It only contains local shards and ShardedTensorMetadata,
and does not have any communication backend attached (i.e. a ProcessGroup).

Differential Revision: [D38190272](https://our.internmc.facebook.com/intern/diff/D38190272)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82291
Approved by: https://github.com/fduwjj
2022-07-28 00:17:52 +00:00
Rodrigo Kumpera
d2078fac11 [dist.checkpoint] Cleanup usage of collectives and introduce narrow helper (#81828)
Introduce a _DistWrapper class that wraps a process group and provides functional
variants of collectives. It works without c10d enabled and is exception
robust.

Introduce tensor_narrow_n, which handles narrowing over multiple dimensions.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81828
Approved by: https://github.com/wanchaol
2022-07-27 12:59:58 +00:00
Wanchao Liang
7ff121e75a [reland] make ShardedTensor be a Tensor and nn.Parameter (#82089)
This is the reland PR of https://github.com/pytorch/pytorch/pull/79825,
which was reverted due to multi-gpu ci failures. Fixes those failures
and reland it again.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82089
Approved by: https://github.com/fduwjj
2022-07-25 19:06:01 +00:00
Andrew Gu
57a566234f [FSDP] Refactor casting of grad to full param dtype (#81574)
I noticed a comment was repeated, so I wanted to refactor the duplication. Refer to the comment for the explanation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81574
Approved by: https://github.com/rohan-varma
2022-07-25 18:05:52 +00:00
Andrew Gu
10a47c533d [FSDP] Update ShardingStrategy and _free_full_params() docs (#80894)
1. I messed up the comment for the post-backward `_free_full_params()` in https://github.com/pytorch/pytorch/pull/75901.
This removes the comment, which is not necessary, and instead adds an explanation in the `SHARD_GRAD_OP` comment itself.
2. This updates the overall `ShardingStrategy` documentation after the observation that `SHARD_GRAD_OP` did not specify that parameters are still sharded outside of computation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80894
Approved by: https://github.com/rohan-varma, https://github.com/zhaojuanmao
2022-07-22 23:22:17 +00:00
PyTorch MergeBot
f51cf774c6 Revert "[_shard] make ShardedTensor be a Tensor and nn.Parameter (#79825)"
This reverts commit 9c32439a77.

Reverted https://github.com/pytorch/pytorch/pull/79825 on behalf of https://github.com/janeyx99 due to Sorry, reverting for breaking multigpu tests 9c32439a77
2022-07-22 20:39:44 +00:00
Andrew Gu
9c94e10bba [FSDP] Move _post_backward_called to _init_param_attributes (#81243)
This moves the initialization of `_post_backward_called` on each `FlatParameter` from the body of `_lazy_init()` to inside `_init_param_attributes()` (which is called in `_lazy_init()`).

Differential Revision: [D37931854](https://our.internmc.facebook.com/intern/diff/D37931854)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81243
Approved by: https://github.com/rohan-varma
2022-07-22 19:27:07 +00:00
Andrew Gu
777ff539f2 [FSDP] Clean up _lazy_init() (#80185)
This PR cleans up `_lazy_init()`. The explanations are left as PR comments.

Differential Revision: [D37726059](https://our.internmc.facebook.com/intern/diff/D37726059)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80185
Approved by: https://github.com/rohan-varma, https://github.com/zhaojuanmao
2022-07-22 19:24:22 +00:00
Andrew Gu
3059b13791 [FSDP] Remove self.numel_padded_per_param (unused) (#80002)
The list `self.numel_padded_per_param` is constructed but never used. This PR removes it.

Differential Revision: [D37726057](https://our.internmc.facebook.com/intern/diff/D37726057)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80002
Approved by: https://github.com/rohan-varma
2022-07-22 19:24:21 +00:00
Andrew Gu
790b122901 [FSDP] Move tensor sharding logic to FlatParamHandle (#80000)
This moves the tensor sharding logic from `FullyShardedDataParallel` to `FlatParamHandle`. In particular, `_get_shard()` and its related subroutines are moved to `FlatParamHandle` as static methods.

The motivation is to start refactoring to move the broader FSDP sharding logic in `_shard_parameters()` to `FlatParamHandle` (as a part of the multiple parameter group and possibly future pluggable sharding efforts). In other words, in follow-ups, I hope to move
cd08954463/torch/distributed/fsdp/fully_sharded_data_parallel.py (L1444-L1447)
to be part of `FlatParamHandle`.

Differential Revision: [D37726060](https://our.internmc.facebook.com/intern/diff/D37726060)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80000
Approved by: https://github.com/fegin
2022-07-22 19:21:51 +00:00
Andrew Gu
b069120d9c [FSDP] Deduplicate _orig_size and _unsharded_size (#79984)
This removes the `_orig_size` attribute that is initialized in `fully_sharded_data_parallel.py` since it represents the same quantity as `_unsharded_size` in `flat_param.py`. Since the quantity is not sharding dependent, we keep its initialization in `FlatParameter.init_metadata()` instead of in `FullyShardedDataParallel._shard_parameters()`.

Differential Revision: [D37726062](https://our.internmc.facebook.com/intern/diff/D37726062)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79984
Approved by: https://github.com/rohan-varma
2022-07-22 19:20:47 +00:00
Andrew Gu
be656f55b1 [FSDP] Introduce FlatParamHandle (#79652)
**Overview**
This PR introduces `FlatParamHandle` to enable non-recursive FSDP wrapping. The class absorbs the unflattening/flattening logic from `FlattenParamsWrapper` but does not require wrapping a particular `nn.Module`.

## Discussion
### Introducing `FlatParamHandle`
There is flexibility in the design space for how to allocate attributes and methods to `FlatParameter` versus a wrapping class like `FlatParamHandle` or `FlattenParamsWrapper`. Several points in the design space provide the same functionality, so deciding on an allocation is arguably stylistic, though preference should then be given to cleaner designs.

The forefront consideration is that a `FlatParameter`'s metadata should be initialized once, while its data may be reloaded via checkpointing. This motivates decoupling the metadata initialization from the `FlatParameter` constructor, which should instead only handle the parameter data. Thus, we have both a `FlatParamHandle` managing a `FlatParameter` and the `FlatParameter` itself.
```
class FlatParamHandle:
    def __init__(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `_init_flat_param()`
    def _init_flat_param(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `flatten_params()` and initializes metadata
    @staticmethod
    def flatten_params(params: Sequence[torch.Tensor], requires_grad: bool) -> FlatParameter:
        # Also may be used for checkpoint reloading
class FlatParameter(nn.Parameter):
    # Constructor is not overridden
```
Under this separation, with `FlatParameter` serving solely as a data container, we keep methods manipulating `FlatParameter` on the `FlatParamHandle`. Because `FlatParameter`'s constructor is not overridden, we should be able to replace it with another tensor type e.g. `ShardedTensor` with minimal changes.

### Compatibility with `FlattenParamsWrapper`
To ensure backward compatibility, `FlattenParamsWrapper` now holds a `FlatParamHandle`. Existing logic from `FlattenParamsWrapper` simply routes to the handle now.

A `FullyShardedDataParallel` instance holds references to all of its handles.
- For the recursive-wrapping paradigm, there is at most one handle, which is from its `FlattenParamsWrapper` if it manages parameters.
- For the non-recursive wrapping paradigm, there may be multiple handles, all owned by the single (root) `FullyShardedDataParallel` instance.

## For Reviewers
### `FlatParameter` Construction
In the existing implementation, a `FlatParameter`'s metadata was partially initialized in its constructor (e.g. `_param_numels`, `_param_shapes`) and partially initialized by the owning `FlattenParamsWrapper` (e.g. `_param_infos`, `_shared_param_infos`). The latter part was needed due to requiring module information. With this PR, the metadata initialization is consolidated in `FlatParamHandle`.
- During model construction, a `FlatParameter` should be initialized via the handle constructor `FlatParamHandle(params, module)`.
- During sharded checkpoint loading, a `FlatParameter` should be initialized via the static method `FlatParamHandle.flatten_params(new_params)`.
    - The checkpointing implementation is responsible for checking that `new_params` used to construct the `FlatParameter` data to load is consistent with the existing `FlatParameter`'s metadata.

These are the only two cases for `FlatParameter` construction right now, so there is no real functionality regression by not recomputing some of the metadata in the `FlatParameter` constructor. `nn.Module.load_state_dict()` is implemented using in-place `copy_()`, so the newly loaded `FlatParameter`'s metadata *should* match the existing `FlatParameter`'s metadata for correctness anyway. (I.e. we do not support a usage where we reload a `FlatParameter` with differing metadata into an existing `FlatParameter`.)

### BC Breaking
- `ShardMetadata` -> `FlatParamShardMetadata` to avoid name conflict with `ShardedTensor`
    - `metadata()` -> removed (unused)
- `FlatParameter` attributes
    - `_param_numels` -> `_numels`
    - `_param_shapes` -> `_shapes`
    - `_param_names` -> `_prefixed_param_names`
    - `full_numel` -> `_unsharded_size.numel()`
    - `_param_indice_in_shard` -> `_shard_indices`
    - `_sharded_param_offsets` -> `_shard_param_offsets`
    - `num_padded` -> `_shard_numel_padded`
    - `param_offsets` -> not saved; directly constructed in `_get_flat_param_offsets()` and used once
- `FlattenParamsWrapper` `param_list` argument -> `params` for consistency with `FlatParameter`

## Follow-Ups

- The current `FlatParameter`'s `data` represents either the sharded unflattened parameter, unsharded unflattened parameter, or reduced-precision sharded unflattened parameter, depending dynamically on the runtime context. When its `data` represents one quantity, the other quantities are still saved as attributes on the `FlatParameter` (e.g. `_local_shard`, `_full_param_padded`, `_mp_shard`). `FullyShardedDataParallel` directly manipulates the `data`.
We should investigate the tradeoffs of having those attributes on the `FlatParameter` versus moving them to the `FlatParamHandle`. The motivation for the latter is to define a clean interface for `FullyShardedDataParallel` to manage parameter data in preparation for generalizing to multiple parameter groups, to managing non-`FlatParameter`s, and to supporting non-CUDA devices. (More explicitly, `FullyShardedDataParallel`'s parameter *variables* would be set to different `Tensor` variables, none of which own another, instead of `FullyShardedDataParallel`'s parameter variables' *data* being set to different `Tensor` variables, all owned by the `FlatParameter`, and the data management would be folded into handle, hidden from `FullyShardedDataParallel`.)
- We should investigate if we can coalesce the remaining logic in `FlattenParamsWrapper` into `FullyShardedDataParallel` and remove `FlattenParamsWrapper`.
- We may want to move the mixed precision logic to the handle instead of the `FullyShardedDataParallel` instance to enable per-`FlatParameter` mixed precision instead of per-`FullyShardedDataParallel`. Otherwise, the non-recursive wrapping path is bound to all-or-nothing mixed precision.

Differential Revision: [D37250558](https://our.internmc.facebook.com/intern/diff/D37250558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79652
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin, https://github.com/rohan-varma
2022-07-22 19:16:50 +00:00
Wanchao Liang
9c32439a77 [_shard] make ShardedTensor be a Tensor and nn.Parameter (#79825)
Differential Revision: [D37707371](https://our.internmc.facebook.com/intern/diff/D37707371)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79825
Approved by: https://github.com/kumpera
2022-07-22 16:50:12 +00:00
Bill Darrow
38988a8d14 [rpc/distributed] eliminate code duplication in distributed/rendezvou… (#81577)
This change eliminates duplication in redundant code paths inside of distributed/rendezvous.py.  Minor additional test coverage is added.

Fixes #74440

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81577
Approved by: https://github.com/H-Huang
2022-07-22 16:21:00 +00:00
Olga Andreeva
a60907ec11 Adding fsdp fp16 and bf16 hooks (#81711)
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.

I've also made slight adjustments to existing implementation of an `all_reduce` hook including:

`AllReduceState` -> `DefaultState`. Motivation: `AllReduceState` is not specific to all_reduce. Gradients' pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. `fp16_hook` and `bf16_hook`.
I've put all 3 hooks into `default_hooks.py`.
Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify `MixedPrecision` for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to the lower precision and back is performed by the lower precision hook only. I think, as a next step, it would be nice to ensure that the user can't have both a lower precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust the current implementation.

As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.

PS: The first version of this PR was reverted because the added unittests didn't include NCCL version checks for `bf16_hook` (and thus failed on trunk). In this version, I've added the appropriate checks to the tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81711
Approved by: https://github.com/rohan-varma
2022-07-19 23:54:51 +00:00
Andrew Gu
87cdb52cc4 [FSDP] Stricten _update_p_data() in _summon_full_params() (#81573)
385ae8721e/torch/distributed/fsdp/fully_sharded_data_parallel.py (L2530-L2542)

The `finally` block below should undo what is done above -- namely, pointing the flattened parameter's data to the CPU copy of the unsharded flattened parameter.

385ae8721e/torch/distributed/fsdp/fully_sharded_data_parallel.py (L2558-L2575)
(This code snipped shows after the change in the PR.)

This PR makes the conditional in the `finally` match the conditional before (adding the `and (not rank0_only or my_rank == 0)` part). Otherwise, for nonzero ranks when `rank0_only == True`, their flattened parameters' `.data` is unnecessarily updated to itself.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81573
Approved by: https://github.com/rohan-varma
2022-07-19 19:39:45 +00:00
Atul Jangra
ba54165392 Make sure that exit code is propagated from Child to parent process (#81408)
Summary: Refactor error_handler.py

Test Plan:
In the previous diff, I added a unit test which showcases the failing case. With this diff, we can see that the override works as expected.

Also added a few additional tests for test coverage.

Reviewed By: wilson100hong

Differential Revision: D37677402

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81408
Approved by: https://github.com/d4l3k
2022-07-19 18:47:54 +00:00
PyTorch MergeBot
a8f4011e90 Revert "Adding fsdp fp16 and bf16 hooks (#80557)"
This reverts commit f7d6828467.

Reverted https://github.com/pytorch/pytorch/pull/80557 on behalf of https://github.com/aovladi due to broke distributed tests on trunk
2022-07-19 03:11:19 +00:00
Olga Andreeva
f7d6828467 Adding fsdp fp16 and bf16 hooks (#80557)
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.

I've also made slight adjustments to existing implementation of an `all_reduce` hook including:

- `AllReduceState` -> `DefaultState`. Motivation: `AllReduceState` is not specific to `all_reduce`. Gradients' pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. fp16_hook and bf16_hook.
- I've put all 3 hooks into `default_hooks.py`

Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify `MixedPrecision` for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to the lower precision and back is performed by the lower precision hook only. I think, as a next step, it would be nice to ensure that the user can't have both a lower precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust the current implementation.

As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
2022-07-18 22:40:56 +00:00
Jerome
547e499731 Enable Zero1's ddp_with_overlap for hpu backend (#80438)
Enable the ZeRO-1 with DDP overlap feature, along with a simple interface to insert a functional optimizer into the map.

Signed-off-by: Jerome <janand@habana.ai>

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80438
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-07-18 15:05:27 +00:00
Chien-Chin Huang
3ea1b9be94 [FSDP] Construct FQN in _full_post_state_dict_hook (#81253)
Differential Revision: [D37079444](https://our.internmc.facebook.com/intern/diff/D37079444/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81253
Approved by: https://github.com/rohan-varma
2022-07-15 20:49:32 +00:00
Sergii Dymchenko
d61ae1a773 Remove unused variables from state_dict_loader (#81513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81513
Approved by: https://github.com/mrshenli
2022-07-15 15:31:34 +00:00
Sergii Dymchenko
fe34bf1201 Remove unused storage_size (#81514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81514
Approved by: https://github.com/mrshenli
2022-07-15 15:30:52 +00:00
Sergii Dymchenko
d083b44818 Remove unused rank from _AllGatherBase backward (#81515)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81515
Approved by: https://github.com/mrshenli
2022-07-15 15:30:07 +00:00
linjianma
dd73c97ea2 [FSDP] Remove the dependency of `_symbolic_trace in wrap` (#81443)
Same as #81339, this fixes internal tests where ``torch.fx`` is not available and FSDP's ``wrap`` is imported. With this, there should be no import errors for FSDP when ``torch.fx`` is not available.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81443
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao, https://github.com/Neilblaze
2022-07-14 15:59:27 +00:00
linjianma
ed8a830da8 [FSDP] import `_symbolic_trace only when torch.fx` is enabled (#81339)
This is used to fix internal tests where ``torch.fx`` is not available in the module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81339
Approved by: https://github.com/rohan-varma, https://github.com/zhaojuanmao
2022-07-13 03:08:38 +00:00
Terry Lam
54bdaf76d6 [PFC] Native UCC process group for Pytorch (#79918)
Summary:
This diff integrates the UCC process group as a native component of PyTorch Distributed core. It is based on the existing torch-ucc (https://github.com/facebookresearch/torch_ucc) as the wrapper for the UCC collective communication library.
The environment and cmake variables are named to mirror those of the existing process groups such as NCCL and Gloo. Specifically,
- USE_UCC: enables UCC PG. This defaults to OFF, so there is no breakage of existing builds that do not have UCX/UCC external libraries.
- USE_SYSTEM_UCC: uses external UCX and UCC shared libraries that are set accordingly with UCX_HOME and UCC_HOME.

Currently, this diff only supports USE_SYSTEM_UCC=ON, i.e., requiring users to specify external libraries for UCX and UCC. In subsequent diffs, we will add UCX and UCC repos as third-party dependencies in pytorch/third-party.

Test Plan:
Passed Torch-UCC tests that invoke UCC process group. For example:

$ sh test/start_test.sh test/torch_allreduce_test.py --backend gloo --use-cuda
...
Test allreduce: succeeded

Differential Revision: D36973688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79918
Approved by: https://github.com/kwen2501, https://github.com/kingchc
2022-07-12 14:45:44 +00:00
anjali411
93912b1a73 Add __all__ to torch.distributed submodules (#80523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80523
Approved by: https://github.com/rohan-varma
2022-07-11 06:54:24 +00:00
Linjian Ma
f3008be900 [FSDP] Getting the parameter execution information using torch.fx (#80294)
This support allows one to get the parameter execution order at FSDP module construction time rather than at runtime for some models, in which case the preparation step can be removed. This will be used in the non-recursive wrapping policy later on.

Note that this support is based on the assumption that the tracer provided by the user will be able to successfully trace the forward pass.

### Advantage of using `torch.fx` to get the parameter order rather than using backward hook:
When using backward hook, the parameter execution order will be the reversed ordering of the parameter gradient ready order. One problem is that we are not able to get the number of times each parameter is used inside the forward function. For example, consider the following forward function,
```python
def forward(self, x):
    z = self.relu(self.layer0(x))
    z = self.relu(self.layer2(z))
    z = self.relu(self.layer1(z))
    z = self.relu(self.layer0(x))
    return z
```
Based on the parameter gradient ready order, the current parameter execution order for the example is `[layer0.weight, layer2.weight, layer1.weight]`. However, we don't get the information that layer0 is called twice.
Using `torch.fx`, we can get a more detailed parameter execution order: [layer0.weight, layer2.weight, layer1.weight, layer0.weight]. This allows us to implement more scheduling algorithms that could be useful in multiple regimes. For example, since we know that `layer0` will be called twice, we can delay the resharding of `layer0.weight` to the end, which would cost more memory but be faster.

### Example of API usage
The execution information is recorded via calling `tracer.trace` in the `_patch_tracer` context manager:
```python
tracer = torch.fx.Tracer() # or an instance of Tracer's children class
execution_info = _init_execution_info(model)
with _patch_tracer(
    tracer=tracer, root_module=model, execution_info=execution_info
):
    tracer.trace(model, concrete_args=...)
```
The execution information will be recorded in `execution_info.module_forward_order` and `execution_info.module_to_execution_infos`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80294
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao
2022-07-09 16:55:41 +00:00
anjali411
4bf076e964 Add __all__ to torch.distributed, futures, fx, nn, package, benchmark submodules (#80520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80520
Approved by: https://github.com/rohan-varma
2022-07-08 14:31:24 +00:00
Howard Huang
81ca2ff353 Prevent automatic cuda init in init_rpc (#80180)
Fixes #80141

Only initialize cuda if there are devices specified in `init_rpc`

Differential Revision: [D37458309](https://our.internmc.facebook.com/intern/diff/D37458309)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80180
Approved by: https://github.com/rohan-varma
2022-07-08 14:18:02 +00:00
PyTorch MergeBot
0b8a5ca01b Revert "Adding maximize to rprop (#80335)"
This reverts commit 495aa9bc3a.

Reverted https://github.com/pytorch/pytorch/pull/80335 on behalf of https://github.com/albanD due to Broke rocm and windows test
2022-07-08 13:34:02 +00:00
Rob Zinkov
495aa9bc3a Adding maximize to rprop (#80335)
Added the maximize flag #68052 to the rprop optimizer and updated the respective tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80335
Approved by: https://github.com/albanD
2022-07-08 08:04:38 +00:00
Rob Zinkov
a1fd5b4273 Adding maximize to RMSprop (#80326)
Added the maximize flag #68052 to the RMSprop optimizer and updated the respective tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80326
Approved by: https://github.com/albanD
2022-07-08 08:04:26 +00:00
Andrew Gu
ab09f34622 [FSDP] Fix full_optim_state_dict() hang (#80712)
Fixes https://github.com/pytorch/pytorch/issues/80581.

Context:
1f08c1d3d6/torch/distributed/fsdp/_optim_utils.py (L152-L163)

To-Do:
I do not understand why inserting this `torch.cuda.synchronize()` prevents the `.cpu()` call from hanging, and why, in particular, this `torch.cuda.synchronize()` must be called on **all ranks**. If it is only called on the saving ranks (i.e. rank 0), then the hang persists.
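
A hedged sketch of the pattern this fix introduces; the function and variable names below are hypothetical, not the actual FSDP internals:
```python
import torch

def offload_optim_state_to_cpu(unflat_osd_state, rank):
    # Every rank must hit this synchronization point, not just the saving rank;
    # otherwise the .cpu() copies on rank 0 can hang.
    torch.cuda.synchronize()
    if rank != 0:
        return None
    return {
        key: value.cpu() if torch.is_tensor(value) else value
        for key, value in unflat_osd_state.items()
    }
```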
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80712
Approved by: https://github.com/rohan-varma
2022-07-07 15:23:06 +00:00
anjali411
120987ffeb Fix macos public bindings failures (#80970)
We are seeing unrelated public-bindings test failures on macOS being triggered on random PRs. Here's an attempt to fix some of those.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80970
Approved by: https://github.com/rohan-varma
2022-07-07 14:10:00 +00:00
Rohan Varma
0c5fdfd95f Revert "Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"" (#80936)
This reverts commit fe361dede4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80936
Approved by: https://github.com/awgu
2022-07-06 22:21:07 +00:00
PyTorch MergeBot
fe361dede4 Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"
This reverts commit 04c50fec1c.

Reverted https://github.com/pytorch/pytorch/pull/80480 on behalf of https://github.com/suo due to Broke master 04c50fec1c, the test failures were not unrelated
2022-07-06 02:43:27 +00:00
Rohan Varma
04c50fec1c [FSDP Optim State] Remove checkpoint prefix (#80480)
Remove `_checkpoint_wrapped_module` prefixes when creating keys for optimizer state_dict.

Having these does not actually create an issue for optim_state_dict save / load, but we'd like to strip these keys out for downstream code that consumes these APIs and typically expects checkpointing prefixes not to exist (as checkpointing should be a transparent operation that does not change module / parameter names).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80480
Approved by: https://github.com/awgu, https://github.com/fegin
2022-07-06 01:17:58 +00:00
Chien-Chin Huang
e0eeb06ec6 Consolidate the naming of named_parameter and state_dict for CheckpointWrapper (#80089)
named_parameters() should return the same parameter names as state_dict(), but the current CheckpointWrapper does not enforce this naming rule. This PR resolves the issue.

Differential Revision: [D37344200](https://our.internmc.facebook.com/intern/diff/D37344200/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80089
Approved by: https://github.com/rohan-varma
2022-07-05 22:11:59 +00:00
wayi1
f76bb88205 fix docstring of PostLocalSGDOptimizer (#80855)
As title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80855
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-07-05 14:58:35 +00:00
Charlie Yan
ffae7308c9 Enable test: distributed/algorithms/quantization/test_quantization (#80097)
fixes  https://github.com/pytorch/pytorch/issues/69017
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80097
Approved by: https://github.com/wanchaol
2022-07-01 01:32:33 +00:00
PyTorch MergeBot
f667aaed1d Revert "Added serialization to postlocal_SGD. (#80435)"
This reverts commit dfdf4e79df.

Reverted https://github.com/pytorch/pytorch/pull/80435 on behalf of https://github.com/suo due to broke distributed tests on trunk, see: dfdf4e79df
2022-06-30 01:34:10 +00:00
Olga Andreeva
dfdf4e79df Added serialization to postlocal_SGD. (#80435)
Fixes #75666

This PR adds the functionality for the `PostLocalSGD` communication hook and tests that the communication hook can be properly saved and restored. This is similar to https://github.com/pytorch/pytorch/pull/79334, where serialization was added to `PowerSGD`.

``__getstate__`` returns a ``Dict[str, Any]`` which will be pickled and saved. ``process_group`` and ``subgroup`` are not serializable and are excluded from the returned state.

``__setstate__`` takes the provided ``state`` and retrieves the ``PostLocalSGDState``. ``process_group`` and ``subgroup`` are set to the default process group and subgroup, respectively; the default subgroup is equivalent to the subgroup on each node.

Small adjustment to `PowerSGD`'s warning message.

Refactored unittest, i.e. separated parity and log checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80435
Approved by: https://github.com/awgu
2022-06-29 23:59:46 +00:00
PyTorch MergeBot
58532256e9 Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8.

Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
2022-06-29 16:20:55 +00:00
zilinzhu
3d9cef8c98 Clone tensor to write in ShardedTensor checkpoint (#79400)
The `torch.save` API will save the original base tensor of a view, which results in saving a much larger checkpoint when parameters are fused, e.g. in torchrec.

Relates to #79016
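
A hedged illustration of the underlying `torch.save` behavior (sizes are approximate):
```python
import io
import torch

base = torch.zeros(1_000_000)
view = base[:10]

buf_view, buf_clone = io.BytesIO(), io.BytesIO()
torch.save(view, buf_view)           # serializes the entire 1M-element storage
torch.save(view.clone(), buf_clone)  # serializes only the 10 cloned elements

print(buf_view.getbuffer().nbytes)   # roughly 4 MB
print(buf_clone.getbuffer().nbytes)  # a few KB at most
```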

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79400
Approved by: https://github.com/kumpera
2022-06-29 03:47:24 +00:00
anjali411
5d40c3d5c8 Add __all__ for torch.distributed and fx modules (#80460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80460
Approved by: https://github.com/albanD, https://github.com/rohan-varma
2022-06-29 02:53:56 +00:00
Rohan Varma
5fc2d45a3a Remove unneeded TODO (#80453)
This TODO is no longer needed, as we use `_register_fused_optim` to register the overlapped optimizer in DDP.  Also, remove comment about API being experimental, as this API is no longer going to be used by end user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80453
Approved by: https://github.com/awgu
2022-06-29 01:19:48 +00:00
Olga Andreeva
a48f3059b7 Corrected comments in fsdp (#80456)
Currently,  pre- and post-division steps in `FullyShardedDataParallel._post_backward_hook` state the following:
>  Average grad by world_size for consistency with PyTorch DDP.

This does not match what is actually going on, i.e. the pre-divide factor may or may not be equal to `world_size`.
For example, for `world_size = 3`, `predivide_factor = 2`.

This PR clarifies pre- and post-division in the code

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80456
Approved by: https://github.com/rohan-varma
2022-06-28 18:46:05 +00:00
PyTorch MergeBot
14a7cf79c1 Add __all__ to torch.distributed and tensorboard submodules (#80444)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80444
Approved by: https://github.com/rohan-varma
2022-06-28 16:33:22 +00:00
Olga Andreeva
5fc209ed11 FSDP communication hook interface for NO_SHARD strategy (#79833)
Fixes #79114

An implementation of an FSDP communication hook interface for the NO_SHARD strategy:
- `FullyShardedDataParallel.register_comm_hook(self, state: object, hook: callable)` checks the current sharding strategy. If it is anything other than NO_SHARD, it raises a runtime error. Otherwise, it sets and shares the specified hook and its state with all submodules
- When FSDP is ready to communicate a gradient, it checks if there is a registered hook and calls it instead of all_reduce. Additionally, gradient pre- and post-division are not performed if a hook is registered.

To test the interface, I've implemented a communication hook that calls `all_reduce`.
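
A hedged usage sketch of the interface described above; the hook body and the `state` object are illustrative assumptions rather than the exact shipped defaults:
```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

class HookState:
    def __init__(self, process_group=None):
        self.process_group = process_group  # None means the default group

def allreduce_hook(state, grad):
    # Called by FSDP instead of its built-in all_reduce for NO_SHARD gradients.
    dist.all_reduce(grad, group=state.process_group)

model = nn.Linear(8, 8).cuda()
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)
fsdp_model.register_comm_hook(HookState(), allreduce_hook)
```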

A unittest:
- checks that if the sharding strategy is anything but NO_SHARD, a runtime error is raised
- checks that for the NO_SHARD case, a model with a registered all_reduce hook and one without a hook work the same
- checks two types of FSDP models: one with the first layer wrapped and one without (to make sure submodules have the hook registered)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79833
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-06-28 08:03:11 +00:00
Rohan Varma
e88dadd5eb Fix FSDP when not all outputs get gradient in backward (#80245)
In some use cases, FSDP runs into an issue where a training state assert in `_wait_for_post_backward` erroneously fires. Digging into the root cause, this happens because `_post_backward_hook`, which sets the module's training state to backward_post, is never actually called, since no param in that module had a gradient computed for it. Similar to DDP, this can happen when not all module outputs are used in the loss computation, or when the module did not participate in the forward pass at all.

Fix this by adding a flag `_post_backward_called` that tracks whether the hook is actually called or not.
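
A hedged sketch of the failure mode being handled, with plain modules and the FSDP wrapping omitted:
```python
import torch
import torch.nn as nn

class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 8)
        self.used_head = nn.Linear(8, 1)
        self.unused_head = nn.Linear(8, 1)  # its output is dropped below

    def forward(self, x):
        z = self.trunk(x)
        return self.used_head(z), self.unused_head(z)

model = TwoHead()
out_used, _ = model(torch.randn(4, 8))
out_used.sum().backward()
# unused_head's parameters receive no gradients, so a post-backward hook
# registered for them would never fire -- the case the new flag accounts for.
print(model.unused_head.weight.grad)  # None
```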
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80245
Approved by: https://github.com/awgu
2022-06-28 05:54:13 +00:00
anjali411
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
PyTorch MergeBot
9db3c517de Add __all__ for torch.nn.modules, torch.distributed.elastic, torch.nn.utils submodules (#80240)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80240
Approved by: https://github.com/rohan-varma
2022-06-27 17:11:12 +00:00
Andrew Gu
0b0e65516d [FSDP] Fix param name prefixes for ignored modules (#79955)
For ignored modules' parameters, we should also clean their parameter names since they will have the FSDP-specific prefixes.

This change only affects the prefixed parameter name keys in `full_optim_state_dict()` (i.e. optim state dict saving). Not having this change does not actually violate the correctness of the optim state dict save-load flow because it only requires that the keys are unique and internally consistent.

Either way, this PR explicitly adds the specification now that the parameter keys in the optim state dict should match the keys of full model state dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79955
Approved by: https://github.com/rohan-varma
2022-06-21 22:10:33 +00:00
Rohan Varma
2ede28724d [CheckpointWrapper] Replace generic mod prefix (#79830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2022-06-21 16:01:59 +00:00
nariaki3551
6d6e77eb6b Fix some links in torch/distributed/CONTRIBUTING.md (#79855)
Fix some invalid links in torch/distributed/CONTRIBUTING.md

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79855
Approved by: https://github.com/H-Huang
2022-06-21 00:48:30 +00:00
Linjian Ma
56c98a2b7b [FSDP] First implementation of ParamExecOrderWrapPolicy (non-recursive wrap policy) (#79238)
This is the first PR for a wrapping policy that wraps parameters and performs the communication scheduling based on the parameter execution order in the forward pass (also called non-recursive wrapping policy).

This PR includes:
- The basic API for using this policy,
- A helper function to get the parameter execution order in the first forward and backward pass.

Other parts will be implemented in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79238
Approved by: https://github.com/zhaojuanmao, https://github.com/awgu
2022-06-21 00:20:01 +00:00
Linjian Ma
5ca9253fa8 [FSDP] Fix a small bug of pre_backward_hook params prefetch (#78851)
Fix a potential small bug in FSDP pre_backward_hook params prefetch.

In `_pre_backward_hook`, `self._need_prefetch_full_params(self.training_state)` is used to decide whether the params for the next backward pass need to be prefetched, and currently it is also used to check whether we want to perform synchronization in the current backward pass before `_rebuild_full_params`.

For some edge cases, using this to decide whether to perform synchronization is not correct. One example is when `self._my_fsdp_idx_in_graph == 0`, which means this is the last backward pass. In that case, `self._need_prefetch_full_params(self.training_state)` is False since there is no backward pass after it, and currently synchronization will not be done before `_rebuild_full_params`.
But the params of this layer were prefetched by the previous layer, so a synchronization needs to be done.

To fix this, we just need to check whether to do the synchronization using another flag rather than `self._need_prefetch_full_params(self.training_state)`, and that is what this PR does.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78851
Approved by: https://github.com/zhaojuanmao
2022-06-18 21:46:12 +00:00
Wanchao Liang
bef2fecbbc [shard] make state_dict hook be consistent
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79650

For the root module we shouldn't accidentally add a "." prefix for state_dict keys; it
should be empty instead to match the `module.state_dict()` behavior.

Differential Revision: [D37191203](https://our.internmc.facebook.com/intern/diff/D37191203/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D37191203/)!

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-06-17 22:08:06 +00:00
pritam
500fb24715 Ensure tensors are contiguous in functional all_gather.
We called `tensor.contiguous()` in the forward pass; however, this was
after `out_tensor_list` was built, which resulted in `out_tensor_list`
containing non-contiguous tensors and led to errors.

Fixing this by moving the `contiguous()` call above.

Differential Revision: [D37222870](https://our.internmc.facebook.com/intern/diff/D37222870/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79747

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-17 01:27:11 +00:00
Olga Andreeva
8a6d83079c Functionality/pickling for commhooks (#79334)
This PR addresses issue #75666.
A stateful communication hook can now be saved and reloaded to resume training.

This PR adds the functionality for the PowerSGD communication hook and tests that the communication hook can be properly saved and restored.

The PowerSGD implementation uses ``__slots__``; as a result, the introduced `__getstate__` and `__setstate__` methods are implemented to work with `__slots__` and not `__dict__`.

`__getstate__` returns a dictionary that represents a ``PowerSGDState``, which will be pickled and saved. ``process_group`` is non-serializable and is excluded from the returned state.

`__setstate__` takes the provided ``state`` and retrieves the ``PowerSGDState``. ``process_group`` is set to the default, with a proper warning issued to the user.
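
A minimal, hypothetical sketch of the `__slots__`-aware pickling approach described above (this is not the actual PowerSGD state class):
```python
class HookState:
    """Hypothetical stand-in for a stateful communication-hook state."""

    __slots__ = ["process_group", "matrix_approximation_rank", "use_error_feedback"]

    def __init__(self, process_group, matrix_approximation_rank=1, use_error_feedback=True):
        self.process_group = process_group
        self.matrix_approximation_rank = matrix_approximation_rank
        self.use_error_feedback = use_error_feedback

    def __getstate__(self):
        # __slots__ classes have no __dict__, so build the state by hand and
        # drop the non-serializable process group.
        return {
            slot: getattr(self, slot)
            for slot in self.__slots__
            if slot != "process_group"
        }

    def __setstate__(self, state):
        # Restore picklable fields; the caller resets process_group to the default.
        self.process_group = None
        for slot, value in state.items():
            setattr(self, slot, value)
```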

Unit test

A hook-independent `_test_hook_pickling` is added with this PR, as well as `test_ddp_hook_pickling_powerSGD`, which tests `powerSGD`’s ability to be saved and reloaded.

Currently, the test creates a DDP model with the provided hook, trains it for 10 epochs, and saves the model's state and the hook's state.
During reloading, the unit test makes sure that a warning was logged (only one warning, and the proper one). It then proceeds to check that the reloaded hook and the original hook are the same. Finally, it checks that the hook's state was properly initialized:
	- it compares slot values (all but 2: `process_group` and `rng`) for the original and reloaded state
	- it checks that the process group was set to the default group
	- it checks that the random state was restored properly with np.testing.assert_array_equal, because `rng` is an instance of `np.random.RandomState`, represented by a tuple. One of the entries is of `ndarray dtype[uint32]` type, and `np.testing.assert_array_equal` is used for the assertion.

Future To-Do:
	- Implement similar __getstate__ and __setstate__ for other stateful communication hooks
	- Add appropriate tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79334
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-06-16 23:15:34 +00:00
Rodrigo Kumpera
270c518be0 [checkpoint] Implement interop between Tensor and Sharded Tensor (#78120)
This allows loading a Tensor from a checkpoint that has a ShardedTensor at the same FQN.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78120
Approved by: https://github.com/pritamdamania87
2022-06-16 15:31:09 +00:00
Andrew Gu
18fcd4826f [FSDP] Fix exec order validation for diff ignored modules across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79533

Approved by: https://github.com/rohan-varma
2022-06-16 02:00:53 +00:00
Linjian Ma
70446c25d7 [FSDP] Add forward prefetching option in FSDP API (#78841)
Fixes #78608

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78841
Approved by: https://github.com/zhaojuanmao
2022-06-15 20:59:08 +00:00
Nikita Shulga
09df27fe45 Revert "Revert "[distributed] Handle object collectives and NCCL. (#79034)""
This reverts commit 279634f384.
2022-06-15 10:04:37 -07:00
PyTorch MergeBot
279634f384 Revert "[distributed] Handle object collectives and NCCL. (#79034)"
This reverts commit 4ebb326b75.

Reverted https://github.com/pytorch/pytorch/pull/79034 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-06-15 16:16:21 +00:00
fduwjj
f4edbaa62f [PT-D] Use process group of the partial tensor so sub pg comm will be enabled during reshard
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79357

Approved by: https://github.com/wanchaol
2022-06-14 17:44:51 +00:00
Rohan Varma
543919cfc8 Forward attributes to wrapped module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78854

Approved by: https://github.com/albanD
2022-06-14 01:13:33 +00:00
Rohan Varma
44fe851feb [WIP] Fix non-reentrant hooks based checkpointing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78752

Approved by: https://github.com/albanD
2022-06-14 01:13:33 +00:00
Rodrigo Kumpera
4ebb326b75 [distributed] Handle object collectives and NCCL. (#79034)
This fixes all object collectives under NCCL and adds some automated tests for them.

This PR *does not* fix sending tensors using object collectives.

It simplifies device handling by computing the appropriate device earlier and then ensuring all tensor ops happen on it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79034
Approved by: https://github.com/rohan-varma
2022-06-13 19:23:39 +00:00
Michael Carilli
ba27ee9e8f [CUDA graphs] Allows Adam and AdamW to be capture-safe (#77862)
Near term fix for https://github.com/pytorch/pytorch/issues/76368.

Q. Why does the user need to request `capturable=True` in the optimizer constructor? Why can't capture safety be completely automatic?
A. We need to set up capture-safe (device-side) state variables before capture. If we don't, and step() internally detects capture is underway, it's too late: the best we could do is create a device state variable and copy the current CPU value into it, which is not something we want baked into the graph.

Q. Ok, why not just do the capture-safe approach with device-side state variables all the time?
A. It incurs several more kernel launches per parameter, which could really add up and regress cpu overhead for ungraphed step()s. If the optimizer won't be captured, we should allow step() to stick with its current cpu-side state handling.

Q. But cuda RNG is a stateful thing that maintains its state on the cpu outside of capture and replay, and we capture it automatically. Why can't we do the same thing here?
A. The graph object can handle RNG generator increments because its capture_begin, capture_end, and replay() methods can see and access the generator object. But the graph object has no explicit knowledge of or access to optimizer steps in its capture scope. We could let the user tell the graph object what optimizers will be stepped in its scope, i.e. something like
```python
graph.will_use_optimizer(opt)
graph.capture_begin()
...
```
but that seems clunkier than an optimizer constructor arg.

I'm open to other ideas, but right now I think constructor arg is necessary and the least bad approach.

Long term, https://github.com/pytorch/pytorch/issues/71274 is a better fix.
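
A hedged usage sketch of the `capturable=True` flag together with stream capture, following the standard CUDA Graphs warmup/capture pattern (details simplified):
```python
import torch

model = torch.nn.Linear(64, 64, device="cuda")
opt = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)
static_input = torch.randn(8, 64, device="cuda")

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        model(static_input).sum().backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture forward, backward, and the (capture-safe) optimizer step.
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = model(static_input).sum()
    static_loss.backward()
    opt.step()

g.replay()  # re-runs the captured iteration on static_input's current contents
```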
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77862
Approved by: https://github.com/ezyang
2022-06-13 01:56:47 +00:00
pritam
a81be44410 Fix shard_module to appropriately deal with sub process groups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79264

`shard_module` API didn't work correctly with a sub-pg since
`dist.scatter` actually takes the global rank as input for `src`.

Fixing this by passing in the appropriate rank to `dist.scatter`

Differential Revision: [D37062766](https://our.internmc.facebook.com/intern/diff/D37062766/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-12 03:50:45 +00:00
Rohan Varma
ec86070922 Checkpoint util
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78704

Approved by: https://github.com/zhaojuanmao
2022-06-10 18:37:36 +00:00
pritam
b9e3d722c4 Use appropriate dtype for sharded linear implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79255

We use several collective operations in our sharded linear
implementation and for many collectives, we do not set the `dtype` of the
output tensor appropriately. As a result, using a datatype like torch.float16
(which is not the default torch.float32) results in errors.

Fixing this across the board and adding appropriate tests.

Differential Revision: [D37059752](https://our.internmc.facebook.com/intern/diff/D37059752/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-10 07:32:15 +00:00
Olga Andreeva
b1ae519df9 Added functionality for post_local SGD (#78988)
Fixes #74556

Added functionality to save and restore the step counter for the model averager.
Added a unit test.
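
A hedged sketch of the save/restore flow this enables (constructor details may differ across versions, and a process group is assumed to be initialized already):
```python
import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers
from torch.distributed.optim import PostLocalSGDOptimizer

# Assumes dist.init_process_group(...) has already run.
model = torch.nn.Linear(8, 8)
local_opt = torch.optim.SGD(model.parameters(), lr=0.01)
averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
opt = PostLocalSGDOptimizer(local_opt, averager)

ckpt = opt.state_dict()      # now also carries the averager's step counter
opt.load_state_dict(ckpt)    # restores it, so averaging resumes on schedule
```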
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78988
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-06-09 17:47:04 +00:00
linjianma
0990a1c627 [FSDP] Profiling range for FSDP.backward (#78479)
Add profiling ranges for the pre- and post-backward hooks; partially fixes #67714

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78479
Approved by: https://github.com/rohan-varma
2022-06-04 03:17:22 +00:00
pritam
c6ca4a4038 Fuse matmul in row-wise sharded linear to have a single matmul.
Performing a single large matmul is more efficient than having to
perform multiple matmuls in a loop.

Similar improvement to https://github.com/pytorch/pytorch/pull/78449
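
The idea in miniature (illustrative shapes and CPU tensors, not the sharded linear code itself):
```python
import torch

shards = [torch.randn(16, 32) for _ in range(4)]   # row shards of the weight
x = torch.randn(8, 4 * 16)                         # input columns match the shards

# Looped version: one matmul (and kernel launch) per shard.
looped = sum(x[:, i * 16:(i + 1) * 16] @ shards[i] for i in range(4))

# Fused version: a single (8, 64) @ (64, 32) matmul.
fused = x @ torch.cat(shards, dim=0)

print(torch.allclose(looped, fused, atol=1e-5))    # True, up to summation order
```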

Differential Revision: [D36828505](https://our.internmc.facebook.com/intern/diff/D36828505/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78672

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-06-04 01:23:55 +00:00
Wanchao Liang
2fce7483a5 Back out "[shard] make ShardedTensor a torch.Tensor subclass" (#78796)
Summary:
Original commit changeset: f3ce270bad56

Original Phabricator Diff: D36569064

Test Plan: wait for sandcastle and doing additional checks

Reviewed By: guangyuwang

Differential Revision: D36890625

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78796
Approved by: https://github.com/pbelevich
2022-06-03 19:39:26 +00:00
Andrew Gu
4615738a3d [FSDP] Allow different optim_input orders across ranks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78599

Approved by: https://github.com/rohan-varma
2022-06-03 11:47:24 +00:00
Andrew Gu
d4d8aaf7cb [FSDP][Docs] Fix typo in full_optim_state_dict()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78784

Approved by: https://github.com/rohan-varma
2022-06-03 11:41:21 +00:00
linjianma
c29df68f95 [FSDP] Return original module when FSDP-wrapped model calls .module (#78671)
Fixes #78607

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78671
Approved by: https://github.com/awgu, https://github.com/rohan-varma
2022-06-03 04:38:19 +00:00
Howard Huang
24b7142d7a Update distributed/CONTRIBUTING.md to remove ProcessGroupAgent references and add test instructions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78625

Approved by: https://github.com/mrshenli, https://github.com/albanD
2022-06-01 21:31:12 +00:00
pritam
5aa2ed1922 Remove call to .contiguous() for local_shard_t.
The call to contiguous was probably left over from a previous
implementation and is no longer needed.

Had to adjust atol for one of the tests to accommodate this.

Differential Revision: [D36797942](https://our.internmc.facebook.com/intern/diff/D36797942/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78598

Approved by: https://github.com/kumpera
2022-06-01 18:50:10 +00:00
pritam
44aa4ad894 Use _all_gather_base and fuse matmul for sharded linear.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78477

Use `_all_gather_base` instead of `all_gather` for col-wise sharding
since `_all_gather_base` returns a single fused tensor that can be used to
perform a single matmul instead of looping through and performing multiple
matmuls.

This improves performance for col-wise sharding.
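
Roughly the pattern, as a sketch (`_all_gather_base` was a private API at the time; names and shapes here are illustrative):
```python
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
local = torch.randn(16, 32, device="cuda")
weight = torch.randn(32, 8, device="cuda")

# One preallocated, contiguous output for the whole gather...
gathered = torch.empty(world_size * 16, 32, dtype=local.dtype, device=local.device)
dist._all_gather_base(gathered, local.contiguous())

# ...which feeds a single matmul instead of a per-shard loop.
out = gathered @ weight
```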

Differential Revision: [D36754385](https://our.internmc.facebook.com/intern/diff/D36754385/)

Approved by: https://github.com/aazzolini, https://github.com/wanchaol
2022-06-01 17:17:34 +00:00
pritam
effd270986 Fuse row-wise sharded linear matmul to increase perf.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78449

Instead of looping through and performing separate matmuls, we can
just perform a single matmul to ensure we launch a single CUDA kernel for this
operation.

Differential Revision: [D36743354](https://our.internmc.facebook.com/intern/diff/D36743354/)

Approved by: https://github.com/aazzolini, https://github.com/wanchaol
2022-06-01 17:13:48 +00:00
Rohan Varma
e387fb4df7 [FSDP][BE][Docs] Improve auto wrap policy doc (#78400)
Closes https://github.com/pytorch/pytorch/issues/78399

- Add the expected type of the callable
- Clarify what the policy function should return and how it's used (i.e., what's done depends on the `recurse` flag).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78400
Approved by: https://github.com/awgu
2022-05-31 15:15:01 +00:00
Rohan Varma
a0b3814433 Clean prefixes when searching for params / buffers to ignore (#78278)
Co-authored with: @awgu

When `state_dict` has a prefix attached to it, the current logic for ignoring parameters and buffers does not work since it doesn't account for this prefix. To fix this, we make the following changes:

- Clean the key if it starts with the prefix (see the sketch after this list). Note that not all keys may start with the prefix, e.g. if the current module's state_dict post-hook is running after a previous module's `state_dict` has already been computed and that previous module is at the same level of the hierarchy as the current module.
- This prefixing makes it so that it is not correct to override a child module's ignored params and buffers with the root FSDP instance's (this wouldn't work if child FSDP instances had ignored modules and the root didn't, for example). We fix this by having each parent know about the ignored modules of its children and computing fully qualified names for ignored params and buffers.
- This means that, for a particular FSDP instance, that instance knows the names (in fully qualified form) of its own and its children's params and buffers that it needs to ignore. It wouldn't know about parent ignored params and buffers, but it doesn't need to store that data.
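
A minimal sketch of the key-cleaning step referenced above (hypothetical helper name, not the actual FSDP code):
```python
def _clean_key(key: str, prefix: str) -> str:
    # Drop the state_dict hook's prefix, if present, before comparing the key
    # against the fully qualified names of ignored params/buffers.
    return key[len(prefix):] if prefix and key.startswith(prefix) else key
```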
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78278
Approved by: https://github.com/awgu
2022-05-26 02:43:03 +00:00
fduwjj
141238a889 [PT-D] Enable nan_to_num op for sharded tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78223

Approved by: https://github.com/pritamdamania87
2022-05-25 18:03:42 +00:00
Andrew Gu
8412f209f0 [FSDP] Remove unneeded padding logic for optim state dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78208

Approved by: https://github.com/rohan-varma
2022-05-25 17:22:03 +00:00
Wanchao Liang
8eb62bd7ba [shard] make ShardedTensor a torch.Tensor subclass
This is the reland of PR https://github.com/pytorch/pytorch/pull/74695, which was reverted due to some internal failures.

It also removes the ShardedTensorInterface change; we will revisit that
change later if we find there's a need for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78027

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-05-24 01:20:45 +00:00
pritam
37eb31599c [reland] Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77987)
1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77987
Approved by: https://github.com/fduwjj, https://github.com/wanchaol, https://github.com/janeyx99
2022-05-21 22:33:58 +00:00
PyTorch MergeBot
0f74b44f1a Revert "Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)"
This reverts commit 8d4c8df33a.

Reverted https://github.com/pytorch/pytorch/pull/77825 on behalf of https://github.com/janeyx99 due to as it will break multigpu test reporting
2022-05-20 17:59:03 +00:00
pritam
8d4c8df33a Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)
1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77825
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-20 16:53:27 +00:00
Andrew Gu
e69d13b8b3 [FSDP][Easy] Update state_dict() docstring
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77853

Approved by: https://github.com/rohan-varma
2022-05-19 23:59:03 +00:00
Andrew Gu
d9b3feb27d [FSDP][Easy] Reword device placement warning
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77850

Approved by: https://github.com/rohan-varma
2022-05-19 23:57:40 +00:00
Andrew Gu
36bf8007f7 [FSDP][Easy] Fix state_dict_type() docstring example
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77848

Approved by: https://github.com/rohan-varma
2022-05-19 23:53:15 +00:00
Andrew Gu
96e674a0c9 [FSDP][Easy] Doc fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77847

Approved by: https://github.com/rohan-varma
2022-05-19 23:53:15 +00:00
pritam
327d313705 Refactor operator dispatch framework across different Tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77707

Refactor to clean up the following pieces:

1) Consolidate decorators to use a common way to look up operator tables.
2) Move a bunch of utilities to `op_registry_utils` and `common_op_utils` and
reuse them across ShardedTensor, ReplicatedTensor and PartialTensor.

Differential Revision: [D36465639](https://our.internmc.facebook.com/intern/diff/D36465639/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-19 19:27:07 +00:00
Rohan Varma
eb0ff991f7 [FSDP] Dont move if on CPU (#77720)
After offline discussion, decided that moving a CPU module to GPU by default is a bit too risky due to possible OOM issues during init.

Theoretically, we should not OOM because the module being wrapped by FSDP is required to fit on the GPU anyway, i.e. during forward. But there can be temporary GPU tensors etc. allocated during `__init__` that break this assumption, so for now it is better to give users a way to init on CPU if needed.

We still warn to use `device_id` for faster init if the model is on CPU.
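
For reference, a minimal sketch of opting into GPU init via `device_id` (the wrapped module here is a placeholder):
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

cpu_module = torch.nn.Linear(1024, 1024)    # built on CPU; placeholder module

# Default: the module stays on CPU during FSDP init (safer w.r.t. OOM).
# Opting in with device_id moves it to the given GPU for a faster init.
fsdp_module = FSDP(cpu_module, device_id=torch.cuda.current_device())
```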
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77720
Approved by: https://github.com/zhaojuanmao
2022-05-19 14:47:50 +00:00
Rohan Varma
4a57321a93 [FSDP] Use post load_state_dict hooks (#76912)
Rehash of https://github.com/pytorch/pytorch/pull/75426 now that a revised version of load_state_dict_post_hook has landed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76912
Approved by: https://github.com/awgu
2022-05-19 00:35:34 +00:00
Rodrigo Kumpera
c9570e4b88 [checkpoint] Synchronize error handling across all ranks (#77091)
Introduce error handling across all ranks when loading and saving checkpoints.

This makes it a lot simpler for users to handle failures and, as a positive side effect, to coordinate on when a checkpoint has successfully finished.

This change requires 3 collectives when saving and 1 when loading.
All those collectives carry a small payload, so they will be latency bound and write time should dominate overall.
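
The coordination idea, as a hedged sketch (the per-rank helper and message format are placeholders, not the checkpoint API):
```python
import torch.distributed as dist

local_err = None
try:
    write_local_checkpoint_shards()          # placeholder for the per-rank work
except Exception as exc:
    local_err = exc

# One small-payload collective: every rank learns every other rank's outcome.
all_errs = [None] * dist.get_world_size()
dist.all_gather_object(all_errs, local_err)

failed = [i for i, e in enumerate(all_errs) if e is not None]
if failed:
    raise RuntimeError(f"checkpoint failed on ranks {failed}: {all_errs}")
```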

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77091
Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
2022-05-18 21:24:09 +00:00
Rohan Varma
4c34343216 [FSDP] Warning for shared params, small doc fixes (#77726)
- Add a warning about limited shared param support
- Some small doc fixes after combing through the docs; we should do a more thorough doc look-through.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77726
Approved by: https://github.com/zhaojuanmao
2022-05-18 14:59:36 +00:00
Andrew Gu
93b20b0232 [FSDP][Easy] Remove extraneous print
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77705

Approved by: https://github.com/zhaojuanmao
2022-05-18 13:16:06 +00:00
lezcano
ff7b6d6b5f Update linalg.*norm
This PR does a number of things:
- Move linalg.vector_norm to structured kernels and simplify the logic
- Fixes a number of preexisting issues with the dtype kwarg of these ops
- Heavily simplifies and corrects the logic of `linalg.matrix_norm` and `linalg.norm` to be consistent with the docs
  - Before, the `_out` versions of these functions were incorrect
  - Their implementation is now as efficient as expected, as it avoids reimplementing these operations whenever possible.
- Deprecates `torch.frobenius_norm` and `torch.nuclear_norm`, as they were exposed in the API and they are apparently being used in mobile (??!!) even though they were not documented and their implementation was slow.
  - I'd love to get rid of these functions already, but I guess we have to go through their deprecation.
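
A few usage examples of the ops touched here (illustrative, not from the PR):
```python
import torch

x = torch.randn(3, 4)

torch.linalg.vector_norm(x, ord=2)                  # treats x as one flat vector
torch.linalg.matrix_norm(x, ord="fro")              # Frobenius norm of the matrix
torch.linalg.norm(x)                                # dispatches per the documented rules
torch.linalg.vector_norm(x, dtype=torch.float64)    # the dtype kwarg touched here
```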

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76547

Approved by: https://github.com/mruberry
2022-05-18 11:46:50 +00:00
fduwjj
3b2375291a [PT-D][Sharding] Fix view op and matrix ops unit test
To fix a corner case: when the sharding dim is a negative number, we need to handle it correctly.

Also disable RPC for the matrix ops where it is not necessary, since it fails on AWS pytest.
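
The usual normalization for a negative sharding dim, as a small sketch (hypothetical helper, not the actual sharding code):
```python
def _normalize_dim(dim: int, ndim: int) -> int:
    # Map a negative dim (e.g. -1) onto its non-negative equivalent.
    return dim + ndim if dim < 0 else dim
```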

Differential Revision: [D36464797](https://our.internmc.facebook.com/intern/diff/D36464797/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77706

Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
2022-05-18 03:10:37 +00:00
pritam
068d35a648 Make PartialTensor a torch.Tensor subclass
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77626

Differential Revision: [D36435152](https://our.internmc.facebook.com/intern/diff/D36435152/)

Approved by: https://github.com/wanchaol
2022-05-17 21:44:14 +00:00
Rohan Varma
6f954d7bbb FSDP parameter sync
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77492

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
Rohan Varma
8ae0b275f5 Fix device_id
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77491

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
pritam
c83f8ee46a Fix partial_tensor ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77580

Replace process_group with _process_group.

Differential Revision: [D36419464](https://our.internmc.facebook.com/intern/diff/D36419464/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-17 08:21:38 +00:00