Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529
Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.
ReplicatedTensor is a :class:`~torch.Tensor` subclass and can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):
ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
ReplicatedTensor + torch.Tensor = torch.Tensor
ReplicatedTensor + ShardedTensor = ShardedTensor
We also added a `validate()` API to help users check whether a tensor that is supposed to be replicated on a certain process_group is truly replicated.
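A minimal usage sketch of the rules and the `validate()` API above; the module path and the `ReplicatedTensor(local_tensor)` constructor used here are assumptions based on this description, not a confirmed stable API:

```python
# Hedged sketch -- module path and ReplicatedTensor(local_tensor) constructor
# are assumptions; run under a multi-rank launcher such as torchrun.
import torch
import torch.distributed as dist
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

dist.init_process_group("gloo")  # assumes env:// rendezvous set up by the launcher

# Every rank constructs the same local value, so the tensor is replicated.
rep_a = ReplicatedTensor(torch.ones(4))
rep_b = ReplicatedTensor(torch.full((4,), 2.0))

out1 = torch.add(rep_a, rep_b)          # ReplicatedTensor + ReplicatedTensor -> ReplicatedTensor
out2 = torch.add(rep_a, torch.ones(4))  # ReplicatedTensor + torch.Tensor     -> torch.Tensor

# validate() collectively checks that the value is identical on every rank of
# the process group; the exact behavior on mismatch follows this PR's implementation.
rep_a.validate()
```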
TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.
ghstack-source-id: 152064781
Test Plan: test_replicated_tensor
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D34529374
fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72130
1. Refactor ShardingSpec: decouple PlacementSpec and ShardingSpec, as they are essentially two separate concepts.
2. Introduce a customizable ShardingSpec. With the help of two APIs, users can inherit from ShardingSpec and define their own customized sharding spec (see the sketch after this list):
* ShardingSpec.build_metadata: takes a tensor shape and defines how to shard a tensor of that shape across ranks, returning a ShardedTensorMetadata that describes the layout.
* ShardingSpec.shard: defines how to shard a tensor into a ShardedTensor.
3. Refactor `ShardedTensor.__init__` and `shard_parameter` to take the new ShardingSpec, enabling these two APIs to support both ChunkShardingSpec and EnumerableShardingSpec.
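A hedged skeleton of what a user-defined spec could look like under this contract; the import paths, method signatures, and the ShardMetadata/ShardedTensorMetadata field names shown are assumptions for illustration, not the exact interfaces in this diff:

```python
# Hedged skeleton -- import paths, method signatures, and metadata field
# names are assumptions for illustration only.
from torch.distributed._shard.sharding_spec import ShardingSpec, ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata

class TwoRankRowSpec(ShardingSpec):
    """Toy custom spec: split dim 0 of a 2-D tensor across exactly two ranks."""

    def build_metadata(self, tensor_sizes, tensor_properties):
        # Describe how a tensor of this global shape is laid out across ranks.
        rows, cols = tensor_sizes
        half = rows // 2
        shards = [
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[half, cols],
                          placement="rank:0/cuda:0"),
            ShardMetadata(shard_offsets=[half, 0], shard_sizes=[rows - half, cols],
                          placement="rank:1/cuda:1"),
        ]
        return ShardedTensorMetadata(shards_metadata=shards,
                                     size=tensor_sizes,
                                     tensor_properties=tensor_properties)

    def shard(self, tensor, src_rank=0, process_group=None):
        # Scatter `tensor` from src_rank according to build_metadata() and
        # wrap each rank's local piece into a ShardedTensor.
        raise NotImplementedError("communication + local-shard wrapping goes here")
```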
ghstack-source-id: 149788833
Test Plan: wait for ci
Reviewed By: fduwjj
Differential Revision: D33923403
fbshipit-source-id: 3236beec8543da651dfd89c32b6968745c59301e
(cherry picked from commit 5994b33a7a6ad96b1fad2e121c6bdd83a877346e)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72733
To reduce the perf cost of communication when initializing a sharded tensor, this PR/diff makes two changes (see the sketch after this list):
1. We create a new API named `_init_from_local_tensor` so that if we have only one local tensor, we can initialize a sharded tensor directly from it. (GH issue: https://github.com/pytorch/pytorch/issues/72092)
2. We create a new API to infer the sharding spec from global metadata, so we don't have to manually set the sharding spec when it's not `EnumerableShardingSpec`. (GH issue: https://github.com/pytorch/pytorch/issues/67244)
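A hedged sketch of how change 1 could be used; the private constructor's exact signature and the device placements here are assumptions drawn from this description rather than a verified interface:

```python
# Hedged sketch -- _init_from_local_tensor's exact signature is an assumption;
# run under an initialized 2-rank process group (e.g. via torchrun, 1 GPU per rank).
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

dist.init_process_group("nccl")
rank = dist.get_rank()

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

# The local shard already exists on this rank, so we wrap it directly into a
# ShardedTensor of global shape (16, 32) without an extra scatter round.
local_shard = torch.randn(8, 32, device=f"cuda:{rank}")
st = ShardedTensor._init_from_local_tensor(local_shard, spec, 16, 32)
```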
ghstack-source-id: 149229259
Test Plan: CI
Reviewed By: wanchaol
Differential Revision: D34132739
fbshipit-source-id: 3a60135761bcc19d6020b6c45cb2979869645ce6
(cherry picked from commit af569325e2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70079
We defined a new concept named `PartialTensor`, which is an abstraction to represent Tensors that need aggregation across multiple devices and multiple processes.
We also defined an API `reshard_output` to reshard a `PartialTensor` to a `Tensor`, or reshard a `ShardedTensor` to a `ShardedTensor/Tensor`. This is done via the `ModuleResharder` class, which acts as a wrapper around the original module and performs a reshard as the final step.
The `reshard` logic is defined in each class (`ShardedTensor` and `PartialTensor`).
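A hedged sketch of the resharding flow described above; the import location and signature of `reshard_output` and the toy module are illustrative assumptions, not the exact code in this diff:

```python
# Hedged sketch -- the import location and signature of reshard_output (and
# the implied ModuleResharder wrapping) are assumptions for illustration.
import torch
from torch.distributed._shard.api import reshard_output  # assumed location
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# In practice this would be a sharded module (e.g. a row-wise sharded linear)
# whose forward returns a PartialTensor that still needs a cross-rank reduction.
model = torch.nn.Linear(32, 32)

output_spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

# Wrap the module so its final output is resharded: a PartialTensor output
# would be aggregated and laid out according to output_spec.
model = reshard_output(model, output_spec)
```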
ghstack-source-id: 148273050
Test Plan: Unit test is in the next PR.
Reviewed By: pritamdamania87
Differential Revision: D33121037
fbshipit-source-id: 5f56617ea526b857c5b73df6e069697d428ec359
(cherry picked from commit 58b1457cbc)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141
We currently have many sharding components:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer, and more coming.
As a result, this organizes all of them under the `torch.distributed._shard`
package. For BC reasons, I'm keeping the old packages and having them simply
reference the new package.
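A hedged sketch of the kind of back-compat shim this implies: the legacy module re-exports the relocated package (the exact shim in this diff may differ, e.g. with different warning text or no warning at all):

```python
# torch/distributed/_sharded_tensor/__init__.py -- hedged sketch of a BC shim;
# the real file in this diff may differ.
import warnings

warnings.warn(
    "torch.distributed._sharded_tensor will be deprecated; "
    "use torch.distributed._shard.sharded_tensor instead",
    DeprecationWarning,
)
from torch.distributed._shard.sharded_tensor import *  # noqa: F401,F403
```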
ghstack-source-id: 148150861
Test Plan: waitforbuildbot
Reviewed By: fduwjj
Differential Revision: D33904585
fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7af)