1. Reworked `MultiProcContinousTest` to spawn processes during `setUpClass` instead of `main`, so that multiple test classes can live in one file.
2. The child processes now run an infinite loop, monitoring test IDs passed from the main process via a task queue; reciprocally, they inform the main process of a test's completion via a completion queue.
3. Added a test template.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153653
Approved by: https://github.com/d4l3k, https://github.com/fegin, https://github.com/fduwjj
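The spawn-once worker-loop pattern described above can be sketched roughly as follows. This is only an illustration of the idea; the queue layout and helper names are assumptions, not the actual `MultiProcContinousTest` internals.

import multiprocessing as mp


def _worker(rank, task_queue, completion_queue):
    # Child process: run whatever test ID the main process sends,
    # looping until a None sentinel arrives.
    while True:
        test_id = task_queue.get()
        if test_id is None:
            break
        # ... look up and run `test_id` on this rank ...
        completion_queue.put((rank, test_id))


def start_workers(world_size):
    # Main process (e.g. in setUpClass): spawn the workers once and
    # keep them alive across tests.
    task_queues = [mp.Queue() for _ in range(world_size)]
    completion_queue = mp.Queue()
    procs = [
        mp.Process(target=_worker, args=(r, task_queues[r], completion_queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    return procs, task_queues, completion_queue


def run_one_test(test_id, task_queues, completion_queue):
    # Dispatch a test ID to every rank, then wait for all completions.
    for q in task_queues:
        q.put(test_id)
    return [completion_queue.get() for _ in range(len(task_queues))]


def shutdown_workers(procs, task_queues):
    # e.g. in tearDownClass: send the sentinel and join the children.
    for q in task_queues:
        q.put(None)
    for p in procs:
        p.join()

Spawning once and reusing the workers means process and communicator setup is not repaid for every test, which is what makes hosting multiple test classes in one file practical.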
134 lines · 4.3 KiB · Python
# Owner(s): ["oncall: distributed"]

# To run:
# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
# OR
# TORCH_SYMMMEM=NVSHMEM torchrun --nproc-per-node 4 test/distributed/test_nvshmem.py

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


symm_mem_backend = os.getenv("TORCH_SYMMMEM")

if symm_mem_backend != "NVSHMEM":
    print(
        "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`, skipping tests",
        file=sys.stderr,
    )
    sys.exit(0)


# Decorator
def requires_nvshmem():
    return skip_but_pass_in_sandcastle_if(
        symm_mem_backend != "NVSHMEM",
        "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`",
    )


# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)


@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    def _init_device(self) -> None:
        # TODO: relieve this (seems to hang without it)
        device_module.set_device(self.device)
        # NOTE: required for nvshmem allocation
        torch.empty(1, device=self.device)

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel_per_peer = 10
        numel = self.world_size * numel_per_peer
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)

        symm_mem.rendezvous(inp, group=group_name)
        symm_mem.rendezvous(out, group=group_name)
        torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)

        expected = torch.cat(
            [
                torch.empty(numel_per_peer, dtype=dtype, device=self.device).fill_(i)
                for i in range(self.world_size)
            ]
        )
        torch.testing.assert_close(out, expected)

    @skipIfRocm
    def test_nvshmem_all_to_all_vdev(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        # Number of elements for each peer is random in [0, k)
        k = 10
        inp_splits = torch.randint(k, (self.world_size,), device=self.device)
        inp_numel = inp_splits.sum().item()
        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        dist.all_to_all_single(out_splits, inp_splits)
        out_numel = out_splits.sum().item()
        # Align up so the output buffer is a bit bigger than needed
        align = 16
        out_numel_max = (out_numel + align - 1) // align * align

        inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
            self.rank
        )
        out = symm_mem.empty(out_numel_max, dtype=dtype, device=self.device).fill_(-1)
        in_out_splits = symm_mem.empty(
            (3, self.world_size), dtype=torch.int64, device=self.device
        )
        # Row 0 is input splits
        in_out_splits[0].copy_(inp_splits)

        torch.ops.symm_mem.nvshmem_all_to_all_vdev(inp, out, in_out_splits, group_name)

        # Check input splits (row 0) -- should not change
        torch.testing.assert_close(in_out_splits[0], inp_splits)

        # Check output splits (row 1)
        torch.testing.assert_close(in_out_splits[1], out_splits)

        # Check output offsets (row 2)
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        # Output offsets from `nvshmem_all_to_all_vdev` are an exclusive scan,
        # e.g. out_splits [3, 1, 2] -> row 2 [0, 3, 4] (vs. inclusive [3, 4, 6])
        self.assertEqual(in_out_splits[2][0], 0)
        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])

        # Check data
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
        dist.all_to_all_single(expected, inp, out_splits.tolist(), inp_splits.tolist())
        torch.testing.assert_close(out[:out_numel], expected)


if __name__ == "__main__":
    run_tests()
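For reference, a new test built on the template mentioned in the commit message might look like the minimal skeleton below. This is only a sketch with made-up class and test names; it assumes, as the tests above do, that `MultiProcContinousTest` initializes the default process group and that rank r can use GPU r.

# Illustrative skeleton only -- class and test names are assumptions.
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import run_tests


class MyCollectiveTest(MultiProcContinousTest):
    def test_all_reduce(self) -> None:
        device = torch.device("cuda", self.rank)
        torch.cuda.set_device(device)
        t = torch.ones(8, device=device)
        dist.all_reduce(t)  # default op is SUM
        torch.testing.assert_close(t, torch.full_like(t, self.world_size))


if __name__ == "__main__":
    run_tests()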