[dtensor] add op support for torch._grouped_mm (#151072)

This PR makes TP (tensor parallelism) work with grouped MM in MoE implementations such as https://github.com/pytorch/torchtitan/pull/1084
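As a rough illustration of the pattern this enables, here is a minimal sketch mirroring the new test below (not taken verbatim from torchtitan; it assumes a torchrun launch with one SM90+ CUDA device per rank and illustrative MoE shapes):

import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
dtype = torch.bfloat16

# tokens grouped by expert: (num_tokens, dim), with per-expert offsets
inp = torch.rand(64, 16, device="cuda", dtype=dtype)
w1 = torch.rand(2, 16, 32, device="cuda", dtype=dtype)  # (experts, dim, hidden)
w2 = torch.rand(2, 32, 16, device="cuda", dtype=dtype)  # (experts, hidden, dim)
offs = torch.tensor([16, 64], device="cuda", dtype=torch.int32)

# TP sharding: replicate the tokens, shard w1 column-wise and w2 row-wise
d_inp = distribute_tensor(inp, mesh, [Replicate()])
d_w1 = distribute_tensor(w1, mesh, [Shard(2)])
d_w2 = distribute_tensor(w2, mesh, [Shard(1)])
d_offs = distribute_tensor(offs, mesh, [Replicate()])

# both grouped MMs run without communication; the result is a Partial DTensor
h = torch._grouped_mm(d_inp, d_w1, offs=d_offs)
out = torch._grouped_mm(h, d_w2, offs=d_offs).full_tensor()

This is the same colwise/rowwise weight sharding the new test asserts on: the only collective happens when the Partial result is reduced.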

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151072
Approved by: https://github.com/wanchaol, https://github.com/wwwjn
Tianyu Liu 2025-04-11 15:47:13 -07:00 committed by PyTorch MergeBot
parent 0c59a031c8
commit 7dd2ed1197
3 changed files with 261 additions and 3 deletions


@@ -17,9 +17,9 @@ from torch.distributed.tensor import (
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
@@ -27,6 +27,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)

funcol = torch.ops.c10d_functional


def scale_for_fp8(
    t: torch.Tensor, scale_shape: tuple[int]
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -501,6 +504,63 @@ class DistMatrixOpsTest(DTensorTestBase):
        dist_result_full = dist_result.full_tensor()
        self.assertEqual(local_result, dist_result_full)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
    @with_comms
    @skip_unless_torch_gpu
    def test_grouped_mm(self):
        # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D).
        # Here we only test the 2D x 3D Tensor Parallel use case in an MoE layer.
        # More tests need to be added.
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
        comm_mode = CommDebugMode()

        dtype = torch.bfloat16
        inp = torch.rand(
            64, 16, device=self.device_type, dtype=dtype, requires_grad=True
        )
        w1 = torch.rand(
            2, 16, 32, device=self.device_type, dtype=dtype, requires_grad=True
        )
        w2 = torch.rand(
            2, 32, 16, device=self.device_type, dtype=dtype, requires_grad=True
        )
        offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32)

        h = torch._grouped_mm(inp, w1, offs=offs)
        out = torch._grouped_mm(h, w2, offs=offs)

        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])
        # colwise sharded
        dist_w1 = distribute_tensor(w1, device_mesh, [Shard(2)])
        # rowwise sharded
        dist_w2 = distribute_tensor(w2, device_mesh, [Shard(1)])
        dist_offs = distribute_tensor(offs, device_mesh, [Replicate()])

        with comm_mode:
            dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs)
            dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs)
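
        # Expected behavior: with inp replicated, w1 colwise-sharded (Shard(2)) and
        # w2 rowwise-sharded (Shard(1)), both grouped MMs run locally, so no
        # collectives are issued and the result is left with a Partial placement.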
        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertTrue(dist_out.placements[0].is_partial())
        self.assertEqual(dist_out.full_tensor(), out)

        out_grad = torch.ones_like(out)
        out.backward(out_grad)

        dist_out = dist_out.redistribute(device_mesh, [Shard(0)])
        dist_out_grad = distribute_tensor(out_grad, device_mesh, [Shard(0)])
        with comm_mode:
            dist_out.backward(dist_out_grad)
        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(
            comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
            1,
        )
        self.assertEqual(dist_inp.grad.full_tensor(), inp.grad)
        self.assertEqual(dist_w1.grad.full_tensor(), w1.grad)
        self.assertEqual(dist_w2.grad.full_tensor(), w2.grad)


if __name__ == "__main__":
    run_tests()


@@ -7241,6 +7241,59 @@ def sigmoid(self: Tensor) -> Tensor:
    return torch.empty_like(self, dtype=result_dtype)


@register_meta(aten._grouped_mm)
@out_wrapper()
def grouped_mm(
    mat1: Tensor,
    mat2: Tensor,
    offs: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
) -> Tensor:
    torch._check(mat1.dim() == 2 or mat1.dim() == 3, lambda: "mat1 must be 2d or 3d")
    torch._check(mat2.dim() == 2 or mat2.dim() == 3, lambda: "mat2 must be 2d or 3d")
    torch._check(
        (offs is not None) == (mat1.dim() == 2 or mat2.dim() == 2),
        lambda: "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d",
    )
    if offs is not None:
        torch._check(offs.dim() == 1, lambda: "offsets must be 1d")
    out_dtype = out_dtype or mat1.dtype
    torch._check(bias is None, lambda: "bias not supported yet")

    def _compute_grouped_gemm_output_size(mat1, mat2, offs):
        mat1_is_2d = mat1.dim() == 2
        mat2_is_2d = mat2.dim() == 2
        if mat1_is_2d:
            if mat2_is_2d:
                return offs.size(0), mat1.size(0), mat2.size(1)
            else:
                torch._check(
                    offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
                )
                return mat1.size(0), mat2.size(-1)
        else:
            if mat2_is_2d:
                torch._check(
                    offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
                )
                return mat1.size(1), mat2.size(1)
            else:
                # regular bmm
                torch._check(
                    mat1.size(0) == mat2.size(0), "batched dimension has to match"
                )
                return mat1.size(0), mat1.size(1), mat2.size(-1)
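
    # Example (matching the new DTensor test): for the 2D x 3D case with mat1 of
    # shape (64, 16), mat2 of shape (2, 16, 32), and offs of length 2, the batch
    # check offs.size(0) == mat2.size(0) passes and the computed output size is
    # (mat1.size(0), mat2.size(-1)) = (64, 32).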
    out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs)
    out = mat1.new_empty(out_size, dtype=out_dtype)
    return out

@register_meta(aten._softmax)
@out_wrapper()
def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:


@@ -24,7 +24,12 @@ from torch.distributed.tensor._ops.utils import (
    prod,
    register_op_strategy,
)
-from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)

aten = torch.ops.aten

@@ -887,3 +892,143 @@ def scaled_scaled_dot_product_cudnn_attention_backward_strategy(
    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=3
    )


@register_op_strategy(aten._grouped_mm.default)
def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
    mesh = op_schema.get_mesh_from_args()
    mat1_strategy = op_schema.args_schema[0]
    assert isinstance(mat1_strategy, OpStrategy)
    mat2_strategy = op_schema.args_schema[1]
    assert isinstance(mat2_strategy, OpStrategy)
    if len(op_schema.args_schema) > 3:
        bias_strategy = op_schema.args_schema[3]
        assert bias_strategy is None, "grouped_mm doesn't support bias yet"

    single_mesh_dim_strategies = []

    offs_placement = None
    if len(op_schema.args_schema) > 2 and op_schema.args_schema[2] is not None:
        offs_placement = Replicate()  # offs should always be replicated
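
    # Each PlacementList below is ordered as [output, mat1, mat2, offs, bias];
    # expand_to_full_mesh_op_strategy (called with input_index=1 at the end) treats
    # the first entry as the output placement and the rest as operand placements.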
    all_replicate: PlacementList = [
        Replicate(),
        Replicate(),  # mat1
        Replicate(),  # mat2
        offs_placement,  # offs
        None,  # bias
    ]
    partial_replicate: PlacementList = [
        Partial(),
        Partial(),  # mat1
        Replicate(),  # mat2
        offs_placement,  # offs
        None,  # bias
    ]
    replicate_partial: PlacementList = [
        Partial(),
        Replicate(),  # mat1
        Partial(),  # mat2
        offs_placement,  # offs
        None,  # bias
    ]
    single_mesh_dim_strategies = [all_replicate, partial_replicate, replicate_partial]

    if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 3:
        # rowwise_replicate for 2dx3d not supported
        replicate_colwise_2x3: PlacementList = [
            Shard(1),
            Replicate(),  # mat1
            Shard(2),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        colwise_rowwise_2x3: PlacementList = [
            Partial(),
            Shard(1),  # mat1
            Shard(1),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        single_mesh_dim_strategies.extend([replicate_colwise_2x3, colwise_rowwise_2x3])

    if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 2:
        # replicate_colwise for 3dx2d not supported
        colwise_rowwise_3x2: PlacementList = [
            Partial(),
            Shard(2),  # mat1
            Shard(0),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        rowwise_replicate_3x2: PlacementList = [
            Shard(0),
            Shard(1),  # mat1
            Replicate(),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        single_mesh_dim_strategies.extend([colwise_rowwise_3x2, rowwise_replicate_3x2])

    if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 2:
        # colwise_rowwise for 2dx2d not supported
        replicate_colwise_2x2: PlacementList = [
            Shard(2),
            Replicate(),  # mat1
            Shard(1),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        rowwise_replicate_2x2: PlacementList = [
            Shard(1),
            Shard(0),  # mat1
            Replicate(),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        single_mesh_dim_strategies.extend(
            [replicate_colwise_2x2, rowwise_replicate_2x2]
        )

    if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 3:
        replicate_colwise_3x3: PlacementList = [
            Shard(2),
            Replicate(),  # mat1
            Shard(2),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        rowwise_replicate_3x3: PlacementList = [
            Shard(1),
            Shard(1),  # mat1
            Replicate(),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        colwise_rowwise_3x3: PlacementList = [
            Partial(),
            Shard(2),  # mat1
            Shard(1),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        batch_dim_sharding: PlacementList = [
            Shard(0),
            Shard(0),  # mat1
            Shard(0),  # mat2
            offs_placement,  # offs
            None,  # bias
        ]
        single_mesh_dim_strategies.extend(
            [
                replicate_colwise_3x3,
                rowwise_replicate_3x3,
                colwise_rowwise_3x3,
                batch_dim_sharding,
            ]
        )

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )