[inductor] Fix mm decomposition evaluating symints (#158998)
Fixes #154111

Resolves an issue during compilation with dynamic shapes where `torch._inductor.decomposition.mm` evaluates the SymInt expression for the input tensor because of a for loop, so the output tensor is no longer dynamically shaped. The issue is limited to small (Mx1)x(1xN) matrix multiplications, and it surfaces as an explicit error with tensor subclasses such as DTensor. The fix replaces the loop with a simple broadcasted product.

Benchmark currently running: https://hud.pytorch.org/benchmark/compilers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158998
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
This commit is contained in:
parent
90fd06be71
commit
24d07b3a67
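
For context, a minimal eager-mode sketch (not part of the commit; the shapes are illustrative) of why the rewrite is safe: for the (Mx1)x(1xN) case, the broadcasted product the fix introduces matches both the old loop-based decomposition and the reference matmul.

import torch

a = torch.randn(3, 1)  # (M x 1)
b = torch.randn(1, 4)  # (1 x N)

old = torch.cat([a[i, :] * b for i in range(a.size(0))])  # old decomposition: loops over size(0)
new = a * b                                               # new decomposition: broadcasted product
ref = a @ b                                               # reference matmul

assert torch.allclose(old, ref) and torch.allclose(new, ref)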
Expected benchmark results (CSV):

@@ -1,32 +1,32 @@
-add_loop_eager,compile_time_instruction_count,3070000000,0.10
+add_loop_eager,compile_time_instruction_count,3070000000,0.1
-add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10
+add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
-add_loop_inductor,compile_time_instruction_count,30280000000,0.10
+add_loop_inductor,compile_time_instruction_count,30280000000,0.1
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.1
-add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10
+add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,0.1
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
@@ -34,56 +34,56 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
-update_hint_regression,compile_time_instruction_count,1719000000,0.10
+update_hint_regression,compile_time_instruction_count,1719000000,0.1
-sum_floordiv_regression,compile_time_instruction_count,966100000,0.10
+sum_floordiv_regression,compile_time_instruction_count,966100000,0.1
-symint_sum,compile_time_instruction_count,3237000000,0.10
+symint_sum,compile_time_instruction_count,3237000000,0.1
-symint_sum_loop,compile_time_instruction_count,4299000000,0.10
+symint_sum_loop,compile_time_instruction_count,4299000000,0.1
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.1
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.1
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.1
-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.1
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.1
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.1
-mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10
+mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1
-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
-basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10
+basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1
-basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10
+basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.1
Test changes:

@@ -6,6 +6,13 @@ from typing import Union
 import torch
 from torch._inductor import config
+from torch._inductor.decomposition import mm
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.symbolic_shapes import (
+    DimDynamic,
+    ShapeEnv,
+    StatelessSymbolicContext,
+)
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
@@ -78,6 +85,19 @@ def torch_baddbmm(add, b, c, alpha, beta):
     return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)


+def create_fake_tensor_with_dynamic_size(x, fake_mode):
+    with fake_mode:
+        dynamic_sizes = [DimDynamic.DYNAMIC for _ in range(x.dim())]
+        dynamic_strides = [DimDynamic.INFER_STRIDE for _ in range(x.dim())]
+        return fake_mode.from_tensor(
+            x,
+            symbolic_context=StatelessSymbolicContext(
+                dynamic_sizes=dynamic_sizes,
+                dynamic_strides=dynamic_strides,
+            ),
+        )
+
+
 # The shapes we test on
 ts_list = [
     (1, 32, 32, 1),
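The helper wraps a real tensor in a FakeTensor whose dimensions are marked dynamic. A quick usage sketch (same internal APIs as imported above; note that size-0/1 dims may still be specialized by the ShapeEnv by default):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
ft = create_fake_tensor_with_dynamic_size(torch.randn(4, 8), fake_mode)
print(ft.size())  # e.g. torch.Size([s0, s1]) -- each dim is a torch.SymInt
print(all(isinstance(d, torch.SymInt) for d in ft.size()))  # True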
@@ -187,6 +207,71 @@ class TestDecomp(NNTestCase):
             init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
         )

+    @parametrize("dtype", [torch.float, torch.bfloat16])
+    def test_dynamic_shape_mm(self, device, dtype):
+        # Test that the mm decomp does not evaluate expressions for dynamic shapes
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+        # Only test decomp for cpu to match fake tensors from dynamo
+        if device != "cpu":
+            return
+
+        for t_size in ts_list:
+            (a1_0, a1_1, a2_0, a2_1) = t_size
+
+            # Create the fake tensors
+            t1 = create_fake_tensor_with_dynamic_size(
+                rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device),
+                fake_mode,
+            )
+            t2 = create_fake_tensor_with_dynamic_size(
+                rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device),
+                fake_mode,
+            )
+
+            # Save the expression types to check if any symints are evaluated
+            og_t1_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
+            ]
+            og_t2_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
+            ]
+
+            r = mm(t1, t2)
+
+            # Make sure all symints are not evaluated
+            new_t1_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
+            ]
+            new_t2_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
+            ]
+            self.assertTrue(
+                all(
+                    og_t1_expr_types[i] == new_t1_expr_types[i]
+                    for i in range(len(og_t1_expr_types))
+                )
+            )
+            self.assertTrue(
+                all(
+                    og_t2_expr_types[i] == new_t2_expr_types[i]
+                    for i in range(len(og_t2_expr_types))
+                )
+            )
+
+            if r is not NotImplemented:
+                # Check that the output is well formed
+                self.assertEqual(t1.size(0), r.size(0))
+                self.assertEqual(t2.size(1), r.size(1))
+                r_expr_types = [
+                    type(d.node.expr) if type(d) is torch.SymInt else int
+                    for d in r.size()
+                ]
+                self.assertTrue(r_expr_types[0] == og_t1_expr_types[0])
+                self.assertTrue(r_expr_types[1] == og_t2_expr_types[1])
+

 device_types = ("cpu", GPU_TYPE)
 instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)
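The before/after comparison of type(d.node.expr) works because a dynamic dimension is backed by a sympy.Symbol; if the decomposition evaluates it, the ShapeEnv records a replacement and the expression simplifies to a constant (sympy.Integer). A sketch of the check the assertions inline, extracted as a helper (hypothetical name, not in the test):

import torch

def expr_types(t):
    # int for static dims; otherwise the type of the backing sympy expression
    # (a Symbol while still dynamic, an Integer once specialized)
    return [type(d.node.expr) if type(d) is torch.SymInt else int for d in t.size()]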
Inductor mm decomposition:

@@ -367,7 +367,7 @@ def mm(
         and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
     ):
         counters["inductor"]["decompose_mm"] += 1
-        return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
+        return self * input2
     if statically_known_true(self.size(0) == 1) and statically_known_true(
         input2.size(-1) == 1
     ):
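Why the removed line specialized dynamic shapes: range() needs a concrete Python int, so range(self.size(0)) sends the SymInt through __index__ and pins the symbolic dimension to its hint, while self * input2 broadcasts without ever reading a concrete size. A minimal sketch (same internal APIs the new test uses):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with fake_mode:
    t = fake_mode.from_tensor(
        torch.randn(3, 1),
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
            dynamic_strides=[DimDynamic.INFER_STRIDE, DimDynamic.INFER_STRIDE],
        ),
    )

print(type(t.size(0)))  # <class 'torch.SymInt'> -- still symbolic
range(t.size(0))        # range() calls SymInt.__index__, specializing the dim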