pytorch/caffe2
Ke Wen 9e9484d022 [SymmMem] Enable NVSHMEM for Triton (#155506)
(This is an **Experimental** feature)
Allow Triton kernels to invoke NVSHMEM device functions.

### Example Triton program
Key parts:
- Call `nvshmem.enable_triton()` on the host to initialize, and keep the library handle it returns;
- Call `nvshmem.putmem_block` inside the Triton kernel;
- Pass that library handle as the `extern_libs` kwarg at kernel invocation.

```
import torch.distributed as dist
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
import triton
import triton.language as tl

@triton.jit
def put_kernel(
    dst_ptr,
    src_ptr,
    numel: tl.constexpr,
    peer: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
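    # One-sided put, performed cooperatively by the thread block:
    # copy from the local src_ptr into dst_ptr on rank `peer`.
    nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)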

if __name__ == "__main__":
    # Enable NVSHMEM for Triton
    nvshmem_lib = nvshmem.enable_triton()

    # Use torch Symmetric Memory to allocate Symmetric tensors
    ...

    peer = 1 - rank
    if rank == 0:
        kernel = put_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            BLOCK_SIZE=BLOCK_SIZE,
            extern_libs=nvshmem_lib,
        )

    dist.barrier()
    if rank == 1:
        print(f"Rank {rank}: received {out=}")
```
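
For readers new to one-sided communication, the put semantics that the kernel above relies on can be modeled on the CPU. The following is a toy sketch only: it uses plain Python byte buffers, involves no GPU, NVSHMEM, or Triton, and `ToySymmetricHeap`, `putmem`, and `run_put` are hypothetical names invented for illustration, not part of any PyTorch or NVSHMEM API. It assumes two ranks and one byte per element, which matches the `int8` tensor in the test.

```python
# Toy CPU model of a one-sided put. Hypothetical names throughout;
# this does not call NVSHMEM, Triton, or torch.


class ToySymmetricHeap:
    """Every rank owns a buffer of the same size at the same logical offset."""

    def __init__(self, world_size: int, nbytes: int, fill: int = 0) -> None:
        self.buffers = [bytearray([fill] * nbytes) for _ in range(world_size)]

    def putmem(self, dst_off: int, src: bytes, nbytes: int, peer: int) -> None:
        # One-sided: only the caller acts; the peer's buffer is written
        # directly and the peer issues no matching receive.
        self.buffers[peer][dst_off:dst_off + nbytes] = src[:nbytes]


def run_put(numel: int = 8, val: int = 5) -> list:
    heap = ToySymmetricHeap(world_size=2, nbytes=numel)
    rank = 0
    peer = 1 - rank  # same peer computation as in the example above
    src = bytes([val] * numel)
    heap.putmem(0, src, numel, peer)
    # In the real program a barrier sits here, so that rank 1 reads its
    # buffer only after rank 0's put has been issued.
    return list(heap.buffers[peer])


if __name__ == "__main__":
    print(run_put())
```

In this model only rank 0 does any work, which is why the real example launches the kernel under `if rank == 0:` and has rank 1 simply read its buffer after `dist.barrier()`.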

### Test output
```
$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put
Rank 0: writing value 5 to Peer 1
Rank 1: received out=tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:1', dtype=torch.int8)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155506
Approved by: https://github.com/ngimel, https://github.com/fegin, https://github.com/fduwjj
2025-06-12 00:22:49 +00:00

| Name | Last commit | Date |
| --- | --- | --- |
| core | Record the XPU and XCCL build settings in the compiled binary (#147161) | 2025-05-20 09:21:39 +00:00 |
| perfkernels | Optimize SVE embedding performance (#150176) | 2025-04-07 18:01:54 +00:00 |
| serialize | Update PyTorchStreamReader API to take cpu allocator override (#150439) | 2025-04-18 01:53:14 +00:00 |
| utils | [3/N] Use internal linkage in C++ files (#151297) | 2025-05-05 17:48:39 +00:00 |
| .clang-format | [BE][clang-format] make macro PyObject_HEAD_INIT(type) and PyVarObject_HEAD_INIT(type, size) have its own line (#136949) | 2024-10-02 18:39:22 +00:00 |
| CMakeLists.txt | [SymmMem] Enable NVSHMEM for Triton (#155506) | 2025-06-12 00:22:49 +00:00 |
| unexported_symbols.lds | | |
| version_script.lds | | |