Commit Graph

1094 Commits

Author SHA1 Message Date
Rohan Varma
a0b3814433 Clean prefixes when searching for params / buffers to ignore (#78278)
Co-authored with: @awgu

When `state_dict` has a prefix attached to it, the current logic for ignoring parameters and buffers does not work since it doesn't account for this prefix. To fix this, we make the following changes:

- Clean the key if it starts with the prefix. Note that not all keys necessarily start with the prefix, e.g. when the current module's state_dict_post_hook runs after a sibling module's `state_dict` at the same level of the hierarchy has already been computed.
- This prefixing means it is no longer correct to override a child module's ignored params and buffers with the root FSDP instance's (this wouldn't work if, for example, child FSDP instances had ignored modules but the root didn't). We fix this by having each parent know about its children's ignored modules and computing fully qualified names for the ignored params and buffers.
- As a result, each FSDP instance knows the fully qualified names it must ignore for itself and its children. It does not know about its parent's ignored params and buffers, but it does not need to store that data. A rough sketch of the key cleaning follows this list.
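A minimal sketch of the key-cleaning step described above; the function and argument names are illustrative, not the actual FSDP internals:

```python
# Illustrative sketch only -- not the actual FSDP implementation.
def drop_ignored_keys(state_dict, prefix, ignored_fqns):
    """Remove ignored params/buffers from ``state_dict`` even when keys carry a prefix.

    ``ignored_fqns`` holds the fully qualified names this FSDP instance must
    ignore, covering itself and its children (but not its parents).
    """
    for key in list(state_dict.keys()):
        # Not every key necessarily starts with ``prefix`` (e.g. keys written by a
        # sibling module at the same level), so only strip it when present.
        cleaned = key[len(prefix):] if key.startswith(prefix) else key
        if cleaned in ignored_fqns:
            del state_dict[key]
    return state_dict
```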
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78278
Approved by: https://github.com/awgu
2022-05-26 02:43:03 +00:00
fduwjj
141238a889 [PT-D] Enable nan_to_num op for sharded tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78223

Approved by: https://github.com/pritamdamania87
2022-05-25 18:03:42 +00:00
Andrew Gu
8412f209f0 [FSDP] Remove unneeded padding logic for optim state dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78208

Approved by: https://github.com/rohan-varma
2022-05-25 17:22:03 +00:00
Wanchao Liang
8eb62bd7ba [shard] make ShardedTensor a torch.Tensor subclass
This is the reland of PR https://github.com/pytorch/pytorch/pull/74695, which was reverted due to some internal failures.

It also removes the ShardedTensorInterface change; we will defer that change until we find a need for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78027

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-05-24 01:20:45 +00:00
pritam
37eb31599c [reland] Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77987)
1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77987
Approved by: https://github.com/fduwjj, https://github.com/wanchaol, https://github.com/janeyx99
2022-05-21 22:33:58 +00:00
PyTorch MergeBot
0f74b44f1a Revert "Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)"
This reverts commit 8d4c8df33a.

Reverted https://github.com/pytorch/pytorch/pull/77825 on behalf of https://github.com/janeyx99 because it will break multigpu test reporting
2022-05-20 17:59:03 +00:00
pritam
8d4c8df33a Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77825)
1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77825
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-20 16:53:27 +00:00
Andrew Gu
e69d13b8b3 [FSDP][Easy] Update state_dict() docstring
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77853

Approved by: https://github.com/rohan-varma
2022-05-19 23:59:03 +00:00
Andrew Gu
d9b3feb27d [FSDP][Easy] Reword device placement warning
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77850

Approved by: https://github.com/rohan-varma
2022-05-19 23:57:40 +00:00
Andrew Gu
36bf8007f7 [FSDP][Easy] Fix state_dict_type() docstring example
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77848

Approved by: https://github.com/rohan-varma
2022-05-19 23:53:15 +00:00
Andrew Gu
96e674a0c9 [FSDP][Easy] Doc fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77847

Approved by: https://github.com/rohan-varma
2022-05-19 23:53:15 +00:00
pritam
327d313705 Refactor operator dispatch framework across different Tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77707

Refactor to clean up the following pieces:

1) Consolidate decorators to use a common way to look up operator tables.
2) Move a bunch of utilities to `op_registry_utils` and `common_op_utils` and
reuse them across ShardedTensor, ReplicatedTensor and PartialTensor.

Differential Revision: [D36465639](https://our.internmc.facebook.com/intern/diff/D36465639/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-19 19:27:07 +00:00
Rohan Varma
eb0ff991f7 [FSDP] Dont move if on CPU (#77720)
After offline discussion, we decided that moving a CPU module to GPU by default is a bit too risky, due to possible OOM during init.

Theoretically, we should not OOM, because the module being wrapped by FSDP is required to fit on the GPU anyway (e.g. during forward). But temporary GPU tensors allocated during `__init__` could break this assumption, so for now it is better to give users a way to init on CPU if needed.

We still warn users to pass `device_id` for faster init if the model is on CPU.
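For reference, a hedged usage sketch of the `device_id` option mentioned above (assumes the process group is already initialized and a GPU is available; the model is illustrative):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = torch.nn.Linear(1024, 1024)  # starts on CPU
# Passing device_id moves the module to that GPU during FSDP init (faster);
# omitting it keeps initialization on CPU, avoiding any risk of init-time OOM.
fsdp_model = FSDP(model, device_id=torch.cuda.current_device())
```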
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77720
Approved by: https://github.com/zhaojuanmao
2022-05-19 14:47:50 +00:00
Rohan Varma
4a57321a93 [FSDP] Use post load_state_dict hooks (#76912)
Rehash of https://github.com/pytorch/pytorch/pull/75426 now that a revised version of load_state_dict_post_hook has landed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76912
Approved by: https://github.com/awgu
2022-05-19 00:35:34 +00:00
Rodrigo Kumpera
c9570e4b88 [checkpoint] Synchronize error handling across all ranks (#77091)
Introduce error handling across all ranks when loading and saving checkpoints.

This makes it a lot simpler for users to handle failures and, as a positive side effect, to coordinate on when a checkpoint has successfully finished.

This change requires 3 collectives when saving and 1 when loading.
All of these collectives carry a small payload, so they are latency bound and write time should dominate.
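A rough sketch of the kind of cross-rank coordination described here, using `all_gather_object`; this is an illustration of the idea, not the actual checkpoint implementation:

```python
import torch.distributed as dist

def sync_errors(local_exception):
    """Share each rank's failure (or None) so every rank agrees on the outcome."""
    results = [None] * dist.get_world_size()
    # The payload is tiny, so this collective is latency bound.
    dist.all_gather_object(results, local_exception)
    failed = {rank: exc for rank, exc in enumerate(results) if exc is not None}
    if failed:
        raise RuntimeError(f"checkpoint failed on ranks {sorted(failed)}")
```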

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77091
Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
2022-05-18 21:24:09 +00:00
Rohan Varma
4c34343216 [FSDP] Warning for shared params, small doc fixes (#77726)
- Add a warning about limited shared param support
- Some small doc fixes after combing through the docs; we should still do a more thorough doc review.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77726
Approved by: https://github.com/zhaojuanmao
2022-05-18 14:59:36 +00:00
Andrew Gu
93b20b0232 [FSDP][Easy] Remove extraneous print
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77705

Approved by: https://github.com/zhaojuanmao
2022-05-18 13:16:06 +00:00
lezcano
ff7b6d6b5f Update linalg.*norm
This PR does a number of things:
- Move linalg.vector_norm to structured kernels and simplify the logic
- Fixes a number of preexisting issues with the dtype kwarg of these ops
- Heavily simplifies and corrects the logic of `linalg.matrix_norm` and `linalg.norm` to be consistent with the docs
  - Previously, the `_out` versions of these functions were incorrect
  - Their implementation is now as efficient as expected, as it avoids reimplementing these operations whenever possible.
- Deprecates `torch.frobenius_norm` and `torch.nuclear_norm`, as they were exposed in the API and are apparently being used on mobile (??!!), even though they were not documented and their implementation was slow.
  - I'd love to get rid of these functions already, but I guess we have to go through their deprecation.
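For reference, a small usage sketch of the public ops touched by this PR (inputs are illustrative):

```python
import torch

A = torch.randn(3, 4)
v = torch.randn(4)

torch.linalg.vector_norm(v, ord=2)                # vector 2-norm
torch.linalg.matrix_norm(A, ord="fro")            # Frobenius norm
torch.linalg.norm(A, ord="nuc")                   # nuclear norm
torch.linalg.vector_norm(v, dtype=torch.float64)  # dtype kwarg controls compute dtype
```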

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76547

Approved by: https://github.com/mruberry
2022-05-18 11:46:50 +00:00
fduwjj
3b2375291a [PT-D][Sharding] Fix view op and matrix ops unit test
Fix a corner case where the sharding dim is a negative number, which needs to be handled correctly.

Also disable RPC for the matrix ops, where it is not necessary and fails on AWS pytest.
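The corner case amounts to normalizing a negative sharding dim, roughly like this (illustrative helper, not the actual code):

```python
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative dim (e.g. -1) onto its non-negative equivalent."""
    return dim if dim >= 0 else dim + ndim
```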

Differential Revision: [D36464797](https://our.internmc.facebook.com/intern/diff/D36464797/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77706

Approved by: https://github.com/pritamdamania87, https://github.com/wanchaol
2022-05-18 03:10:37 +00:00
pritam
068d35a648 Make PartialTensor a torch.Tensor subclass
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77626

Differential Revision: [D36435152](https://our.internmc.facebook.com/intern/diff/D36435152/)

Approved by: https://github.com/wanchaol
2022-05-17 21:44:14 +00:00
Rohan Varma
6f954d7bbb FSDP parameter sync
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77492

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
Rohan Varma
8ae0b275f5 Fix device_id
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77491

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
pritam
c83f8ee46a Fix partial_tensor ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77580

Replace process_group with _process_group.

Differential Revision: [D36419464](https://our.internmc.facebook.com/intern/diff/D36419464/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-17 08:21:38 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state_dict / load_state_dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into a non-checkpoint-wrapped module.

This is because a training run can use activation checkpointing and save its `state_dict`, while a future run may not want to wrap modules with activation checkpointing, or may change the checkpoint wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.

Tests are added to ensure a CheckpointWrapper-wrapped module's state_dict can be loaded into both another CheckpointWrapper module and a local module. state_dict with FSDP is also verified.
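A minimal sketch of the prefix handling idea; the prefix string and function names are illustrative rather than the exact hooks added in this PR:

```python
WRAPPER_PREFIX = "_checkpoint_wrapped_module."  # assumed wrapper prefix

def strip_wrapper_prefix(state_dict):
    """state_dict-side hook idea: drop the wrapper prefix so the result loads
    into a module that is not checkpoint-wrapped."""
    return {k.replace(WRAPPER_PREFIX, ""): v for k, v in state_dict.items()}

def add_wrapper_prefix(state_dict, prefix):
    """load_state_dict-side hook idea: re-insert the wrapper prefix under the
    given module prefix before loading into a checkpoint-wrapped module."""
    return {
        (prefix + WRAPPER_PREFIX + k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```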
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
Rodrigo Kumpera
668599a673 Rewrite ShardedTensor.gather to use dist.gather instead of gather_object (#77272)
gather_object is problematic when used with Tensors, as they can unpickle onto the wrong
device and lead to deadlocks or spurious failures.

This change also introduces an RPC workaround for EFA when initializing TensorPipe, until
it is properly addressed there.
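Roughly, the switch looks like the following (a simplified sketch assuming equally shaped shards, not the actual ShardedTensor.gather code):

```python
import torch
import torch.distributed as dist

def gather_shards(local_shard: torch.Tensor, dst: int = 0):
    """Gather per-rank shards onto rank ``dst`` via dist.gather (no pickling)."""
    world_size = dist.get_world_size()
    gather_list = (
        [torch.empty_like(local_shard) for _ in range(world_size)]
        if dist.get_rank() == dst
        else None
    )
    dist.gather(local_shard, gather_list=gather_list, dst=dst)
    return gather_list  # None on non-destination ranks
```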

Fixes #73935

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77272
Approved by: https://github.com/pritamdamania87
2022-05-17 02:14:40 +00:00
Wanchao Liang
25fa964d96 [shard] add clone/detach and set requires_grad for ShardedTensor
This PR adds clone/detach and requires_grad setting to ShardedTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77367

Approved by: https://github.com/pritamdamania87
2022-05-16 21:42:27 +00:00
wayi1
5ab8afe487 [Model Averaging] Support disabling post-local gradient sync (#76723)
I find that disabling intra-subgroup gradient allreduce can still give satisfying accuracy in some cases, so it is better to make such gradient averaging configurable. This does not even account for the communication savings from skipping the gradient allreduce.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76723
Approved by: https://github.com/rohan-varma
2022-05-16 18:09:09 +00:00
Rob Zinkov
2a496e2f80 Adding maximize to Adamax (#77409)
Added the maximize flag (#68052) to the Adamax optimizer and updated the respective tests.
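Usage is simply (parameters are illustrative):

```python
import torch

params = [torch.randn(2, 2, requires_grad=True)]
# maximize=True flips the update direction so the objective is maximized.
opt = torch.optim.Adamax(params, lr=1e-3, maximize=True)
```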
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77409
Approved by: https://github.com/albanD
2022-05-16 17:34:44 +00:00
Sisil Mehta
9d3ffed327 [FSDP] Sharded Grad Scaler (#76918)
Summary: Adding a shard-aware grad scaler for FSDP + mixed precision support

Test Plan: Tests added

Differential Revision: D35988676

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76918
Approved by: https://github.com/rohan-varma
2022-05-16 15:53:21 +00:00
Chien-Chin Huang
58c9d521a1 [FSDP] Implement sharded_state_dict and load_sharded_state_dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77356

Implement ShardedTensor compatible sharded_state_dict() and load_sharded_state_dict().

Algorithm overview (a rough code sketch follows the steps below):
sharded_state_dict():
  1. Call summon_full_parameters().
  2. For each unflattened, non-sharded parameter.
      2.1 Call chunk() to get the local shard of the parameter.
      2.2 Create a ShardedTensor.
  3. Replace the tensor in the state_dict with the newly created ShardedTensor.

load_sharded_state_dict():
   1. For each unflattened, sharded parameter (ShardedTensor) in the given state_dict:
       1.1 Pop out from the state_dict.
       1.2 Do allgather to reconstruct the unflattened, non-sharded parameter.
   2. Create a FlatParameter with the unflattened, non-sharded parameters.
   3. Shard the newly created FlatParameter.
   4. Insert the new FlatParameter into the state_dict.
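A rough sketch of steps 2.1-2.2/3 of sharded_state_dict(); the chunking dim, placement string, and use of init_from_local_shards are simplifying assumptions, not the exact implementation:

```python
import torch
import torch.distributed as dist
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard, init_from_local_shards

def param_to_sharded_tensor(full_param: torch.Tensor):
    """Chunk an unflattened, non-sharded param along dim 0 and wrap the local
    chunk into a ShardedTensor (assumes dim 0 >= world_size and one GPU per rank)."""
    rank, world_size = dist.get_rank(), dist.get_world_size()
    chunks = full_param.chunk(world_size, dim=0)   # 2.1: this rank's local shard
    local_chunk = chunks[rank].clone()
    offset = sum(c.shape[0] for c in chunks[:rank])
    metadata = ShardMetadata(                      # 2.2: describe the shard
        shard_offsets=[offset] + [0] * (full_param.dim() - 1),
        shard_sizes=list(local_chunk.shape),
        placement=f"rank:{rank}/cuda:{rank}",
    )
    # 3: the resulting ShardedTensor replaces the plain tensor in the state_dict.
    return init_from_local_shards([Shard(local_chunk, metadata)], *full_param.shape)
```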

Differential Revision: [D36284983](https://our.internmc.facebook.com/intern/diff/D36284983/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36284983/)!

Approved by: https://github.com/zhaojuanmao
2022-05-15 22:48:56 +00:00
fduwjj
a2cb94d21a [PT-D][Sharding] Enable more ops needed in the transformer model training
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77214

From the MetaSeq model code base, we found that many ops are not supported by ShardedTensor. In https://github.com/pytorch/pytorch/pull/75374 we already enabled most of these ops, and this PR/diff aims to enable the rest.

Fix some unit test errors.

Differential Revision: [D36302780](https://our.internmc.facebook.com/intern/diff/D36302780/)

Approved by: https://github.com/wanchaol, https://github.com/pritamdamania87
2022-05-15 22:43:47 +00:00
Rohan Varma
a275491c6f [Reland] load_state_dict post hook (#77392)
Reland of https://github.com/pytorch/pytorch/pull/76823 with fixes to call `__setstate__` for softmax/softmin/logsoftmax as per discussion with @albanD and @jbschlosser. Original description:

Implements `register_load_state_dict_post_hook` API as discussed in https://github.com/pytorch/pytorch/issues/75287.

Unit tests cover:
- Ensuring hooks are called with the correct module
- The hook is called with an `IncompatibleKeys` field
- If the hook modifies this, load_state_dict returns the modified result
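A small usage sketch of the API (module choice is illustrative):

```python
import torch

def report_keys(module, incompatible_keys):
    # incompatible_keys carries .missing_keys and .unexpected_keys; mutating
    # them in place changes what load_state_dict() ultimately returns.
    if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
        print(f"{type(module).__name__}: {incompatible_keys}")

m = torch.nn.Linear(4, 4)
m.register_load_state_dict_post_hook(report_keys)
m.load_state_dict(m.state_dict())
```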

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77392
Approved by: https://github.com/jbschlosser
2022-05-14 06:06:23 +00:00
Rohan Varma
aaf5c32992 device_id
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77321

Approved by: https://github.com/awgu
2022-05-13 22:14:49 +00:00
PyTorch MergeBot
d92b0a51aa Revert "Load state dict post hook"
This reverts commit 56bed0dcfe.

Reverted https://github.com/pytorch/pytorch/pull/76823 on behalf of https://github.com/rohan-varma
2022-05-12 21:00:49 +00:00
pritam
bdbb7fe37a Use _process_group in ReplicatedTensor and ShardedTensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77209

As per comment in
https://github.com/pytorch/pytorch/pull/77191#discussion_r869641344, making
everything consistent amongst these tensors.

Differential Revision: [D36300798](https://our.internmc.facebook.com/intern/diff/D36300798/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-12 16:06:15 +00:00
Andrew Gu
e912d24303 [FSDP] Do not check fwd order in eval mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77195

Approved by: https://github.com/zhaojuanmao
2022-05-11 22:55:01 +00:00
Wanchao Liang
0303647083 [shard] Add deepcopy for ShardedTensor
This PR adds deep copy support for ShardedTensor. Sometimes a user might want to deep copy a ShardedTensor param for comparison purposes.

Differential Revision: [D36084109](https://our.internmc.facebook.com/intern/diff/D36084109/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36084109/)!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76758

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-05-11 22:02:20 +00:00
Rohan Varma
bbb1f106c7 Separate input moving to utils file
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77187

Test fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77235

Lint fix

Approved by: https://github.com/awgu
2022-05-11 21:55:38 +00:00
pritam
b91a14900e General fixes for ShardedTensor op framework.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77191

1) Add more basic validation to all ops.
2) Ensure `register_on_local_shards` uses the appropriate sharding_spec.

Differential Revision: [D36292159](https://our.internmc.facebook.com/intern/diff/D36292159/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2022-05-11 17:21:44 +00:00
Chien-Chin Huang
4f4ebc5491 [FSDP] Fix local_state_dict and state_dict_type bugs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77101

1. state_dict_type does not correctly clean up self._state_dict.type.
2. local_state_dict wrongly assumes that self.module.flat_param must be present.

Differential Revision: [D36260234](https://our.internmc.facebook.com/intern/diff/D36260234/)

Approved by: https://github.com/rohan-varma, https://github.com/zhaojuanmao
2022-05-11 16:41:09 +00:00
Rohan Varma
9493900876 [Reland] Mixed precision batchnorm fix (#77234)
Reland of https://github.com/pytorch/pytorch/pull/77089, which was reverted due to land race.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77234
Approved by: https://github.com/zhaojuanmao
2022-05-11 15:03:34 +00:00
PyTorch MergeBot
091f8915ae Revert "Mixed Precision batchnorm fix (#77089)"
This reverts commit bf61b79503.

Reverted https://github.com/pytorch/pytorch/pull/77089 on behalf of https://github.com/suo
2022-05-11 03:00:33 +00:00
Rohan Varma
bf61b79503 Mixed Precision batchnorm fix (#77089)
Rehash of https://github.com/pytorch/pytorch/pull/76642 which could not be updated due to GHF out of sync issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77089
Approved by: https://github.com/awgu
2022-05-11 02:22:01 +00:00
Wanchao Liang
81528d4b21 [shard] add more tensor creation ops (#77185)
Summary: This PR fixes some existing tensor constructors (e.g. full), adds sharded_tensor.randn, and adds more tensor-like creation ops (e.g. full_like, rand_like, etc.)

Test Plan:
test_create_sharded_tensor_with_rand
test_create_sharded_tensor_like

Differential Revision: D36274148

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77185
Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
2022-05-11 00:00:35 +00:00
Rohan Varma
e31b6213ac Profiling range for FSDP.forward (#76899)
Same as https://github.com/pytorch/pytorch/pull/76749 which had issues in updating via ghstack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76899
Approved by: https://github.com/awgu
2022-05-10 17:52:02 +00:00
Andrew Gu
9903f1ae4a [FSDP] Do not clone buffers; offload buffers to CPU if needed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77000

Approved by: https://github.com/rohan-varma
2022-05-10 13:21:30 +00:00
Yanli Zhao
3621462ebb [FSDP] Change default auto wrap policy name (#76858)
The current default_auto_wrap_policy is not really a recommended default auto wrap policy, so change the name to avoid confusion; the docs are updated as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76858
Approved by: https://github.com/rohan-varma
2022-05-10 13:20:26 +00:00
Andrew Gu
75f316f14e [FSDP] Move param/buffer name comp. to ctor for ignored_modules
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76994

Approved by: https://github.com/rohan-varma
2022-05-10 13:16:57 +00:00
Jeroen Van Goey
a238bab17c Typo fix in generated module name (#76880)
`f"{_FILE_PREFIX}non_sriptable"` -> `f"{_FILE_PREFIX}non_scriptable"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76880
Approved by: https://github.com/mrshenli
2022-05-09 00:51:58 +00:00
pritam
9e52b50e34 Additional ops for ShardedTensor, ReplicatedTensor and PartialTensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76477

Adding the following ops:

1) softmax for ShardedTensor
2) getitem and unsqueeze for ReplicatedTensor
3) transpose and cat for PartialTensor

Differential Revision: [D35979510](https://our.internmc.facebook.com/intern/diff/D35979510/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2022-05-06 16:28:04 +00:00