Revert "[SymmMem] Add root argument to broadcast op (#161090)"

This reverts commit 3c0ff1b569.

Reverted https://github.com/pytorch/pytorch/pull/161090 on behalf of https://github.com/jeanschmidt due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/161090#issuecomment-3255574093))
This commit is contained in:
PyTorch MergeBot 2025-09-04 20:42:31 +00:00
parent 48bedd753d
commit d5b38410b5
4 changed files with 11 additions and 15 deletions

View File

@@ -99,7 +99,7 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
tensor = torch.zeros(numel, dtype=dtype, device=self.device)
symm_mem.rendezvous(tensor, group=group_name)
-torch.ops.symm_mem.nvshmem_broadcast(tensor, src_rank, group_name)
+torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))
@skipIfRocm
@@ -124,7 +124,7 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
y = torch.mm(x, w)
# y should be a symm tensor
-torch.ops.symm_mem.nvshmem_broadcast(y, 0, group_name)
+torch.ops.symm_mem.nvshmem_broadcast(y, group_name)
expected = torch.mm(x0, w)
self.assertEqual(y, expected)

View File

@@ -497,8 +497,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
m.def("nvshmem_put(Tensor(a!) tensor, int peer) -> ()");
m.def("nvshmem_get(Tensor(a!) tensor, int peer) -> ()");
-m.def(
-"nvshmem_broadcast(Tensor(a!) input, int root, str group_name) -> Tensor(a!)");
+m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
m.def(

View File

@@ -106,20 +106,19 @@ nvshmem_team_t group_to_team(
return team;
}
-at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name) {
+at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) {
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
int rank = input_hdl->get_rank();
int world_size = input_hdl->get_world_size();
auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank());
void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank];
int team_size = nvshmem_team_n_pes(team);
-TORCH_CHECK(root < team_size, "root must be smaller than group size");
auto stream = at::cuda::getCurrentCUDAStream();
-nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), root, stream);
+nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream);
return input;
}
-void nvshmem_put(at::Tensor& tensor, const int64_t peer) {
+void nvshmem_put(at::Tensor& tensor, int64_t peer) {
// TODO: support non-contiguous tensors
TORCH_CHECK(tensor.is_contiguous(),
"put op currently supports contiguous tensors only");
@@ -128,14 +127,13 @@ void nvshmem_put(at::Tensor& tensor, const int64_t peer) {
auto rank = hdl->get_rank();
void* buffer_ptr = hdl->get_buffer_ptrs()[rank];
auto buffer_size = tensor.numel() * tensor.element_size();
-TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size");
c10::cuda::CUDAGuard guard(tensor.device());
auto stream = at::cuda::getCurrentCUDAStream();
nvshmemx_putmem_on_stream(buffer_ptr, tensor.data_ptr(), buffer_size, peer, stream);
}
-void nvshmem_get(at::Tensor& tensor, const int64_t peer) {
+void nvshmem_get(at::Tensor& tensor, int64_t peer) {
// TODO: support non-contiguous tensors
TORCH_CHECK(tensor.is_contiguous(),
"get op currently supports contiguous tensors only");
@@ -144,7 +142,6 @@ void nvshmem_get(at::Tensor& tensor, const int64_t peer) {
auto rank = hdl->get_rank();
void* buffer_ptr = hdl->get_buffer_ptrs()[rank];
auto buffer_size = tensor.numel() * tensor.element_size();
-TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size");
c10::cuda::CUDAGuard guard(tensor.device());
auto stream = at::cuda::getCurrentCUDAStream();

View File

@@ -21,11 +21,11 @@ TORCH_API bool is_nvshmem_available();
// operations.
TORCH_API void nvshmemx_cumodule_init(uintptr_t module);
-TORCH_API void nvshmem_put(at::Tensor& tensor, const int64_t peer);
+TORCH_API void nvshmem_put(at::Tensor& tensor, int64_t peer);
-TORCH_API void nvshmem_get(at::Tensor& tensor, const int64_t peer);
+TORCH_API void nvshmem_get(at::Tensor& tensor, int64_t peer);
-at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name);
+at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name);
at::Tensor nvshmem_all_to_all(
at::Tensor& input,