From 01e6d1cae46ff4af8d55e04237a05e430cfb3136 Mon Sep 17 00:00:00 2001 From: Anshul Sinha <50644008+sinhaanshul@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:18:46 -0700 Subject: [PATCH] [dtensor][debug] added c10d reduce_scatter_ and reduce_scatter_tensor_coalesced tracing_ to CommDebugMode (#127358) **Summary** Added c10d reduce_scatter_ and reduce_scatter_tensor_coalesced tracing to CommDebugMode and edited test case in test_comm_mode to include added features. **Test Plan** pytest test/distributed/_tensor/debug/test_comm_mode.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/127358 Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/yifuwang --- test/distributed/_tensor/debug/test_comm_mode.py | 12 ++++++++++++ torch/distributed/_tensor/debug/comm_mode.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 6cb94c86002..dc088f38988 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -193,6 +193,18 @@ class TestCommMode(TestCase): self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1) + # tests c10d reduce_scatter_ + with comm_mode: + dist.reduce_scatter(all_gather_out, [inp]) + + self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1) + + # tests c10d reduce_scatter_tensor_coalesced + with comm_mode as A, dist._coalescing_manager() as B: + dist.reduce_scatter_tensor(all_gather_out, inp) + + self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index d566da546d2..82f7e98c07c 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -36,6 +36,8 @@ c10d_collective_ops = { c10d_ops.gather_, c10d_ops.scatter_, c10d_ops.reduce_, + c10d_ops.reduce_scatter_, + c10d_ops.reduce_scatter_tensor_coalesced_, }