After the UT suite moved to `MultiProcContinuousTest`, the `skipIfRocm` decorator started failing rather than skipping UTs: the worker processes are now spawned before the skip decorator is taken into account, and the decorator raised an exception to exit the child process, which the parent process treated as a crash rather than a skip. Additionally, in `MultiProcContinuousTest`, if one UT fails, all subsequent ones are also skipped; this makes sense since there is one setup for the entire suite, but it showed up as many failing/skipped UTs in the parity report.

This change adds multiprocess versions of the ROCm skip decorators, including `skip_if_rocm_arch_multiprocess` and `skip_if_rocm_ver_lessthan_multiprocess`. These are needed because the symmetric memory feature is only supported on MI300 and newer, so the tests must be skipped on other archs, and some UTs only work on ROCm 7.0 and later.

Fixes #161249
Fixes #161187
Fixes #161078
Fixes #160989
Fixes #160881
Fixes #160768
Fixes #160716
Fixes #160665
Fixes #160621
Fixes #160549
Fixes #160506
Fixes #160445
Fixes #160347
Fixes #160203
Fixes #160177
Fixes #160049
Fixes #159921
Fixes #159764
Fixes #159643
Fixes #159499
Fixes #159397
Fixes #159396
Fixes #159347
Fixes #159067
Fixes #159066
Fixes #158916
Fixes #158760
Fixes #158759
Fixes #158422
Fixes #158138
Fixes #158136
Fixes #158135
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162811
Approved by: https://github.com/jeffdaily
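For context on how such multiprocess skip decorators avoid the crash-vs-skip problem: instead of raising `unittest.SkipTest` inside an already-spawned worker, the decorator makes the child process exit with a dedicated skip exit code that the parent test harness translates back into a skipped test. The snippet below is a minimal sketch of that idea, not the actual implementation added by this PR; it assumes the `TEST_SKIPS` exit-code table used by `MultiProcessTestCase`, and the arch lookup via `gcnArchName` is purely illustrative.

import sys
from functools import wraps

import torch
from torch.testing._internal.common_distributed import TEST_SKIPS
from torch.testing._internal.common_utils import TEST_WITH_ROCM


def skip_if_rocm_arch_multiprocess_sketch(archs):
    # Sketch: skip a multiprocess UT on the listed ROCm archs by exiting the
    # child process with a skip exit code instead of raising an exception,
    # so the parent records a skip rather than a crash.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if TEST_WITH_ROCM:
                # Illustrative arch query (e.g. "gfx942" for MI300-class GPUs).
                arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
                if arch in archs:
                    # Exit code understood by the multiprocess test harness.
                    sys.exit(TEST_SKIPS["skipIfRocm"].exit_code)
            return func(*args, **kwargs)

        return wrapper

    return decorator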
1245 lines
46 KiB
Python
# Owner(s): ["module: c10d"]

import itertools
import os
import random
from contextlib import nullcontext
from unittest import skip, skipIf

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch._inductor.utils import fresh_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import all_gather_tensor
from torch.distributed._symmetric_memory import (
    _fused_all_gather_matmul_fallback,
    _fused_all_gather_scaled_matmul_fallback,
    _fused_matmul_reduce_scatter_fallback,
    _test_mode,
    enable_symm_mem_for_group,
    restride_A_for_fused_matmul_reduce_scatter,
    restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
    MultiProcContinuousTest,
    MultiProcessTestCase,
    PLATFORM_SUPPORTS_SYMM_MEM,
    requires_multicast_support,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
    skip_if_rocm_ver_lessthan_multiprocess,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    requires_cuda_p2p_access,
    run_tests,
    TEST_WITH_ROCM,
    TestCase,
)


test_contexts = [nullcontext, _test_mode]

# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)

@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcContinuousTest):
    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _init_process(self):
        torch.cuda.set_device(self.device)
        torch.manual_seed(42 + self.rank)

    def test_has_multicast_support(self) -> None:
        # validate that has_multicast_support() returns "false" instead of throwing
        self.assertFalse(_SymmetricMemory.has_multicast_support(DeviceType.CPU, 0))
        # NOTE: DeviceType.CUDA is implicitly tested through @requires_multicast_support

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    def test_get_backend(self) -> None:
        backend = symm_mem.get_backend(torch.device("cuda"))
        self.assertIsNotNone(backend)
        backend = symm_mem.get_backend("cuda")
        self.assertIsNotNone(backend)

    @skip_if_rocm_multiprocess
    @skip_if_lt_x_gpu(2)
    def test_cuda_nvlink_connectivity_detection(self) -> None:
        from torch._C._distributed_c10d import _detect_dma_connectivity

        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
        self.assertEqual(connectivity.connection_type, "nvlink")
        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
        for row in connectivity.matrix:
            self.assertEqual(len(row), torch.cuda.device_count())

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    def test_large_alloc(self) -> None:
        t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
        self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    def test_get_signal_pad(self) -> None:
        self._init_process()

        t = symm_mem.empty(1, device="cuda")
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        peer_rank = (self.rank + 1) % self.world_size

        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
        self.assertEqual(
            signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
        )

        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
        self.assertEqual(signal_pad.dtype, torch.uint32)
        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)

        # Only specify sizes
        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
        self.assertEqual(signal_pad.dtype, torch.uint32)
        self.assertEqual(signal_pad.numel(), 64)

        # Only specify dtype
        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
        self.assertEqual(signal_pad.dtype, torch.uint64)
        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)

        # Specify both sizes and dtype
        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
        self.assertEqual(signal_pad.dtype, torch.uint64)
        self.assertEqual(signal_pad.numel(), 64)

        # Sanity check that writes to the buffer don't corrupt the signal_pad
        t = symm_mem.empty(0, device="cuda")
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
        signal_pad.fill_(42)
        t.fill_(0)
        self.assertTrue(signal_pad.eq(42).all())

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @requires_cuda
    def test_allow_overlapping_devices(self) -> None:
        os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "1"
        t = symm_mem.empty(64, device="cuda:0")
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)

        self.assertEqual(symm_mem_hdl.rank, self.rank)
        self.assertEqual(symm_mem_hdl.world_size, self.world_size)

        for rank in range(self.world_size):
            buf = symm_mem_hdl.get_buffer(rank, (64,), torch.float32)
            if rank == self.rank:
                self.assertEqual(buf.data_ptr(), t.data_ptr())
            else:
                self.assertEqual(buf.device, t.device)

        os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "0"

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            ).fill_(self.rank)
        else:
            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)

        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 * self.world_size, 64))

        chunks = res.chunk(self.world_size)
        for r in range(self.world_size):
            self.assertTrue(chunks[r].eq(r).all())

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("reduce_op", ["sum", "avg"])
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_reduce_scatter(
        self, reduce_op: str, symm_mem_input: bool
    ) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            )
        else:
            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)

        chunks = t.chunk(self.world_size)
        for r in range(self.world_size):
            chunks[r].fill_(r)

        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 // self.world_size, 64))

        if reduce_op == "sum":
            expect = self.rank * self.world_size
        elif reduce_op == "avg":
            expect = self.rank
        else:
            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
        self.assertTrue(res.eq(expect).all())

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(4)
    def test_subgroup(self) -> None:
        self._init_process()

        ranks = list(range(self.world_size))
        subgroup_0 = dist.new_group(ranks[: len(ranks) // 2])
        subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :])

        world = dist.group.WORLD
        subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1

        t = symm_mem.empty(64, device="cuda")
        symm_mem_world = symm_mem.rendezvous(t, group=world)
        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)

        self.assertEqual(symm_mem_world.world_size, world.size())
        self.assertEqual(symm_mem_world.rank, world.rank())
        self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2)
        self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size())

        t.fill_(world.rank())
        symm_mem_world.barrier()

        # Observe a peer buffer via the world group
        peer_rank = (world.rank() + 1) % world.size()
        buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32)
        self.assertTrue(buf.eq(peer_rank).all())

        # Observe a peer buffer via the subgroup
        peer_rank = (subgroup.rank() + 1) % subgroup.size()
        buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32)
        if world.rank() < world.size() // 2:
            self.assertTrue(buf.eq(peer_rank).all())
        else:
            self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())


# We move AsyncTP tests to a separate test suite because 1) Async TP ops are not
# the core symmetric memory APIs, they are more like applications, 2)
# MultiProcContinuousTest will skip all the following tests if a test fails (
# we should fix this too). We still want to get the test signals for the core
# symmetric memory APIs when Async TP ops fail.
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class AsyncTPTest(MultiProcContinuousTest):
    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _init_process(self):
        torch.cuda.set_device(self.device)
        torch.manual_seed(42 + self.rank)
        torch.use_deterministic_algorithms(True)
        torch.set_deterministic_debug_mode("warn")
        torch.utils.deterministic.fill_uninitialized_memory = True

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
        Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )

        assert torch.allclose(ag_output_0, ag_output_1)
        assert ag_output_0.stride() == ag_output_1.stride()
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            assert torch.allclose(mm_output_0, mm_output_1)
            assert mm_output_0.stride() == mm_output_1.stride()

    @skip_if_rocm_multiprocess  # this requires async_input_mm support
    @skipIf(
        not SM90OrLater,
        "_fused_all_gather_matmul_native currently only supports sm>=90",
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    @parametrize("is_b_row_major", [True, False])
    def test_fused_all_gather_matmul_native(
        self, symm_mem_input: bool, is_b_row_major: bool
    ) -> None:
        os.environ["TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP"] = "1"
        self._init_process()

        # See _should_use_fused_all_gather_matmul_native() for the algo
        # selection criteria of _fused_all_gather_matmul_native().
        M = 4096
        N = 1024
        K = 1024
        group_name = dist.group.WORLD.group_name

        torch.manual_seed(42 + self.rank)
        if symm_mem_input:
            A_shard = symm_mem.empty(
                M // self.world_size,
                K,
                dtype=torch.bfloat16,
                device=self.device,
            ).normal_()
        else:
            A_shard = torch.rand(
                M // self.world_size, K, dtype=torch.bfloat16, device="cuda"
            )

        if is_b_row_major:
            B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")
        else:
            B = torch.rand(N, K, dtype=torch.bfloat16, device="cuda").t()

        ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
            A_shard, [B], gather_dim=0, group_name=group_name
        )
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CUDA,
            ],
        ) as prof:
            ag_target, mm_target = torch.ops.symm_mem.fused_all_gather_matmul(
                A_shard, [B], gather_dim=0, group_name=group_name
            )

        self.assertTrue(
            any("PersistentAsyncInputScheduler" in event.key for event in prof.events())
        )

        torch.testing.assert_close(ag_target, ag_baseline)
        torch.testing.assert_close(mm_target[0], mm_baseline[0])
        os.environ["TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP"] = "0"

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    def test_multimem_all_gather_matmul(self) -> None:
        self._init_process()

        # See _should_use_multimem_all_gather_matmul() for the algo
        # selection criteria of _multimem_gather_matmul().
        M = 1024
        N = 1024
        K = 1024
        group_name = dist.group.WORLD.group_name

        torch.manual_seed(42 + self.rank)
        A_shard = torch.rand(
            M // self.world_size, K, dtype=torch.bfloat16, device="cuda"
        )

        B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")

        ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
            A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
        )
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CUDA,
            ],
        ) as prof:
            ag_target, mm_target = torch.ops.symm_mem.fused_all_gather_matmul(
                A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
            )

        self.assertTrue(
            any("multimem_all_gather_kernel" in event.key for event in prof.events())
        )

        torch.testing.assert_close(ag_target, ag_baseline)
        torch.testing.assert_close(mm_target[0], mm_baseline[0])

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    @parametrize(
        "scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
    )
    def test_fused_all_gather_scaled_matmul(
        self, gather_dim: int, scale_mode: str
    ) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        if gather_dim == 0:
            leading_dims = (BATCH // self.world_size, M)
        elif gather_dim == 1:
            leading_dims = (BATCH, M // self.world_size)
        else:
            raise AssertionError(f"Invalid gather_dim: {gather_dim}")

        torch.manual_seed(42 + rank)

        A_shard = torch.rand(*leading_dims, K, device="cuda").to(e4m3_type)
        Bs = [torch.rand(N, K, device="cuda").to(e4m3_type).T for _ in range(3)]

        if scale_mode == "tensor-wise":
            A_scale = torch.tensor(0.1, device="cuda")
            B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
            out_dtypes = [None, torch.bfloat16, torch.float32]
        elif scale_mode == "row-wise-sharded":
            A_scale = torch.full((*leading_dims, 1), 0.1, device="cuda")
            B_scales = [torch.full((1, N), 0.1, device="cuda") for _ in range(3)]
            out_dtypes = [torch.bfloat16] * 3
        elif scale_mode == "row-wise-replicated":
            A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda")
            B_scales = [torch.full((1, N), 0.1, device="cuda") for _ in range(3)]
            out_dtypes = [torch.bfloat16] * 3
        else:
            raise AssertionError(f"Invalid scale_mode: {scale_mode}")

        ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )

        self.assertTrue(
            torch.allclose(
                ag_output_0.to(torch.float32),
                ag_output_1.to(torch.float32),
            )
        )
        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            self.assertTrue(
                torch.allclose(
                    mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
                )
            )
            self.assertEqual(mm_output_0.stride(), mm_output_1.stride())
            self.assertEqual(mm_output_0.dtype, mm_output_1.dtype)

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda")
        B = torch.rand(K, N, device="cuda")

        output_0 = _fused_matmul_reduce_scatter_fallback(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )
        output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

    @skip_if_rocm_multiprocess  # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    @parametrize("rowwise", [True, False])
    def test_fused_scaled_matmul_reduce_scatter(
        self, scatter_dim: int, rowwise: bool
    ) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda").to(e4m3_type)
        B = torch.rand(N, K, device="cuda").to(e4m3_type).T

        if rowwise:
            A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda")
            B_scale = torch.full((1, N), 0.1, device="cuda")
        else:
            A_scale = torch.tensor(0.1, device="cuda")
            B_scale = torch.tensor(0.1, device="cuda")

        output_shape = [*A.shape[:-1], B.shape[1]]

        outputs = []
        for context in test_contexts:
            with context():
                outputs.append(
                    torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                        A,
                        B,
                        A_scale,
                        B_scale,
                        "avg",
                        scatter_dim,
                        scatter_dim,
                        group.group_name,
                        output_shape,
                        out_dtype=torch.bfloat16,
                    )
                )

        assert outputs[0].stride() == outputs[1].stride()
        self.assertEqual(outputs[0], outputs[1])

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @parametrize("dim", [0, 1, 2])
    def test_optimal_layout(self, dim: int) -> None:
        t = torch.rand(8, 64, 32, 16)

        x = restride_A_shard_for_fused_all_gather_matmul(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

        x = restride_A_for_fused_matmul_reduce_scatter(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))


# [READ ME FIRST]
# The `SymmMemEmptySetDeviceTest` suite parameterizes whether the user sets the
# device before calling symm_mem.empty. Either way should work.
# However, since `set_device` is persistent, we cannot use the
# `MultiProcContinuousTest` template because the next function will be
# "contaminated", leading to flaky tests (e.g. hang). Therefore, we use
# `MultiProcessTestCase` which spawns new processes for each test function.
# Please limit the number of tests you want to add under this test
# suite as respawning processes and `init_process_group` is expensive.
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemEmptySetDeviceTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return device_module.device_count()

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _init_process(self, set_device: bool):
        if set_device:
            torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    def _get_test_alloc_args(self):
        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        return (shape, stride, dtype, device)

    def _verify_symmetric_memory(self, symm_mem_hdl):
        self.assertEqual(symm_mem_hdl.world_size, self.world_size)

        buf = symm_mem_hdl.get_buffer(
            0, (symm_mem_hdl.buffer_size // 4,), torch.float32
        )
        self.assertEqual(buf.storage_offset(), 0)
        self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)

        if symm_mem_hdl.rank == 0:
            symm_mem_hdl.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem_hdl.put_signal(dst_rank=0)

        symm_mem_hdl.barrier()

        if symm_mem_hdl.rank == 0:
            symm_mem_hdl.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem_hdl.barrier()

        symm_mem_hdl.barrier()

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("set_device", [True, False])
    def test_empty_strided_p2p(self, set_device: bool) -> None:
        self._init_process(set_device)
        group_name = dist.group.WORLD.group_name
        enable_symm_mem_for_group(group_name)

        alloc_args = self._get_test_alloc_args()

        t = torch.empty((64, 64), device=self.device)
        self.assertIsNone(_SymmetricMemory.rendezvous(t))

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args, group_name=group_name)
        symm_mem_hdl = _SymmetricMemory.rendezvous(t)

        del t
        self._verify_symmetric_memory(symm_mem_hdl)

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_rocm_ver_lessthan_multiprocess((7, 0))
    @skip_if_lt_x_gpu(2)
    @parametrize("set_device", [True, False])
    def test_empty_strided_p2p_persistent(self, set_device: bool) -> None:
        self._init_process(set_device)
        group_name = dist.group.WORLD.group_name
        enable_symm_mem_for_group(group_name)

        alloc_args = self._get_test_alloc_args()

        alloc_id = 42 + random.randint(0, 2147483647)
        t = _SymmetricMemory.empty_strided_p2p(
            *alloc_args, group_name=group_name, alloc_id=alloc_id
        )
        data_ptr = t.data_ptr()

        # Verify that persistent allocation would fail if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(
                *alloc_args, group_name=group_name, alloc_id=alloc_id
            )

        # Verify that persistent allocation succeeds once the active
        # allocation with the same alloc_id has been freed, and that the
        # returned tensor has the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(
            *alloc_args, group_name=group_name, alloc_id=alloc_id
        )
        self.assertEqual(t.data_ptr(), data_ptr)

        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
        self._verify_symmetric_memory(symm_mem_hdl)


# This Test class is used to test the error handling of SymmetricMemory APIs.
# Since a process restart is often needed after each test, we use the
# MultiProcessTestCase instead of MultiProcContinuousTest.
@requires_cuda_p2p_access()
class SymmMemNegativeTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return device_module.device_count()

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    # These timeout tests are skipped on ROCm because timeout calls trap(), which
    # is handled differently inside the HIP runtime: it collects a GPU coredump and
    # causes the Linux kernel to create a core dump of the host application. The
    # functionality is there, meaning the timeout is happening correctly. However,
    # there isn't a nice way to test it as the current executing thread will
    # coredump and exit.
    @skip_if_rocm_multiprocess
    @skip_if_lt_x_gpu(2)
    def test_barrier_timeout(self) -> None:
        self._init_process()

        t = symm_mem.empty(1, device="cuda")
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)

        if self.rank == 0:
            with self.assertRaises(RuntimeError):
                symm_mem_hdl.barrier(timeout_ms=1000)
                torch.cuda.synchronize()
        else:
            torch.cuda.synchronize()

        # The device-side timeout triggers a __trap() that causes all
        # subsequent host/device interactions to result in an "unspecified
        # launch failure." Using os._exit(0) to abort the test, as it's
        # impossible to terminate the process in this state.
        os._exit(0)

    # These timeout tests are skipped on ROCm because timeout calls trap(), which
    # is handled differently inside the HIP runtime: it collects a GPU coredump and
    # causes the Linux kernel to create a core dump of the host application. The
    # functionality is there, meaning the timeout is happening correctly. However,
    # there isn't a nice way to test it as the current executing thread will
    # coredump and exit.
    @skip_if_rocm_multiprocess
    @skip_if_lt_x_gpu(2)
    def test_put_signal_timeout(self) -> None:
        self._init_process()

        t = symm_mem.empty(1, device="cuda")
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)

        if self.rank == 0:
            with self.assertRaises(RuntimeError):
                # First, put a signal into rank 1's signal pad. Since rank 1
                # doesn't wait on this signal, the subsequent put will timeout.
                symm_mem_hdl.put_signal(dst_rank=1)
                symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
                torch.cuda.synchronize()
        else:
            torch.cuda.synchronize()

        # The device-side timeout triggers a __trap() that causes all
        # subsequent host/device interactions to result in an "unspecified
        # launch failure." Using os._exit(0) to abort the test, as it's
        # impossible to terminate the process in this state.
        os._exit(0)

    # These timeout tests are skipped on ROCm because timeout calls trap(), which
    # is handled differently inside the HIP runtime: it collects a GPU coredump and
    # causes the Linux kernel to create a core dump of the host application. The
    # functionality is there, meaning the timeout is happening correctly. However,
    # there isn't a nice way to test it as the current executing thread will
    # coredump and exit.
    @skip_if_rocm_multiprocess
    @skip_if_lt_x_gpu(2)
    def test_wait_signal_timeout(self) -> None:
        self._init_process()

        t = symm_mem.empty(1, device="cuda")
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)

        if self.rank == 0:
            with self.assertRaises(RuntimeError):
                symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
                torch.cuda.synchronize()
        else:
            torch.cuda.synchronize()

        # The device-side timeout triggers a __trap() that causes all
        # subsequent host/device interactions to result in an "unspecified
        # launch failure." Using os._exit(0) to abort the test, as it's
        # impossible to terminate the process in this state.
        os._exit(0)


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemCollectiveTest(MultiProcContinuousTest):
    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _init_process(self):
        torch.cuda.set_device(self.device)
        torch.manual_seed(42 + self.rank)

    @skip_if_lt_x_gpu(4)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        # Zero-fill the buffer so the head/tail checks below are meaningful.
        t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
        symm_mem.rendezvous(t, group=dist.group.WORLD)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        res = t[shift : shift + numel]
        res.normal_()
        inp = res.clone()

        torch.ops.symm_mem.multimem_all_reduce_(res, "sum", group_name)

        # Head and tail should not be written
        self.assertTrue(t[:shift].eq(0).all().item())
        self.assertTrue(t[shift + numel :].eq(0).all().item())
        self._verify_all_reduce_result(inp, res)

    @skip_if_lt_x_gpu(4)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_one_shot_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        inp = symm_mem.empty(
            size_bytes // dtype.itemsize, dtype=dtype, device=self.device
        ).normal_()
        symm_mem.rendezvous(inp, group=group_name)

        res = torch.ops.symm_mem.multimem_one_shot_all_reduce(inp, "sum", group_name)

        gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, -1)
        # Only verify that the results are close to the sum of inputs across
        # ranks (see Note [multimem_one_shot_all_reduce]).
        torch.testing.assert_close(
            gathered_inps.sum(dim=0), res, rtol=1e-03, atol=1e-05
        )

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(4)
    def test_one_shot_all_reduce(self) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
            [torch.float, torch.bfloat16],
            [4, 8192, 8196],
            [
                8
            ],  # TODO: add back [4, 8, 16], currently OOM when looping over all combinations
            [True, False],
            [0, 16],
        ):
            inp = symm_mem.empty(
                size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
            )
            symm_mem.rendezvous(inp, group=group_name)
            if not copy:
                inp.normal_()
                res = torch.ops.symm_mem.one_shot_all_reduce(
                    inp[offset:], "sum", group_name
                )
            if copy:
                local_inp = torch.randn_like(inp[offset:])
                res = torch.ops.symm_mem.one_shot_all_reduce_copy(
                    inp[offset:], local_inp, "sum", group_name
                )
            self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(4)
    def test_two_shot_all_reduce(self) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        for dtype, size_bytes, align_bytes, inplace in itertools.product(
            [torch.float, torch.bfloat16],
            [4, 8192, 8196],
            [
                8
            ],  # TODO: add back [4, 8, 16], currently OOM when looping over all combinations
            [True, False],
        ):
            t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
            symm_mem.rendezvous(t, group=group_name)

            self.assertTrue(t.data_ptr() % 16 == 0)
            self.assertTrue(align_bytes % t.element_size() == 0)
            self.assertTrue(size_bytes % t.element_size() == 0)

            shift = align_bytes // t.element_size()
            numel = size_bytes // t.element_size()
            res = t[shift : shift + numel]
            res.normal_()
            inp = res.clone()
            if not inplace:
                out = torch.empty_like(inp)
                torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out)
            else:
                torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)

            # Head and tail should not be written
            self.assertTrue(t[:shift].eq(0).all().item())
            self.assertTrue(t[shift + numel :].eq(0).all().item())
            self._verify_all_reduce_result(inp, res if inplace else out)

    def _verify_all_reduce_result(self, inp, res):
        gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, -1)
        # Verify that the results across ranks are identical
        self.assertEqual(
            (gathered_res == gathered_res[0, :]).all(dim=0).sum(), inp.numel()
        )

        # Verify that the results are close to the sum of inputs across ranks
        gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, -1)
        torch.testing.assert_close(
            gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01
        )

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(4)
    def test_reduce_scatter(self) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        for dtype, size_bytes, align_bytes, split_last_dim in itertools.product(
            [torch.float, torch.bfloat16],
            [128, 8192, 36 * 1024 * 16],
            [
                8
            ],  # TODO: add back [4, 8, 16], currently OOM when looping over all combinations
            [True, False],
        ):
            t = symm_mem.empty(36 * 1024 * 16, dtype=dtype, device=self.device).fill_(0)
            symm_mem.rendezvous(t, group=group_name)

            self.assertTrue(t.data_ptr() % 16 == 0)
            self.assertTrue(align_bytes % t.element_size() == 0)
            self.assertTrue(size_bytes % t.element_size() == 0)

            shift = align_bytes // t.element_size()
            numel = size_bytes // t.element_size()
            res = t[shift : shift + numel].normal_()
            if split_last_dim:
                res = res.view(-1, 128 // t.element_size())
            inp = res.clone()
            out_size = list(inp.shape)
            out_size[-1] = inp.shape[-1] // self.world_size
            out = torch.empty(out_size, dtype=dtype, device=self.device)
            torch.ops.symm_mem.reduce_scatter_out(res, group_name, split_last_dim, out)

            # Head and tail should not be written
            self.assertTrue(t[:shift].eq(0).all().item())
            self.assertTrue(t[shift + numel :].eq(0).all().item())
            self._verify_reduce_scatter_result(inp, out)

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @skip_if_lt_x_gpu(4)
    def test_reduce_scatter_corner_cases(self) -> None:
        self._init_process()
        dtype = torch.bfloat16
        group_name = dist.group.WORLD.group_name
        t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
        symm_mem.rendezvous(t, group=group_name)
        res = t[:0]
        out_size = res.shape[0] // self.world_size
        out = torch.empty(out_size, dtype=dtype, device=self.device)
        torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out)
        res = t[:48]
        out_size = res.shape[0] // self.world_size
        out = torch.empty(out_size, dtype=dtype, device=self.device)
        with self.assertRaisesRegex(RuntimeError, "divisible"):
            torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out)
        res = t[: 2 * 48].view(2, 48)
        out = torch.empty(2, 48 // self.world_size, dtype=dtype, device=self.device)
        with self.assertRaisesRegex(RuntimeError, "divisible"):
            torch.ops.symm_mem.reduce_scatter_out(res, group_name, True, out)

    def _verify_reduce_scatter_result(self, inp, res):
        gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, *res.shape)
        gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, *inp.shape)
        sum_inps = gathered_inps.sum(0)
        slice_width = sum_inps.shape[-1] // self.world_size
        for i in range(self.world_size):
            torch.testing.assert_close(
                gathered_res[i],
                sum_inps[..., i * slice_width : (i + 1) * slice_width],
                rtol=1e-01,
                atol=1.1e-01,
            )

    @skip_if_lt_x_gpu(4)
    @requires_multicast_support()
    @parametrize("align_bytes", [4, 8, 16])
    def test_multimem_all_gather(self, align_bytes: int) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        input_numel = 32
        shift = align_bytes // 4
        input = torch.zeros(shift + input_numel, device=self.device)[shift:].fill_(
            self.rank
        )

        out = symm_mem.empty(
            shift + input_numel * self.world_size, device=self.device
        ).zero_()[shift:]
        symm_mem.rendezvous(out, group=group_name)

        torch.ops.symm_mem.multimem_all_gather_out(input, group_name, out)
        ref = torch.ops._c10d_functional.all_gather_into_tensor(
            input, self.world_size, group_name
        )
        ref = torch.ops._c10d_functional.wait_tensor(ref)

        self.assertTrue(out.eq(ref).all())


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class LoweringTest(MultiProcContinuousTest):
    def _init_process(self) -> None:
        torch.cuda.set_device(self.device)
        enable_symm_mem_for_group(dist.group.WORLD.group_name)
        torch.manual_seed(42 + self.rank)
        torch._inductor.config._collective.auto_select = True

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skip("Fails with 'one_shot_all_reduce' not found in AOT graph, TODO: fix")
    @skip_if_rocm_multiprocess  # requires registered-buffer support
    @skip_if_lt_x_gpu(2)
    @fresh_cache()
    def test_lowering_one_shot_all_reduce(self):
        self._init_process()
        arg = torch.rand(4, 4, device=self.device)

        def func_0(x):
            x = x + 1
            x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
            return torch.ops._c10d_functional.wait_tensor(x)

        compiled_0 = torch.compile(func_0, fullgraph=True)
        code_0 = run_and_get_triton_code(compiled_0, arg)

        self.assertIn("one_shot_all_reduce", code_0)
        self.assertNotIn("return (buf0", code_0)

        # All-reduce on a slice view
        def func_1(x):
            x = x + 1
            x = x[2:]
            x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
            return torch.ops._c10d_functional.wait_tensor(x)

        compiled_1 = torch.compile(func_1, fullgraph=True)
        code_1 = run_and_get_triton_code(compiled_1, arg)

        self.assertIn("one_shot_all_reduce", code_1)
        self.assertNotIn("return (buf0", code_1)

        # All-reduce on input
        def func_2(x):
            x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
            return torch.ops._c10d_functional.wait_tensor(x)

        compiled_2 = torch.compile(func_2, fullgraph=True)
        code_2 = run_and_get_triton_code(compiled_2, arg)

        self.assertNotIn("one_shot_all_reduce", code_2)

        # All-reduce on matmul output
        def func_3(x):
            x = x @ x
            x = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
            return torch.ops._c10d_functional.wait_tensor(x)

        compiled_3 = torch.compile(func_3, fullgraph=True)
        code_3 = run_and_get_triton_code(compiled_3, arg)

        self.assertIn("one_shot_all_reduce", code_3)
        self.assertNotIn("return (buf0", code_3)


class SymmMemSingleProcTest(TestCase):
    @requires_cuda
    @skipIf(
        not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0),
        "stream_write_value32 currently only supports cuda version>=12.0",
    )
    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    def test_stream_write_value32(self):
        tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")
        expect = torch.tril(torch.ones(4, 4, device="cuda")).to(torch.uint32)

        for i in range(4):
            _SymmetricMemory.stream_write_value32(tensor, i, 1)
            torch.testing.assert_close(tensor, expect[i])

        with self.assertRaises(RuntimeError):
            _SymmetricMemory.stream_write_value32(tensor, offset=-1, val=1)

        with self.assertRaises(RuntimeError):
            _SymmetricMemory.stream_write_value32(tensor, offset=0, val=4294967296)

    @skipIf(
        not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
    )
    @requires_cuda
    def test_memset32(self):
        t = _SymmetricMemory.empty_strided_p2p(
            (64,),
            (1,),
            dtype=torch.uint32,
            device=torch.device("cuda:0"),
            group_name="0",
        ).fill_(0)

        _SymmetricMemory.memset32(t, offset=32, val=1, count=16)
        self.assertTrue(t[:32].eq(0).all())
        self.assertTrue(t[32:48].eq(1).all())
        self.assertTrue(t[48:].eq(0).all())

        with self.assertRaisesRegex(
            RuntimeError, "input must be a flat, contiguous uint32 tensor"
        ):
            _SymmetricMemory.memset32(t.view(8, 8), offset=0, val=1, count=1)

        with self.assertRaisesRegex(
            RuntimeError, "input must be a flat, contiguous uint32 tensor"
        ):
            _SymmetricMemory.memset32(t.view(torch.float32), offset=0, val=1, count=1)

        with self.assertRaisesRegex(
            RuntimeError, "offset must be greater than or equal to 0"
        ):
            _SymmetricMemory.memset32(t, offset=-1, val=1, count=1)

        with self.assertRaisesRegex(
            RuntimeError, r"val must be in the range of.*\(uint32_t\)"
        ):
            _SymmetricMemory.memset32(t, offset=0, val=4294967296, count=1)

        with self.assertRaisesRegex(RuntimeError, "count must be a positive integer"):
            _SymmetricMemory.memset32(t, offset=0, val=1, count=-1)

        with self.assertRaisesRegex(RuntimeError, "count must be a positive integer"):
            _SymmetricMemory.memset32(t, offset=0, val=1, count=0)

        with self.assertRaisesRegex(
            RuntimeError, r"offset \+ count.*exceeded the numel of the input"
        ):
            _SymmetricMemory.memset32(t, offset=64, val=1, count=1)

        with self.assertRaisesRegex(
            RuntimeError, r"offset \+ count.*exceeded the numel of the input"
        ):
            _SymmetricMemory.memset32(t, offset=0, val=1, count=65)

        _SymmetricMemory.memset32(t, offset=0, val=1, count=64)
        _SymmetricMemory.memset32(t, offset=63, val=1, count=1)


if __name__ == "__main__":
    run_tests()