Commit Graph

29 Commits

Author SHA1 Message Date
NVS Abhilash
44c0521e8c fix: docstring error in torch/distributed module (#113241)
Fixes: #113193

`pydocstyle <all_files_in_issue> --count`

- Before: 345
- After: 130

For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.
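For illustration only, a minimal sketch of the kind of change involved; the function names here are hypothetical stand-ins, not code from the actual PR:

```python
def get_rank() -> int:
    """Return the rank of the current process."""  # a pydocstyle-clean one-line docstring
    return 0


def _legacy_helper():  # noqa
    # Deprecated helper: pydocstyle findings are suppressed with `noqa` instead of rewriting it.
    pass
```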

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
2023-11-09 19:10:20 +00:00
fduwjj
bfcd86955e [TP] Fix TP doc format to show examples correctly (#111346)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111346
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176, #111177
2023-10-16 06:15:10 +00:00
fduwjj
25a2845d78 [TP] Enable embedding sharding in TP API (#111177)
We see use cases where embedding sharding is also needed in the TP API, so we enabled it there, since DTensor already supports colwise embedding sharding.
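As a rough illustration (not taken from the PR itself), a colwise-sharded embedding can be requested through the usual plan dict. The module name, mesh size, and use of `init_device_mesh` are assumptions for this sketch, which is meant to run under `torchrun` on a machine with enough GPUs:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Assumed toy model with a "tok_embeddings" submodule; a process group must already exist
# (e.g. launched via torchrun). init_device_mesh is available in recent PyTorch releases.
tp_mesh = init_device_mesh("cuda", (8,))
model = nn.ModuleDict({"tok_embeddings": nn.Embedding(32000, 4096)})
model = parallelize_module(model, tp_mesh, {"tok_embeddings": ColwiseParallel()})
```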

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
2023-10-15 11:49:56 +00:00
fduwjj
8085e08a84 [TP] Add prepareInput and output for input/output DTensor layout annotation in the parent module in TP API (#111166)
In some use cases, we found that users may want to annotate the input/output DTensor layouts on the parent module rather than on the submodule whose parameters are to be distributed. These two classes let users annotate those input/output DTensor layouts, and we register pre-forward/forward hooks on the TP-parallelized module accordingly.
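A minimal sketch of annotating the parent module rather than its children follows; it uses the present-day keyword names (`input_layouts`/`desired_input_layouts`, etc.), which may differ slightly from the version introduced in this PR, and the module names are hypothetical:

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, PrepareModuleOutput

# Hooks are registered on the parent modules ("attn", "mlp"), not on the children whose
# parameters get sharded, so the DTensor layout conversion happens at the module boundary.
plan = {
    "attn": PrepareModuleInput(input_layouts=Shard(0), desired_input_layouts=Replicate()),
    "mlp": PrepareModuleOutput(output_layouts=Replicate(), desired_output_layouts=Shard(0)),
}
```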

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111166
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160
2023-10-14 15:37:52 +00:00
fduwjj
3a8b10e2da [TP] Refactor Parallel Style to make it more usable (#111160)
One thing we find challenging for users is that we don't want to expose the concepts of prepare_input and prepare_output, since there are too many function names to choose from, which is quite confusing. On the other hand, colwise and rowwise parallel always need their input and output to be in a certain layout, so we can simplify the logic here and make it more usable.

So we added three public attributes to ParallelStyle, and the code logic looks like this:

```python
from abc import ABC
from typing import Tuple, Union

from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.placement_types import Placement


class ParallelStyle(ABC):
    """
    The parallel style the user wants the module or submodule to be parallelized with.

    We can add more in the future, but this seems sufficient for immediate needs. Users can
    extend this class to build their own parallel style with customized input/output preparation.
    """

    input_layouts: Union[Placement, Tuple[Placement, ...]]
    output_layouts: Union[Placement, Tuple[Placement, ...]]
    use_local: bool

    def __init__(self, *, input_layouts, output_layouts, use_local):
        self.input_layouts = input_layouts
        self.output_layouts = output_layouts
        self.use_local = use_local


class RowwiseParallel(ParallelStyle):
    """
    Partition the rows of a module. We assume the input to be a sharded DTensor and the
    output to be a replicated tensor.
    """

    def __init__(self, *, input_layouts=Shard(-1), output_layouts=Replicate(), use_local=True):
        super().__init__(input_layouts=input_layouts, output_layouts=output_layouts, use_local=use_local)


class ColwiseParallel(ParallelStyle):
    """
    Partition the columns of a module. We assume the input to be a replicated DTensor and the
    output to be a sharded DTensor.
    """

    def __init__(self, *, input_layouts=Replicate(), output_layouts=Shard(-1), use_local=True):
        super().__init__(input_layouts=input_layouts, output_layouts=output_layouts, use_local=use_local)


# For sequence parallel, users just set a different input layout, Shard(0) or Shard(1),
# instead of Replicate().


class PrepareModuleInput(ParallelStyle):
    """
    Only used to specify the input distribution spec for a module.
    """

    def __init__(self, *, input_layouts=Shard(0), output_layouts=Replicate(), use_local=False):
        super().__init__(input_layouts=input_layouts, output_layouts=output_layouts, use_local=use_local)


class PrepareModuleOutput(ParallelStyle):
    """
    Only used to specify the output distribution spec for a module.
    """

    def __init__(self, *, input_layouts=Replicate(), output_layouts=Shard(0), use_local=True):
        super().__init__(input_layouts=input_layouts, output_layouts=output_layouts, use_local=use_local)


parallelize_plan = {
    "embedding": ColwiseParallel(output_layouts=Replicate()),
    "attn": PrepareModuleInput(),
    "attn.w1": ColwiseParallel(),
    "attn.w2": ColwiseParallel(),
    "attn.w3": ColwiseParallel(),
    "attn.wo": RowwiseParallel(),
}

# `block` and `mesh` are assumed to be defined elsewhere.
parallelize_module(
    module=block,  # this can be a submodule or a module
    device_mesh=mesh["tp"],
    parallelize_plan=parallelize_plan,
)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111160
Approved by: https://github.com/wanchaol
2023-10-14 15:26:36 +00:00
wz337
8e32e62f67 [TP] Validate TP mesh dim for 2D composition (#111001)
Currently, we only support intra-node TP when composing TP with other parallelisms. This PR adds an additional check to validate the TP mesh dim during TP initialization when a parent mesh exists.

cc. @fegin, @fduwjj
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111001
Approved by: https://github.com/fduwjj
2023-10-12 02:11:44 +00:00
Wanchao Liang
d8f2ef10a6 [dtensor][1/n] refactor op dispatch logic to reduce overhead (#107305)
This PR is the first change in a series of refactors to the op dispatch logic to:
1. remove the redundant logic in op dispatch and simplify the error checking
2. reduce the number of tree_map/tree_flatten/unflatten calls to cut the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools directly; this not only helps TP, but also makes general DTensor operations faster (see the sketch after this list)
4. change the view ops behavior: in-place changing the op_schema is dangerous for sharding prop caching, so model view ops as one type of resharding too
5. enrich output sharding to include whether the op needs a redistribute, so that we don't need an explicit op schema comparison to know it
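A hedged sketch of the caching idea in item 3; the real propagator hashes DTensor's OpSchema objects rather than plain tuples:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def propagate_sharding(op_name: str, input_placements: tuple) -> tuple:
    """Memoize the (expensive) sharding decision for a given op and hashable input shardings."""
    # ... real propagation logic would go here; as a placeholder, follow the first input ...
    return input_placements
```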

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)

overall, one layer of MLP time reduced from 13.535 ms -> 9.665 ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
2023-08-18 18:30:46 +00:00
alanhe151220037
1afbc985fe Make RNGStateTracker support cuda-like device (#106771)
Replace `CudaRNGStateTracker` with `RNGStateTracker` by rewriting some CUDA-binding code to go through a `device_handle`.
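A hedged sketch of the device_handle pattern, not the actual PyTorch implementation:

```python
import torch


class RNGStateTracker:
    """Dispatch RNG state calls through a per-device-type handle instead of hard-coding torch.cuda."""

    def __init__(self, device_type: str = "cuda"):
        self._device_type = device_type
        # torch.cuda for "cuda"; other accelerator backends expose their own torch.<device_type> module.
        self._device_handle = getattr(torch, device_type, None)

    def get_state(self) -> torch.Tensor:
        assert self._device_handle is not None, f"no device module found for {self._device_type}"
        return self._device_handle.get_rng_state()

    def set_state(self, state: torch.Tensor) -> None:
        assert self._device_handle is not None, f"no device module found for {self._device_type}"
        self._device_handle.set_rng_state(state)
```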

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106771
Approved by: https://github.com/wanchaol
2023-08-10 19:14:33 +00:00
fduwjj
487ebcac3b Clean up unused MHA code to avoid confusion (#105956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105956
Approved by: https://github.com/wz337, https://github.com/ezyang, https://github.com/wanchaol
2023-07-27 17:10:17 +00:00
FFFrog
9a1cdcb8a0 Format: fixing multiple string concatenation in single line (#106013)
Fix instances of multiple string concatenation on a single line.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106013
Approved by: https://github.com/albanD
2023-07-26 18:39:18 +00:00
Xilun Wu
e799f565eb [DTensor][TP][Random] Introduce TensorParallelRNGTracker to integrate parallel RNG state with Tensor Parallel (#103910)
This PR enables the automatic use of `TensorParallelRNGTracker` in the Tensor Parallel API. Unit tests will be added for coverage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103910
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-30 08:06:41 +00:00
fduwjj
d4380edb9b [TP] Add API logging for TP high level API (#102209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102209
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-05-25 03:33:00 +00:00
Wanchao Liang
a1aa32e204 [dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first in a series of PRs that adapts operator impls to use a
strategy-based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By using the strategy-based approach along
with the op graph, we can enable more advanced op implementations
(decomposition becomes possible) and turn sharding prop into something
more like a constraint satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it works directly
on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follow one of the arg strategies. The next set of
PRs will add strategies for more ops.
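Purely illustrative sketch of "following one of the arg strategies"; these dataclasses are simplified stand-ins, not the real OpStrategy/PlacementStrategy classes:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PlacementStrategy:        # simplified stand-in
    output_placement: str       # e.g. "Shard(0)" or "Replicate()"
    input_placements: List[str]


@dataclass
class OpStrategy:               # simplified stand-in: all candidate strategies for one op
    strategies: List[PlacementStrategy]


def follow_arg_strategy(arg: OpStrategy) -> OpStrategy:
    """Ops like clone/detach simply reuse whatever placements their input argument already has."""
    return OpStrategy(
        [PlacementStrategy(s.output_placement, [s.output_placement]) for s in arg.strategies]
    )
```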
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
2023-05-11 02:47:20 +00:00
fduwjj
953aa6d90e [TP] Enable more generic attn in Tensor Parallelism (#100508)
To make TP more generic for the attention module, we came up with this new col/rowwise parallel style.

The basic idea is that we only perform DTensor ops for the col/rowwise-sharded parts; the rest of the ATen ops are left as regular tensor ops.

We set this behavior as the default for the colwise and rowwise parallel styles. If users want to customize it, they can always pass in a different prepare_input or prepare_output.
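For illustration, a plan for a generic attention block might look like the sketch below. The submodule names are hypothetical; only the projections get the DTensor treatment, while everything in between (SDPA, reshapes, softmax) stays plain tensor ops:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical submodule names for a generic attention block.
attn_plan = {
    "attn.q_proj": ColwiseParallel(),
    "attn.k_proj": ColwiseParallel(),
    "attn.v_proj": ColwiseParallel(),
    "attn.out_proj": RowwiseParallel(),
}
# parallelize_module(block, tp_mesh, attn_plan)  # block and tp_mesh assumed to exist
```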

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
2023-05-07 18:15:49 +00:00
fduwjj
89b1e67d0a [Tensor Parallel] Add a new Colwise Parallel style when Pairwise cannot directly used (#100137)
In some use cases, users cannot directly use `PairwiseParallelStyle` and might need to specify colwise and rowwise sharding separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100137
Approved by: https://github.com/wz337
2023-04-28 03:27:51 +00:00
Rohan Varma
be8c7c06b6 [Tensor Parallel] Simplify distribute for MHA (#100046)
This function is only called for nn.MHA or the custom MHA we use, and
if it is the former it is converted to the latter. So this check can actually
be an assert.

Differential Revision: [D45300396](https://our.internmc.facebook.com/intern/diff/D45300396/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100046
Approved by: https://github.com/wanchaol
2023-04-27 00:54:21 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under the torch/distributed directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
fduwjj
b209d8fa0d [PT-D][Sequence Parallelism] Enable DTensor based Naive sequence parallelism (#94369)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94369
Approved by: https://github.com/wanchaol
2023-02-16 21:21:00 +00:00
Wanchao Liang
cd9ca4c73f [tp] additional doc fixes (#94786)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94786
Approved by: https://github.com/fduwjj
2023-02-15 21:25:26 +00:00
PyTorch MergeBot
28ed0bdb37 Revert "[tp] additional doc fixes (#94786)"
This reverts commit 7522ca55f1.

Reverted https://github.com/pytorch/pytorch/pull/94786 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but the doc failure looks related and they are also failing in trunk 7522ca55f1
2023-02-14 05:43:37 +00:00
Wanchao Liang
7522ca55f1 [tp] additional doc fixes (#94786)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94786
Approved by: https://github.com/fduwjj
2023-02-14 04:52:04 +00:00
Wanchao Liang
2db12e3844 [tp] minor update to TP docs (#94748)
minor update to TP docs for beta release
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94748
Approved by: https://github.com/fduwjj
2023-02-13 21:54:19 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplifies and speeds up isinstance calls by checking for multiple types at the same time.
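A tiny illustration of the pattern, using built-in types:

```python
def is_number(x) -> bool:
    # One isinstance call with a tuple of types replaces chained
    # `isinstance(x, int) or isinstance(x, float)`.
    return isinstance(x, (int, float))
```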

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
fduwjj
3fb6e119e2 [PT-D][TP] Fix the module registration in TP API (#93412)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93412
Approved by: https://github.com/XilunWu
2023-02-01 21:03:56 +00:00
Wanchao Liang
9a56997fe1 [dtensor][5/N] add cached propagator for TP (#90734)
This PR adds a cached propagator for TP use; it caches the sharding
prop decision for the same input sharding on an operator, which can
improve eager-mode performance.
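A hypothetical sketch of the caching idea; the real propagator keys on DTensor's op schema rather than plain tuples:

```python
class CachedShardingPropagator:
    """Reuse the sharding decision whenever the same op sees the same input shardings."""

    def __init__(self, propagate_fn):
        self._propagate_fn = propagate_fn
        self._cache = {}

    def propagate(self, op_name, input_placements):
        key = (op_name, tuple(input_placements))
        if key not in self._cache:
            self._cache[key] = self._propagate_fn(op_name, input_placements)
        return self._cache[key]
```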

Differential Revision: [D42876249](https://our.internmc.facebook.com/intern/diff/D42876249)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90734
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-02-01 05:04:08 +00:00
fduwjj
913866efbf [PT-D][TP] Fix TP API for FQN path based parallelization (#93029)
We had not tested dict-based parallelize_module, and it turns out there were mistakes here.

1. Fix the error.
2. Add unit test cases for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93029
Approved by: https://github.com/wz337
2023-01-26 09:10:21 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Wanchao Liang
9b5e6b029f [tp] ufmt distributed.tensor.parallel (#89969)
cmd: `ufmt format torch/distributed/tensor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89969
Approved by: https://github.com/fduwjj
2022-12-01 20:58:16 +00:00
Wanchao Liang
4451eb24e6 Move tensor_parallel out to distributed.tensor folder (#89878)
This PR moves tensor parallel from torch.distributed._tensor.parallel
to torch.distributed.tensor.parallel, to prepare for the beta release.
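After the move, the public import path looks like the line below (the old private path was torch.distributed._tensor.parallel):

```python
from torch.distributed.tensor.parallel import parallelize_module
```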
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89878
Approved by: https://github.com/fduwjj
2022-11-30 22:13:10 +00:00