Commit Graph

10 Commits

fduwjj
25a2845d78 [TP] Enable embedding sharding in TP API (#111177)
We see use cases where embedding sharding is also needed in the TP API, so we enabled it in the API since DTensor already supports colwise embedding sharding.
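
A minimal sketch of what this enables (mesh size and embedding dims are illustrative, and it assumes the plan-based `parallelize_module` interface with the refactored `ColwiseParallel` accepting `output_layouts`):

```python
# Minimal sketch: shard an nn.Embedding colwise through the TP API.
# Mesh size and dimensions are illustrative; a process group / device mesh
# must already be initialized in a real run.
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = DeviceMesh("cuda", list(range(4)))  # 4-way tensor-parallel group
emb = nn.Embedding(num_embeddings=32000, embedding_dim=1024)

# Shard the embedding weight along the embedding dim across the mesh and
# gather the output back to a replicated tensor for downstream modules.
emb = parallelize_module(emb, mesh, ColwiseParallel(output_layouts=Replicate()))
```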

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
2023-10-15 11:49:56 +00:00
fduwjj
8085e08a84 [TP] Add prepareInput and output for input/output DTensor layout annotation in the parent module in TP API (#111166)
In some use cases, we found that users might want to annotate the input/output DTensor layouts for the parent module rather than for the submodule whose parameters are to be distributed. We add these two classes so that users can annotate input/output DTensor layouts there, and we register a pre-forward/forward hook on the TP-parallelized module accordingly.
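
As a rough sketch of the intended usage (module names are illustrative, the no-argument constructors rely on the defaults from the #111160 design sketch, and `block` stands in for the user's transformer block):

```python
# Rough sketch: annotate the DTensor input layout on the parent "attn" module
# via PrepareModuleInput (registered as a pre-forward hook), while the child
# linears hold the parameters that are actually distributed. PrepareModuleOutput
# plays the symmetric role for a module's output (forward hook).
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

mesh = DeviceMesh("cuda", list(range(4)))
plan = {
    "attn": PrepareModuleInput(),  # parent module: input layout annotation only
    "attn.wq": ColwiseParallel(),  # children: parameters get sharded
    "attn.wv": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}
parallelize_module(block, mesh, plan)  # `block` is the user's transformer block
```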

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111166
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160
2023-10-14 15:37:52 +00:00
fduwjj
3a8b10e2da [TP] Refactor Parallel Style to make it more usable (#111160)
One thing we find challenging for users is that we don't want to expose the concepts of prepare_input and prepare_output, since there are so many function names to select from, which is quite confusing. On the other hand, colwise and rowwise parallel always need the input and output to be in a certain layout, so we can simplify the logic here and make it more usable.

So we added three public attributes to ParallelStyle, and the code logic looks like this:

```python
from abc import ABC
from typing import Tuple, Union

from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.placement_types import Placement


class ParallelStyle(ABC):
    """
    The parallel style the user wants the module or submodule to be parallelized with.

    We can add more in the future, but this seems sufficient for immediate needs.
    Users can extend this class to build their own parallel style with customized
    input/output preparations.
    """

    input_layouts: Union[Placement, Tuple[Placement, ...]]
    output_layouts: Union[Placement, Tuple[Placement, ...]]
    use_local: bool

    def __init__(self, input_layouts, output_layouts, use_local):
        self.input_layouts = input_layouts
        self.output_layouts = output_layouts
        self.use_local = use_local


class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module. We assume the input is a sharded DTensor
    and the output is a replicated tensor.
    """

    # Defaults encode the style; callers can override the layouts if needed.
    def __init__(self, *, input_layouts=Shard(-1), output_layouts=Replicate(), use_local=True):
        super().__init__(input_layouts, output_layouts, use_local)


class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a module. We assume the input is a replicated
    DTensor and the output is a sharded DTensor.
    """

    def __init__(self, *, input_layouts=Replicate(), output_layouts=Shard(-1), use_local=True):
        super().__init__(input_layouts, output_layouts, use_local)


# For sequence parallelism, users just set a different input layout,
# Shard(0) or Shard(1), instead of Replicate().


class PrepareModuleInput(ParallelStyle):
    """
    Only used to specify the input distribution spec for a module.
    """

    def __init__(self, *, input_layouts=Shard(0), output_layouts=Replicate(), use_local=False):
        super().__init__(input_layouts, output_layouts, use_local)


class PrepareModuleOutput(ParallelStyle):
    """
    Only used to specify the output distribution spec for a module.
    """

    def __init__(self, *, input_layouts=Replicate(), output_layouts=Shard(0), use_local=True):
        super().__init__(input_layouts, output_layouts, use_local)


parallelize_plan = {
    "embedding": ColwiseParallel(output_layouts=Replicate()),
    "attn": PrepareModuleInput(),
    "attn.w1": ColwiseParallel(),
    "attn.w2": ColwiseParallel(),
    "attn.w3": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}

parallelize_module(
    module=block,            # this can be a submodule or a module
    device_mesh=mesh["tp"],  # the TP dim of a (possibly multi-dim) DeviceMesh
    parallelize_plan=parallelize_plan,
)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111160
Approved by: https://github.com/wanchaol
2023-10-14 15:26:36 +00:00
fduwjj
3828cd4b79 [TP][EZ] Update doc for TP parallel style (#107819)
We need to update the docs for PairwiseParallel and SequenceParallel so that users don't get the wrong impression that these work for ``nn.Transformer``.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107819
Approved by: https://github.com/awgu, https://github.com/wanchaol
2023-08-24 00:13:52 +00:00
fduwjj
953aa6d90e [TP] Enable more generic attn in Tensor Parallelism (#100508)
To make TP more generic for the attention module, we came up with this new col/rowwise parallel style.

Basically, the idea is that we only run DTensor ops for the col/rowwise-sharded part; the rest of the ATen ops are left to plain tensor ops.

We set this behavior as the default for the colwise and rowwise parallel styles. If people want to customize it, they can always pass in a different prepare_input or prepare_output.
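
For illustration (the module structure, dims, and mesh size below are assumptions, not from the PR), the pattern is: shard the q/k/v/o projections col/rowwise and let everything between them run as regular tensor ops on local shards:

```python
# Sketch: only the col/rowwise-sharded projections go through DTensor; the
# attention math between them (matmuls, softmax) runs on plain local tensors.
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class Attn(nn.Module):
    """Toy attention shell (forward omitted); only the named linears matter here."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)


mesh = DeviceMesh("cuda", list(range(8)))  # 8-way tensor-parallel group
attn = parallelize_module(
    Attn(),
    mesh,
    {
        "wq": ColwiseParallel(),  # colwise shard; output stays a local tensor
        "wk": ColwiseParallel(),
        "wv": ColwiseParallel(),
        "wo": RowwiseParallel(),  # rowwise shard; output gathered for the caller
    },
)
```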

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
2023-05-07 18:15:49 +00:00
fduwjj
89b1e67d0a [Tensor Parallel] Add a new Colwise Parallel style when Pairwise cannot directly used (#100137)
In some use cases, users cannot directly use `PairwiseParallelStyle` and might need to specify colwise and rowwise parallel styles separately.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100137
Approved by: https://github.com/wz337
2023-04-28 03:27:51 +00:00
fduwjj
b209d8fa0d [PT-D][Sequence Parallelism] Enable DTensor based Naive sequence parallelism (#94369)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94369
Approved by: https://github.com/wanchaol
2023-02-16 21:21:00 +00:00
fduwjj
41e3189222 [PT-D][Tensor parallelism] Add documentations for TP (#94421)
This is far from complete, and we will definitely polish it down the road.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94421
Approved by: https://github.com/wz337
2023-02-09 02:31:06 +00:00
Wanchao Liang
9b5e6b029f [tp] ufmt distributed.tensor.parallel (#89969)
cmd: `ufmt format torch/distributed/tensor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89969
Approved by: https://github.com/fduwjj
2022-12-01 20:58:16 +00:00
Wanchao Liang
4451eb24e6 Move tensor_parallel out to distributed.tensor folder (#89878)
This PR moves tensor parallel from torch.distributed._tensor.parallel to torch.distributed.tensor.parallel, to prepare for the beta release.
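
For example, the import path becomes the public one (a small illustration; `parallelize_module` is used here as the representative entry point under this namespace):

```python
# Old private path (before this PR): torch.distributed._tensor.parallel
# New public path (after this PR):
from torch.distributed.tensor.parallel import parallelize_module
```
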
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89878
Approved by: https://github.com/fduwjj
2022-11-30 22:13:10 +00:00