# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
import unittest
from typing import cast, Optional

import torch
import torch.nn.functional as F
from torch.distributed import init_device_mesh
from torch.distributed.tensor import (
    distribute_tensor,
    DTensor,
    Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_ROCM,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    create_local_tensor_test_class,
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


funcol = torch.ops.c10d_functional


def scale_for_fp8(
    t: torch.Tensor, scale_shape: tuple[int, ...]
) -> tuple[torch.Tensor, torch.Tensor]:
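    """Fake-quantize ``t`` to float8 (e4m3) with block-wise scales.

    ``scale_shape`` selects the granularity: ``()``, ``(1,)`` or ``(1, 1)`` gives a
    single tensor-wise scale, while ``(rows, cols)`` splits ``t`` into that many
    row/column blocks, each scaled by its own ``amax / E4M3_MAX_POS`` factor.
    Returns the fp8 tensor (original shape) and the scales viewed as ``scale_shape``.
    """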
    if all(d == 1 for d in scale_shape):
        t = t.unsqueeze(0).unsqueeze(-2)
    else:
        t = t.unflatten(0, (scale_shape[0], -1)).unflatten(-1, (scale_shape[1], -1))

    scale = t.abs().amax(dim=[1, -1]).float() / E4M3_MAX_POS
    t_fp8 = (t / scale[:, None, :, None]).to(e4m3_type)

    return t_fp8.flatten(end_dim=1).flatten(start_dim=-2), scale.view(scale_shape)


class DistMatrixOpsTest(DTensorTestBase):
    @with_comms
    def test_addmm(self):
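        """addmm with mat1 sharded on dim 0 and mat2/bias replicated.

        The distributed result gathered via full_tensor() must match the local addmm.
        """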
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 8)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(8, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        input = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(input, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_empty_operand(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 0)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(0, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        inp = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(inp, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_auto_redistribute(self):
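        """addmm with mat1 sharded on dim 1 and mat2 sharded on dim 0.

        Each rank then holds matching slices of the contraction dimension, so the
        local addmm produces partial sums and the result is a Partial DTensor that
        only becomes the full value after reduction (full_tensor). Also checks that
        gradients flow back to the sharded operand.
        """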
        device_mesh = self.build_device_mesh()
        shard0_spec = [Shard(0)]
        shard1_spec = [Shard(1)]
        replica_spec = [Replicate()]

        tensor_to_shard1 = torch.randn(12, 8, requires_grad=True)
        mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec)
        tensor_to_shard0 = torch.randn(8, 4, requires_grad=True)
        mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec)
        input_tensor = torch.randn(4, requires_grad=True)
        input = distribute_tensor(input_tensor, device_mesh, replica_spec)

        local_res = torch.addmm(input_tensor, tensor_to_shard1, tensor_to_shard0)
        dist_res = torch.addmm(input, mat1, mat2)

        # test if addmm output is a partial
        self.assertIsInstance(dist_res, DTensor)
        self.assertIsInstance(dist_res.placements[0], Partial)

        # test if result is the same as tensor
        dist_local_res = dist_res.full_tensor()
        self.assertEqual(local_res, dist_local_res)

        # backward checks
        dist_local_res.sum().backward()
        local_res.sum().backward()
        self.assertIsNotNone(mat2.grad)
        self.assertEqual(mat2.grad.full_tensor(), tensor_to_shard0.grad)

    @with_comms
    def test_mm(self):
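        """mm under every combination of Shard(0)/Shard(1)/Replicate placements.

        Each result is redistributed back to Replicate and compared against the
        local mm; the backward pass is run to make sure gradients are produced.
        """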
        device_mesh = self.build_device_mesh()
        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        replica_spec = Replicate()

        t1 = torch.randn(12, 8, requires_grad=True)
        t2 = torch.randn(8, 16, requires_grad=True)
        local_res = torch.mm(t1, t2)

        def test_placement_comb(
            placements1: list[Placement], placements2: list[Placement]
        ) -> None:
            dt1 = distribute_tensor(t1, device_mesh, placements1)
            dt2 = distribute_tensor(t2, device_mesh, placements2)
            dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(
                device_mesh, [replica_spec]
            )
            self.assertEqual(dist_res.to_local(), local_res)
            # backward
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(dt1.grad)

        placement_specs = [shard0_spec, shard1_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    @skip_unless_torch_gpu
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
    def test_scaled_mm(self):
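        """torch._scaled_mm on fp8 DTensors with tensor-wise and row-wise scales.

        Each case lists the expected output placement plus the input/scale
        placements; the test checks that sharding propagation picks that placement,
        that the result matches the bf16 reference within fp8 tolerances, and that
        no communication is issued (CommDebugMode count stays 0).
        """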
        device_mesh = self.build_device_mesh()
        shrd0 = Shard(0)
        shrd1 = Shard(1)
        repl = Replicate()
        part = Partial()

        ws = self.world_size
        # _scaled_mm requires all dimensions to be multiples of 16. Since we'll
        # shard along n and k, we need to ensure this stays true on each rank.
        m, n, k = 16, 32 * ws, 16 * ws

        t1 = torch.randn(m, k, device=self.device_type, dtype=torch.bfloat16)
        t2 = torch.randn(n, k, device=self.device_type, dtype=torch.bfloat16)

        for (
            output_spec,
            t1_spec,
            t2_spec,
            scale1_shape,
            scale2_shape,
            scale1_spec,
            scale2_spec,
        ) in [
            # Tensor-wise scaling
            # Replicated, zero-dim scale
            (repl, repl, repl, (), (), repl, repl),
            # Column-parallel, two-dim scale
            (shrd1, repl, shrd0, (1, 1), (1, 1), repl, repl),
            # Row-parallel, one-dim scale
            (part, shrd1, shrd1, (1,), (1,), repl, repl),
            # Row-wise scaling
            # Replicated
            (repl, repl, repl, (m, 1), (n, 1), repl, repl),
            # Column-parallel
            (shrd1, repl, shrd0, (m, 1), (n, 1), repl, shrd0),
            # Row-parallel (which actually ends up doing sub-row-wise scaling)
            (part, shrd1, shrd1, (m, ws), (n, ws), shrd1, shrd1),
        ]:
            full_ref_res = t1 @ t2.t()

            t1_fp8, scale1 = scale_for_fp8(t1, scale1_shape)
            t2_fp8, scale2 = scale_for_fp8(t2, scale2_shape)

            dist_t1_fp8 = distribute_tensor(t1_fp8, device_mesh, [t1_spec])
            dist_t2_fp8 = distribute_tensor(t2_fp8, device_mesh, [t2_spec])
            dist_scale1 = distribute_tensor(scale1, device_mesh, [scale1_spec])
            dist_scale2 = distribute_tensor(scale2, device_mesh, [scale2_spec])

            with CommDebugMode() as comm_mode:
                dist_res = cast(
                    DTensor,
                    torch._scaled_mm(
                        dist_t1_fp8,
                        dist_t2_fp8.t(),
                        scale_a=dist_scale1,
                        scale_b=dist_scale2.t(),
                        out_dtype=torch.bfloat16,
                    ),
                )

            self.assertEqual(dist_res.placements[0], output_spec)

            full_dist_res = dist_res.full_tensor()
            # Fp8 matmuls are quite inaccurate, we need high tolerances
            self.assertEqual(full_dist_res, full_ref_res, atol=1.5, rtol=7e-2)

            self.assertEqual(comm_mode.get_total_counts(), 0)

    @with_comms
    def test_matmul(self):
        device_mesh = self.build_device_mesh()
        dim = 128
        x = torch.randn(8, dim)
        A = torch.randn(dim, dim)
        y = torch.matmul(x, A)

        # Prepare DTensors
        dx = distribute_tensor(x, device_mesh, [Replicate()])
        dA = distribute_tensor(A, device_mesh, [Shard(0)])

        # Use `inference_mode` to test DTensor's capability of decomposing
        # `matmul` op
        with torch.inference_mode():
            dy = torch.matmul(dx, dA)

        self.assertEqual(y, dy.full_tensor())

    @with_comms
    def test_t(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor_to_transpose = torch.randn(12, 8, requires_grad=True)
        mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec)
        transposed_mat = mat.t()
        self.assertEqual(transposed_mat.size(), torch.Size([8, 12]))
        self.assertEqual(transposed_mat.placements, [Shard(1)])
        transposed_mat2 = transposed_mat.t()
        self.assertEqual(transposed_mat2.size(), torch.Size([12, 8]))
        self.assertEqual(transposed_mat2.placements, shard_spec)

    @with_comms
    def test_t_partial(self):
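        """Transposing a Partial DTensor keeps the Partial placement.

        mm(Shard(1), Shard(0)) contracts over the sharded dimension, so the product
        is Partial; .t() should preserve that, and redistributing to Replicate must
        still match the local result.
        """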
        device_mesh = self.build_device_mesh()

        a = torch.randn(12, 8)
        b = torch.randn(8, 4)
        c = torch.mm(a, b).t()

        da = distribute_tensor(a, device_mesh, [Shard(1)])
        db = distribute_tensor(b, device_mesh, [Shard(0)])

        # mm(da, db) should return a Partial tensor.
        # transposing it should keep it Partial
        dc = torch.mm(da, db).t()

        self.assertTrue(isinstance(dc.placements[0], Partial))

        # check that the local and distributed op results match
        self.assertEqual(
            c,
            dc.redistribute(device_mesh, [Replicate()]).to_local(),
        )

    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_baddbmm(self):
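        """baddbmm across all shard/replicate placements of the three operands.

        Covers beta=0 (bias ignored) and a non-zero beta; each distributed result
        is redistributed to Replicate and compared against the local baddbmm.
        """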
        device_mesh = self.build_device_mesh()
        tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)

        def test_placement_comb(
            tensor_placements: list[Placement],
            batch_1_placements: list[Placement],
            batch_2_placements: list[Placement],
            beta: float,
            alpha: float,
            batch_1_grad: Optional[torch.Tensor],
        ) -> None:
            tensor_dt = distribute_tensor(tensor, device_mesh, tensor_placements)
            batch_1_dt = distribute_tensor(batch_1, device_mesh, batch_1_placements)
            batch_2_dt = distribute_tensor(batch_2, device_mesh, batch_2_placements)
            dist_res = cast(
                DTensor,
                torch.baddbmm(
                    tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha
                ),
            ).redistribute(device_mesh, [Replicate()])
            dist_local_res = dist_res.to_local()
            assert not torch.isnan(local_result).any()
            assert not torch.isnan(dist_local_res).any()
            self.assertEqual(dist_local_res.detach(), local_result.detach())

            # TODO: add test backward
            # grad_dist_res = torch.ones_like(dist_res)
            # dist_res.backward(grad_dist_res)
            # self.assertIsNotNone(batch_1_dt.grad)
            # batch_1_grad_local = batch_1_dt.grad.redistribute(
            #     device_mesh, [Replicate()]
            # ).to_local()
            # self.assertEqual(batch_1_grad_local, batch_1_grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        shard_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(
            itertools.product(shard_specs, shard_specs, shard_specs)
        )
        # If beta is 0, input tensor will be ignored
        numeric_params_comb = [
            (0.0, 0.5),  # zero-beta
            (0.8, 0.5),  # non-zero-beta
        ]

        for beta, alpha in numeric_params_comb:
            local_result = torch.baddbmm(
                tensor, batch_1, batch_2, beta=beta, alpha=alpha
            )
            grad_local_res = torch.ones_like(local_result)
            local_result.backward(grad_local_res)
            # test all combos
            for spec in shard_specs_comb:
                test_placement_comb(
                    [spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad
                )

    @with_comms
    def test_bmm(self):
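        """bmm across all shard/replicate placement pairs, with a backward check.

        Each distributed result is redistributed to Replicate and compared to the
        local bmm, and mat1's gradient is checked against the local gradient.
        """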
        device_mesh = self.build_device_mesh()
        mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True)
        mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        local_result = torch.bmm(mat1, mat2)
        grad_local_res = torch.ones_like(local_result)
        local_result.backward(grad_local_res)

        def test_placement_comb(
            placements1: list[Placement],
            placements2: list[Placement],
        ) -> None:
            mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
            mat2_dt = distribute_tensor(mat2, device_mesh, placements2)
            dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute(
                device_mesh, [Replicate()]
            )
            dist_local_res = dist_res.to_local()
            self.assertEqual(dist_local_res, local_result)

            # test backward
            # TODO: figure out (replicate, shard1) fail on backward
            # it generates a different grad shape
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(mat1_dt.grad)
            mat1_dt_grad = cast(DTensor, mat1_dt.grad)
            mat1_grad_local = mat1_dt_grad.redistribute(
                device_mesh, [Replicate()]
            ).to_local()
            self.assertEqual(mat1_grad_local, mat1.grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        placement_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))

        # tests that currently pass
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    @skip_unless_torch_gpu
    def test_scaled_dot_product_attention(self):
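        """Run SDPA with the flash/efficient backends on DTensor q/k/v.

        The q/k/v are replicated or sharded on the batch or head dimension; since
        attention is independent across batch and heads, forward and backward are
        expected to run with zero communication, keep the input placements, and
        match the local results and gradients.
        """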
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()
        # bsz, n_heads, slen, head_dim
        query = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        key = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        value = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )

        from torch.nn.attention import sdpa_kernel, SDPBackend

        available_backends = []
        dropout_p = 0.0
        # TODO: Add test cases where is_causal=False and an attention mask is provided.
        # Gaps include missing op support for aten.masked_fill_.Scalar.
        is_causal = True
        enable_gqa = False
        params = torch.backends.cuda.SDPAParams(
            query, key, value, None, dropout_p, is_causal, enable_gqa
        )
        if torch.backends.cuda.can_use_flash_attention(params, debug=False):
            available_backends.append(SDPBackend.FLASH_ATTENTION)
        if torch.backends.cuda.can_use_efficient_attention(params, debug=False):
            available_backends.append(SDPBackend.EFFICIENT_ATTENTION)

        placement_specs = [(Replicate(),), (Shard(0),), (Shard(1),)]
        for backend, input_placements in itertools.product(
            available_backends, placement_specs
        ):
            dist_query = distribute_tensor(query, device_mesh, input_placements)
            dist_key = distribute_tensor(key, device_mesh, input_placements)
            dist_value = distribute_tensor(value, device_mesh, input_placements)
            with sdpa_kernel(backends=[backend]):
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=dropout_p, is_causal=is_causal
                )
                with comm_mode:
                    dist_out = F.scaled_dot_product_attention(
                        dist_query,
                        dist_key,
                        dist_value,
                        dropout_p=dropout_p,
                        is_causal=is_causal,
                    )
                self.assertEqual(comm_mode.get_total_counts(), 0)
                self.assertEqual(dist_out.placements, input_placements)
                self.assertEqual(dist_out.full_tensor(), out)

                out.sum().backward()
                with comm_mode:
                    dist_out.sum().backward()
                self.assertEqual(comm_mode.get_total_counts(), 0)
                self.assertEqual(dist_query.grad.placements, input_placements)
                self.assertEqual(dist_query.grad.full_tensor(), query.grad)
                self.assertEqual(dist_key.grad.placements, input_placements)
                self.assertEqual(dist_key.grad.full_tensor(), key.grad)
                self.assertEqual(dist_value.grad.placements, input_placements)
                self.assertEqual(dist_value.grad.full_tensor(), value.grad)
                query.grad.zero_()
                key.grad.zero_()
                value.grad.zero_()

    @skip_unless_torch_gpu
    @with_comms()
    def test_dtensor_mm(self):
        """
        Test mm with DTensor on a 2D mesh.

        We add this test here since test_dtensor_ops.py only covers 1D meshes.
        It also covers the corner case where one of the 2D mesh dimensions is 1.

        TODO: test more DTensor ops with 2D meshes, especially when one of the
        mesh dimensions is 1.
        """
        mesh_0 = init_device_mesh(self.device_type, (self.world_size // 2, 2))
        mesh_1 = init_device_mesh(self.device_type, (self.world_size, 1))
        mesh_2 = init_device_mesh(self.device_type, (1, self.world_size))

        for mesh in [mesh_0, mesh_1, mesh_2]:
            lhs = torch.randn(256, 128)
            rhs = torch.randn(128, 256)
            mm_result = lhs @ rhs

            lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()])
            rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)])
            dtensor_result = lhs_dtensor @ rhs_dtensor
            self.assertEqual(
                dtensor_result.full_tensor(), mm_result, atol=1.5e-5, rtol=1e-6
            )

    @with_comms
    @skip_unless_torch_gpu
    def test_tensordot_shampoo(self):
        """
        Create a simple test for Shampoo's use case.
        """
        device_mesh = self.build_device_mesh()

        local_a = torch.randn(4, 4)
        local_b = torch.randn(4, 15)
        dims = ([0], [0])
        local_result = torch.tensordot(local_a, local_b, dims=dims)

        placements = [Replicate(), Shard(0), Shard(1)]
        placements_tuples = itertools.product(placements, repeat=2)

        for placement1, placement2 in placements_tuples:
            dist_a = distribute_tensor(local_a, device_mesh, [placement1])
            dist_b = distribute_tensor(local_b, device_mesh, [placement2])
            dist_result = torch.tensordot(dist_a, dist_b, dims=dims)
            dist_result_full = dist_result.full_tensor()
            self.assertEqual(local_result, dist_result_full)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
    @with_comms
    @skip_unless_torch_gpu
    @parametrize(
        "kwargs",
        [
            {
                # 2D x 3D case from MoE layer
                "inp_shape": (64, 16),
                "w1_shape": (2, 16, 32),
                "w2_shape": (2, 32, 16),
                "inp_placements": [Replicate()],
                "w1_placements": [Shard(2)],
                "w2_placements": [Shard(1)],
                "expected_comm_counts_fwd": 0,
                "expected_comm_counts_bwd": 1,
                "expected_out_placements": [Partial()],
            },
            {
                # Case that would have invalid strides on inp * mat1 when sharded
                "inp_shape": (64, 16),
                "w1_shape": (2, 16, 16),
                "w2_shape": (2, 16, 16),
                "inp_placements": [Replicate()],
                "w1_placements": [Shard(2)],
                "w2_placements": [Shard(1)],
                "expected_comm_counts_fwd": 2,
                "expected_comm_counts_bwd": 4,
                "expected_out_placements": [Replicate()],
            },
        ],
    )
    def test_grouped_mm(self, kwargs):
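        """torch._grouped_mm with a replicated input and sharded grouped weights.

        Runs the two grouped matmuls forward and backward under CommDebugMode and
        checks the output placement, the communication counts against the
        parametrized expectations, and the results/gradients against the local run.
        """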
        # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D)
        # More tests need to be added.
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()
        dtype = torch.bfloat16
        inp = torch.rand(
            *kwargs["inp_shape"],
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        w1 = torch.rand(
            *kwargs["w1_shape"],
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        w2 = torch.rand(
            *kwargs["w2_shape"],
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32)

        h = torch._grouped_mm(inp, w1, offs=offs)
        out = torch._grouped_mm(h, w2, offs=offs)

        dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"])
        # colwise sharded
        dist_w1 = distribute_tensor(w1, device_mesh, kwargs["w1_placements"])
        # rowwise sharded
        dist_w2 = distribute_tensor(w2, device_mesh, kwargs["w2_placements"])
        dist_offs = distribute_tensor(offs, device_mesh, [Replicate()])

        with comm_mode:
            dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs)
            dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs)
        self.assertEqual(
            comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"]
        )
        self.assertEqual(dist_out.placements, kwargs["expected_out_placements"])
        self.assertEqual(dist_out.full_tensor(), out)

        out_grad = torch.ones_like(out)
        out.backward(out_grad)

        dist_out = dist_out.redistribute(device_mesh, [Shard(0)])
        dist_out_grad = distribute_tensor(out_grad, device_mesh, [Shard(0)])

        with comm_mode:
            dist_out.backward(dist_out_grad)
        self.assertEqual(
            comm_mode.get_total_counts(), kwargs["expected_comm_counts_bwd"]
        )
        self.assertEqual(
            comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
            kwargs["expected_comm_counts_bwd"],
        )
        self.assertEqual(dist_inp.grad.full_tensor(), inp.grad)
        self.assertEqual(dist_w1.grad.full_tensor(), w1.grad)
        self.assertEqual(dist_w2.grad.full_tensor(), w2.grad)


instantiate_parametrized_tests(DistMatrixOpsTest)

DistMatrixOpsTestWithLocalTensor = create_local_tensor_test_class(
    DistMatrixOpsTest,
)

if __name__ == "__main__":
    run_tests()