Commit Graph

5 Commits

Author SHA1 Message Date
Wanchao Liang
0524b2829a [shard] Add ReplicatedTensor (#73529)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529

Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.

ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):
    ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
    ReplicatedTensor + torch.Tensor = torch.Tensor
    ReplicatedTensor + ShardedTensor = ShardedTensor

We also added a `validate()` API to help users validate whether a replicated tensor on a certain process_group is truly replicated or not.
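
A minimal usage sketch of the rules and the `validate()` check above (the constructor signature and import path here are assumptions, not confirmed by this commit message):

    import torch
    from torch.distributed._shard.replicated_tensor import ReplicatedTensor  # assumed path

    # Assumes the default process group is already initialized on every rank.
    replica = ReplicatedTensor(torch.ones(4, 4))   # same value on all ranks

    out1 = replica + replica              # -> ReplicatedTensor
    out2 = replica + torch.rand(4, 4)     # -> torch.Tensor (per the inter-op rules)

    # Checks that the values are truly identical across the process group.
    replica.validate()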

TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.
ghstack-source-id: 152064781

Test Plan: test_replicated_tensor

Reviewed By: pritamdamania87, fduwjj

Differential Revision: D34529374

fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
2022-03-24 12:41:17 +00:00
Wanchao Liang
d6c5295ec8 [shard] Extensible ShardingSpec (#72130)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72130

1. Refactor ShardingSpec and decouple PlacementSpec from ShardingSpec, as they are essentially two separate concepts
2. Introduce a customizable ShardingSpec. With the help of two APIs, users can inherit from ShardingSpec and define their own customized sharding spec (a sketch follows after this list):
  * ShardingSpec.build_metadata: takes a tensor shape, defines how to shard a tensor of that shape across ranks, and returns a ShardedTensorMetadata that describes the layout.
  * ShardingSpec.shard: defines how to shard a tensor into a ShardedTensor
3. Refactor `ShardedTensor.__init__` and `shard_parameter` to take the new ShardingSpec, enabling these two APIs to support both ChunkShardingSpec and EnumerableShardingSpec
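
A hedged sketch of what a user-defined spec could look like under this design (the method signatures here are assumptions inferred from the description above):

    from torch.distributed._shard.sharding_spec import ShardingSpec  # assumed import path

    class MyRowWiseSpec(ShardingSpec):
        """Illustrative custom spec that shards dim 0 evenly across ranks."""

        def build_metadata(self, tensor_sizes, tensor_properties):
            # Given a global shape, describe each shard's offsets, sizes and
            # placement, and return a ShardedTensorMetadata (omitted here).
            raise NotImplementedError("sketch only")

        def shard(self, tensor, src_rank=0, process_group=None):
            # Scatter `tensor` according to build_metadata() and wrap the
            # local shard into a ShardedTensor (communication omitted here).
            raise NotImplementedError("sketch only")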

ghstack-source-id: 149788833

Test Plan: wait for ci

Reviewed By: fduwjj

Differential Revision: D33923403

fbshipit-source-id: 3236beec8543da651dfd89c32b6968745c59301e
(cherry picked from commit 5994b33a7a6ad96b1fad2e121c6bdd83a877346e)
2022-02-24 04:30:48 +00:00
Junjie Wang (PyTorch)
b02c514764 [PT-D][Sharded Tensor] new init api for local tensor and sharding spec auto inference (#72733)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72733

To reduce the communication cost incurred when initializing a sharded tensor, this PR/diff makes two changes:

1. We create a new API named `_init_from_local_tensor` so that if we have only one local tensor, we can initialize a sharded tensor directly from it (sketched below). (GH issue: https://github.com/pytorch/pytorch/issues/72092)
2. We create a new API to infer the sharding spec from global metadata, so we don't have to manually set the sharding spec when it's not `EnumerableShardingSpec`. (GH issue: https://github.com/pytorch/pytorch/issues/67244)
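
A rough usage sketch (only the API name `_init_from_local_tensor` comes from this commit; the call signature, `ChunkShardingSpec`, and placements shown are assumptions):

    import torch
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    # Each rank already holds its own shard, so no scatter communication is needed.
    local_shard = torch.rand(2, 4)
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

    # Assumed call shape: the local tensor, the spec describing it, and the global size.
    st = ShardedTensor._init_from_local_tensor(local_shard, spec, 4, 4)
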
ghstack-source-id: 149229259

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D34132739

fbshipit-source-id: 3a60135761bcc19d6020b6c45cb2979869645ce6
(cherry picked from commit af569325e2)
2022-02-16 17:42:39 +00:00
Junjie Wang (PyTorch)
19d0de8a57 [PT-D][RFC] Resharding related API implement for ShardedTensor and Partial Tensor (#70079)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70079

We defined a new concept named `PartialTensor`, which is an abstraction to represent Tensors that need aggregation across multiple devices and multiple processes.

We also defined an API `reshard_output` to reshard a `PartialTensor` to a `Tensor`, or reshard a `ShardedTensor` to a `ShardedTensor`/`Tensor`. This is done via the class `ModuleResharder`, which acts as a wrapper around the original module plus a reshard in the final step.

The `reshard` logic is defined in each class (`ShardedTensor` and `PartialTensor`).
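
A conceptual sketch of the flow (the import path and exact signature are assumptions; only `PartialTensor`, `reshard_output`, and `ModuleResharder` are named by this commit):

    import torch.nn as nn
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec
    from torch.distributed._shard.api import reshard_output  # assumed location

    # Stand-in for a module whose forward returns a PartialTensor (e.g. partial
    # sums from a row-wise sharded linear that still need to be aggregated).
    sharded_module = nn.Linear(16, 16)

    # Wrapping the module (via ModuleResharder under the hood) reshards its
    # output in a final step according to the given spec.
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
    resharded_module = reshard_output(sharded_module, spec)
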
ghstack-source-id: 148273050

Test Plan: Unit test is in the next PR.

Reviewed By: pritamdamania87

Differential Revision: D33121037

fbshipit-source-id: 5f56617ea526b857c5b73df6e069697d428ec359
(cherry picked from commit 58b1457cbc)
2022-02-03 05:26:02 +00:00
Pritam Damania
64670e414e [reland] Create torch.distributed._shard package. (#72141)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141

We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.

As a result, this PR organizes all of this under the `torch.distributed._shard`
package. For BC reasons, I'm still keeping the old packages and having them just
reference the new package.
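
A rough illustration of such a backward-compat shim (not the literal file contents); the old path can simply re-export the new package:

    # torch/distributed/_sharded_tensor/__init__.py -- illustrative only
    # Keep the old import path working by forwarding to the new package.
    from torch.distributed._shard.sharded_tensor import *  # noqa: F401,F403
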
ghstack-source-id: 148150861

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33904585

fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7af)
2022-02-02 06:58:20 +00:00