pytorch/test/distributed/test_symmetric_memory.py
Prachi Gupta f638854e1d [ROCm][SymmMem] re-enable UTs (#162811)
After the UT suite moved to `MultiProcContinuousTest`, the `skipIfRocm` decorator started failing rather than skipping UTs: we now spawn the worker processes before the skip decorator is taken into account, and the decorator raised an exception to exit the process, which the parent process treated as a crash rather than a skip. Additionally, in `MultiProcContinuousTest`, if one UT fails, all subsequent ones are skipped; this makes sense since there is one setup for the entire suite, but it showed up as many failing/skipped UTs in the parity.

I added multiprocess versions of the skip decorators for ROCm, including `skip_if_rocm_arch_multiprocess` and `skip_if_rocm_ver_lessthan_multiprocess`. These are needed because the symmetric memory feature is only supported on MI300 and newer, so it must be skipped on other archs, and some UTs only work on ROCm 7.0 and later.
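
As a rough sketch of the idea (illustrative only, with an assumed exit code and ROCm check; not the actual decorators added here), a multiprocess-friendly skip exits the child process with a sentinel exit code that the parent harness maps back to a skipped test, instead of raising an exception that the parent cannot distinguish from a crash:

```python
import functools
import sys

import torch

ASSUMED_SKIP_EXIT_CODE = 86  # placeholder; the real harness defines its own skip exit codes


def skip_if_rocm_multiprocess_sketch(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if torch.version.hip is not None:  # running on a ROCm build
            # Exit the child with the skip code so the parent records a skip,
            # not a crash.
            sys.exit(ASSUMED_SKIP_EXIT_CODE)
        return fn(*args, **kwargs)

    return wrapper
```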

Fixes #161249
Fixes #161187
Fixes #161078
Fixes #160989
Fixes #160881
Fixes #160768
Fixes #160716
Fixes #160665
Fixes #160621
Fixes #160549
Fixes #160506
Fixes #160445
Fixes #160347
Fixes #160203
Fixes #160177
Fixes #160049
Fixes #159921
Fixes #159764
Fixes #159643
Fixes #159499
Fixes #159397
Fixes #159396
Fixes #159347
Fixes #159067
Fixes #159066
Fixes #158916
Fixes #158760
Fixes #158759
Fixes #158422
Fixes #158138
Fixes #158136
Fixes #158135

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162811
Approved by: https://github.com/jeffdaily
2025-09-16 15:35:39 +00:00


# Owner(s): ["module: c10d"]
import itertools
import os
import random
from contextlib import nullcontext
from unittest import skip, skipIf
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._inductor.utils import fresh_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import all_gather_tensor
from torch.distributed._symmetric_memory import (
_fused_all_gather_matmul_fallback,
_fused_all_gather_scaled_matmul_fallback,
_fused_matmul_reduce_scatter_fallback,
_test_mode,
enable_symm_mem_for_group,
restride_A_for_fused_matmul_reduce_scatter,
restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
MultiProcessTestCase,
PLATFORM_SUPPORTS_SYMM_MEM,
requires_multicast_support,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
skip_if_rocm_ver_lessthan_multiprocess,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
requires_cuda_p2p_access,
run_tests,
TEST_WITH_ROCM,
TestCase,
)
test_contexts = [nullcontext, _test_mode]
# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcContinuousTest):
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _init_process(self):
torch.cuda.set_device(self.device)
torch.manual_seed(42 + self.rank)
def test_has_multicast_support(self) -> None:
# validate that has_multicast_support() returns "false" instead of throwing
self.assertFalse(_SymmetricMemory.has_multicast_support(DeviceType.CPU, 0))
# NOTE: DeviceType.CUDA is implicitly tested through @requires_multicast_support
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
def test_get_backend(self) -> None:
backend = symm_mem.get_backend(torch.device("cuda"))
self.assertIsNotNone(backend)
backend = symm_mem.get_backend("cuda")
self.assertIsNotNone(backend)
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_cuda_nvlink_connectivity_detection(self) -> None:
from torch._C._distributed_c10d import _detect_dma_connectivity
connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
self.assertEqual(connectivity.device_type, DeviceType.CUDA)
self.assertEqual(connectivity.connection_type, "nvlink")
self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
for row in connectivity.matrix:
self.assertEqual(len(row), torch.cuda.device_count())
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
def test_large_alloc(self) -> None:
t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
def test_get_signal_pad(self) -> None:
self._init_process()
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
peer_rank = (self.rank + 1) % self.world_size
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
self.assertEqual(
signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
)
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
self.assertEqual(signal_pad.dtype, torch.uint32)
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)
# Only specify sizes
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
self.assertEqual(signal_pad.dtype, torch.uint32)
self.assertEqual(signal_pad.numel(), 64)
# Only specify dtype
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
self.assertEqual(signal_pad.dtype, torch.uint64)
self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)
# Specify both sizes and dtype
signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
self.assertEqual(signal_pad.dtype, torch.uint64)
self.assertEqual(signal_pad.numel(), 64)
        # Sanity check that writes to the buffer don't corrupt the signal_pad
t = symm_mem.empty(0, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
signal_pad.fill_(42)
t.fill_(0)
self.assertTrue(signal_pad.eq(42).all())
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@requires_cuda
def test_allow_overlapping_devices(self) -> None:
os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "1"
t = symm_mem.empty(64, device="cuda:0")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
self.assertEqual(symm_mem_hdl.rank, self.rank)
self.assertEqual(symm_mem_hdl.world_size, self.world_size)
for rank in range(self.world_size):
buf = symm_mem_hdl.get_buffer(rank, (64,), torch.float32)
if rank == self.rank:
self.assertEqual(buf.data_ptr(), t.data_ptr())
else:
self.assertEqual(buf.device, t.device)
os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "0"
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("symm_mem_input", [True, False])
def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
self._init_process()
if symm_mem_input:
t = _SymmetricMemory.empty_strided_p2p(
size=(64, 64),
stride=(64, 1),
dtype=torch.float32,
device=self.device,
group_name="0",
).fill_(self.rank)
else:
t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)
res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
res = torch.ops._c10d_functional.wait_tensor(res)
self.assertEqual(res.shape, (64 * self.world_size, 64))
chunks = res.chunk(self.world_size)
for r in range(self.world_size):
self.assertTrue(chunks[r].eq(r).all())
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("reduce_op", ["sum", "avg"])
@parametrize("symm_mem_input", [True, False])
def test_low_contention_reduce_scatter(
self, reduce_op: str, symm_mem_input: bool
) -> None:
self._init_process()
if symm_mem_input:
t = _SymmetricMemory.empty_strided_p2p(
size=(64, 64),
stride=(64, 1),
dtype=torch.float32,
device=self.device,
group_name="0",
)
else:
t = torch.empty((64, 64), dtype=torch.float32, device=self.device)
chunks = t.chunk(self.world_size)
for r in range(self.world_size):
chunks[r].fill_(r)
res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
res = torch.ops._c10d_functional.wait_tensor(res)
self.assertEqual(res.shape, (64 // self.world_size, 64))
if reduce_op == "sum":
expect = self.rank * self.world_size
elif reduce_op == "avg":
expect = self.rank
else:
raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
self.assertTrue(res.eq(expect).all())
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(4)
def test_subgroup(self) -> None:
self._init_process()
ranks = list(range(self.world_size))
subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])
world = dist.group.WORLD
subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
t = symm_mem.empty(64, device="cuda")
symm_mem_world = symm_mem.rendezvous(t, group=world)
symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
self.assertEqual(symm_mem_world.world_size, world.size())
self.assertEqual(symm_mem_world.rank, world.rank())
self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())
t.fill_(world.rank())
symm_mem_world.barrier()
# Observe a peer buffer via the world group
peer_rank = (world.rank() + 1) % world.size()
buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
self.assertTrue(buf.eq(peer_rank).all())
# Observe a peer buffer via the subgroup
peer_rank = (subgroup.rank() + 1) % subgroup.size()
buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
if world.rank() < world.size() // 2:
self.assertTrue(buf.eq(peer_rank).all())
else:
self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
# We move the AsyncTP tests to a separate test suite because 1) Async TP ops are
# not core symmetric memory APIs, they are more like applications, and 2)
# MultiProcContinuousTest skips all subsequent tests if one test fails (we
# should fix this too). We still want test signals for the core symmetric
# memory APIs even when the Async TP ops fail.
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class AsyncTPTest(MultiProcContinuousTest):
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _init_process(self):
torch.cuda.set_device(self.device)
torch.manual_seed(42 + self.rank)
torch.use_deterministic_algorithms(True)
torch.set_deterministic_debug_mode("warn")
torch.utils.deterministic.fill_uninitialized_memory = True
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
self._init_process()
BATCH = 8
M = 64
N = 16
K = 32
group = dist.group.WORLD
rank = self.rank
torch.manual_seed(42 + rank)
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
)
ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
)
assert torch.allclose(ag_output_0, ag_output_1)
assert ag_output_0.stride() == ag_output_1.stride()
for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
assert torch.allclose(mm_output_0, mm_output_1)
            assert mm_output_0.stride() == mm_output_1.stride()
@skip_if_rocm_multiprocess # this requires async_input_mm support
@skipIf(
not SM90OrLater,
"_fused_all_gather_matmul_native currently only supports sm>=90",
)
@skip_if_lt_x_gpu(2)
@parametrize("symm_mem_input", [True, False])
@parametrize("is_b_row_major", [True, False])
def test_fused_all_gather_matmul_native(
self, symm_mem_input: bool, is_b_row_major: bool
) -> None:
os.environ["TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP"] = "1"
self._init_process()
# See _should_use_fused_all_gather_matmul_native() for the algo
# selection criteria of _fused_all_gather_matmul_native().
M = 4096
N = 1024
K = 1024
group_name = dist.group.WORLD.group_name
torch.manual_seed(42 + self.rank)
if symm_mem_input:
A_shard = symm_mem.empty(
M // self.world_size,
K,
dtype=torch.bfloat16,
device=self.device,
).normal_()
else:
A_shard = torch.rand(
M // self.world_size, K, dtype=torch.bfloat16, device="cuda"
)
if is_b_row_major:
B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")
else:
B = torch.rand(N, K, dtype=torch.bfloat16, device="cuda").t()
ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
A_shard, [B], gather_dim=0, group_name=group_name
)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CUDA,
],
) as prof:
ag_target, mm_target = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, [B], gather_dim=0, group_name=group_name
)
self.assertTrue(
any("PersistentAsyncInputScheduler" in event.key for event in prof.events())
)
torch.testing.assert_close(ag_target, ag_baseline)
torch.testing.assert_close(mm_target[0], mm_baseline[0])
os.environ["TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP"] = "0"
@skip_if_lt_x_gpu(2)
@requires_multicast_support()
def test_multimem_all_gather_matmul(self) -> None:
self._init_process()
# See _should_use_multimem_all_gather_matmul() for the algo
# selection criteria of _multimem_gather_matmul().
M = 1024
N = 1024
K = 1024
group_name = dist.group.WORLD.group_name
torch.manual_seed(42 + self.rank)
A_shard = torch.rand(
M // self.world_size, K, dtype=torch.bfloat16, device="cuda"
)
B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")
ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CUDA,
],
) as prof:
ag_target, mm_target = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
)
self.assertTrue(
any("multimem_all_gather_kernel" in event.key for event in prof.events())
)
torch.testing.assert_close(ag_target, ag_baseline)
torch.testing.assert_close(mm_target[0], mm_baseline[0])
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
@parametrize(
"scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
)
def test_fused_all_gather_scaled_matmul(
self, gather_dim: int, scale_mode: str
) -> None:
self._init_process()
BATCH = 8
M = 64
N = 16
K = 32
group = dist.group.WORLD
rank = self.rank
if gather_dim == 0:
leading_dims = (BATCH // self.world_size, M)
elif gather_dim == 1:
leading_dims = (BATCH, M // self.world_size)
else:
raise AssertionError("Invalid scale_mode: {scale_mode}")
torch.manual_seed(42 + rank)
A_shard = torch.rand(*leading_dims, K, device="cuda").to(e4m3_type)
Bs = [torch.rand(N, K, device="cuda").to(e4m3_type).T for _ in range(3)]
if scale_mode == "tensor-wise":
A_scale = torch.tensor(0.1, device="cuda")
B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
out_dtypes = [None, torch.bfloat16, torch.float32]
elif scale_mode == "row-wise-sharded":
A_scale = torch.full((*leading_dims, 1), 0.1, device="cuda")
B_scales = [torch.full((1, N), 0.1, device="cuda") for _ in range(3)]
out_dtypes = [torch.bfloat16] * 3
elif scale_mode == "row-wise-replicated":
A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda")
B_scales = [torch.full((1, N), 0.1, device="cuda") for _ in range(3)]
out_dtypes = [torch.bfloat16] * 3
else:
raise AssertionError(f"Invalid scale_mode: {scale_mode}")
ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
self.assertTrue(
torch.allclose(
ag_output_0.to(torch.float32),
ag_output_1.to(torch.float32),
)
)
self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
self.assertTrue(
torch.allclose(
mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
)
)
self.assertEqual(mm_output_0.stride(), mm_output_1.stride())
self.assertEqual(mm_output_0.dtype, mm_output_1.dtype)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("scatter_dim", [0, 1])
def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
self._init_process()
BATCH = 8
M = 64
N = 16
K = 32
group = dist.group.WORLD
rank = self.rank
torch.manual_seed(42 + rank)
A = torch.rand(BATCH, M, K, device="cuda")
B = torch.rand(K, N, device="cuda")
output_0 = _fused_matmul_reduce_scatter_fallback(
A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
)
output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
)
assert torch.allclose(output_0, output_1)
assert output_0.stride() == output_1.stride()
@skip_if_rocm_multiprocess # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
@skip_if_lt_x_gpu(2)
@parametrize("scatter_dim", [0, 1])
@parametrize("rowwise", [True, False])
def test_fused_scaled_matmul_reduce_scatter(
self, scatter_dim: int, rowwise: bool
) -> None:
self._init_process()
BATCH = 8
M = 64
N = 16
K = 32
group = dist.group.WORLD
rank = self.rank
torch.manual_seed(42 + rank)
A = torch.rand(BATCH, M, K, device="cuda").to(e4m3_type)
B = torch.rand(N, K, device="cuda").to(e4m3_type).T
if rowwise:
A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda")
B_scale = torch.full((1, N), 0.1, device="cuda")
else:
A_scale = torch.tensor(0.1, device="cuda")
B_scale = torch.tensor(0.1, device="cuda")
output_shape = [*A.shape[:-1], B.shape[1]]
outputs = []
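        # Run the fused op under each context in test_contexts (nullcontext and
        # _test_mode) and check that both paths produce the same result.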
for context in test_contexts:
with context():
outputs.append(
torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
A,
B,
A_scale,
B_scale,
"avg",
scatter_dim,
scatter_dim,
group.group_name,
output_shape,
out_dtype=torch.bfloat16,
)
)
assert outputs[0].stride() == outputs[1].stride()
self.assertEqual(outputs[0], outputs[1])
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@parametrize("dim", [0, 1, 2])
def test_optimal_layout(self, dim: int) -> None:
t = torch.rand(8, 64, 32, 16)
x = restride_A_shard_for_fused_all_gather_matmul(t, dim)
self.assertTrue(x.movedim(dim, 0).is_contiguous())
self.assertTrue(torch.allclose(x, t))
x = restride_A_for_fused_matmul_reduce_scatter(t, dim)
self.assertTrue(x.movedim(dim, 0).is_contiguous())
self.assertTrue(torch.allclose(x, t))
# [READ ME FIRST]
# The `SymmMemEmptySetDeviceTest` suite parameterizes whether the user sets the
# device before calling symm_mem.empty. Either way should work.
# However, since `set_device` is persistent, we cannot use the
# `MultiProcContinuousTest` template because the next test function would be
# "contaminated", leading to flaky tests (e.g. hangs). Therefore, we use
# `MultiProcessTestCase`, which spawns new processes for each test function.
# Please limit the number of tests you add under this test suite, as
# respawning processes and `init_process_group` are expensive.
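# Illustrative sketch of the two parameterizations (comments only; `rank` stands
# in for the per-process rank, and the tests below use the lower-level
# _SymmetricMemory.empty_strided_p2p API):
#     torch.cuda.set_device(rank)                    # set_device=True
#     t = symm_mem.empty(64, device=f"cuda:{rank}")
#     # vs. set_device=False: allocate with the explicit device index only
#     t = symm_mem.empty(64, device=f"cuda:{rank}")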
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemEmptySetDeviceTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@property
def world_size(self) -> int:
return device_module.device_count()
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _init_process(self, set_device: bool):
if set_device:
torch.cuda.set_device(self.device)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
torch.manual_seed(42 + self.rank)
def _get_test_alloc_args(self):
shape = (64, 64)
stride = (64, 1)
dtype = torch.float32
device = self.device
return (shape, stride, dtype, device)
def _verify_symmetric_memory(self, symm_mem_hdl):
self.assertEqual(symm_mem_hdl.world_size, self.world_size)
buf = symm_mem_hdl.get_buffer(
0, (symm_mem_hdl.buffer_size // 4,), torch.float32
)
self.assertEqual(buf.storage_offset(), 0)
self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)
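        # Handshake: non-zero ranks write 42 into rank 0's buffer and signal rank 0,
        # which waits for the signal before checking the contents. A second round
        # writes 43 and synchronizes with barriers instead of signals.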
if symm_mem_hdl.rank == 0:
symm_mem_hdl.wait_signal(src_rank=1)
self.assertTrue(buf.eq(42).all())
else:
buf.fill_(42)
symm_mem_hdl.put_signal(dst_rank=0)
symm_mem_hdl.barrier()
if symm_mem_hdl.rank == 0:
symm_mem_hdl.barrier()
self.assertTrue(buf.eq(43).all())
else:
buf.fill_(43)
symm_mem_hdl.barrier()
symm_mem_hdl.barrier()
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("set_device", [True, False])
def test_empty_strided_p2p(self, set_device: bool) -> None:
self._init_process(set_device)
group_name = dist.group.WORLD.group_name
enable_symm_mem_for_group(group_name)
alloc_args = self._get_test_alloc_args()
t = torch.empty((64, 64), device=self.device)
self.assertIsNone(_SymmetricMemory.rendezvous(t))
t = _SymmetricMemory.empty_strided_p2p(*alloc_args, group_name=group_name)
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
del t
self._verify_symmetric_memory(symm_mem_hdl)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_rocm_ver_lessthan_multiprocess((7, 0))
@skip_if_lt_x_gpu(2)
@parametrize("set_device", [True, False])
def test_empty_strided_p2p_persistent(self, set_device: bool) -> None:
self._init_process(set_device)
group_name = dist.group.WORLD.group_name
enable_symm_mem_for_group(group_name)
alloc_args = self._get_test_alloc_args()
alloc_id = 42 + random.randint(0, 2147483647)
t = _SymmetricMemory.empty_strided_p2p(
*alloc_args, group_name=group_name, alloc_id=alloc_id
)
data_ptr = t.data_ptr()
# Verify that persistent allocation would fail if there's an active
# allocation with the same alloc_id.
with self.assertRaises(RuntimeError):
_SymmetricMemory.empty_strided_p2p(
*alloc_args, group_name=group_name, alloc_id=alloc_id
)
        # Verify that persistent allocation succeeds once there is no longer an
        # active allocation with the same alloc_id, and that the returned tensor
        # has the same data pointer.
del t
t = _SymmetricMemory.empty_strided_p2p(
*alloc_args, group_name=group_name, alloc_id=alloc_id
)
self.assertEqual(t.data_ptr(), data_ptr)
symm_mem_hdl = _SymmetricMemory.rendezvous(t)
self._verify_symmetric_memory(symm_mem_hdl)
# This Test class is used to test the error handling of SymmetricMemory APIs.
# Since a process restart is often needed after each test, we use the
# MultiProcessTestCase instead of MultiProcContinuousTest.
@requires_cuda_p2p_access()
class SymmMemNegativeTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@property
def world_size(self) -> int:
return device_module.device_count()
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _init_process(self):
torch.cuda.set_device(self.device)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
torch.manual_seed(42 + self.rank)
    # These timeout tests are skipped on ROCm because a timeout calls trap(), which
    # is handled differently inside the HIP runtime: it collects a GPU coredump and
    # causes the Linux kernel to create a core dump of the host application. The
    # functionality is there, meaning the timeout is happening correctly. However,
    # there isn't a nice way to test it, as the currently executing thread will
    # coredump and exit.
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_barrier_timeout(self) -> None:
self._init_process()
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
if self.rank == 0:
with self.assertRaises(RuntimeError):
symm_mem_hdl.barrier(timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
# The device-side timeout triggers a __trap() that causes all
# subsequent host/device interactions to result in an "unspecified
# launch failure." Using os._exit(0) to abort the test, as it's
# impossible to terminate the process in this state.
os._exit(0)
    # These timeout tests are skipped on ROCm because a timeout calls trap(), which
    # is handled differently inside the HIP runtime: it collects a GPU coredump and
    # causes the Linux kernel to create a core dump of the host application. The
    # functionality is there, meaning the timeout is happening correctly. However,
    # there isn't a nice way to test it, as the currently executing thread will
    # coredump and exit.
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_put_signal_timeout(self) -> None:
self._init_process()
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
if self.rank == 0:
with self.assertRaises(RuntimeError):
# First, put a signal into rank 1's signal pad. Since rank 1
# doesn't wait on this signal, the subsequent put will timeout.
symm_mem_hdl.put_signal(dst_rank=1)
symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
# The device-side timeout triggers a __trap() that causes all
# subsequent host/device interactions to result in an "unspecified
# launch failure." Using os._exit(0) to abort the test, as it's
# impossible to terminate the process in this state.
os._exit(0)
    # These timeout tests are skipped on ROCm because a timeout calls trap(), which
    # is handled differently inside the HIP runtime: it collects a GPU coredump and
    # causes the Linux kernel to create a core dump of the host application. The
    # functionality is there, meaning the timeout is happening correctly. However,
    # there isn't a nice way to test it, as the currently executing thread will
    # coredump and exit.
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_wait_signal_timeout(self) -> None:
self._init_process()
t = symm_mem.empty(1, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
if self.rank == 0:
with self.assertRaises(RuntimeError):
symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
torch.cuda.synchronize()
else:
torch.cuda.synchronize()
# The device-side timeout triggers a __trap() that causes all
# subsequent host/device interactions to result in an "unspecified
# launch failure." Using os._exit(0) to abort the test, as it's
# impossible to terminate the process in this state.
os._exit(0)
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemCollectiveTest(MultiProcContinuousTest):
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
def _init_process(self):
torch.cuda.set_device(self.device)
torch.manual_seed(42 + self.rank)
@skip_if_lt_x_gpu(4)
@requires_multicast_support()
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("align_bytes", [4, 8, 16])
@parametrize("size_bytes", [4, 8192, 8196])
def test_multimem_all_reduce(
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
t = symm_mem.empty((16384), dtype=dtype, device=self.device)
symm_mem.rendezvous(t, group=dist.group.WORLD)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
res = t[shift : shift + numel]
res.normal_()
inp = res.clone()
torch.ops.symm_mem.multimem_all_reduce_(res, "sum", group_name)
# Head and tail should not be written
self.assertTrue(t[:shift].eq(0).all().item())
self.assertTrue(t[shift + numel :].eq(0).all().item())
self._verify_all_reduce_result(inp, res)
@skip_if_lt_x_gpu(4)
@requires_multicast_support()
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("align_bytes", [4, 8, 16])
@parametrize("size_bytes", [4, 8192, 8196])
def test_multimem_one_shot_all_reduce(
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
inp = symm_mem.empty(
size_bytes // dtype.itemsize, dtype=dtype, device=self.device
).normal_()
symm_mem.rendezvous(inp, group=group_name)
res = torch.ops.symm_mem.multimem_one_shot_all_reduce(inp, "sum", group_name)
gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, -1)
# Only verify that the results are close to the sum of inputs across
# ranks (see Note [multimem_one_shot_all_reduce]).
torch.testing.assert_close(
gathered_inps.sum(dim=0), res, rtol=1e-03, atol=1e-05
)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(4)
def test_one_shot_all_reduce(self) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
[torch.float, torch.bfloat16],
[4, 8192, 8196],
[
8
], # TODO: add back [4, 8, 16], currently OOM when looping over all combinations
[True, False],
[0, 16],
):
inp = symm_mem.empty(
size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
)
symm_mem.rendezvous(inp, group=group_name)
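            # copy=False: the symmetric buffer itself holds the input;
            # copy=True: the input comes from a separate local tensor passed to
            # one_shot_all_reduce_copy.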
if not copy:
inp.normal_()
res = torch.ops.symm_mem.one_shot_all_reduce(
inp[offset:], "sum", group_name
)
if copy:
local_inp = torch.randn_like(inp[offset:])
res = torch.ops.symm_mem.one_shot_all_reduce_copy(
inp[offset:], local_inp, "sum", group_name
)
self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(4)
def test_two_shot_all_reduce(self) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
for dtype, size_bytes, align_bytes, inplace in itertools.product(
[torch.float, torch.bfloat16],
[4, 8192, 8196],
[
8
], # TODO: add back [4, 8, 16], currently OOM when looping over all combinations
[True, False],
):
t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
symm_mem.rendezvous(t, group=group_name)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
res = t[shift : shift + numel]
res.normal_()
inp = res.clone()
if not inplace:
out = torch.empty_like(inp)
torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out)
else:
torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)
# Head and tail should not be written
self.assertTrue(t[:shift].eq(0).all().item())
self.assertTrue(t[shift + numel :].eq(0).all().item())
self._verify_all_reduce_result(inp, res if inplace else out)
def _verify_all_reduce_result(self, inp, res):
gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, -1)
# Verify that the results across ranks are identical
self.assertEqual(
(gathered_res == gathered_res[0, :]).all(dim=0).sum(), inp.numel()
)
        # Verify that the results are close to the sum of inputs across ranks
gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, -1)
torch.testing.assert_close(
gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01
)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(4)
def test_reduce_scatter(self) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
for dtype, size_bytes, align_bytes, split_last_dim in itertools.product(
[torch.float, torch.bfloat16],
[128, 8192, 36 * 1024 * 16],
[
8
], # TODO: add back [4, 8, 16], currently OOM when looping over all combinations
[True, False],
):
t = symm_mem.empty(36 * 1024 * 16, dtype=dtype, device=self.device).fill_(0)
symm_mem.rendezvous(t, group=group_name)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
res = t[shift : shift + numel].normal_()
if split_last_dim:
res = res.view(-1, 128 // t.element_size())
inp = res.clone()
out_size = list(inp.shape)
out_size[-1] = inp.shape[-1] // self.world_size
out = torch.empty(out_size, dtype=dtype, device=self.device)
torch.ops.symm_mem.reduce_scatter_out(res, group_name, split_last_dim, out)
# Head and tail should not be written
self.assertTrue(t[:shift].eq(0).all().item())
self.assertTrue(t[shift + numel :].eq(0).all().item())
self._verify_reduce_scatter_result(inp, out)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(4)
def test_reduce_scatter_corner_cases(self) -> None:
self._init_process()
dtype = torch.bfloat16
group_name = dist.group.WORLD.group_name
t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
symm_mem.rendezvous(t, group=group_name)
res = t[:0]
out_size = res.shape[0] // self.world_size
out = torch.empty(out_size, dtype=dtype, device=self.device)
torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out)
res = t[:48]
out_size = res.shape[0] // self.world_size
out = torch.empty(out_size, dtype=dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, "divisible"):
torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out)
res = t[: 2 * 48].view(2, 48)
out = torch.empty(2, 48 // self.world_size, dtype=dtype, device=self.device)
with self.assertRaisesRegex(RuntimeError, "divisible"):
torch.ops.symm_mem.reduce_scatter_out(res, group_name, True, out)
def _verify_reduce_scatter_result(self, inp, res):
gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, *res.shape)
gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, *inp.shape)
sum_inps = gathered_inps.sum(0)
slice_width = sum_inps.shape[-1] // self.world_size
for i in range(self.world_size):
torch.testing.assert_close(
gathered_res[i],
sum_inps[..., i * slice_width : (i + 1) * slice_width],
rtol=1e-01,
atol=1.1e-01,
)
@skip_if_lt_x_gpu(4)
@requires_multicast_support()
@parametrize("align_bytes", [4, 8, 16])
def test_multimem_all_gather(self, align_bytes: int) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
input_numel = 32
shift = align_bytes // 4
input = torch.zeros(shift + input_numel, device=self.device)[shift:].fill_(
self.rank
)
out = symm_mem.empty(
shift + input_numel * self.world_size, device=self.device
).zero_()[shift:]
symm_mem.rendezvous(out, group=group_name)
torch.ops.symm_mem.multimem_all_gather_out(input, group_name, out)
ref = torch.ops._c10d_functional.all_gather_into_tensor(
input, self.world_size, group_name
)
ref = torch.ops._c10d_functional.wait_tensor(ref)
self.assertTrue(out.eq(ref).all())
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class LoweringTest(MultiProcContinuousTest):
def _init_process(self) -> None:
torch.cuda.set_device(self.device)
enable_symm_mem_for_group(dist.group.WORLD.group_name)
torch.manual_seed(42 + self.rank)
torch._inductor.config._collective.auto_select = True
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skip("Fails with 'one_shot_all_reduce' not found in AOT graph, TODO: fix")
@skip_if_rocm_multiprocess # requires registered-buffer support
@skip_if_lt_x_gpu(2)
@fresh_cache()
def test_lowering_one_shot_all_reduce(self):
self._init_process()
arg = torch.rand(4, 4, device=self.device)
def func_0(x):
x = x + 1
x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
return torch.ops._c10d_functional.wait_tensor(x)
compiled_0 = torch.compile(func_0, fullgraph=True)
code_0 = run_and_get_triton_code(compiled_0, arg)
self.assertIn("one_shot_all_reduce", code_0)
self.assertNotIn("return (buf0", code_0)
# All-reduce on a slice view
def func_1(x):
x = x + 1
x = x[2:]
x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
return torch.ops._c10d_functional.wait_tensor(x)
compiled_1 = torch.compile(func_1, fullgraph=True)
code_1 = run_and_get_triton_code(compiled_1, arg)
self.assertIn("one_shot_all_reduce", code_1)
self.assertNotIn("return (buf0", code_1)
# All-reduce on input
def func_2(x):
x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
return torch.ops._c10d_functional.wait_tensor(x)
compiled_2 = torch.compile(func_2, fullgraph=True)
code_2 = run_and_get_triton_code(compiled_2, arg)
self.assertNotIn("one_shot_all_reduce", code_2)
# All-reduce on matmul output
def func_3(x):
x = x @ x
x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
return torch.ops._c10d_functional.wait_tensor(x)
compiled_3 = torch.compile(func_3, fullgraph=True)
code_3 = run_and_get_triton_code(compiled_3, arg)
self.assertIn("one_shot_all_reduce", code_3)
self.assertNotIn("return (buf0", code_3)
class SymmMemSingleProcTest(TestCase):
@requires_cuda
@skipIf(
not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0),
"stream_write_value32 currently only supports cuda version>=12.0",
)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
def test_stream_write_value32(self):
tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")
expect = torch.tril(torch.ones(4, 4, device="cuda")).to(torch.uint32)
for i in range(4):
_SymmetricMemory.stream_write_value32(tensor, i, 1)
torch.testing.assert_close(tensor, expect[i])
with self.assertRaises(RuntimeError):
_SymmetricMemory.stream_write_value32(tensor, offset=-1, val=1)
with self.assertRaises(RuntimeError):
_SymmetricMemory.stream_write_value32(tensor, offset=0, val=4294967296)
@skipIf(
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@requires_cuda
def test_memset32(self):
t = _SymmetricMemory.empty_strided_p2p(
(64,),
(1,),
dtype=torch.uint32,
device=torch.device("cuda:0"),
group_name="0",
).fill_(0)
_SymmetricMemory.memset32(t, offset=32, val=1, count=16)
self.assertTrue(t[:32].eq(0).all())
self.assertTrue(t[32:48].eq(1).all())
self.assertTrue(t[48:].eq(0).all())
with self.assertRaisesRegex(
RuntimeError, "input must be a flat, contiguous uint32 tensor"
):
_SymmetricMemory.memset32(t.view(8, 8), offset=0, val=1, count=1)
with self.assertRaisesRegex(
RuntimeError, "input must be a flat, contiguous uint32 tensor"
):
_SymmetricMemory.memset32(t.view(torch.float32), offset=0, val=1, count=1)
with self.assertRaisesRegex(
RuntimeError, "offset must be greater than or equal to 0"
):
_SymmetricMemory.memset32(t, offset=-1, val=1, count=1)
with self.assertRaisesRegex(
RuntimeError, r"val must be in the range of.*\(uint32_t\)"
):
_SymmetricMemory.memset32(t, offset=0, val=4294967296, count=1)
with self.assertRaisesRegex(RuntimeError, "count must be a positive integer"):
_SymmetricMemory.memset32(t, offset=0, val=1, count=-1)
with self.assertRaisesRegex(RuntimeError, "count must be a positive integer"):
_SymmetricMemory.memset32(t, offset=0, val=1, count=0)
with self.assertRaisesRegex(
RuntimeError, r"offset \+ count.*exceeded the numel of the input"
):
_SymmetricMemory.memset32(t, offset=64, val=1, count=1)
with self.assertRaisesRegex(
RuntimeError, r"offset \+ count.*exceeded the numel of the input"
):
_SymmetricMemory.memset32(t, offset=0, val=1, count=65)
_SymmetricMemory.memset32(t, offset=0, val=1, count=64)
_SymmetricMemory.memset32(t, offset=63, val=1, count=1)
if __name__ == "__main__":
run_tests()