[SymmMem] Add NVSHMEM GET support to Triton (#155890)

Adds NVSHMEM GET operation support for Triton kernels:

- Add `getmem_block` core.extern wrapper for `nvshmemx_getmem_block` (overall call pattern sketched below)
- Add basic `test_triton_get` for 2-rank GET operation
- Add `test_triton_get_ring` for ring topology GET across arbitrary ranks
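
For orientation, the end-to-end pattern these pieces enable looks roughly like the condensed sketch below (assembled from the tests in this PR; the `nvshmem` import path, the `device="cuda"` choice, and the ambient process-group / per-rank device initialization are assumptions, not part of the diff):

```python
# Condensed sketch of the GET path; see the actual tests further down for the
# authoritative version. Assumes torch.distributed is already initialized with
# TORCH_SYMMMEM=NVSHMEM and the current CUDA device is set per rank.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem  # import path assumed
import triton
import triton.language as tl


@triton.jit
def get_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
    # One-sided GET: fetch `numel` elements of the symmetric source buffer
    # from `peer` into the local destination buffer.
    nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)


nvshmem_lib = nvshmem.enable_triton()  # extern lib handed to the kernel launch
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

rank = dist.get_rank()
numel = 8
inp = symm_mem.empty(numel, dtype=torch.int8, device="cuda").fill_(rank)
out = symm_mem.empty(numel, dtype=torch.int8, device="cuda").fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
dist.barrier()

# Two-rank example: each rank pulls its peer's `inp` into its own `out`.
peer = 1 - rank
get_kernel[(1, 1, 1)](
    out_hdl.buffer_ptrs[rank],   # dst: local out buffer
    inp_hdl.buffer_ptrs[rank],   # src: local address of the symmetric inp buffer
    numel=numel,
    peer=peer,
    extern_libs=nvshmem_lib,
)
```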

**Tests:**
`$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py`

`$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_get`

```python
@skipIfRocm
@requires_triton()
def test_triton_get(self) -> None:
    @triton.jit
    def get_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
        nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

    # ... setup code (device init, symmetric buffers, rendezvous; see full diff below) ...

    val = 7
    inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
        val if rank == 0 else -1
    )
    out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)

    peer = 1 - rank
    if rank == 1:
        # Rank 1 gets data from rank 0
        dst_ptr = out_hdl.buffer_ptrs[rank]
        src_ptr = inp_hdl.buffer_ptrs[rank]
        get_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, peer=peer, extern_libs=nvshmem_lib)

    dist.barrier()
    print(f"[Rank {rank}] inp buffer: {inp}")
    print(f"[Rank {rank}] out buffer: {out}")
    print(f"[Rank {rank}] got data from peer {peer}")
```

```
Output:

[Rank 0] inp buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0', dtype=torch.int8)
[Rank 1] inp buffer: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:1', dtype=torch.int8)
...
[Rank 1] out buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:1', dtype=torch.int8)
...
[Rank 1] got data from peer 0

----------------------------------------------------------------------
Ran 2 tests in 17.046s

OK
```

```python
@skipIfRocm
@requires_triton()
def test_triton_get_ring(self) -> None:
    @triton.jit
    def get_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
        nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

    # ... setup code (each rank fills `inp` with its own rank value; see full diff below) ...

    # Ring topology: each rank gets data from the rank to its left
    peer = (rank - 1) % world_size

    # All ranks execute the get operation
    dst_ptr = out_hdl.buffer_ptrs[rank]
    src_ptr = inp_hdl.buffer_ptrs[rank]
    get_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, peer=peer, extern_libs=nvshmem_lib)

    dist.barrier()
    print(f"[Rank {rank}] inp buffer: {inp}")
    print(f"[Rank {rank}] out buffer: {out}")
    print(f"[Rank {rank}] got data from peer {peer}")
```
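
Concretely, the `(rank - 1) % world_size` peer selection builds the ring seen in the 8-GPU output below:

```python
world_size = 8
for rank in range(world_size):
    peer = (rank - 1) % world_size  # left neighbor; rank 0 wraps around to rank 7
    print(f"rank {rank} gets from peer {peer}")
# rank 0 gets from peer 7, rank 1 gets from peer 0, ..., rank 7 gets from peer 6
```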

```
Output (8 GPUs):

[Rank 0] inp buffer: tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0', dtype=torch.int8)
[Rank 2] inp buffer: tensor([2, 2, 2, 2, 2, 2, 2, 2], device='cuda:2', dtype=torch.int8)
[Rank 5] inp buffer: tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:5', dtype=torch.int8)
[Rank 6] inp buffer: tensor([6, 6, 6, 6, 6, 6, 6, 6], device='cuda:6', dtype=torch.int8)
[Rank 3] inp buffer: tensor([3, 3, 3, 3, 3, 3, 3, 3], device='cuda:3', dtype=torch.int8)
[Rank 1] inp buffer: tensor([1, 1, 1, 1, 1, 1, 1, 1], device='cuda:1', dtype=torch.int8)
[Rank 2] out buffer: tensor([1, 1, 1, 1, 1, 1, 1, 1], device='cuda:2', dtype=torch.int8)
[Rank 5] out buffer: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:5', dtype=torch.int8)
[Rank 0] out buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0', dtype=torch.int8)
[Rank 2] got data from peer 1
[Rank 5] got data from peer 4
[Rank 0] got data from peer 7
[Rank 7] inp buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:7', dtype=torch.int8)
[Rank 6] out buffer: tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:6', dtype=torch.int8)
[Rank 3] out buffer: tensor([2, 2, 2, 2, 2, 2, 2, 2], device='cuda:3', dtype=torch.int8)
[Rank 6] got data from peer 5
[Rank 3] got data from peer 2
[Rank 1] out buffer: tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:1', dtype=torch.int8)
[Rank 1] got data from peer 0
[Rank 4] inp buffer: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:4', dtype=torch.int8)
[Rank 7] out buffer: tensor([6, 6, 6, 6, 6, 6, 6, 6], device='cuda:7', dtype=torch.int8)
[Rank 7] got data from peer 6
[Rank 4] out buffer: tensor([3, 3, 3, 3, 3, 3, 3, 3], device='cuda:4', dtype=torch.int8)
[Rank 4] got data from peer 3

----------------------------------------------------------------------
Ran 1 test in 41.045s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155890
Approved by: https://github.com/kwen2501, https://github.com/mandroid6

Commit 4781b0ee60 (parent bb1f3d1a55), authored by codingwithsurya, 2025-06-16, committed by PyTorch MergeBot. 2 files changed, 119 insertions(+), 0 deletions(-).

New tests in `test/distributed/test_nvshmem.py` (hunk `@@ -315,6 +315,107 @@`, inside `class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest)`):

```python
    @skipIfRocm
    @requires_triton()
    def test_triton_get(self) -> None:
        # A Triton kernel that calls nvshmem device side API for GET
        @triton.jit
        def get_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val = 7
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
            val if rank == 0 else -1
        )
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        dist.barrier()
        peer = 1 - rank
        if rank == 1:
            # Rank 1 gets data from rank 0
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            get_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
        if rank == 1:
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )

    @skipIfRocm
    @requires_triton()
    def test_triton_get_ring(self) -> None:
        # A Triton kernel that calls nvshmem device side API for GET
        # with ring topology
        @triton.jit
        def get_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        world_size = dist.get_world_size()
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        # Each rank fills its input buffer with its own rank value
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        dist.barrier()
        # Ring topology: each rank gets data from the rank to its left
        # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc.
        peer = (rank - 1) % world_size
        # All ranks execute the get operation
        dst_ptr = out_hdl.buffer_ptrs[rank]
        src_ptr = inp_hdl.buffer_ptrs[rank]
        get_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            extern_libs=nvshmem_lib,
        )
        expected_value = peer
        torch.testing.assert_close(
            out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
        )


if __name__ == "__main__":
    run_tests()
```

New `getmem_block` extern wrapper (hunk `@@ -71,3 +71,21 @@`), added under `if has_triton():` in the NVSHMEM Triton wrapper module:

```python
    @core.extern
    def getmem_block(dst, src, nelems, pe, _builder=None):  # type: ignore[no-untyped-def]
        return core.extern_elementwise(
            "",
            "",
            [dst, src, nelems, pe],
            {
                (
                    core.dtype("int64"),
                    core.dtype("int64"),
                    core.dtype("int64"),
                    core.dtype("int64"),
                ): ("nvshmemx_getmem_block", core.dtype("int32"))
            },
            is_pure=False,
            _builder=_builder,
        )
```
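
To read the signature table above: all four operands lower to `int64` and dispatch to the `nvshmemx_getmem_block` device function (its `int32` return value is unused by these tests), so on the host side the kernel is launched with plain integer device addresses taken from the rendezvous handles. A rough sketch, reusing the names from the tests above:

```python
# Rough sketch; `out_hdl`/`inp_hdl`, `get_kernel`, `nvshmem_lib`, `numel`, and
# `peer` are the objects set up in the tests above, not new API.
dst_ptr = out_hdl.buffer_ptrs[rank]  # raw device address of the local out buffer
src_ptr = inp_hdl.buffer_ptrs[rank]  # local address of the symmetric inp buffer
# extern_libs=nvshmem_lib lets Triton resolve the nvshmemx_getmem_block symbol.
get_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, peer=peer, extern_libs=nvshmem_lib)
```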