diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index dcbac368157..0d6750add17 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -28,9 +28,8 @@ from torch._inductor.autotune_process import (
     TuningProcessPool,
 )
 from torch._inductor.graph import GraphLowering
-from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout, InputBuffer
+from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
 from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
-from torch._inductor.kernel_inputs import MMKernelInputs
 from torch._inductor.select_algorithm import (
     add_feedback_saver,
     AlgorithmSelectorCache,
@@ -76,7 +75,7 @@ from torch.testing._internal.inductor_utils import (
 )
 
 
-torch.backends.cuda.matmul.allow_tf32 = True
+torch.set_float32_matmul_precision("high")
 
 if HAS_CUDA_AND_TRITON:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")
@@ -2077,39 +2076,6 @@ class TestMaxAutotuneRemoteCache(TestCase):
         global_stats.report()
         self.assertEqual(global_stats.autotune_remote, Stats(2, 3, 2))
 
-    def test_get_mm_configs_float32_precision_ieee(self):
-        """Test that configs returned from choices.get_mm_configs use float32_precision == ieee."""
-        from torch._inductor.choices import InductorChoices
-        from torch._inductor.graph import GraphLowering
-        from torch._inductor.ir import FlexibleLayout
-        from torch.fx.experimental.proxy_tensor import make_fx
-
-        # Create a simple graph to get proper context
-        gm = make_fx(lambda: torch.zeros(2, 3))()
-        graph = GraphLowering(gm)
-
-        with V.set_graph_handler(graph):
-            device = torch.device(f"{GPU_TYPE}:0")
-            mat1 = InputBuffer(
-                name="mat1",
-                layout=FixedLayout(device, torch.float32, [64, 128], [128, 1]),
-            )
-            mat2 = InputBuffer(
-                name="mat2",
-                layout=FixedLayout(device, torch.float32, [128, 64], [64, 1]),
-            )
-            kernel_inputs = MMKernelInputs([mat1, mat2])
-            output_layout = FlexibleLayout(device, torch.float32, [64, 64])
-
-            choices = InductorChoices()
-            configs = list(
-                choices.get_mm_configs(kernel_inputs, output_layout, "mm", "mm")
-            )
-
-            for cfg in configs:
-                self.assertIn("ALLOW_TF32", cfg)
-                self.assertEqual(cfg["ALLOW_TF32"], True)
-
 
 class _TestTritonTemplateCaller(TritonTemplateCaller):
     def __init__(self, bmreq: _TestBenchmarkRequest):
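
Note (illustrative, not part of the patch): the setup change above swaps the CUDA-specific `torch.backends.cuda.matmul.allow_tf32` flag for the backend-agnostic `torch.set_float32_matmul_precision` API. A minimal sketch of the two controls, using only documented PyTorch calls:

```python
import torch

# Old, CUDA-specific flag removed by this patch (backend-scoped):
torch.backends.cuda.matmul.allow_tf32 = True

# New, backend-agnostic API used in the updated test setup. "high"
# permits reduced-precision (e.g. TF32) kernels for float32 matmuls:
torch.set_float32_matmul_precision("high")

# "highest" restores strict IEEE float32 matmul behavior:
torch.set_float32_matmul_precision("highest")
print(torch.get_float32_matmul_precision())  # -> "highest"
```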