mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[SymmMem] Add NVSHMEM GET support to Triton (#155890)
Adds NVSHMEM GET operation support for Triton kernels:
- Add `getmem_block` core.extern wrapper for nvshmemx_getmem_block
- Add basic `test_triton_get` for 2-rank GET operation
- Add `test_triton_get_ring` for ring topology GET across arbitrary ranks
**Tests:**
`$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py`
`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_get`
```python
@skipIfRocm
@requires_triton()
def test_triton_get(self) -> None:
@triton.jit
def get_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)
# ... setup code ...
val = 7
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
val if rank == 0 else -1
)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
peer = 1 - rank
if rank == 1:
# Rank 1 gets data from rank 0
get_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, peer=peer, extern_libs=nvshmem_lib)
dist.barrier()
print(f"[Rank {rank}] inp buffer: {inp}")
print(f"[Rank {rank}] out buffer: {out}")
print(f"[Rank {rank}] got data from peer {peer}")
```
```
[Rank 0] inp buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0', dtype=torch.int8)
[Rank 1] inp buffer: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:1', dtype=torch.int8)
...
[Rank 1] out buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:1', dtype=torch.int8)
...
[Rank 1] got data from peer 0
----------------------------------------------------------------------
Ran 2 tests in 17.046s
OK
```
```python
@skipIfRocm
@requires_triton()
def test_triton_get_ring(self) -> None:
@triton.jit
def get_kernel(dst_ptr, src_ptr, numel: tl.constexpr, peer: tl.constexpr):
nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)
# ... setup code ...
# Ring topology: each rank gets data from the rank to its left
peer = (rank - 1) % world_size
# All ranks execute the get operation
get_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, peer=peer, extern_libs=nvshmem_lib)
dist.barrier()
print(f"[Rank {rank}] inp buffer: {inp}")
print(f"[Rank {rank}] out buffer: {out}")
print(f"[Rank {rank}] got data from peer {peer}")
```
```
Output (8 GPUs):
[Rank 0] inp buffer: tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0', dtype=torch.int8)
[Rank 2] inp buffer: tensor([2, 2, 2, 2, 2, 2, 2, 2], device='cuda:2', dtype=torch.int8)
[Rank 5] inp buffer: tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:5', dtype=torch.int8)
[Rank 6] inp buffer: tensor([6, 6, 6, 6, 6, 6, 6, 6], device='cuda:6', dtype=torch.int8)
[Rank 3] inp buffer: tensor([3, 3, 3, 3, 3, 3, 3, 3], device='cuda:3', dtype=torch.int8)
[Rank 1] inp buffer: tensor([1, 1, 1, 1, 1, 1, 1, 1], device='cuda:1', dtype=torch.int8)
[Rank 2] out buffer: tensor([1, 1, 1, 1, 1, 1, 1, 1], device='cuda:2', dtype=torch.int8)
[Rank 5] out buffer: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:5', dtype=torch.int8)
[Rank 0] out buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0', dtype=torch.int8)
[Rank 2] got data from peer 1
[Rank 5] got data from peer 4
[Rank 0] got data from peer 7
[Rank 7] inp buffer: tensor([7, 7, 7, 7, 7, 7, 7, 7], device='cuda:7', dtype=torch.int8)
[Rank 6] out buffer: tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:6', dtype=torch.int8)
[Rank 3] out buffer: tensor([2, 2, 2, 2, 2, 2, 2, 2], device='cuda:3', dtype=torch.int8)
[Rank 6] got data from peer 5
[Rank 3] got data from peer 2
[Rank 1] out buffer: tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:1', dtype=torch.int8)
[Rank 1] got data from peer 0
[Rank 4] inp buffer: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:4', dtype=torch.int8)
[Rank 7] out buffer: tensor([6, 6, 6, 6, 6, 6, 6, 6], device='cuda:7', dtype=torch.int8)
[Rank 7] got data from peer 6
[Rank 4] out buffer: tensor([3, 3, 3, 3, 3, 3, 3, 3], device='cuda:4', dtype=torch.int8)
[Rank 4] got data from peer 3
----------------------------------------------------------------------
Ran 1 test in 41.045s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155890
Approved by: https://github.com/kwen2501, https://github.com/mandroid6
This commit is contained in:
parent
bb1f3d1a55
commit
4781b0ee60
|
|
@@ -315,6 +315,107 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
out, val * torch.ones(numel, dtype=dtype, device=self.device)
|
||||
)
|
||||
|
||||
@skipIfRocm
@requires_triton()
def test_triton_get(self) -> None:
    """Two-rank GET: rank 1 pulls rank 0's symmetric buffer via NVSHMEM.

    Rank 0 fills its input buffer with a known value; rank 1 launches a
    Triton kernel that issues a device-side GET against rank 0 and then
    verifies the payload landed in its local output buffer.
    """

    # Device-side GET: copy `numel` elements from the peer's symmetric
    # memory at src_ptr into local memory at dst_ptr, block-cooperatively.
    @triton.jit
    def getmem_kernel(
        dst_ptr,
        src_ptr,
        numel: tl.constexpr,
        peer: tl.constexpr,
    ):
        nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

    torch.manual_seed(42 + self.rank)
    self._init_device()

    nvshmem_lib = nvshmem.enable_triton()
    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)
    rank = self.rank

    payload_bytes = 8
    dtype = torch.int8
    numel = payload_bytes // dtype.itemsize
    val = 7

    # Only rank 0 holds the payload; every other rank starts at -1 so a
    # successful GET is distinguishable from a no-op.
    inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
        val if rank == 0 else -1
    )
    out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
    inp_hdl = symm_mem.rendezvous(inp, group=group_name)
    out_hdl = symm_mem.rendezvous(out, group=group_name)

    # Make sure rank 0's buffer is populated before anyone reads it.
    dist.barrier()

    peer = 1 - rank
    if rank == 1:
        # Rank 1 gets data from rank 0, then checks the result locally.
        dst_ptr = out_hdl.buffer_ptrs[rank]
        src_ptr = inp_hdl.buffer_ptrs[rank]
        getmem_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            extern_libs=nvshmem_lib,
        )
        torch.testing.assert_close(
            out, val * torch.ones(numel, dtype=dtype, device=self.device)
        )
|
||||
|
||||
@skipIfRocm
@requires_triton()
def test_triton_get_ring(self) -> None:
    """Ring-topology GET: every rank pulls from its left neighbor.

    Each rank fills its input buffer with its own rank id, then issues a
    device-side GET against rank ``(rank - 1) % world_size`` and verifies
    it received exactly that neighbor's value.
    """

    # Device-side GET: copy `numel` elements from the peer's symmetric
    # memory at src_ptr into local memory at dst_ptr, block-cooperatively.
    @triton.jit
    def getmem_kernel(
        dst_ptr,
        src_ptr,
        numel: tl.constexpr,
        peer: tl.constexpr,
    ):
        nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

    torch.manual_seed(42 + self.rank)
    self._init_device()

    nvshmem_lib = nvshmem.enable_triton()
    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)
    rank = self.rank
    world_size = dist.get_world_size()

    payload_bytes = 8
    dtype = torch.int8
    numel = payload_bytes // dtype.itemsize

    # Each rank fills its input buffer with its own rank value.
    inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
    out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
    inp_hdl = symm_mem.rendezvous(inp, group=group_name)
    out_hdl = symm_mem.rendezvous(out, group=group_name)

    # All input buffers must be populated before any rank reads a peer's.
    dist.barrier()

    # Ring topology: rank 0 gets from rank (world_size - 1),
    # rank 1 gets from rank 0, and so on.
    peer = (rank - 1) % world_size

    # Every rank executes the GET.
    dst_ptr = out_hdl.buffer_ptrs[rank]
    src_ptr = inp_hdl.buffer_ptrs[rank]
    getmem_kernel[(1, 1, 1)](
        dst_ptr,
        src_ptr,
        numel=numel,
        peer=peer,
        extern_libs=nvshmem_lib,
    )

    # The neighbor wrote its own rank id, so that's what we expect here.
    expected_value = peer
    torch.testing.assert_close(
        out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@@ -71,3 +71,21 @@ if has_triton():
|
|||
is_pure=False,
|
||||
_builder=_builder,
|
||||
)
|
||||
|
||||
@core.extern
def getmem_block(dst, src, nelems, pe, _builder=None):  # type: ignore[no-untyped-def]
    """Device-side wrapper for ``nvshmemx_getmem_block``.

    Copies ``nelems`` elements from the symmetric address ``src`` on PE
    ``pe`` into local memory at ``dst``, cooperatively across the thread
    block. All four operands are passed as int64; the extern returns int32.
    """
    i64 = core.dtype("int64")
    return core.extern_elementwise(
        "",
        "",
        [dst, src, nelems, pe],
        {(i64, i64, i64, i64): ("nvshmemx_getmem_block", core.dtype("int32"))},
        # Remote memory access: must not be reordered or CSE'd away.
        is_pure=False,
        _builder=_builder,
    )
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user