Commit Graph

20 Commits

Wanchao Liang
ff061baa94 [comm_mode] adding some initial c10d ops to CommDebugMode (#125475)
looks like we can make it work :)
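Per the title, this is about counting plain c10d collectives in CommDebugMode; a hedged sketch of what that could look like (that `dist.all_reduce` is among the newly covered ops is an assumption; run under torchrun with 2+ GPUs):

```
import torch
import torch.distributed as dist
from torch.distributed._tensor.debug import CommDebugMode

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
t = torch.ones(4, device="cuda")

comm_mode = CommDebugMode()
with comm_mode:
    dist.all_reduce(t)  # a c10d op, not a functional collective
print(comm_mode.get_comm_counts())
```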

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125475
Approved by: https://github.com/awgu
2024-05-04 04:20:46 +00:00
Wanchao Liang
00df0d3e94 [dtensor] implement shard dim change with alltoall (#124872)
As titled: we implement a dedicated communication op that changes the sharding dimension efficiently with alltoall, replacing our previous allgather + local chunk.
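
A hedged usage sketch of the path this affects (assumes a 2-rank torchrun launch; that `redistribute` takes the new alltoall path here is my reading of the PR):

```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))
x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
# shard dim 0 -> 1: a single alltoall instead of allgather + local chunk
y = x.redistribute(mesh, [Shard(1)])
```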

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124872
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #124871
2024-04-30 18:30:34 +00:00
PyTorch MergeBot
f1d1e3246f Revert "[dtensor] implement shard dim change with alltoall (#124872)"
This reverts commit 6b79469d24.

Reverted https://github.com/pytorch/pytorch/pull/124872 on behalf of https://github.com/clee2000 due to broke distributed/tensor/parallel/test_tp_examples.py::DistTensorParallelExampleTest::test_transformer_training_is_seq_parallel_True https://github.com/pytorch/pytorch/actions/runs/8882762411/job/24389191482 f7f018a0ed.  Bad TD ([comment](https://github.com/pytorch/pytorch/pull/124872#issuecomment-2083599445))
2024-04-29 20:26:16 +00:00
Wanchao Liang
6b79469d24 [dtensor] implement shard dim change with alltoall (#124872)
As titled: we implement a dedicated communication op that changes the sharding dimension efficiently with alltoall, replacing our previous allgather + local chunk.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124872
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #124871
2024-04-29 17:22:30 +00:00
Tristan Rice
ddd0ed1b43 distributed: templated ring attention (#124215)
This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. It doesn't add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.

This templating is also a POC of how to support other attention ops, such as jagged/nested tensor attention, as well as how to implement striped attention in a scalable way.
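
For intuition, a conceptual sketch of the ring attention pattern (my own simplified illustration, not the PR's template: no log-sum-exp stabilization, no causal masking, and naive non-overlapped communication; assumes a multi-rank torchrun launch):

```
import torch
import torch.distributed as dist

def _ring_shift(t: torch.Tensor) -> torch.Tensor:
    """Send `t` to the next rank and receive the previous rank's copy."""
    rank, world = dist.get_rank(), dist.get_world_size()
    out = torch.empty_like(t)
    ops = [
        dist.P2POp(dist.isend, t, (rank + 1) % world),
        dist.P2POp(dist.irecv, out, (rank - 1) % world),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return out

def ring_attention(q, k, v):
    """q/k/v are this rank's shards, sharded on the sequence dimension."""
    scale = q.shape[-1] ** -0.5
    num = torch.zeros_like(q)              # running sum of exp(scores) @ V
    den = num.new_zeros(*q.shape[:-1], 1)  # running softmax denominator
    for _ in range(dist.get_world_size()):
        w = ((q @ k.transpose(-1, -2)) * scale).exp()  # unnormalized weights
        num = num + w @ v
        den = den + w.sum(dim=-1, keepdim=True)
        k, v = _ring_shift(k), _ring_shift(v)          # rotate K/V around the ring
    return num / den
```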

Misc changes:

* Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
2024-04-19 00:57:08 +00:00
Xilun Wu
605c0a28aa [dtensor][debug] force visualize_sharding not to print for empty tensors (#121217)
**Summary**
The current `visualize_sharding` code cannot print empty DTensor objects, which leads to an exception. This PR skips the printing logic if the DTensor passed in has 0 elements.
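A hedged sketch of the added guard (the helper name is mine, not the PR's):

```
from torch.distributed._tensor import DTensor

def _has_something_to_print(dtensor: DTensor) -> bool:
    # visualize_sharding returns early when this is False (assumed placement)
    return dtensor.numel() != 0
```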

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121217
Approved by: https://github.com/wanchaol
ghstack dependencies: #121385, #121382
2024-03-11 09:22:49 +00:00
Xilun Wu
3a5ab17bdc [dtensor][debug] visualize_sharding skip if the current rank is not in mesh (#121382)
**Summary**
We should skip the `visualize_sharding()` call on ranks that are not part of the DTensor's mesh. Otherwise, an exception is thrown in the current visualization logic.
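
A hedged sketch of the skip condition (the helper name is mine):

```
from torch.distributed._tensor import DTensor

def _rank_is_in_mesh(dtensor: DTensor) -> bool:
    # get_coordinate() returns None when the calling rank is not in the mesh
    return dtensor.device_mesh.get_coordinate() is not None
```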

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121382
Approved by: https://github.com/wanchaol
ghstack dependencies: #121385
2024-03-11 09:22:49 +00:00
Xilun Wu
b383123e37 [dtensor][debug] visualize_sharding only compute offset on the first rank in mesh (#121385)
**Summary**
Avoid computing offsets on ranks where we do not plan to visualize the DTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121385
Approved by: https://github.com/wanchaol
2024-03-11 09:22:31 +00:00
Xilun Wu
df2ad1fecc [dtensor][debug] have visualize_sharding correctly print for sub-mesh DTensor (#121216)
**Summary**
In `visualize_sharding` we chose to only print on rank 0 (global rank), which means calling `visualize_sharding` will never print anything when the DTensor's mesh doesn't include rank 0 (i.e. a sub-mesh). This PR has `visualize_sharding` print on the rank whose mesh coordinate is (0, 0, ..., 0) instead of the rank whose global rank is 0.
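
A hedged sketch of the new rank selection (the helper name is mine):

```
from torch.distributed._tensor import DTensor

def _is_printing_rank(dtensor: DTensor) -> bool:
    coord = dtensor.device_mesh.get_coordinate()
    # print on the rank at mesh coordinate (0, ..., 0), not on global rank 0
    return coord is not None and all(c == 0 for c in coord)
```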

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121216
Approved by: https://github.com/wanchaol
ghstack dependencies: #121179, #120260
2024-03-07 04:50:15 +00:00
Xilun Wu
9cc0f23e5c [dtensor][debug] allow visualize_sharding to print header (#121179)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121179
Approved by: https://github.com/wanchaol
2024-03-07 04:50:06 +00:00
Brian Hirsh
9e0631cc8a get CommsDebugMode to work with DTensor (#118769)
Tested with Wanchao's repro:
```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._tensor.debug import CommDebugMode

mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))
x = torch.randn(4, 8, requires_grad=True)
y = torch.randn(4, 32, requires_grad=True)
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

comm_mode = CommDebugMode()
with comm_mode:
    z = torch.mm(x_dtensor, y_dtensor)
print(comm_mode.get_comm_counts())
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118769
Approved by: https://github.com/wanchaol
2024-02-29 01:11:05 +00:00
Yifu Wang
5a3e19578f Make tests using CommDebugMode work for both legacy and native funcol (#120070)
We have many tests that use CommDebugMode to verify the occurrence of collectives. These tests do so by querying comm_counts with legacy funcol ops as keys. For the native funcol migration, we need these tests to work for both legacy and native funcol. To avoid modifying all tests to accommodate the two implementations, we make CommDebugMode translate native funcol ops into legacy funcol ops until the migration finishes.
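
A hedged sketch of the test pattern this preserves (the exact count keys and the helper below are assumptions, not the PR's code):

```
import torch
import torch.distributed._functional_collectives  # registers c10d_functional ops
from torch.distributed._tensor.debug import CommDebugMode

legacy = torch.ops.c10d_functional

def assert_one_all_reduce(fn):
    comm_mode = CommDebugMode()
    with comm_mode:
        fn()
    # Holds whether fn() issued the legacy or the native all_reduce op.
    assert comm_mode.get_comm_counts()[legacy.all_reduce] == 1
```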

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120070
Approved by: https://github.com/wconstab, https://github.com/wanchaol
ghstack dependencies: #120042, #120043
2024-02-22 20:24:15 +00:00
Edward Z. Yang
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
Kumar Ashutosh
405a0040cf Adds tool to visualize sharding (#114307)
This pull request adds a tool to visualize sharding. It uses the device_mesh and placement details to construct a visualization of how a torch DTensor is split.
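
A hedged usage sketch (assumes a 2-GPU torchrun launch; the import path is an assumption based on where the debug tools live):

```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed._tensor.debug import visualize_sharding

mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))
dt = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])
visualize_sharding(dt)  # prints which slice of the global tensor each rank holds
```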

Things to fix:

- [x] This implementation only uses the first element of the placement tuple; when can there be more than one element?
- [x] The calculation of the split happens here, but maybe it is already done internally in the Shard class; can we call that directly here?

Fixes #108746

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114307
Approved by: https://github.com/wanchaol
2023-12-12 06:18:03 +00:00
Wanchao Liang
624f202522 [dtensor] add CommDebugMode for debugging (#113592)
This PR adds a CommDebugMode debugging tool to record the number of distributed collectives, utilizing TorchDispatchMode. The idea borrows from FlopCounterMode, and we can expand this later to make it as feature complete as FlopCounterMode.

This is useful for debugging and testing with DTensor. In general it fits any complex distributed algorithm where it is non-trivial to understand what happened under the hood; we can use this tool to find out. We can later cover c10d collectives directly.

Not sure whether it would be a good general distributed debug tool yet, so it is added to the dtensor package first.
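
A minimal sketch of the TorchDispatchMode pattern this builds on (a generic op counter for illustration, not CommDebugMode's actual implementation):

```
import torch
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode

class OpCounterMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.counts = defaultdict(int)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.counts[func] += 1  # record every dispatched ATen op
        return func(*args, **(kwargs or {}))

with OpCounterMode() as mode:
    torch.ones(2, 2) + torch.ones(2, 2)
print(dict(mode.counts))  # op overloads such as aten.add.Tensor mapped to counts
```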

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113592
Approved by: https://github.com/wconstab
2023-11-27 02:40:28 +00:00
Wanchao Liang
6ed20af10e [dtensor] refactor op dispatch and fix is_same_size/equal (#112927)
torch.equal/is_same_size currently skip sharding prop and directly do local tensor compute, which is wrong. For these two ops:

- torch.equal: should not skip sharding prop; the two DTensors need to have the SAME sharding before comparing local shard values
- torch.is_same_size: needs to completely skip both sharding prop and local compute

This PR refactors the existing op_dispatch into a class instance so that we can do custom op handling, and then fixes both torch.equal and torch.is_same_size.
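
A hedged illustration of the intended semantics (assumes a 2-rank torchrun launch; the align-shardings-before-compare behavior is my reading of the fix):

```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))
a = distribute_tensor(torch.ones(4, 4), mesh, [Shard(0)])
b = distribute_tensor(torch.ones(4, 4), mesh, [Replicate()])

print(torch.is_same_size(a, b))  # True: only global shapes matter, no sharding prop
print(torch.equal(a, b))         # True: shardings are aligned before comparing local values
```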

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112927
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-11-13 22:46:31 +00:00
NVS Abhilash
44c0521e8c fix: docstring error in torch/distributed module (#113241)
Fixes: #113193

`pydocstyle <all_files_in_issue> --count`

- Before: 345
- After: 130

For deprecated methods, I have added a `noqa` to ignore them. I was not able to find the file `torch/distributed/tensor/parallel/multihead_attention_tp.py`, so I've ignored it for this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113241
Approved by: https://github.com/kit1980
2023-11-09 19:10:20 +00:00
Wanchao Liang
979e706f8e [dtensor] update some comments (#107608)
This updates some comments as a follow-up to https://github.com/pytorch/pytorch/pull/107305.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107608
Approved by: https://github.com/fduwjj
ghstack dependencies: #107606
2023-08-22 23:08:13 +00:00
Wanchao Liang
d8f2ef10a6 [dtensor][1/n] refactor op dispatch logic to reduce overhead (#107305)
This PR is the first of a series of refactors to the op dispatch logic to:
1. remove redundant logic in op dispatch and simplify the error checking
2. reduce the number of tree_map/tree_flatten/unflatten calls to cut the overhead coming from those operations
3. remove CachedShardingPropagator by using functools.lru_cache directly; this not only helps TP, general DTensor operations could be faster too (see the sketch after this list)
4. stop the view ops from in-place changing the op_schema, which is dangerous for sharding prop caching, and model view ops as one type of resharding too
5. enrich output sharding to include whether the op needs redistribute, so that we don't need an explicit op schema comparison to know it
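
A minimal sketch of item 3 (names are assumptions, not the actual DTensor internals):

```
from functools import lru_cache

class ShardingPropagator:
    @lru_cache(maxsize=None)  # relies on op_schema being hashable; keyed on (self, op_schema)
    def propagate_op_sharding(self, op_schema):
        ...  # the expensive sharding rule runs only on a cache miss
```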

This should help further reduce the CPU overhead. Benchmark results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

Overall, one MLP layer's time is reduced from 13.535 ms to 9.665 ms.

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
2023-08-18 18:30:46 +00:00
Wanchao Liang
123be4b694 [dtensor] add debug tool to track op coverage (#100124)
This PR adds a debug tool to track the op coverage needed in DTensor.

Note that we specifically target the ops remaining after the decomp table in inductor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100124
Approved by: https://github.com/XilunWu
2023-05-02 01:45:55 +00:00