pytorch/test/distributed/test_nvshmem.py
Ke Wen 9d922b55ef [Distributed][CI] Rework continuous TestCase (#153653)
1. Reworked `MultiProcContinousTest` to spawn processes during `setUpClass` instead of `main`, so that multiple test classes can live in one file.

2. The child processes now run in an infinite loop, monitoring test IDs passed from the main process via a task queue. Reciprocally, the child processes inform the main process of each test's completion via a completion queue (see the sketch after this list).

3. Added a test template.
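
A minimal sketch of the task-queue / completion-queue protocol (illustrative only; `_worker_loop`, `task_queue`, and `completion_queue` are hypothetical names built directly on `multiprocessing`, not the actual helpers in `MultiProcContinousTest`):

    import multiprocessing as mp

    SENTINEL = None  # sent once per worker to request shutdown

    def _worker_loop(rank, task_queue, completion_queue):
        # Child process: loop forever, pulling test IDs from the main
        # process and reporting each completion back.
        while True:
            test_id = task_queue.get()
            if test_id is SENTINEL:
                break
            # ... look up and run the test method named by test_id ...
            completion_queue.put((rank, test_id))

    if __name__ == "__main__":
        world_size = 4
        ctx = mp.get_context("spawn")
        task_queues = [ctx.Queue() for _ in range(world_size)]
        completion_queue = ctx.Queue()
        procs = [
            ctx.Process(target=_worker_loop, args=(r, task_queues[r], completion_queue))
            for r in range(world_size)
        ]
        for p in procs:
            p.start()
        # Dispatch one test to every rank, then wait for all ranks to finish.
        for q in task_queues:
            q.put("test_example")
        for _ in range(world_size):
            completion_queue.get()
        # Tear down the worker pool.
        for q in task_queues:
            q.put(SENTINEL)
        for p in procs:
            p.join()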

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153653
Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
2025-05-25 03:49:29 +00:00


# Owner(s): ["oncall: distributed"]
# To run:
# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
# OR
# TORCH_SYMMMEM=NVSHMEM torchrun --nproc-per-node 4 test/distributed/test_nvshmem.py

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)

symm_mem_backend = os.getenv("TORCH_SYMMMEM")
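# Exit with status 0 below so that environments without NVSHMEM treat this
# file as skipped (a pass) rather than a failure.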
if symm_mem_backend != "NVSHMEM":
print(
"test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`, skipping tests",
file=sys.stderr,
)
sys.exit(0)


# Decorator
def requires_nvshmem():
return skip_but_pass_in_sandcastle_if(
symm_mem_backend != "NVSHMEM",
"test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`",
)


# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)


@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
def _init_device(self) -> None:
        # TODO: relax this requirement (the test seems to hang without it)
device_module.set_device(self.device)
# NOTE: required for nvshmem allocation
torch.empty(1, device=self.device)

    @property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)

    @skipIfRocm
def test_nvshmem_all_to_all(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
numel_per_peer = 10
numel = self.world_size * numel_per_peer
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
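        # Rendezvous registers each buffer with the group so that peers
        # can access it as symmetric memory.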
symm_mem.rendezvous(inp, group=group_name)
symm_mem.rendezvous(out, group=group_name)
torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)
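        # Rank r fills its whole input with the value r, so after the
        # all-to-all, chunk i of every rank's output holds the value i.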
expected = torch.cat(
[
torch.empty(numel_per_peer, dtype=dtype, device=self.device).fill_(i)
for i in range(self.world_size)
]
)
torch.testing.assert_close(out, expected)

    @skipIfRocm
def test_nvshmem_all_to_all_vdev(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
        # Number of elements for each peer is random, in [0, k)
k = 10
inp_splits = torch.randint(k, (self.world_size,), device=self.device)
inp_numel = inp_splits.sum().item()
# Exchange input splits to get output splits
out_splits = torch.zeros_like(inp_splits)
dist.all_to_all_single(out_splits, inp_splits)
out_numel = out_splits.sum().item()
        # Round the output size up to a multiple of `align` (over-allocate)
align = 16
out_numel_max = (out_numel + align - 1) // align * align
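        # e.g., with align=16: out_numel=23 -> out_numel_max=32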
inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
self.rank
)
out = symm_mem.empty(out_numel_max, dtype=dtype, device=self.device).fill_(-1)
in_out_splits = symm_mem.empty(
(3, self.world_size), dtype=torch.int64, device=self.device
)
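        # `in_out_splits` layout (one column per peer):
        #   row 0: input splits (written by us, read by the kernel)
        #   row 1: output splits (written by the kernel)
        #   row 2: output offsets, an exclusive prefix sum of row 1
        #          (written by the kernel)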
        # Fill row 0 with our input splits
in_out_splits[0].copy_(inp_splits)
torch.ops.symm_mem.nvshmem_all_to_all_vdev(inp, out, in_out_splits, group_name)
# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_out_splits[0], inp_splits)
# Check output splits (row 1)
torch.testing.assert_close(in_out_splits[1], out_splits)
# Check output offsets (row 2)
out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan
        # the output offsets from `nvshmem_all_to_all_vdev` are an exclusive scan
self.assertEqual(in_out_splits[2][0], 0)
torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
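        # Worked example: splits [2, 3, 1] -> inclusive scan [2, 5, 6],
        # exclusive scan [0, 2, 5]; the kernel reports the exclusive form.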
# Check data
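        # Build the reference output with the regular all_to_all_single,
        # using the same per-peer splits.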
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
dist.all_to_all_single(expected, inp, out_splits.tolist(), inp_splits.tolist())
torch.testing.assert_close(out[:out_numel], expected)


if __name__ == "__main__":
run_tests()