[SymmMem] Add NVSHMEM PUT with Signal support to Triton (#156211)
Adds NVSHMEM PUT with Signal operation support for Triton kernels:

- Added a `putmem_signal_block` `core.extern` wrapper for `nvshmemx_putmem_signal_block` (see the sketch below)
- Added kernel for 2-rank PUT operation with atomic SET signaling (`test_triton_put_signal_set`)
- Added kernel for 2-rank PUT operation with atomic ADD signaling (`test_triton_put_signal_add`)

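For reference, the wrapper follows the same pattern as the other NVSHMEM extern wrappers: a `core.extern`-decorated function that lowers to the device-side `nvshmemx_putmem_signal_block` call via `core.extern_elementwise`. The sketch below is illustrative only; the library name and the argument/return dtypes are assumptions, not copied from the patch.

```python
from triton.language import core


@core.extern
def putmem_signal_block(dst, src, nelems, sig_addr, signal, sig_op, pe, _builder=None):
    # Illustrative sketch: map the seven scalar arguments onto the NVSHMEM
    # device-side symbol. The dtypes and the "libnvshmem_device" name below
    # are assumptions for illustration, not the exact code from this PR.
    return core.extern_elementwise(
        "libnvshmem_device",
        "",
        [dst, src, nelems, sig_addr, signal, sig_op, pe],
        {
            (
                core.dtype("int64"),  # dst pointer
                core.dtype("int64"),  # src pointer
                core.dtype("int64"),  # nelems (bytes)
                core.dtype("int64"),  # sig_addr pointer
                core.dtype("int64"),  # signal value
                core.dtype("int64"),  # sig_op (SET / ADD)
                core.dtype("int64"),  # destination PE
            ): ("nvshmemx_putmem_signal_block", core.dtype("int32")),
        },
        is_pure=False,
        _builder=_builder,
    )
```
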
**Tests:**
`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py`

`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put_signal_set`
`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put_signal_add`

```python
@skipIfRocm
@requires_triton()
def test_triton_put_signal_set(self) -> None:
    @triton.jit
    def put_signal_kernel(dst_ptr, src_ptr, numel: tl.constexpr, sig_ptr,
                         signal_val: tl.constexpr, sig_op: tl.constexpr, peer: tl.constexpr):
        nvshmem.putmem_signal_block(dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer)

    # ... setup code ...

    val = 11
    inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
    out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)  # destination buffer

    # Signal flag buffer - starts at 0, will be set to 1 upon completion
    flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0)

    peer = 1 - rank
    NVSHMEM_SIGNAL_SET = 0  # atomic set operation
    SIGNAL_VAL = 1  # completion signal value

    if rank == 0:
        # Rank 0 atomically: (1) puts data to rank 1, (2) sets rank 1's flag to 1
        put_signal_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, sig_ptr=sig_ptr,
                                    signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_SET,
                                    peer=peer, extern_libs=nvshmem_lib)

    dist.barrier()
    # Rank 1 can check the flag to know the data transfer has completed
    print(f"[Rank {rank}] inp buffer: {inp}")
    print(f"[Rank {rank}] out buffer: {out}")
    print(f"[Rank {rank}] flag buffer: {flag}")
```

```
[Rank 0] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:0', dtype=torch.int8)
[Rank 0] out buffer: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0', dtype=torch.int8)
[Rank 0] got data from peer 1
[Rank 0] flag buffer: tensor([0], device='cuda:0')
[Rank 1] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] out buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] got data from peer 0
[Rank 1] flag buffer: tensor([1], device='cuda:1')

----------------------------------------------------------------------
Ran 2 tests in 17.046s

OK
```

Working as expected: the data is received on rank 1, and the flag is set to 1 as the completion signal.
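
Because the signal arrives atomically with the data, the receiving rank could in principle poll the flag instead of relying on `dist.barrier()`. A minimal host-side sketch, assuming the flag tensor is readable from the host once the signal lands (a device-side signal wait would be the more idiomatic NVSHMEM pattern):

```python
# Hedged sketch (not part of this PR): rank 1 spins on its local flag rather
# than joining a global barrier. Each .item() call copies the flag to the host.
if rank == 1:
    while flag.item() != SIGNAL_VAL:
        pass  # busy-wait until rank 0's putmem_signal sets the flag to 1
    # The data put into `out` is now visible on rank 1.
```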

```python
@skipIfRocm
@requires_triton()
def test_triton_put_signal_add(self) -> None:
    @triton.jit
    def put_signal_kernel(dst_ptr, src_ptr, numel: tl.constexpr, sig_ptr,
                          signal_val: tl.constexpr, sig_op: tl.constexpr, peer: tl.constexpr):
        nvshmem.putmem_signal_block(dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer)

    # ... setup code ...

    # Signal flag buffer (NVSHMEM signals are 64-bit values; torch.int64 is used here)
    flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0)

    peer = 1 - rank
    NVSHMEM_SIGNAL_ADD = 5  # atomic add operation
    SIGNAL_VAL = 16  # signal value to add

    if rank == 0:
        # Rank 0 puts data into Rank 1 and adds to its signal flag
        put_signal_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, sig_ptr=sig_ptr,
                                     signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_ADD,
                                     peer=peer, extern_libs=nvshmem_lib)

    dist.barrier()
    print(f"[Rank {rank}] inp buffer: {inp}")
    print(f"[Rank {rank}] out buffer: {out}")
    print(f"[Rank {rank}] flag buffer: {flag}")
```

```
[Rank 0] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:0', dtype=torch.int8)
[Rank 0] out buffer: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0', dtype=torch.int8)
[Rank 0] got data from peer 1
[Rank 0] flag buffer: tensor([0], device='cuda:0')
[Rank 1] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] out buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] got data from peer 0
[Rank 1] flag buffer: tensor([16], device='cuda:1')

----------------------------------------------------------------------
Ran 1 test in 17.145s

OK
```

The flag transition from [0] → [16] confirms both data delivery and atomic signal completion in a single operation!
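
The ADD variant also generalizes to multiple producers: if several ranks each put into the consumer and add the same `SIGNAL_VAL`, the consumer knows all transfers have landed once the flag reaches the accumulated total. A hedged host-side sketch under that assumption (`num_producers` and `expected` are hypothetical, not from the test):

```python
# Hedged sketch: with NVSHMEM_SIGNAL_ADD, each producer rank adds SIGNAL_VAL
# to the consumer's flag, so the consumer polls for the accumulated total.
num_producers = 3                       # hypothetical number of producer ranks
expected = num_producers * SIGNAL_VAL   # e.g. 3 * 16 = 48

while flag.item() != expected:
    pass  # busy-wait on the host; a device-side signal wait would avoid this
```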

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156211
Approved by: https://github.com/kwen2501, https://github.com/mandroid6