[dtensor][debug] added c10d reduce_scatter_ and reduce_scatter_tensor_coalesced tracing_ to CommDebugMode (#127358)

**Summary**
Added c10d reduce_scatter_ and reduce_scatter_tensor_coalesced tracing to CommDebugMode and edited test case in test_comm_mode to include added features.

**Test Plan**
pytest test/distributed/_tensor/debug/test_comm_mode.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127358
Approved by: https://github.com/wz337, https://github.com/XilunWu, https://github.com/yifuwang
This commit is contained in:
Anshul Sinha 2024-06-04 10:18:46 -07:00 committed by PyTorch MergeBot
parent 9a25ff77af
commit 01e6d1cae4
2 changed files with 14 additions and 0 deletions

View File

@ -193,6 +193,18 @@ class TestCommMode(TestCase):
self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1)
# tests c10d reduce_scatter_
with comm_mode:
dist.reduce_scatter(all_gather_out, [inp])
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)
# tests c10d reduce_scatter_tensor_coalesced
with comm_mode as A, dist._coalescing_manager() as B:
dist.reduce_scatter_tensor(all_gather_out, inp)
self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)
if __name__ == "__main__":
run_tests()

View File

@ -36,6 +36,8 @@ c10d_collective_ops = {
c10d_ops.gather_,
c10d_ops.scatter_,
c10d_ops.reduce_,
c10d_ops.reduce_scatter_,
c10d_ops.reduce_scatter_tensor_coalesced_,
}