Commit Graph

8 Commits

Author SHA1 Message Date
Rohan Varma
d2f37401b8 Silence namedtuple warning in dist (#84072)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84072
Approved by: https://github.com/awgu
2022-08-26 00:24:28 +00:00
Rohan Varma
b29a074882 [BE] Revert distributed change in https://github.com/pytorch/pytorch/pull/68779 (#83181)
https://github.com/pytorch/pytorch/issues/82641 points out a regression in how inputs/outputs are processed by DDP, blocking the reporter's HF use case. The regression was narrowed down to https://github.com/pytorch/pytorch/pull/68779, and reverting the distributed change there fixes the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83181
Approved by: https://github.com/kumpera
2022-08-23 02:38:23 +00:00
Rohan Varma
6f954d7bbb FSDP parameter sync
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77492

Approved by: https://github.com/zhaojuanmao
2022-05-17 19:58:49 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses `state_dict` / `load_state_dict` hooks to ensure that a state dict saved from a module wrapped with `CheckpointWrapper` can be loaded into a module that is not checkpoint-wrapped.

This is needed because a training run may use activation checkpointing and save a `state_dict`, while a future run may not wrap modules with activation checkpointing, or may change the checkpoint wrapping structure. To support this, we add hooks that remove or add the relevant prefix as needed (a minimal sketch of this prefix translation follows below).

Tests are added to ensure that a state dict saved from a `CheckpointWrapper`-wrapped module can be loaded into both a `CheckpointWrapper` module and a local (unwrapped) module. `state_dict` with FSDP is also verified.
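
A minimal sketch of the prefix-translation idea, not the PR's actual hook code; the prefix name `_checkpoint_wrapped_module.` and the helper functions are assumptions for illustration, and the sketch assumes the wrapper sits at the root of each key (no nested wrapping):

```python
from collections import OrderedDict

# Assumed prefix that a checkpoint wrapper would add to parameter keys.
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def strip_checkpoint_prefix(state_dict):
    # Make a checkpoint-wrapped state_dict loadable by a plain module
    # (the state_dict post-hook direction).
    return OrderedDict(
        (key.replace(_CHECKPOINT_PREFIX, ""), value)
        for key, value in state_dict.items()
    )

def add_checkpoint_prefix(state_dict):
    # Make a plain state_dict loadable by a checkpoint-wrapped module
    # (the load_state_dict pre-hook direction).
    return OrderedDict(
        (_CHECKPOINT_PREFIX + key, value)
        for key, value in state_dict.items()
    )
```

In the PR, translations like these are wired up as module state dict / load state dict hooks so that callers never have to rewrite the keys themselves.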
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
Rohan Varma
bbb1f106c7 Separate input moving to utils file
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77187

Test fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77235

Lint fix

Approved by: https://github.com/awgu
2022-05-11 21:55:38 +00:00
Rohan Varma
ffb0946504 Generalize param verification and broadcast
New PR for https://github.com/pytorch/pytorch/pull/75970 to be compatible with GHF.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76374
Approved by: https://github.com/awgu
2022-04-26 22:25:53 +00:00
Pritam Damania
b8e6144e0a Add a _RemoteDevice structure for ShardedTensor/ShardingSpec. (#62927)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62927

As part of the ShardedTensor work, we realized we do need some sort of
_RemoteDevice structure that deals with our format of "workername/device" so
that users don't have to worry about parsing this string directly.

Right now this structure is just the bare minimum and is mostly a container for
describing a remote device. It is currently only used in ShardedTensor,
ShardingSpec and RemoteModule.

Once we actually have a consolidated remote device proposal, this class can be
extended appropriately if needed.
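
A minimal sketch of what such a container might look like, assuming the "workername/device" string format described above; the class name, method names, and accepted formats are illustrative, not the actual _RemoteDevice API:

```python
import torch

class RemoteDeviceSketch:
    """Holds a remote device spec such as "trainer0/cuda:0" (assumed format)."""

    def __init__(self, remote_device: str):
        fields = remote_device.split("/")
        if len(fields) == 2:
            self._worker_name, device = fields
        elif len(fields) == 1:
            # A bare device string such as "cpu": no worker component.
            self._worker_name, device = None, fields[0]
        else:
            raise ValueError(f"Invalid remote device: {remote_device!r}")
        self._device = torch.device(device)

    def worker_name(self):
        return self._worker_name

    def device(self):
        return self._device

# Example: RemoteDeviceSketch("trainer0/cuda:0").device() -> device(type='cuda', index=0)
```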
ghstack-source-id: 135534086

Test Plan:
1) unit tests
2) waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D30170689

fbshipit-source-id: 1ac2e81c7a597dc40bf3fbf2c1168c382c66649f
2021-08-11 11:27:32 -07:00
Pritam Damania
0d6fa1adc5 Introduce ChunkShardingSpec as a model sharding specification. (#55728)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55728

Full design: https://github.com/pytorch/pytorch/issues/55207

This PR introduces ChunkShardingSpec (SingleShardingSpec in the design). The
name ChunkShardingSpec was chosen because it is very similar to `torch.chunk` in
terms of how a Tensor is split up, and it feels clearer than SingleShardingSpec.
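
By analogy, a short sketch of the chunk-style splitting the spec describes; `torch.chunk` is the real PyTorch op, while the placement strings are illustrative assumptions, not the spec's exact API:

```python
import torch

tensor = torch.arange(12).reshape(6, 2)

# torch.chunk splits a tensor into (roughly) equal pieces along one dimension,
# which is the splitting behavior a chunk sharding spec mirrors.
shards = torch.chunk(tensor, chunks=3, dim=0)
for rank, shard in enumerate(shards):
    print(f"shard {rank}: shape={tuple(shard.shape)}")  # each shard is (2, 2) here

# A chunk sharding spec would pair each such chunk with a placement, e.g.
# (illustrative placements only):
placements = ["rank:0/cuda:0", "rank:1/cuda:1", "rank:2/cuda:2"]
```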
ghstack-source-id: 129603318

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D27694108

fbshipit-source-id: c8764abe6a4d5fc56d023fda29b74b5af2a73b49
2021-05-23 16:04:57 -07:00