[dtensor] add op support for torch._grouped_mm (#151072)
This PR makes TP work with grouped MM in MoE implementations such as https://github.com/pytorch/torchtitan/pull/1084.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151072
Approved by: https://github.com/wanchaol, https://github.com/wwwjn
This commit is contained in:
parent 0c59a031c8
commit 7dd2ed1197
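For orientation, a minimal usage sketch of my own (not part of this diff) of the tensor-parallel pattern the PR enables, mirroring the new test further down: activations and offsets are replicated, the first expert weight is colwise-sharded (Shard(2)), the second rowwise-sharded (Shard(1)), both grouped MMs run without communication, and only materializing the Partial output triggers an all-reduce. It assumes a single-node 1-D mesh launched with torchrun and an SM90+ GPU per rank, as suggested by the test's SM90OrLater skip.

# Hypothetical sketch: TP-sharded MoE expert weights driving torch._grouped_mm
# through DTensor. Run with e.g. `torchrun --nproc_per_node=2 sketch.py`.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)  # assumes one GPU per rank on a single node
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

dtype = torch.bfloat16
inp = torch.rand(64, 16, device="cuda", dtype=dtype)    # tokens x hidden
w1 = torch.rand(2, 16, 32, device="cuda", dtype=dtype)  # experts x hidden x ffn
w2 = torch.rand(2, 32, 16, device="cuda", dtype=dtype)  # experts x ffn x hidden
offs = torch.tensor([16, 64], device="cuda", dtype=torch.int32)  # cumulative group ends

d_inp = distribute_tensor(inp, mesh, [Replicate()])
d_w1 = distribute_tensor(w1, mesh, [Shard(2)])  # colwise-sharded expert weights
d_w2 = distribute_tensor(w2, mesh, [Shard(1)])  # rowwise-sharded expert weights
d_offs = distribute_tensor(offs, mesh, [Replicate()])

h = torch._grouped_mm(d_inp, d_w1, offs=d_offs)  # activations become Shard(1), no comm
out = torch._grouped_mm(h, d_w2, offs=d_offs)    # output is Partial(), still no comm
print(rank, out.placements, out.full_tensor().shape)  # full_tensor() does the all-reduce

dist.destroy_process_group()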
@@ -17,9 +17,9 @@ from torch.distributed.tensor import (
     Shard,
 )
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
 from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_unless_torch_gpu,
@@ -27,6 +27,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 )
 
 
+funcol = torch.ops.c10d_functional
+
+
 def scale_for_fp8(
     t: torch.Tensor, scale_shape: tuple[int]
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -501,6 +504,63 @@ class DistMatrixOpsTest(DTensorTestBase):
         dist_result_full = dist_result.full_tensor()
         self.assertEqual(local_result, dist_result_full)
 
+    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
+    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_grouped_mm(self):
+        # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D)
+        # Here we only test the 2D x 3D Tensor Parallel use case in an MoE layer.
+        # More tests need to be added.
+        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
+        comm_mode = CommDebugMode()
+        dtype = torch.bfloat16
+
+        inp = torch.rand(
+            64, 16, device=self.device_type, dtype=dtype, requires_grad=True
+        )
+        w1 = torch.rand(
+            2, 16, 32, device=self.device_type, dtype=dtype, requires_grad=True
+        )
+        w2 = torch.rand(
+            2, 32, 16, device=self.device_type, dtype=dtype, requires_grad=True
+        )
+        offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32)
+
+        h = torch._grouped_mm(inp, w1, offs=offs)
+        out = torch._grouped_mm(h, w2, offs=offs)
+
+        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])
+        # colwise sharded
+        dist_w1 = distribute_tensor(w1, device_mesh, [Shard(2)])
+        # rowwise sharded
+        dist_w2 = distribute_tensor(w2, device_mesh, [Shard(1)])
+        dist_offs = distribute_tensor(offs, device_mesh, [Replicate()])
+
+        with comm_mode:
+            dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs)
+            dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs)
+            self.assertEqual(comm_mode.get_total_counts(), 0)
+            self.assertTrue(dist_out.placements[0].is_partial())
+            self.assertEqual(dist_out.full_tensor(), out)
+
+        out_grad = torch.ones_like(out)
+        out.backward(out_grad)
+
+        dist_out = dist_out.redistribute(device_mesh, [Shard(0)])
+        dist_out_grad = distribute_tensor(out_grad, device_mesh, [Shard(0)])
+
+        with comm_mode:
+            dist_out.backward(dist_out_grad)
+            self.assertEqual(comm_mode.get_total_counts(), 1)
+            self.assertEqual(
+                comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
+                1,
+            )
+            self.assertEqual(dist_inp.grad.full_tensor(), inp.grad)
+            self.assertEqual(dist_w1.grad.full_tensor(), w1.grad)
+            self.assertEqual(dist_w2.grad.full_tensor(), w2.grad)
+
 
 if __name__ == "__main__":
     run_tests()
@@ -7241,6 +7241,59 @@ def sigmoid(self: Tensor) -> Tensor:
     return torch.empty_like(self, dtype=result_dtype)
 
 
+@register_meta(aten._grouped_mm)
+@out_wrapper()
+def grouped_mm(
+    mat1: Tensor,
+    mat2: Tensor,
+    offs: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+) -> Tensor:
+    torch._check(mat1.dim() == 2 or mat1.dim() == 3, lambda: "mat1 must be 2d or 3d")
+    torch._check(mat2.dim() == 2 or mat2.dim() == 3, lambda: "mat2 must be 2d or 3d")
+    torch._check(
+        (offs is not None) == (mat1.dim() == 2 or mat2.dim() == 2),
+        lambda: "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d",
+    )
+
+    if offs is not None:
+        torch._check(offs.dim() == 1, lambda: "offsets must be 1d")
+
+    out_dtype = out_dtype or mat1.dtype
+    torch._check(bias is None, lambda: "bias not supported yet")
+
+    def _compute_grouped_gemm_output_size(mat1, mat2, offs):
+        mat1_is_2d = mat1.dim() == 2
+        mat2_is_2d = mat2.dim() == 2
+
+        if mat1_is_2d:
+            if mat2_is_2d:
+                return offs.size(0), mat1.size(0), mat2.size(1)
+            else:
+                torch._check(
+                    offs.size(0) == mat2.size(0), "matrix batch sizes have to match"
+                )
+                return mat1.size(0), mat2.size(-1)
+        else:
+            if mat2_is_2d:
+                torch._check(
+                    offs.size(0) == mat1.size(0), "matrix batch sizes have to match"
+                )
+                return mat1.size(1), mat2.size(1)
+            else:
+                # regular bmm
+                torch._check(
+                    mat1.size(0) == mat2.size(0), "batched dimension has to match"
+                )
+                return mat1.size(0), mat1.size(1), mat2.size(-1)
+
+    out_size = _compute_grouped_gemm_output_size(mat1, mat2, offs)
+    out = mat1.new_empty(out_size, dtype=out_dtype)
+
+    return out
+
+
 @register_meta(aten._softmax)
 @out_wrapper()
 def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
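One thing the registration above buys (a sketch of mine, not part of the diff, assuming the meta kernel is all that is needed to handle meta-device inputs) is shape inference for torch._grouped_mm without launching a real grouped GEMM:

# Shape-only sketch: the 2D x 3D case used by TP MoE, evaluated on the meta device.
import torch

mat1 = torch.empty(64, 16, device="meta", dtype=torch.bfloat16)     # 2D activations
mat2 = torch.empty(2, 16, 32, device="meta", dtype=torch.bfloat16)  # 3D expert weights
offs = torch.empty(2, device="meta", dtype=torch.int32)             # one end offset per group

out = torch._grouped_mm(mat1, mat2, offs=offs)
# Per the meta kernel: the output size is (mat1.size(0), mat2.size(-1)) and the
# dtype defaults to mat1.dtype, i.e. torch.Size([64, 32]) and torch.bfloat16.
print(out.shape, out.dtype)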
@@ -24,7 +24,12 @@ from torch.distributed.tensor._ops.utils import (
     prod,
     register_op_strategy,
 )
-from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
+from torch.distributed.tensor.placement_types import (
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
 
 
 aten = torch.ops.aten
@@ -887,3 +892,143 @@ def scaled_dot_product_cudnn_attention_backward_strategy(
     return expand_to_full_mesh_op_strategy(
         mesh, op_schema, single_mesh_dim_strategies, input_index=3
     )
+
+
+@register_op_strategy(aten._grouped_mm.default)
+def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
+    mesh = op_schema.get_mesh_from_args()
+
+    mat1_strategy = op_schema.args_schema[0]
+    assert isinstance(mat1_strategy, OpStrategy)
+    mat2_strategy = op_schema.args_schema[1]
+    assert isinstance(mat2_strategy, OpStrategy)
+    if len(op_schema.args_schema) > 3:
+        bias_strategy = op_schema.args_schema[3]
+        assert bias_strategy is None, "grouped_mm doesn't support bias yet"
+
+    single_mesh_dim_strategies = []
+
+    offs_placement = None
+    if len(op_schema.args_schema) > 2 and op_schema.args_schema[2] is not None:
+        offs_placement = Replicate()  # offs should always be replicated
+
+    all_replicate: PlacementList = [
+        Replicate(),
+        Replicate(),  # mat1
+        Replicate(),  # mat2
+        offs_placement,  # offs
+        None,  # bias
+    ]
+    partial_replicate: PlacementList = [
+        Partial(),
+        Partial(),  # mat1
+        Replicate(),  # mat2
+        offs_placement,  # offs
+        None,  # bias
+    ]
+    replicate_partial: PlacementList = [
+        Partial(),
+        Replicate(),  # mat1
+        Partial(),  # mat2
+        offs_placement,  # offs
+        None,  # bias
+    ]
+    single_mesh_dim_strategies = [all_replicate, partial_replicate, replicate_partial]
+
+    if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 3:
+        # rowwise_replicate for 2dx3d not supported
+        replicate_colwise_2x3: PlacementList = [
+            Shard(1),
+            Replicate(),  # mat1
+            Shard(2),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        colwise_rowwise_2x3: PlacementList = [
+            Partial(),
+            Shard(1),  # mat1
+            Shard(1),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend([replicate_colwise_2x3, colwise_rowwise_2x3])
+
+    if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 2:
+        # replicate_colwise for 3dx2d not supported
+        colwise_rowwise_3x2: PlacementList = [
+            Partial(),
+            Shard(2),  # mat1
+            Shard(0),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        rowwise_replicate_3x2: PlacementList = [
+            Shard(0),
+            Shard(1),  # mat1
+            Replicate(),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend([colwise_rowwise_3x2, rowwise_replicate_3x2])
+
+    if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 2:
+        # colwise_rowwise for 2dx2d not supported
+        replicate_colwise_2x2: PlacementList = [
+            Shard(2),
+            Replicate(),  # mat1
+            Shard(1),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        rowwise_replicate_2x2: PlacementList = [
+            Shard(1),
+            Shard(0),  # mat1
+            Replicate(),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend(
+            [replicate_colwise_2x2, rowwise_replicate_2x2]
+        )
+
+    if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 3:
+        replicate_colwise_3x3: PlacementList = [
+            Shard(2),
+            Replicate(),  # mat1
+            Shard(2),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        rowwise_replicate_3x3: PlacementList = [
+            Shard(1),
+            Shard(1),  # mat1
+            Replicate(),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        colwise_rowwise_3x3: PlacementList = [
+            Partial(),
+            Shard(2),  # mat1
+            Shard(1),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        batch_dim_sharding: PlacementList = [
+            Shard(0),
+            Shard(0),  # mat1
+            Shard(0),  # mat2
+            offs_placement,  # offs
+            None,  # bias
+        ]
+        single_mesh_dim_strategies.extend(
+            [
+                replicate_colwise_3x3,
+                rowwise_replicate_3x3,
+                colwise_rowwise_3x3,
+                batch_dim_sharding,
+            ]
+        )
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
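To see why the colwise_rowwise_2x3 entry above can mark the output Partial with zero communication, here is a CPU-only sketch of mine (not part of the diff) that emulates the 2D x 3D grouped MM with per-group torch.mm, assuming offs holds cumulative group end offsets as in the new test: the per-rank partial results sum exactly to the unsharded output, which is what a single all-reduce would materialize.

# Emulate grouped MM on CPU and check the colwise (Shard(2)) x rowwise (Shard(1))
# sharding of the expert weights: each rank produces a partial sum of the output.
import torch

def grouped_mm_ref(inp, w, offs):
    # inp: (tokens, k), w: (groups, k, n); rows [start:end) of inp use w[i].
    outs, start = [], 0
    for i, end in enumerate(offs.tolist()):
        outs.append(inp[start:end] @ w[i])
        start = end
    return torch.cat(outs)

torch.manual_seed(0)
inp = torch.rand(64, 16)       # tokens x hidden
w1 = torch.rand(2, 16, 32)     # experts x hidden x ffn
w2 = torch.rand(2, 32, 16)     # experts x ffn x hidden
offs = torch.tensor([16, 64])  # cumulative group boundaries

full = grouped_mm_ref(grouped_mm_ref(inp, w1, offs), w2, offs)

world_size = 2
partials = []
for rank in range(world_size):
    w1_local = w1.chunk(world_size, dim=2)[rank]   # colwise shard of w1
    w2_local = w2.chunk(world_size, dim=1)[rank]   # rowwise shard of w2
    h_local = grouped_mm_ref(inp, w1_local, offs)  # activations sharded on the ffn dim
    partials.append(grouped_mm_ref(h_local, w2_local, offs))  # per-rank partial output

# Summing the partials (what an all-reduce would do) recovers the full result.
torch.testing.assert_close(sum(partials), full)
print("partial sums match:", full.shape)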