[SymmMem] Enable NVSHMEM for Triton (#155506)
(This is an **Experimental** feature)
Allow Triton kernels to invoke NVSHMEM device functions.

### Example Triton program
Key parts:
- Call `nvshmem.enable_triton()` to initialize NVSHMEM support;
- Call `nvshmem.putmem_block` inside the Triton kernel;
- Pass the `extern_libs` kwarg at kernel invocation.

```python
import torch.distributed as dist
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
import triton
import triton.language as tl

@triton.jit
def put_kernel(
    dst_ptr,
    src_ptr,
    numel: tl.constexpr,
    peer: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # All threads in the block cooperate to copy `numel` bytes
    # (int8 elements here) from the local `src_ptr` to `dst_ptr`
    # on rank `peer`
    nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)

if __name__ == "__main__":
    # Enable NVSHMEM for Triton; returns the extern-lib mapping that
    # links the NVSHMEM device library when the kernel is compiled
    nvshmem_lib = nvshmem.enable_triton()

    # Use torch Symmetric Memory to allocate Symmetric tensors
    ...

    # In a two-rank run, each rank targets the other
    peer = 1 - rank
    if rank == 0:
        kernel = put_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            BLOCK_SIZE=BLOCK_SIZE,
            extern_libs=nvshmem_lib,  # link in the NVSHMEM device library
        )

    # Synchronize ranks before checking the result on the receiver
    dist.barrier()
    if rank == 1:
        print(f"Rank {rank}: received {out=}")
```
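
The `...` above stands for the Symmetric Memory allocation step. As a rough guide, here is a minimal sketch of what it might look like, assuming a two-rank NCCL process group and the `empty`/`rendezvous` helpers in `torch.distributed._symmetric_memory`; the tensor size, fill value, and `rendezvous` call details are illustrative rather than taken from this PR:

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")

numel = 8       # matches the 8-element int8 tensor in the test output below
BLOCK_SIZE = 8  # illustrative launch parameter

# Allocate symmetric tensors; every rank must participate
src = symm_mem.empty(numel, dtype=torch.int8, device=device).fill_(5)
out = symm_mem.empty(numel, dtype=torch.int8, device=device).fill_(0)
# Establish the symmetric-memory handle across ranks (the exact
# rendezvous signature varies across PyTorch versions)
symm_mem.rendezvous(out, group=dist.group.WORLD)

# Raw device pointers handed to the Triton kernel above
dst_ptr = out.data_ptr()
src_ptr = src.data_ptr()
```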

### Test output
```
$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put
Rank 0: writing value 5 to Peer 1
Rank 1: received out=tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:1', dtype=torch.int8)
```
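
Outside the test harness, a standalone two-GPU script built from the sketch above (say `example.py`, a hypothetical name) would be launched along these lines:

```
$ TORCH_SYMMMEM=NVSHMEM torchrun --nproc-per-node=2 example.py
```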

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155506
Approved by: https://github.com/ngimel, https://github.com/fegin, https://github.com/fduwjj