Differential Revision: D47568928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105507
Approved by: https://github.com/awgu, https://github.com/fduwjj
Author: Mo Mo, 2023-07-19 20:34:43 +00:00; committed by: PyTorch MergeBot
parent 801fb93b0c
commit 7b56238551

@@ -139,7 +139,7 @@ sharded_module = distribute_module(MyModule(), mesh, partition_fn=shard_fc)
## Compiler and PyTorch DTensor
-DTensor provides efficient solutions for cases like Tensor Parallelism. But when using the DTensor's replication in a data parallel fashion, it might become observably slower compared to our existing solutions like DDP/FSDP. This is mainly because mainly because DDP/FSDP have a global view of the entire model architecture, thus could optimize for data parallel specifically, i.e. collective fusion and computation overlap, etc. In contract, DistributedTensor as a Tensor-like object can only optimize within individual tensor operations.
+DTensor provides efficient solutions for cases like Tensor Parallelism. But when using DTensor's replication in a data parallel fashion, it might become observably slower than our existing solutions like DDP/FSDP. This is mainly because DDP/FSDP have a global view of the entire model architecture and can therefore apply data-parallel-specific optimizations such as collective fusion and computation overlap. In contrast, DistributedTensor as a Tensor-like object can only optimize within individual tensor operations.
To improve efficiency of DTensor-based data parallel training, we are exploring a compiler-based solution on top of DTensor, which can extract graph information from user programs to expose more performance optimization opportunities.
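For reference, the replication pattern the paragraph above refers to can be sketched roughly as follows. This sketch is not part of the diff: `replicate_params` is a hypothetical partition function, and `MyModule` plus the mesh setup mirror the earlier example in this README. Because every parameter becomes a replicated DTensor, the collectives are issued per tensor operation instead of being fused across the whole model the way DDP/FSDP can.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    Replicate,
    distribute_module,
    distribute_tensor,
)

# Assumes the default process group is already initialized (e.g. via torchrun)
# and that MyModule is the toy module from the earlier README example.
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

def replicate_params(mod_name, mod, mesh):
    # Hypothetical partition_fn: replicate every parameter of each submodule.
    # Each parameter becomes a DTensor with a Replicate() placement, so
    # forward/backward behave like data parallelism, but collectives are
    # issued per tensor operation rather than fused across the whole model.
    for name, param in mod.named_parameters(recurse=False):
        replicated = distribute_tensor(param, mesh, [Replicate()])
        mod.register_parameter(name, nn.Parameter(replicated))

replicated_module = distribute_module(MyModule(), mesh, partition_fn=replicate_params)
```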