diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py
index fdf71f79c89..cdd68591e7f 100644
--- a/test/distributed/test_nvshmem.py
+++ b/test/distributed/test_nvshmem.py
@@ -315,6 +315,107 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
             out, val * torch.ones(numel, dtype=dtype, device=self.device)
         )
 
+    @skipIfRocm
+    @requires_triton()
+    def test_triton_get(self) -> None:
+        # A Triton kernel that calls nvshmem device side API for GET
+        @triton.jit
+        def get_kernel(
+            dst_ptr,
+            src_ptr,
+            numel: tl.constexpr,
+            peer: tl.constexpr,
+        ):
+            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)
+
+        torch.manual_seed(42 + self.rank)
+        self._init_device()
+
+        nvshmem_lib = nvshmem.enable_triton()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+        rank = self.rank
+        msg_size_bytes = 8
+        dtype = torch.int8
+        numel = msg_size_bytes // dtype.itemsize
+        val = 7
+        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
+            val if rank == 0 else -1
+        )
+        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
+        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
+        out_hdl = symm_mem.rendezvous(out, group=group_name)
+        dist.barrier()
+        peer = 1 - rank
+        if rank == 1:
+            # Rank 1 gets data from rank 0
+            dst_ptr = out_hdl.buffer_ptrs[rank]
+            src_ptr = inp_hdl.buffer_ptrs[rank]
+            get_kernel[(1, 1, 1)](
+                dst_ptr,
+                src_ptr,
+                numel=numel,
+                peer=peer,
+                extern_libs=nvshmem_lib,
+            )
+        if rank == 1:
+            torch.testing.assert_close(
+                out, val * torch.ones(numel, dtype=dtype, device=self.device)
+            )
+
+    @skipIfRocm
+    @requires_triton()
+    def test_triton_get_ring(self) -> None:
+        # A Triton kernel that calls nvshmem device side API for GET
+        # with ring topology
+        @triton.jit
+        def get_kernel(
+            dst_ptr,
+            src_ptr,
+            numel: tl.constexpr,
+            peer: tl.constexpr,
+        ):
+            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)
+
+        torch.manual_seed(42 + self.rank)
+        self._init_device()
+
+        nvshmem_lib = nvshmem.enable_triton()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+        rank = self.rank
+        world_size = dist.get_world_size()
+        msg_size_bytes = 8
+        dtype = torch.int8
+        numel = msg_size_bytes // dtype.itemsize
+
+        # Each rank fills its input buffer with its own rank value
+        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
+        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
+        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
+        out_hdl = symm_mem.rendezvous(out, group=group_name)
+        dist.barrier()
+
+        # Ring topology: each rank gets data from the rank to its left
+        # rank 0 gets from rank (world_size - 1), rank 1 gets from rank 0, etc.
+        peer = (rank - 1) % world_size
+
+        # All ranks execute the get operation
+        dst_ptr = out_hdl.buffer_ptrs[rank]
+        src_ptr = inp_hdl.buffer_ptrs[rank]
+        get_kernel[(1, 1, 1)](
+            dst_ptr,
+            src_ptr,
+            numel=numel,
+            peer=peer,
+            extern_libs=nvshmem_lib,
+        )
+
+        expected_value = peer
+        torch.testing.assert_close(
+            out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py
index be583175a47..fb4df142a1e 100644
--- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py
+++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py
@@ -71,3 +71,21 @@ if has_triton():
             is_pure=False,
             _builder=_builder,
         )
+
+    @core.extern
+    def getmem_block(dst, src, nelems, pe, _builder=None):  # type: ignore[no-untyped-def]
+        return core.extern_elementwise(
+            "",
+            "",
+            [dst, src, nelems, pe],
+            {
+                (
+                    core.dtype("int64"),
+                    core.dtype("int64"),
+                    core.dtype("int64"),
+                    core.dtype("int64"),
+                ): ("nvshmemx_getmem_block", core.dtype("int32"))
+            },
+            is_pure=False,
+            _builder=_builder,
+        )
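For reference, a minimal sketch of how the new device-side extern is meant to be used (illustrative only, not part of the patch; the kernel name is hypothetical, and the import path is taken from the module modified above):

    import triton
    import triton.language as tl
    from torch.distributed._symmetric_memory import _nvshmem_triton as nvshmem

    @triton.jit
    def pull_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
        # Block-scoped GET: pull data from the symmetric buffer on `peer`
        # into the local destination buffer via nvshmemx_getmem_block.
        nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

As in the tests above, the kernel has to be launched with extern_libs=nvshmem_lib (obtained from nvshmem.enable_triton()) so that the nvshmemx_getmem_block symbol can be linked into the Triton-generated code, and the source and destination pointers come from symm_mem buffers that have gone through symm_mem.rendezvous().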