mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[SymmMem] Remove redundant dist.barrier in Triton NVSHMEM tests & add device‐side signal_op support (#156684)
## Summary This PR removes unnecessary `dist.barrier` calls in our Triton NVSHMEM test suite and adds `signal_op` support, a lightweight device-side signaling mechanism. A test for it is added alongside our `wait_until` kernel, together with the corresponding `core.extern` wrapper. **Why did we drop the `dist.barrier()` calls?** We dropped the host-side `dist.barrier()` in all Triton NVSHMEM tests (except the raw put/get cases) because every other test already uses NVSHMEM collectives or device-side sync primitives (fence/quiet/signal/wait), making the extra barrier redundant. This keeps synchronization entirely on the GPU and leverages NVSHMEM’s native ordering guarantees for clearer, more efficient tests. **`test_triton_wait_until` update** - **Rank 1**: after `put_kernel` writes the data, launches `signal_op_kernel` to atomically set Rank 0's flag via `nvshmemx_signal_op` - **Rank 0**: drops its old `dist.barrier()` and simply calls `wait_until_kernel` to spin-wait on the device flag, then asserts data correctness - Changes made per [this comment](https://github.com/pytorch/pytorch/pull/156472#discussion_r2159734046) ## Testing ```bash TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/156684 Approved by: https://github.com/kwen2501, https://github.com/mandroid6
This commit is contained in:
parent
43a09189c6
commit
b6e625e34f
|
|
@ -611,6 +611,16 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
):
|
||||
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
|
||||
|
||||
# A Triton kernel that calls the NVSHMEM device-side API for SIGNAL_OP.
# It is a thin wrapper: the four arguments are forwarded, unchanged and in
# order, to `nvshmem.signal_op(sig_addr, signal, sig_op, peer)`.
#   sig_addr - runtime argument; the test in this file passes a signal-pad
#              pointer (`out_hdl.signal_pad_ptrs[rank]`).
#   signal   - tl.constexpr; the value handed to the signal operation
#              (the test passes `flag_val`).
#   sig_op   - tl.constexpr; operation selector (the test passes
#              `NVSHMEM_SIGNAL_SET`).
#   peer     - tl.constexpr; forwarded as the wrapper's `pe` argument,
#              presumably the target PE/rank - confirm against NVSHMEM docs.
|
||||
@triton.jit
|
||||
def signal_op_kernel(
|
||||
sig_addr,
|
||||
signal: tl.constexpr,
|
||||
sig_op: tl.constexpr,
|
||||
peer: tl.constexpr,
|
||||
):
|
||||
nvshmem.signal_op(sig_addr, signal, sig_op, peer)
|
||||
|
||||
# A Triton kernel that calls nvshmem device side API for WAIT_UNTIL
|
||||
@triton.jit
|
||||
def wait_until_kernel(
|
||||
|
|
@ -637,10 +647,10 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
|
||||
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
|
||||
out_hdl = symm_mem.rendezvous(out, group=group_name)
|
||||
dist.barrier()
|
||||
|
||||
peer = 1 - rank
|
||||
NVSHMEM_CMP_EQ = 0 # from nvshmem.h
|
||||
NVSHMEM_SIGNAL_SET = 0 # atomic set operation
|
||||
|
||||
if rank == 0:
|
||||
# Rank 0 waits for the flag to be set by Rank 1, then checks the data
|
||||
|
|
@ -666,17 +676,13 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
peer=peer,
|
||||
extern_libs=nvshmem_lib,
|
||||
)
|
||||
# Rank 1 sets the flag on Rank 0
|
||||
# We use a temporary tensor for the value to put.
|
||||
flag_update_val = torch.tensor(
|
||||
[flag_val], dtype=torch.int64, device=self.device
|
||||
)
|
||||
dst_ptr = out_hdl.signal_pad_ptrs[rank]
|
||||
src_ptr = flag_update_val.data_ptr()
|
||||
put_kernel[(1, 1, 1)](
|
||||
dst_ptr,
|
||||
src_ptr,
|
||||
numel=1,
|
||||
|
||||
# Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op
|
||||
sig_addr = out_hdl.signal_pad_ptrs[rank]
|
||||
signal_op_kernel[(1, 1, 1)](
|
||||
sig_addr,
|
||||
signal=flag_val,
|
||||
sig_op=NVSHMEM_SIGNAL_SET,
|
||||
peer=peer,
|
||||
extern_libs=nvshmem_lib,
|
||||
)
|
||||
|
|
@ -736,8 +742,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
# Use the signal pad for synchronization, as in previous tests
|
||||
flag_dtype = torch.int64
|
||||
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
|
||||
# Ensure setup is complete on all ranks before proceeding
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
|
||||
|
|
@ -773,8 +777,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
[COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
|
||||
),
|
||||
)
|
||||
# Final barrier to ensure the test does not exit before assertions complete
|
||||
dist.barrier()
|
||||
|
||||
@skipIfRocm
|
||||
@requires_triton()
|
||||
|
|
@ -851,7 +853,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
[flag_val], dtype=torch.int64, device=self.device
|
||||
)
|
||||
NVSHMEM_CMP_EQ = 0 # compare equal
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
dst_ptr1 = out1_hdl.buffer_ptrs[rank]
|
||||
|
|
@ -892,7 +893,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
torch.testing.assert_close(
|
||||
flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device)
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
@skipIfRocm
|
||||
@requires_triton()
|
||||
|
|
@ -944,7 +944,6 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
|
|||
):
|
||||
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)
|
||||
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
# Rank 0 waits for flag from Rank 1
|
||||
ivar_ptr = out_hdl.signal_pad_ptrs[rank]
|
||||
|
|
|
|||
|
|
@ -154,6 +154,24 @@ if has_triton():
|
|||
_builder=_builder,
|
||||
)
|
||||
|
||||
# Triton extern binding for the NVSHMEM device function `nvshmemx_signal_op`.
# The four operands (sig_addr, signal, sig_op, pe) are passed through
# `core.extern_elementwise`; the only registered overload types all four as
# int64 and types the call's result as int32.
# The library name/path arguments are empty strings; presumably the symbol is
# resolved from the `extern_libs` supplied at kernel launch - TODO confirm.
# `is_pure=False` marks the call as impure, presumably so the compiler keeps
# it for its side effect - TODO confirm.
# NOTE(review): the NVSHMEM docs appear to declare `nvshmemx_signal_op` as
# returning void with `int sig_op, int pe`; confirm that the int32 return and
# int64 operand typing here is intentional.
@core.extern
|
||||
def signal_op(sig_addr, signal, sig_op, pe, _builder=None): # type: ignore[no-untyped-def]
|
||||
return core.extern_elementwise(
|
||||
"",
|
||||
"",
|
||||
[sig_addr, signal, sig_op, pe],
|
||||
{
|
||||
(
|
||||
core.dtype("int64"),
|
||||
core.dtype("int64"),
|
||||
core.dtype("int64"),
|
||||
core.dtype("int64"),
|
||||
): ("nvshmemx_signal_op", core.dtype("int32"))
|
||||
},
|
||||
is_pure=False,
|
||||
_builder=_builder,
|
||||
)
|
||||
|
||||
@core.extern
|
||||
def fence(_builder=None): # type: ignore[no-untyped-def]
|
||||
return core.extern_elementwise(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user