Adds NVSHMEM PUT with Signal operation support for Triton kernels:
- Added `putmem_signal_block` core.extern wrapper for `nvshmemx_putmem_signal_block` (see the sketch after this list)
- Added kernel for 2-rank PUT operation with atomic SET signaling (`test_triton_put_signal_set`)
- Added kernel for 2-rank PUT operation with atomic ADD signaling (`test_triton_put_signal_add`)
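For context, a minimal sketch of what such a `core.extern` wrapper can look like, following the `extern_elementwise` pattern used by the existing NVSHMEM Triton wrappers; the argument dtypes and exact signature below are illustrative, not the merged code:

```python
from triton.language import core


@core.extern
def putmem_signal_block(dst, src, nelems, sig_addr, signal, sig_op, pe, _builder=None):
    # Lower the call to the NVSHMEM device-side symbol; the whole thread block
    # cooperates in the transfer, then the signal word on the peer is updated atomically.
    return core.extern_elementwise(
        "",
        "",
        [dst, src, nelems, sig_addr, signal, sig_op, pe],
        {
            (
                core.dtype("int64"), core.dtype("int64"), core.dtype("int64"),
                core.dtype("int64"), core.dtype("int64"), core.dtype("int64"),
                core.dtype("int64"),
            ): ("nvshmemx_putmem_signal_block", core.dtype("int32"))
        },
        is_pure=False,
        _builder=_builder,
    )
```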
**Tests:**
`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py`
`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put_signal_set`
`TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put_signal_add`
```python
@skipIfRocm
@requires_triton()
def test_triton_put_signal_set(self) -> None:
    @triton.jit
    def put_signal_kernel(dst_ptr, src_ptr, numel: tl.constexpr, sig_ptr,
                          signal_val: tl.constexpr, sig_op: tl.constexpr, peer: tl.constexpr):
        nvshmem.putmem_signal_block(dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer)

    # ... setup code ...
    val = 11
    inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
    out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)  # destination buffer
    # Signal flag buffer - starts at 0, will be set to 1 upon completion
    flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0)
    peer = 1 - rank
    NVSHMEM_SIGNAL_SET = 0  # atomic set operation
    SIGNAL_VAL = 1  # completion signal value
    if rank == 0:
        # Rank 0 atomically: (1) puts data to rank 1, (2) sets rank 1's flag to 1
        put_signal_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, sig_ptr=sig_ptr,
                                     signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_SET,
                                     peer=peer, extern_libs=nvshmem_lib)
    dist.barrier()
    # Rank 1 can check flag to know data transfer completed!
    print(f"[Rank {rank}] inp buffer: {inp}")
    print(f"[Rank {rank}] out buffer: {out}")
    print(f"[Rank {rank}] flag buffer: {flag}")
```
```
[Rank 0] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:0', dtype=torch.int8)
[Rank 0] out buffer: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0', dtype=torch.int8)
[Rank 0] got data from peer 1
[Rank 0] flag buffer: tensor([0], device='cuda:0')
[Rank 1] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] out buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] got data from peer 0
[Rank 1] flag buffer: tensor([1], device='cuda:1')
----------------------------------------------------------------------
Ran 2 tests in 17.046s
OK
```
Working as expected! The data is received on rank 1, and its flag is set to 1 as the completion signal.
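A plausible host-side verification on the receiving rank, using only the names from the excerpt above (the exact assertions in the merged test may differ):

```python
# After the barrier, rank 1 can confirm both the payload and the signal landed.
if rank == 1:
    torch.testing.assert_close(out, torch.full_like(out, val))  # data put by rank 0
    assert flag.item() == SIGNAL_VAL  # flag atomically SET from 0 to 1
```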
```python
@skipIfRocm
@requires_triton()
def test_triton_put_signal_add(self) -> None:
    @triton.jit
    def put_signal_kernel(dst_ptr, src_ptr, numel: tl.constexpr, sig_ptr,
                          signal_val: tl.constexpr, sig_op: tl.constexpr, peer: tl.constexpr):
        nvshmem.putmem_signal_block(dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer)

    # ... setup code ...
    # Signal buffer (uint64 flag)
    flag = symm_mem.empty(1, dtype=torch.int64, device=self.device).fill_(0)
    peer = 1 - rank
    NVSHMEM_SIGNAL_ADD = 5  # atomic add operation
    SIGNAL_VAL = 16  # signal value to add
    if rank == 0:
        # Rank 0 puts into Rank 1 and adds to the signal
        put_signal_kernel[(1, 1, 1)](dst_ptr, src_ptr, numel=numel, sig_ptr=sig_ptr,
                                     signal_val=SIGNAL_VAL, sig_op=NVSHMEM_SIGNAL_ADD,
                                     peer=peer, extern_libs=nvshmem_lib)
    dist.barrier()
    print(f"[Rank {rank}] inp buffer: {inp}")
    print(f"[Rank {rank}] out buffer: {out}")
    print(f"[Rank {rank}] flag buffer: {flag}")
```
```
[Rank 0] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:0', dtype=torch.int8)
[Rank 0] out buffer: tensor([-1, -1, -1, -1, -1, -1, -1, -1], device='cuda:0', dtype=torch.int8)
[Rank 0] got data from peer 1
[Rank 0] flag buffer: tensor([0], device='cuda:0')
[Rank 1] inp buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] out buffer: tensor([11, 11, 11, 11, 11, 11, 11, 11], device='cuda:1', dtype=torch.int8)
[Rank 1] got data from peer 0
[Rank 1] flag buffer: tensor([16], device='cuda:1')
----------------------------------------------------------------------
Ran 1 test in 17.145s
OK
```
The flag transitioning from [0] to [16] confirms both data delivery and the atomic signal update in a single operation!
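One reason to prefer ADD over SET is that the signal word can act as a completion counter when several producers target the same flag. A small sketch of that idea, reusing the names from the test above (the 2-rank test has only one producer, so this is purely illustrative):

```python
# Hypothetical: with N producers each issuing a put-with-signal of SIGNAL_VAL onto the
# same flag, the consumer knows all transfers have landed once the accumulated value matches.
num_producers = 1  # rank 0 is the only producer in the 2-rank test above
if rank == 1:
    assert flag.item() == num_producers * SIGNAL_VAL  # 16 here; 32 with two producers, etc.
```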
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156211
Approved by: https://github.com/kwen2501, https://github.com/mandroid6