[dtensor][debug] added c10d alltoall_ and alltoall_base_ to CommDebugMode (#127360)

**Summary**
Added c10d alltoall_ and alltoall_base tracing to CommDebugMode and edited test case in test_comm_mode to include added features.

**Test Plan**
pytest test/distributed/_tensor/debug/test_comm_mode.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127360
Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #127358
This commit is contained in:
Anshul Sinha 2024-06-04 10:18:47 -07:00 committed by PyTorch MergeBot
parent 01e6d1cae4
commit e76b28c765
2 changed files with 14 additions and 0 deletions

View File

@ -205,6 +205,18 @@ class TestCommMode(TestCase):
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)
# tests c10d alltoall_
with comm_mode:
dist.all_to_all([inp], [inp])
self.checksAssert(comm_mode, c10d_ops.alltoall_, 1, 1)
# tests c10d alltoall_base_
with comm_mode:
dist.all_to_all_single(inp, inp)
self.checksAssert(comm_mode, c10d_ops.alltoall_base_, 1, 1)
if __name__ == "__main__":
run_tests()

View File

@ -32,6 +32,8 @@ c10d_collective_ops = {
c10d_ops.allgather_into_tensor_coalesced_,
c10d_ops.allreduce_,
c10d_ops.allreduce_coalesced_,
c10d_ops.alltoall_,
c10d_ops.alltoall_base_,
c10d_ops.broadcast_,
c10d_ops.gather_,
c10d_ops.scatter_,