From 14d4a77495dc80da9cd7c5c29b21aaf4613ddc5b Mon Sep 17 00:00:00 2001
From: eellison
Date: Wed, 29 Oct 2025 13:52:30 -0700
Subject: [PATCH] disable current modes instead of no dispatch in estimation
 (#166571)

otherwise, the custom estimation's TorchDispatchModes will be disabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166571
Approved by: https://github.com/SherlockNoMad, https://github.com/bdhirsh
---
 .../test_aten_comm_compute_reordering.py      | 56 +++++++++++++++++++
 .../_inductor/fx_passes/overlap_scheduling.py |  4 +-
 2 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py
index 4894c6853cd..3cdbe8b84a3 100644
--- a/test/distributed/test_aten_comm_compute_reordering.py
+++ b/test/distributed/test_aten_comm_compute_reordering.py
@@ -887,6 +887,62 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
         correct = func(a, b, c, d, ranks=ranks)
         self.assertTrue(same(test_out, correct))
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_bucket_patches())
+    def test_custom_estimation_with_fake_tensor_mode(self):
+        """Test that custom estimation can use FakeTensorMode for analysis."""
+        from torch._subclasses.fake_tensor import FakeTensorMode
+
+        estimation_calls = 0
+
+        def estimate_with_fake_mode(fx_node, compute_multiplier=1.0):
+            with FakeTensorMode():
+                nonlocal estimation_calls
+                estimation_calls += 1
+                assert isinstance(torch.rand([20]), torch._subclasses.FakeTensor)
+
+            return 1.0
+
+        patches = get_bucket_patches()
+        patches["aten_distributed_optimizations.custom_runtime_estimation"] = (
+            estimate_with_fake_mode
+        )
+
+        def func(a, b, *, ranks):
+            # Two independent all_gathers that should be bucketed
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+
+            # Matmul that can hide the collectives
+            mm1 = torch.matmul(a, a)
+
+            return ag1.sum() + ag2.sum() + mm1.sum()
+
+        with _dynamo_dist_per_rank_init(
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
+        ):
+            inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
+            inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+            with torch._inductor.config.patch(patches):
+                compiled = torch.compile(func_c)
+                out, aten_graph_str = run_and_get_aten_graph(
+                    compiled, inputs_a, inputs_b
+                )
+
+            # Verify the custom estimation was called
+            self.assertTrue(
+                estimation_calls > 0, "Custom estimation should have been called"
+            )
+
+            correct = func(inputs_a, inputs_b, ranks=ranks)
+            self.assertTrue(same(out, correct))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py
index 3575b2b49ef..de131891bb4 100644
--- a/torch/_inductor/fx_passes/overlap_scheduling.py
+++ b/torch/_inductor/fx_passes/overlap_scheduling.py
@@ -18,8 +18,8 @@ from torch._inductor.fx_passes.memory_estimator import (
     MemoryTracker,
 )
 from torch.fx.operator_schemas import normalize_function
-from torch.utils._mode_utils import no_dispatch
 from torch.utils._ordered_set import OrderedSet
+from torch.utils._python_dispatch import _disable_current_modes
 
 
 log = logging.getLogger(__name__)
@@ -136,7 +136,7 @@ def benchmark_node_with_cache_key(
         key += f"T: {shape, stride, t.dtype} "
         return rand_strided(shape, stride, device=t.device, dtype=t.dtype)  # type: ignore[arg-type]
 
-    with no_dispatch():
+    with _disable_current_modes():
         args, kwargs = torch.utils._pytree.tree_map_only(
             torch.Tensor,
             lambda t: to_real(t),
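
A minimal sketch (not part of the patch) of the distinction the commit message relies on, assuming the usual semantics of the two context managers: no_dispatch() excludes torch dispatch outright, so even a TorchDispatchMode entered inside the block (such as the FakeTensorMode used by the custom estimation in the new test) stops intercepting ops, while _disable_current_modes() only pops the modes that are already on the stack and restores them on exit, so modes entered inside the block keep working.

# Sketch only, not part of the patch; assumes the semantics described above.
import torch
from torch._subclasses import FakeTensor
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import _disable_current_modes

with no_dispatch():
    with FakeTensorMode():
        t = torch.rand([20])
    # Dispatch is disabled wholesale, so FakeTensorMode never sees the op and
    # a real tensor is expected here; this is what broke custom estimators.
    print(isinstance(t, FakeTensor))  # expected: False

with _disable_current_modes():
    with FakeTensorMode():
        t = torch.rand([20])
    # Only the modes already on the stack were popped; a mode entered inside
    # the block still intercepts ops, so a FakeTensor is expected here.
    print(isinstance(t, FakeTensor))  # expected: True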