Revert "[SymmMem] Add root argument to broadcast op (#161090)"
This reverts commit 3c0ff1b569.
Reverted https://github.com/pytorch/pytorch/pull/161090 on behalf of https://github.com/jeanschmidt because it breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/161090#issuecomment-3255574093))
This commit is contained in:
parent 48bedd753d
commit d5b38410b5
Tests:

@@ -99,7 +99,7 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
         tensor = torch.zeros(numel, dtype=dtype, device=self.device)
 
         symm_mem.rendezvous(tensor, group=group_name)
-        torch.ops.symm_mem.nvshmem_broadcast(tensor, src_rank, group_name)
+        torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
         self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))
 
     @skipIfRocm
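For reference, a minimal sketch of how the op is called after this revert. The surrounding setup (process-group initialization, an NVSHMEM-backed symmetric-memory allocation via `symm_mem.empty`) is assumed here, not shown in the diff:

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # Assumes dist.init_process_group(...) has already run on every rank
    # and NVSHMEM is the symmetric-memory backend in use.
    rank = dist.get_rank()
    group_name = dist.group.WORLD.group_name
    tensor = symm_mem.empty(1024, dtype=torch.int64, device=f"cuda:{rank}")
    symm_mem.rendezvous(tensor, group=group_name)
    if rank == 0:
        tensor.copy_(torch.arange(1024, device=tensor.device))
    # With the root argument reverted, rank 0 is implicitly the source.
    torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)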
@@ -124,7 +124,7 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
         y = torch.mm(x, w)
 
         # y should be a symm tensor
-        torch.ops.symm_mem.nvshmem_broadcast(y, 0, group_name)
+        torch.ops.symm_mem.nvshmem_broadcast(y, group_name)
         expected = torch.mm(x0, w)
         self.assertEqual(y, expected)
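This second test exercises a compute-then-broadcast pattern: every rank writes its own x @ w into a symmetric buffer, and the broadcast then replaces each rank's result with rank 0's. A hedged sketch of the same pattern (shapes, the per-rank inputs, and the shared weight w are illustrative assumptions):

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    rank = dist.get_rank()
    group_name = dist.group.WORLD.group_name
    torch.manual_seed(0)  # keep w identical across ranks
    w = torch.randn(64, 64, device=f"cuda:{rank}")
    x = torch.full((64, 64), float(rank), device=w.device)  # per-rank input
    y = symm_mem.empty(64, 64, dtype=w.dtype, device=w.device)
    symm_mem.rendezvous(y, group=group_name)
    torch.mm(x, w, out=y)  # y lives in symmetric memory
    torch.ops.symm_mem.nvshmem_broadcast(y, group_name)
    # Every rank now holds rank 0's product (here all zeros, since x0 == 0).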
Op schema registration:

@@ -497,8 +497,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def("nvshmem_put(Tensor(a!) tensor, int peer) -> ()");
   m.def("nvshmem_get(Tensor(a!) tensor, int peer) -> ()");
-  m.def(
-      "nvshmem_broadcast(Tensor(a!) input, int root, str group_name) -> Tensor(a!)");
+  m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
   m.def(
       "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
   m.def(
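In these schema strings, Tensor(a!) marks the argument as mutated in place, with the return value aliasing it. A rough Python-side sketch of the same kind of registration, using a hypothetical mylib namespace rather than anything from this PR:

    import torch

    # "FRAGMENT" lets several modules contribute ops to one namespace,
    # mirroring the TORCH_LIBRARY_FRAGMENT macro above.
    lib = torch.library.Library("mylib", "FRAGMENT")
    lib.define("my_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)")

    @torch.library.impl(lib, "my_broadcast", "CUDA")
    def my_broadcast(input, group_name):
        # Placeholder body; the real op dispatches into NVSHMEM.
        return input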
Implementation:

@@ -106,20 +106,19 @@ nvshmem_team_t group_to_team(
   return team;
 }
 
-at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name) {
+at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) {
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   int rank = input_hdl->get_rank();
+  int world_size = input_hdl->get_world_size();
   auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank());
   void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank];
-  int team_size = nvshmem_team_n_pes(team);
-  TORCH_CHECK(root < team_size, "root must be smaller than group size");
 
   auto stream = at::cuda::getCurrentCUDAStream();
-  nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), root, stream);
+  nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream);
   return input;
 }
 
-void nvshmem_put(at::Tensor& tensor, const int64_t peer) {
+void nvshmem_put(at::Tensor& tensor, int64_t peer) {
   // TODO: support non-contiguous tensors
   TORCH_CHECK(tensor.is_contiguous(),
       "put op currently supports contiguous tensors only");
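After the revert, the PE index passed to nvshmemx_broadcastmem_on_stream is fixed at 0, so the source rank is no longer selectable. A hypothetical Python-side guard (the wrapper name is illustrative, not part of this PR) that surfaces the limitation instead of silently broadcasting from rank 0:

    import torch

    def broadcast_from(tensor, src_rank, group_name):
        # Hypothetical wrapper: post-revert, the op has no root argument
        # and always broadcasts from rank 0 of the group.
        if src_rank != 0:
            raise NotImplementedError(
                "nvshmem_broadcast broadcasts from rank 0 only after the revert")
        return torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)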
@@ -128,14 +127,13 @@ void nvshmem_put(at::Tensor& tensor, const int64_t peer) {
   auto rank = hdl->get_rank();
   void* buffer_ptr = hdl->get_buffer_ptrs()[rank];
   auto buffer_size = tensor.numel() * tensor.element_size();
-  TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size");
 
   c10::cuda::CUDAGuard guard(tensor.device());
   auto stream = at::cuda::getCurrentCUDAStream();
   nvshmemx_putmem_on_stream(buffer_ptr, tensor.data_ptr(), buffer_size, peer, stream);
 }
 
-void nvshmem_get(at::Tensor& tensor, const int64_t peer) {
+void nvshmem_get(at::Tensor& tensor, int64_t peer) {
   // TODO: support non-contiguous tensors
   TORCH_CHECK(tensor.is_contiguous(),
       "get op currently supports contiguous tensors only");
@@ -144,7 +142,6 @@ void nvshmem_get(at::Tensor& tensor, const int64_t peer) {
   auto rank = hdl->get_rank();
   void* buffer_ptr = hdl->get_buffer_ptrs()[rank];
   auto buffer_size = tensor.numel() * tensor.element_size();
-  TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size");
 
   c10::cuda::CUDAGuard guard(tensor.device());
   auto stream = at::cuda::getCurrentCUDAStream();
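The revert also drops the peer-range TORCH_CHECKs from nvshmem_put and nvshmem_get, so an out-of-range peer now reaches NVSHMEM unvalidated. A hedged caller-side equivalent (hypothetical helper name, shown for put; get is symmetric):

    import torch
    import torch.distributed as dist

    def checked_put(tensor, peer):
        # Re-creates the removed C++ check on the Python side.
        world_size = dist.get_world_size()
        if not 0 <= peer < world_size:
            raise ValueError(f"peer {peer} out of range [0, {world_size})")
        torch.ops.symm_mem.nvshmem_put(tensor, peer)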
Header declarations:

@@ -21,11 +21,11 @@ TORCH_API bool is_nvshmem_available();
 // operations.
 TORCH_API void nvshmemx_cumodule_init(uintptr_t module);
 
-TORCH_API void nvshmem_put(at::Tensor& tensor, const int64_t peer);
+TORCH_API void nvshmem_put(at::Tensor& tensor, int64_t peer);
 
-TORCH_API void nvshmem_get(at::Tensor& tensor, const int64_t peer);
+TORCH_API void nvshmem_get(at::Tensor& tensor, int64_t peer);
 
-at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name);
+at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name);
 
 at::Tensor nvshmem_all_to_all(
     at::Tensor& input,