From d5b38410b5b6cf75c7a7389972777a6497926ee7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Sep 2025 20:42:31 +0000 Subject: [PATCH] Revert "[SymmMem] Add root argument to broadcast op (#161090)" This reverts commit 3c0ff1b569c45cfa6935ad8031a9d4cf1551aa3f. Reverted https://github.com/pytorch/pytorch/pull/161090 on behalf of https://github.com/jeanschmidt due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/161090#issuecomment-3255574093)) --- test/distributed/test_nvshmem.py | 4 ++-- .../distributed/c10d/symm_mem/SymmetricMemory.cpp | 3 +-- .../distributed/c10d/symm_mem/nvshmem_extension.cu | 13 +++++-------- .../distributed/c10d/symm_mem/nvshmem_extension.cuh | 6 +++--- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py index 7046a8bf735..16fed916d91 100644 --- a/test/distributed/test_nvshmem.py +++ b/test/distributed/test_nvshmem.py @@ -99,7 +99,7 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest): tensor = torch.zeros(numel, dtype=dtype, device=self.device) symm_mem.rendezvous(tensor, group=group_name) - torch.ops.symm_mem.nvshmem_broadcast(tensor, src_rank, group_name) + torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name) self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device)) @skipIfRocm @@ -124,7 +124,7 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest): y = torch.mm(x, w) # y should be a symm tensor - torch.ops.symm_mem.nvshmem_broadcast(y, 0, group_name) + torch.ops.symm_mem.nvshmem_broadcast(y, group_name) expected = torch.mm(x0, w) self.assertEqual(y, expected) diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index 949e6d7c9fb..c3ed9dcd0d0 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp @@ -497,8 +497,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def("nvshmem_put(Tensor(a!) tensor, int peer) -> ()"); m.def("nvshmem_get(Tensor(a!) tensor, int peer) -> ()"); - m.def( - "nvshmem_broadcast(Tensor(a!) input, int root, str group_name) -> Tensor(a!)"); + m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)"); m.def( "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)"); m.def( diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu index d422c4859b6..bb6b5414eaf 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu @@ -106,20 +106,19 @@ nvshmem_team_t group_to_team( return team; } -at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name) { +at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) { auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank]; - int team_size = nvshmem_team_n_pes(team); - TORCH_CHECK(root < team_size, "root must be smaller than group size"); auto stream = at::cuda::getCurrentCUDAStream(); - nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), root, stream); + nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream); return input; } -void nvshmem_put(at::Tensor& tensor, const int64_t peer) { +void nvshmem_put(at::Tensor& tensor, int64_t peer) { // TODO: support non-contiguous tensors TORCH_CHECK(tensor.is_contiguous(), "put op currently supports contiguous tensors only"); @@ -128,14 +127,13 @@ void nvshmem_put(at::Tensor& tensor, const int64_t peer) { auto rank = hdl->get_rank(); void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; auto buffer_size = tensor.numel() * tensor.element_size(); - TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size"); c10::cuda::CUDAGuard guard(tensor.device()); auto stream = at::cuda::getCurrentCUDAStream(); nvshmemx_putmem_on_stream(buffer_ptr, tensor.data_ptr(), buffer_size, peer, stream); } -void nvshmem_get(at::Tensor& tensor, const int64_t peer) { +void nvshmem_get(at::Tensor& tensor, int64_t peer) { // TODO: support non-contiguous tensors TORCH_CHECK(tensor.is_contiguous(), "get op currently supports contiguous tensors only"); @@ -144,7 +142,6 @@ void nvshmem_get(at::Tensor& tensor, const int64_t peer) { auto rank = hdl->get_rank(); void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; auto buffer_size = tensor.numel() * tensor.element_size(); - TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size"); c10::cuda::CUDAGuard guard(tensor.device()); auto stream = at::cuda::getCurrentCUDAStream(); diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh index fc37bd931fa..f364e2ebfa3 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh @@ -21,11 +21,11 @@ TORCH_API bool is_nvshmem_available(); // operations. TORCH_API void nvshmemx_cumodule_init(uintptr_t module); -TORCH_API void nvshmem_put(at::Tensor& tensor, const int64_t peer); +TORCH_API void nvshmem_put(at::Tensor& tensor, int64_t peer); -TORCH_API void nvshmem_get(at::Tensor& tensor, const int64_t peer); +TORCH_API void nvshmem_get(at::Tensor& tensor, int64_t peer); -at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name); +at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name); at::Tensor nvshmem_all_to_all( at::Tensor& input,