# Owner(s): ["module: dynamo"] import unittest import torch import torch._dynamo.test_case from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm from torch.testing._internal.common_cuda import TEST_CUDA class PythonDispatcherTests(torch._dynamo.test_case.TestCase): def test_dispatch_key1(self): @torch.compile(backend="aot_eager", fullgraph=True) def fn(x): x = x + 1 return torch._C._dispatch_keys(x) x = torch.randn(2, 3) self.assertTrue(fn(x).raw_repr() == torch._C._dispatch_keys(x + 1).raw_repr()) def test_dispatch_key2(self): from torch.testing._internal.two_tensor import TwoTensor @torch.compile(backend="aot_eager", fullgraph=True) def fn(x): x = x.sin() return torch._C._dispatch_keys(x) x = torch.randn(3) y = torch.randn(3) z = TwoTensor(x, y) self.assertTrue(fn(z).raw_repr() == torch._C._dispatch_keys(z.sin()).raw_repr()) def test_dispatch_key3(self): @torch.compile(backend="aot_eager", fullgraph=True) def fn(x): key_set = torch._C._dispatch_tls_local_include_set() return torch.sin(x + 1), key_set x = torch.randn(2, 3) self.assertEqual(fn(x)[0], torch.sin(x + 1)) self.assertTrue( fn(x)[1].raw_repr() == torch._C._dispatch_tls_local_include_set().raw_repr() ) def test_dispatch_key4(self): eager = EagerAndRecordGraphs() @torch.compile(backend=eager, fullgraph=True) def fn(x): key_set = torch._C._dispatch_tls_local_include_set() key_set = key_set | torch._C._dispatch_keys(x) key_set = key_set - torch._C._dispatch_tls_local_exclude_set() if key_set.highestPriorityTypeId() == torch.DispatchKey.PythonDispatcher: return torch.sin(x + 1) else: return torch.sin(x - 1) x = torch.randn(2, 3) self.assertEqual(fn(x), torch.sin(x - 1)) graph = eager.graphs[0] actual = normalize_gm(graph.print_readable(False)) self.assertExpectedInline( actual, """\ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 3]"): l_x_ = L_x_ sub: "f32[2, 3]" = l_x_ - 1; l_x_ = None sin: "f32[2, 3]" = torch.sin(sub); sub = None return (sin,) """, # NOQA: B950 ) @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_dispatch_key_set_guard(self): counter = CompileCounter() @torch.compile(backend=counter, fullgraph=True) def fn(x, dks): if dks.has("CPU"): return torch.sin(x + 1) else: return torch.sin(x - 1) x1 = torch.randn(2, 3) dks1 = torch._C._dispatch_keys(x1) self.assertEqual(fn(x1, dks1), torch.sin(x1 + 1)) self.assertEqual(counter.frame_count, 1) x2 = torch.randn(2, 3) dks2 = torch._C._dispatch_keys(x2) self.assertEqual(fn(x2, dks2), torch.sin(x2 + 1)) # No recompile since the dispatch key set is the same though the tensor is different. self.assertEqual(counter.frame_count, 1) x3 = torch.randn(2, 3, device="cuda") dks3 = torch._C._dispatch_keys(x3) self.assertEqual(fn(x3, dks3), torch.sin(x3 - 1)) # Re-compile since the dispatch key set is different. 
self.assertEqual(counter.frame_count, 2) def test_functorch_interpreter(self): counter = CompileCounter() def square_and_add(x, y): interpreter = ( torch._functorch.pyfunctorch.retrieve_current_functorch_interpreter() ) level = interpreter.level() if interpreter.key() == torch._C._functorch.TransformType.Vmap: return (x**2 + y) * level else: return x**2 * level @torch.compile(backend=counter, fullgraph=True) def fn(x, y): return torch.vmap(square_and_add)(x, y) x = torch.tensor([1, 2, 3, 4]) y = torch.tensor([10, 20, 30, 40]) self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 56])) self.assertEqual(counter.frame_count, 1) x = torch.tensor([1, 2, 3, 1]) y = torch.tensor([10, 20, 30, 10]) self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 11])) # No recompile self.assertEqual(counter.frame_count, 1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()