[inductor] Fix mm decomposition evaluating symints (#158998)

Fixes #154111

Resolves an issue during compilation with dynamic shapes where `torch._inductor.decomposition.mm` evaluates the input tensor's SymInt size expression because the decomposition iterates over `self.size(0)` in a Python for loop, so the output tensor is no longer dynamically shaped. The issue is limited to small (Mx1)x(1xN) matrix multiplications, and it surfaces as an explicit error with tensor subclasses such as DTensor.

The proposed fix replaces the loop with a single broadcasted product. Benchmarks are currently running at https://hud.pytorch.org/benchmark/compilers.
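
As a minimal sketch of the reasoning (illustrative code, not part of the PR; `mm_loop` and `mm_broadcast` are hypothetical names), both decompositions below produce the same (M, N) result for (M, 1) x (1, N) inputs, but only the loop-free version avoids forcing `size(0)` to a concrete integer under dynamic shapes:

```python
import torch


def mm_loop(a, b):
    # Old decomposition: range(a.size(0)) needs a concrete Python int, so a
    # symbolic size (SymInt) gets evaluated/specialized here.
    return torch.cat([a[i, :] * b for i in range(a.size(0))])


def mm_broadcast(a, b):
    # New decomposition: broadcasting (M, 1) * (1, N) -> (M, N) never
    # converts a SymInt to a concrete int.
    return a * b


a, b = torch.randn(5, 1), torch.randn(1, 6)
torch.testing.assert_close(mm_loop(a, b), a @ b)
torch.testing.assert_close(mm_broadcast(a, b), a @ b)
```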

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158998
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
Author: Arsh Zahed, 2025-07-30 16:34:12 +00:00, committed by PyTorch MergeBot
parent 90fd06be71
commit 24d07b3a67
3 changed files with 108 additions and 23 deletions


@@ -1,32 +1,32 @@
-add_loop_eager,compile_time_instruction_count,3070000000,0.10
+add_loop_eager,compile_time_instruction_count,3070000000,0.1
-add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10
+add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
-add_loop_inductor,compile_time_instruction_count,30280000000,0.10
+add_loop_inductor,compile_time_instruction_count,30280000000,0.1
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.1
-add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10
+add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,0.1
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
@@ -34,56 +34,56 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
-update_hint_regression,compile_time_instruction_count,1719000000,0.10
+update_hint_regression,compile_time_instruction_count,1719000000,0.1
-sum_floordiv_regression,compile_time_instruction_count,966100000,0.10
+sum_floordiv_regression,compile_time_instruction_count,966100000,0.1
-symint_sum,compile_time_instruction_count,3237000000,0.10
+symint_sum,compile_time_instruction_count,3237000000,0.1
-symint_sum_loop,compile_time_instruction_count,4299000000,0.10
+symint_sum_loop,compile_time_instruction_count,4299000000,0.1
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.1
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.1
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.1
-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.1
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.1
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.1
-mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10
+mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1
-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
-basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10
+basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1
-basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10
+basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.1


@@ -6,6 +6,13 @@ from typing import Union
 import torch
 from torch._inductor import config
 from torch._inductor.decomposition import mm
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.symbolic_shapes import (
+    DimDynamic,
+    ShapeEnv,
+    StatelessSymbolicContext,
+)
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
@@ -78,6 +85,19 @@ def torch_baddbmm(add, b, c, alpha, beta):
     return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)
 
 
+def create_fake_tensor_with_dynamic_size(x, fake_mode):
+    with fake_mode:
+        dynamic_sizes = [DimDynamic.DYNAMIC for _ in range(x.dim())]
+        dynamic_strides = [DimDynamic.INFER_STRIDE for _ in range(x.dim())]
+        return fake_mode.from_tensor(
+            x,
+            symbolic_context=StatelessSymbolicContext(
+                dynamic_sizes=dynamic_sizes,
+                dynamic_strides=dynamic_strides,
+            ),
+        )
+
+
 # The shapes we test on
 ts_list = [
     (1, 32, 32, 1),
@@ -187,6 +207,71 @@ class TestDecomp(NNTestCase):
             init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
         )
 
+    @parametrize("dtype", [torch.float, torch.bfloat16])
+    def test_dynamic_shape_mm(self, device, dtype):
+        # Test that the mm decomp does not evaluate expressions for dynamic shapes
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+        # Only test the decomp on cpu, to match the fake tensors coming from dynamo
+        if device != "cpu":
+            return
+        for t_size in ts_list:
+            a1_0, a1_1, a2_0, a2_1 = t_size
+            # Create the fake tensors
+            t1 = create_fake_tensor_with_dynamic_size(
+                rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device),
+                fake_mode,
+            )
+            t2 = create_fake_tensor_with_dynamic_size(
+                rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device),
+                fake_mode,
+            )
+            # Save the expression types so we can check whether any symints get evaluated
+            og_t1_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
+            ]
+            og_t2_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
+            ]
+            r = mm(t1, t2)
+            # Make sure no symints were evaluated
+            new_t1_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
+            ]
+            new_t2_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
+            ]
+            self.assertTrue(
+                all(
+                    og_t1_expr_types[i] == new_t1_expr_types[i]
+                    for i in range(len(og_t1_expr_types))
+                )
+            )
+            self.assertTrue(
+                all(
+                    og_t2_expr_types[i] == new_t2_expr_types[i]
+                    for i in range(len(og_t2_expr_types))
+                )
+            )
+            if r is not NotImplemented:
+                # Check that the output is well formed
+                self.assertEqual(t1.size(0), r.size(0))
+                self.assertEqual(t2.size(1), r.size(1))
+                r_expr_types = [
+                    type(d.node.expr) if type(d) is torch.SymInt else int
+                    for d in r.size()
+                ]
+                self.assertTrue(r_expr_types[0] == og_t1_expr_types[0])
+                self.assertTrue(r_expr_types[1] == og_t2_expr_types[1])
+
 
 device_types = ("cpu", GPU_TYPE)
 instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)
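
For context, here is a self-contained usage sketch of the `create_fake_tensor_with_dynamic_size` helper added above, using only APIs that appear in the diff. By default a `ShapeEnv` specializes sizes 0 and 1 to constants, so the sketch avoids those:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
x = torch.randn(4, 3)
with fake_mode:
    fake_x = fake_mode.from_tensor(
        x,
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=[DimDynamic.DYNAMIC] * x.dim(),
            dynamic_strides=[DimDynamic.INFER_STRIDE] * x.dim(),
        ),
    )
# Every dimension is now a symbolic SymInt instead of a concrete int.
print([type(d) for d in fake_x.size()])  # -> [SymInt, SymInt]
```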

torch/_inductor/decomposition.py

@@ -367,7 +367,7 @@ def mm(
         and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
     ):
         counters["inductor"]["decompose_mm"] += 1
-        return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
+        return self * input2
     if statically_known_true(self.size(0) == 1) and statically_known_true(
         input2.size(-1) == 1
     ):
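
Finally, a hedged repro sketch for the shape class the fix targets (adapted from the description of issue #154111, not taken from the PR): a small (M, 1) x (1, N) matmul compiled with dynamic shapes, the case where the loop-based decomposition previously evaluated the symbolic batch dimension:

```python
import torch


def f(a, b):
    return a @ b


compiled = torch.compile(f, dynamic=True)
out = compiled(torch.randn(4, 1), torch.randn(1, 4))
print(out.shape)  # torch.Size([4, 4])
```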