diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
index c0d676f8851..a0ef356d29d 100644
--- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
+++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
@@ -1,32 +1,32 @@
-add_loop_eager,compile_time_instruction_count,3070000000,0.10
+add_loop_eager,compile_time_instruction_count,3070000000,0.1



-add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10
+add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1



-add_loop_inductor,compile_time_instruction_count,30280000000,0.10
+add_loop_inductor,compile_time_instruction_count,30280000000,0.1



-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.1



-add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10
+add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1



-basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1



-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,0.1



-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1



@@ -34,56 +34,56 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000



-update_hint_regression,compile_time_instruction_count,1719000000,0.10
+update_hint_regression,compile_time_instruction_count,1719000000,0.1



-sum_floordiv_regression,compile_time_instruction_count,966100000,0.10
+sum_floordiv_regression,compile_time_instruction_count,966100000,0.1



-symint_sum,compile_time_instruction_count,3237000000,0.10
+symint_sum,compile_time_instruction_count,3237000000,0.1



-symint_sum_loop,compile_time_instruction_count,4299000000,0.10
+symint_sum_loop,compile_time_instruction_count,4299000000,0.1



-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.1



-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.1



-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.1



-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.1



-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.1



-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.1



-mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10
+mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1



-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1



-basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10
+basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1



-basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10
+basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.1
diff --git a/test/inductor/test_mmdecomp.py b/test/inductor/test_mmdecomp.py
index 05b7afe0d2c..22a5d833245 100644
--- a/test/inductor/test_mmdecomp.py
+++ b/test/inductor/test_mmdecomp.py
@@ -6,6 +6,13 @@ from typing import Union

 import torch
 from torch._inductor import config
+from torch._inductor.decomposition import mm
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.symbolic_shapes import (
+    DimDynamic,
+    ShapeEnv,
+    StatelessSymbolicContext,
+)
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
@@ -78,6 +85,19 @@ def torch_baddbmm(add, b, c, alpha, beta):
     return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)


+def create_fake_tensor_with_dynamic_size(x, fake_mode):
+    with fake_mode:
+        dynamic_sizes = [DimDynamic.DYNAMIC for _ in range(x.dim())]
+        dynamic_strides = [DimDynamic.INFER_STRIDE for _ in range(x.dim())]
+        return fake_mode.from_tensor(
+            x,
+            symbolic_context=StatelessSymbolicContext(
+                dynamic_sizes=dynamic_sizes,
+                dynamic_strides=dynamic_strides,
+            ),
+        )
+
+
 # The shapes we test on
 ts_list = [
     (1, 32, 32, 1),
@@ -187,6 +207,71 @@ class TestDecomp(NNTestCase):
             init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
         )

+    @parametrize("dtype", [torch.float, torch.bfloat16])
+    def test_dynamic_shape_mm(self, device, dtype):
+        # Test that the mm decomp does not evaluate expressions for dynamic shapes
+
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+
+        # Only test the decomp for cpu, to match the fake tensors dynamo produces
+        if device != "cpu":
+            return
+
+        for t_size in ts_list:
+            a1_0, a1_1, a2_0, a2_1 = t_size
+
+            # Create the fake tensors
+            t1 = create_fake_tensor_with_dynamic_size(
+                rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device),
+                fake_mode,
+            )
+            t2 = create_fake_tensor_with_dynamic_size(
+                rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device),
+                fake_mode,
+            )
+
+            # Save the expression types to check whether any symints get evaluated
+            og_t1_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
+            ]
+            og_t2_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
+            ]
+
+            r = mm(t1, t2)
+
+            # Make sure none of the symints were evaluated
+            new_t1_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t1.size()
+            ]
+            new_t2_expr_types = [
+                type(d.node.expr) if type(d) is torch.SymInt else int for d in t2.size()
+            ]
+            self.assertTrue(
+                all(
+                    og_t1_expr_types[i] == new_t1_expr_types[i]
+                    for i in range(len(og_t1_expr_types))
+                )
+            )
+            self.assertTrue(
+                all(
+                    og_t2_expr_types[i] == new_t2_expr_types[i]
+                    for i in range(len(og_t2_expr_types))
+                )
+            )
+
+            if r is not NotImplemented:
+                # Check that the output is well formed
+                self.assertEqual(t1.size(0), r.size(0))
+                self.assertEqual(t2.size(1), r.size(1))
+                r_expr_types = [
+                    type(d.node.expr) if type(d) is torch.SymInt else int
+                    for d in r.size()
+                ]
+                self.assertTrue(r_expr_types[0] == og_t1_expr_types[0])
+                self.assertTrue(r_expr_types[1] == og_t2_expr_types[1])
+

 device_types = ("cpu", GPU_TYPE)

 instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)
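[Editor's note, not part of the patch] A minimal sketch of what the new create_fake_tensor_with_dynamic_size helper yields, using the same APIs the test above imports; the shape (4, 8) and variable names are illustrative. Marking every dim DimDynamic.DYNAMIC makes each size a torch.SymInt backed by a fresh symbol, which is what the test later inspects via type(d.node.expr):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
x = torch.randn(4, 8)
with fake_mode:
    fake = fake_mode.from_tensor(
        x,
        symbolic_context=StatelessSymbolicContext(
            # Every dim dynamic: sizes become symbols instead of constants
            dynamic_sizes=[DimDynamic.DYNAMIC] * x.dim(),
            dynamic_strides=[DimDynamic.INFER_STRIDE] * x.dim(),
        ),
    )
print([type(d) for d in fake.size()])  # -> [<class 'torch.SymInt'>, <class 'torch.SymInt'>]

INFER_STRIDE derives the stride expressions from the symbolic sizes rather than allocating separate symbols for them, mirroring the helper in the test diff.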
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index a73da08c904..2622ab6b95e 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -367,7 +367,7 @@ def mm(
             and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
         ):
             counters["inductor"]["decompose_mm"] += 1
-            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
+            return self * input2
         if statically_known_true(self.size(0) == 1) and statically_known_true(
            input2.size(-1) == 1
         ):
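[Editor's note, not part of the patch] Why the one-line replacement is safe: the enclosing guard (only partly visible in this hunk) limits the branch to tiny matrices whose left operand has a single column, so self is (n, 1) and input2 is (1, m). Broadcasting then yields the same (n, m) outer product as the old per-row torch.cat loop, but without iterating over self.size(0), which is exactly the step that could force a symbolic size to be evaluated. A quick eager-mode check of the equivalence under those assumed shapes:

import torch

a = torch.randn(4, 1)  # column vector, the size(-1) == 1 case this branch targets
b = torch.randn(1, 8)  # row vector

old = torch.cat([a[i, :] * b for i in range(a.size(0))])  # previous decomposition
new = a * b                                               # new broadcast form

assert torch.equal(old, new)               # identical elementwise products
assert torch.allclose(new, torch.mm(a, b)) # both match the real matmul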