(This is an **Experimental** feature)
Allow Triton kernels to invoke NVSHMEM device functions.
### Example Triton program
Key parts:
- Call `nvshmem.enable_triton()` to initialize the NVSHMEM device library;
- Call `nvshmem.putmem_block` inside the Triton kernel;
- Pass the `extern_libs` kwarg at kernel invocation.
```
import torch.distributed as dist
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
import triton
import triton.language as tl


@triton.jit
def put_kernel(
    dst_ptr,
    src_ptr,
    numel: tl.constexpr,
    peer: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)


if __name__ == "__main__":
    # Enable NVSHMEM for Triton
    nvshmem_lib = nvshmem.enable_triton()

    # Use torch Symmetric Memory to allocate symmetric tensors (see sketch below)
    ...

    peer = 1 - rank
    if rank == 0:
        kernel = put_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            BLOCK_SIZE=BLOCK_SIZE,
            extern_libs=nvshmem_lib,
        )

    dist.barrier()

    if rank == 1:
        print(f"Rank {rank}: received {out=}")
```
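The `...` above elides the symmetric-memory setup. Below is a minimal sketch of one way that step might look, assuming the `torch.distributed._symmetric_memory` allocation API (`symm_mem.empty` / `symm_mem.rendezvous`) and a two-rank NCCL process group; the variable names (`src`, `out`, `numel`) and fill values are illustrative, not the exact code from the test.

```
# Sketch (assumption): symmetric-tensor allocation for the put example above.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
dist.init_process_group("nccl")

numel, BLOCK_SIZE = 8, 8
# Every rank must allocate and rendezvous the same symmetric buffers.
src = symm_mem.empty(numel, dtype=torch.int8, device=f"cuda:{rank}")
out = symm_mem.empty(numel, dtype=torch.int8, device=f"cuda:{rank}")
src.fill_(5 if rank == 0 else -1)
out.fill_(-1)
symm_mem.rendezvous(src, group=dist.group.WORLD)
symm_mem.rendezvous(out, group=dist.group.WORLD)

# Triton treats tensor arguments as pointers to their first element,
# so `src` and `out` can be passed where the kernel expects src_ptr / dst_ptr.
src_ptr, dst_ptr = src, out
```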
### Test output:
```
$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put
Rank 0: writing value 5 to Peer 1
Rank 1: received out=tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:1', dtype=torch.int8)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155506
Approved by: https://github.com/ngimel, https://github.com/fegin, https://github.com/fduwjj