[SymmMem] Remove redundant dist.barrier in Triton NVSHMEM tests & add device‐side signal_op support (#156684)

## Summary

This PR removes unnecessary `dist.barrier` calls from our Triton NVSHMEM test suite and adds `signal_op` support, a lightweight device-side signaling mechanism. It adds a `core.extern` wrapper for `signal_op` and exercises it in the `wait_until` test via a new `signal_op_kernel`.

**Why did we drop the `dist.barrier()` calls?**
We dropped the host‐side dist.barrier() in all Triton NVSHMEM tests (except the raw put/get cases) because every other test already uses NVSHMEM collectives or device‐side sync primitives (fence/quiet/signal/wait), making the extra barrier redundant. This keeps synchronization entirely on the GPU and leverages NVSHMEM’s native ordering guarantees for clearer, more efficient tests.

**`test_triton_wait_until` update**
- **Rank 1**: after `put_kernel` writes the data, launches `signal_op_kernel` to atomically set Rank 0's flag via `nvshmemx_signal_op`
- **Rank 0**: drops its old `dist.barrier()` and simply calls `wait_until_kernel` to spin-wait on the device flag, then asserts data correctness
- Changes made per [this comment](https://github.com/pytorch/pytorch/pull/156472#discussion_r2159734046)

## Testing

```bash
TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156684
Approved by: https://github.com/kwen2501, https://github.com/mandroid6
This commit is contained in:
codingwithsurya 2025-06-26 01:14:42 -07:00 committed by PyTorch MergeBot
parent 43a09189c6
commit b6e625e34f
2 changed files with 36 additions and 19 deletions

View File

@ -611,6 +611,16 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
):
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
# A Triton kernel that calls the NVSHMEM device-side API for SIGNAL_OP.
# It forwards its four arguments unchanged to `nvshmem.signal_op`.
#   sig_addr: address of the signal word to update (this test passes a
#             signal-pad pointer obtained from the symmetric-memory handle).
#   signal:   value handed to the signal operation (compile-time constant).
#   sig_op:   NVSHMEM signal operation code, e.g. NVSHMEM_SIGNAL_SET
#             (compile-time constant).
#   peer:     PE/rank the signal is directed at (compile-time constant).
@triton.jit
def signal_op_kernel(
sig_addr,
signal: tl.constexpr,
sig_op: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.signal_op(sig_addr, signal, sig_op, peer)
# A Triton kernel that calls nvshmem device side API for WAIT_UNTIL
@triton.jit
def wait_until_kernel(
@ -637,10 +647,10 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
dist.barrier()
peer = 1 - rank
NVSHMEM_CMP_EQ = 0 # from nvshmem.h
NVSHMEM_SIGNAL_SET = 0 # atomic set operation
if rank == 0:
# Rank 0 waits for the flag to be set by Rank 1, then checks the data
@ -666,17 +676,13 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
peer=peer,
extern_libs=nvshmem_lib,
)
# Rank 1 sets the flag on Rank 0
# We use a temporary tensor for the value to put.
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int64, device=self.device
)
dst_ptr = out_hdl.signal_pad_ptrs[rank]
src_ptr = flag_update_val.data_ptr()
put_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=1,
# Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op
sig_addr = out_hdl.signal_pad_ptrs[rank]
signal_op_kernel[(1, 1, 1)](
sig_addr,
signal=flag_val,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
extern_libs=nvshmem_lib,
)
@ -736,8 +742,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
# Use the signal pad for synchronization, as in previous tests
flag_dtype = torch.int64
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
# Ensure setup is complete on all ranks before proceeding
dist.barrier()
if rank == 0:
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
@ -773,8 +777,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
[COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
),
)
# Final barrier to ensure the test does not exit before assertions complete
dist.barrier()
@skipIfRocm
@requires_triton()
@ -851,7 +853,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
[flag_val], dtype=torch.int64, device=self.device
)
NVSHMEM_CMP_EQ = 0 # compare equal
dist.barrier()
if rank == 0:
dst_ptr1 = out1_hdl.buffer_ptrs[rank]
@ -892,7 +893,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
torch.testing.assert_close(
flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device)
)
dist.barrier()
@skipIfRocm
@requires_triton()
@ -944,7 +944,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
):
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)
dist.barrier()
if rank == 0:
# Rank 0 waits for flag from Rank 1
ivar_ptr = out_hdl.signal_pad_ptrs[rank]

View File

@ -154,6 +154,24 @@ if has_triton():
_builder=_builder,
)
@core.extern
def signal_op(sig_addr, signal, sig_op, pe, _builder=None):  # type: ignore[no-untyped-def]
    """Emit a call to the external symbol ``nvshmemx_signal_op``.

    Args:
        sig_addr: Address of the signal word, passed as an int64.
        signal: Value handed to the signal operation, passed as an int64.
        sig_op: Signal operation code, passed as an int64.
        pe: Target PE, passed as an int64.
        _builder: Triton IR builder forwarded to ``extern_elementwise``.

    Returns:
        Whatever ``core.extern_elementwise`` returns for the call; the
        symbol is declared here with an int32 result.
    """
    i64 = core.dtype("int64")
    # Single overload: four int64 operands -> (symbol name, declared return type).
    overloads = {
        (i64, i64, i64, i64): ("nvshmemx_signal_op", core.dtype("int32")),
    }
    # Library name and path are left empty here.
    # NOTE(review): presumably the library is supplied via `extern_libs`
    # at kernel launch - confirm against callers.
    return core.extern_elementwise(
        "",
        "",
        [sig_addr, signal, sig_op, pe],
        overloads,
        # Marked impure: the call is made for its side effect.
        is_pure=False,
        _builder=_builder,
    )
@core.extern
def fence(_builder=None): # type: ignore[no-untyped-def]
return core.extern_elementwise(