# Owner(s): ["module: dynamo"]
import contextlib

import torch
import torch.fx
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph_and_tracker
from torch.utils._pytree import tree_map


class GraphRegionTrackerTests(TestCase):
    def setUp(self):
        self.exit_stack = contextlib.ExitStack()
        self.exit_stack.enter_context(
            torch._dynamo.config.patch("track_nodes_for_deduplication", True)
        )
        super().setUp()

    def tearDown(self):
        self.exit_stack.close()
        super().tearDown()

    def get_result(self, fn, *args, **kwargs):
        graph, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
        region_groups = region_tracker.get_identical_regions(graph)
        region_groups = tree_map(lambda n: n.name, region_groups)
        return str(region_groups)

    def get_mutation_tracking(self, fn, *args, **kwargs):
        _, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
        return str(region_tracker.node_to_mutated_arg_positions)

    def test_get_regions_single_region_group(self):
        def inner_fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = torch.sin(y)
            o2 = inner_fn(x, o1)
            o3 = inner_fn(x, y)
            o4 = o3 * o3
            return o2 * o4 + o0

        self.assertExpectedInline(
            self.get_result(
                fn,
                torch.rand(10, 10),
                torch.ones(10, 20),
            ),
            """[[['x0', 'y0', 'sum_1', 'sum_2', 'z'], ['x0_1', 'y0_1', 'sum_3', 'sum_4', 'z_1'],\
 ['x0_2', 'y0_2', 'sum_5', 'sum_6', 'z_2']]]""",
        )

    def test_get_regions_multiple_region_groups(self):
        def inner_fn(x, y):
            x1 = x + 1
            y1 = y + 2
            z = x1.sum() + y1.sum()
            return z

        def inner_fn2(a, b):
            a += 2
            b += 3
            c = a * b.cos().sum()
            return c

        def fn(x, y):
            x0 = torch.cos(x)
            y0 = torch.sin(y)
            o1 = inner_fn2(x0, y0)
            o0 = inner_fn(x, y)
            o1 = torch.sin(o0)
            o2 = inner_fn(x, y0)
            o2 = inner_fn2(x0, y0)
            o3 = inner_fn(x, y)
            return o1 * o2 + o3

        self.assertExpectedInline(
            self.get_result(
                fn,
                torch.rand(10, 10),
                torch.ones(10, 20),
            ),
            """[[['x1', 'y1', 'sum_2', 'sum_3', 'z'], ['x1_1', 'y1_1', 'sum_4', 'sum_5', 'z_1'],\
 ['x1_2', 'y1_2', 'sum_7', 'sum_8', 'z_2']], [['a', 'b', 'cos_1', 'sum_1', 'c'], ['a_1', 'b_1', 'cos_2', 'sum_6', 'c_1']]]""",
        )

    def test_no_single_node_regions(self):
        def inner_fn(x):
            return x + 1

        def fn(x):
            o0 = inner_fn(x)
            o1 = inner_fn(x)
            o2 = inner_fn(x)
            return o0 + o1 + o2

        self.assertExpectedInline(self.get_result(fn, torch.ones(10, 10)), """[]""")

    def test_mismatched_arg_shapes(self):
        def inner_fn(x, y):
            x1 = x + 1
            y1 = y + 2
            z = x1.sum() + y1.sum()
            return z

        def inner_fn2(a, b):
            a += 2
            b += 3
            c = a * b.cos().sum()
            return c

        def fn(x, y):
            x0 = torch.cos(x)
            y0 = torch.sin(y)
            o1 = inner_fn2(x0, y0)
            o0 = inner_fn(x, o1)
            o1 = torch.sin(o0)
            o2 = inner_fn(x, y0)
            o2 = inner_fn2(o2, y0)
            o3 = inner_fn(x, y)
            return o1 * o2 + o3

        self.assertExpectedInline(
            self.get_result(
                fn,
                torch.rand(10, 10),
                torch.ones(10, 20),
            ),
            """[[['y1_1', 'sum_5'], ['y1_2', 'sum_8']], [['x1', 'sum_2', 'z'], ['x1_1', 'sum_4', 'z_1'], \
['x1_2', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1'], ['b_1', 'cos_2', 'sum_6']]]""",
        )

    def test_mismatched_dtypes(self):
        def inner_fn(x, y):
            x1 = x * 1
            y1 = y + 1
            return x1 + y1.sum()

        def fn(x, y):
            x0 = torch.sin(x)
            y0 = torch.cos(y)
            o0 = inner_fn(x0, y0)
            o2 = inner_fn(x0, y0)
            o4 = inner_fn(x0, y0)
            o5 = inner_fn(x0, y0)
            o1 = inner_fn(x0.to(torch.bfloat16), y0.to(torch.bfloat16))
            o3 = o1 + o2
            return o3 * o0 + o4 + o5

        self.assertExpectedInline(
            self.get_result(
                fn,
                torch.rand(10, 10),
                torch.ones(10, 20),
            ),
            """[[['x1', 'y1', 'sum_1', 'o0'], ['x1_1', 'y1_1', 'sum_2', 'o2'], \
['x1_2', 'y1_2', 'sum_3', 'o4'], ['x1_3', 'y1_3', 'sum_4', 'o5']]]""",
        )
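
    # List (nested) tensor args should still form identical regions; the
    # bfloat16 variant below must not match the float32 calls.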
    def test_nested_args(self):
        def inner_fn(xs, ys):
            out = torch._foreach_add(xs, ys)
            return out[0] + out[1].sum()

        def fn(x, y, z):
            x0 = torch.sin(x)
            y0 = torch.cos(y)
            z0 = torch.sin(z)
            o0 = inner_fn([x0, z0], [x0, y0])
            o2 = inner_fn([x0, z0], [x0, y0])
            o4 = inner_fn([x0, z0], [x0, y0])
            o5 = inner_fn([x0, z0], [x0, y0])
            o1 = inner_fn(
                [x0.to(torch.bfloat16), z0.to(torch.bfloat16)],
                [x0.to(torch.bfloat16), y0.to(torch.bfloat16)],
            )
            o3 = o1 + o2
            return o3 * o0 + o4 + o5

        self.assertExpectedInline(
            self.get_result(
                fn,
                torch.rand(10, 10),
                torch.rand(10, 20),
                torch.ones(10, 20),
            ),
            """[[['_foreach_add', 'getitem', 'getitem_1', 'sum_1', 'o0'], ['_foreach_add_1', \
'getitem_2', 'getitem_3', 'sum_2', 'o2'], ['_foreach_add_2', 'getitem_4', 'getitem_5', 'sum_3', \
'o4'], ['_foreach_add_3', 'getitem_6', 'getitem_7', 'sum_4', 'o5']]]""",
        )

    def test_mismatched_global_state(self):
        def inner_fn(x, y):
            x1 = x * 1
            y1 = y + 1
            return x1 + y1.sum()

        def fn(x, y, c):
            x0 = torch.sin(x)
            y0 = torch.cos(y)
            o4 = inner_fn(x0, y0)
            o5 = inner_fn(x0, y0)
            if isinstance(c, tuple):
                c[0]()
                o0 = inner_fn(x0, y0)
                o2 = inner_fn(x0, y0)
                c[1]()
            else:
                with c():
                    o0 = inner_fn(x0, y0)
                    o2 = inner_fn(x0, y0)
            return o0 + o2 + o4 + o5

        def create_toggle_fns(property):
            old_value = getattr(torch.backends.cuda.matmul, property)

            def toggle_property():
                setattr(torch.backends.cuda.matmul, property, not old_value)

            def reset_property():
                setattr(torch.backends.cuda.matmul, property, old_value)

            return toggle_property, reset_property

        old_dtype = torch.get_default_dtype()

        def set_default_dtype_bfloat16():
            torch.set_default_dtype(torch.bfloat16)

        def reset_default_dtype():
            torch.set_default_dtype(old_dtype)

        for ctx in [
            lambda: torch.set_grad_enabled(False),
            torch.autograd.grad_mode.inference_mode,
            lambda: torch.autograd.graph.disable_saved_tensors_hooks(
                "This is not supported"
            ),
            # lambda: torch.set_num_threads(2), : Unsupported
            (set_default_dtype_bfloat16, reset_default_dtype),
            (
                lambda: torch.use_deterministic_algorithms(True),
                lambda: torch.use_deterministic_algorithms(False),
            ),
            # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
            # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
            create_toggle_fns("allow_bf16_reduced_precision_reduction"),
            create_toggle_fns("allow_fp16_reduced_precision_reduction"),
            create_toggle_fns("allow_tf32"),
        ]:
            self.assertExpectedInline(
                self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
                """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
            )

    def test_mutation_tracking_simple(self):
        def fn(x, y, z):
            x0 = torch.sin(x)
            y0 = torch.cos(y)
            z.sin_()
            y0.add_(z)
            return x0.sum() + y0.sum()

        self.assertExpectedInline(
            self.get_mutation_tracking(
                fn,
                torch.rand(10, 10),
                torch.rand(10, 20),
                torch.ones(10, 20),
            ),
            """{sin_: OrderedSet([0]), add_: OrderedSet([0])}""",
        )

    def test_mutation_tracking_setitem(self):
        def fn(x):
            y = x + 1
            y[0] = 3
            return y

        self.assertExpectedInline(
            self.get_mutation_tracking(fn, torch.rand(10, 10)),
            """{setitem: OrderedSet([0])}""",
        )

    def test_mutation_tracking_allow_in_graph(self):
        @torch._dynamo.allow_in_graph
        def fn_mut(x, y):
            x.add_(y)
            return x.sum() + y.sum()

        def fn(x, y):
            z = x + y
            o0 = fn_mut(z, y)
            z.sin_()
            return x + o0

        self.assertExpectedInline(
            self.get_mutation_tracking(
                fn,
                torch.rand(20, 10),
                torch.rand(20, 10),
            ),
            """{o0: OrderedSet([0]), sin_: OrderedSet([0])}""",
        )
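
    # Non-tensor args (conv2d stride/padding/dilation/groups) take part in
    # region hashing: o4 uses a different stride, so only o1/o2/o3 group.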
    def test_non_tensor_arg_hashing(self):
        def inner(x, w, t):
            y = x + x
            return torch.conv2d(y, w, None, *t)

        def fn(x, y):
            o1 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1))
            o2 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1))
            o3 = inner(x, y, ((1, 1), (0, 0), (1, 1), 1))
            o4 = inner(x, y, ((2, 2), (0, 0), (1, 1), 1))
            return o1.sum() + o2.sum() + o3.sum() + o4.sum()

        self.assertExpectedInline(
            self.get_result(
                fn,
                torch.rand(32, 256, 56, 56),
                torch.nn.Parameter(torch.rand(512, 256, 1, 1)),
            ),
            """[[['y', 'o1'], ['y_1', 'o2'], ['y_2', 'o3']]]""",
        )

    def test_region_sorting(self):
        from torch._dynamo.graph_region_tracker import _sort_with_ref_region

        index_to_rank = {0: 0, 2: 1, 1: 2}
        regions = [[0, 1, 2], [1, 2, 0]]
        _sort_with_ref_region(index_to_rank, regions)
        # assertExpectedInline compares strings, so stringify the sorted regions
        self.assertExpectedInline(str(regions), """[[0, 2, 1], [1, 0, 2]]""")

    def test_no_duplicate_tracking(self):
        def inner_fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = torch.sin(y)
            o2 = inner_fn(x, o1)
            o3 = inner_fn(x, y)
            o4 = o3 * o3
            return o2 * o4 + o0

        graph, tracker = extract_graph_and_tracker(
            fn, torch.rand(10, 10), torch.ones(10, 20)
        )
        self.assertExpectedInline(
            str(tracker.node_to_duplicates),
            """{l_x_: [l_x_], x0: [x0, x0_1, x0_2], l_y_: [l_y_], y0: [y0, y0_1, y0_2], sum_1: \
[sum_1, sum_3, sum_5], sum_2: [sum_2, sum_4, sum_6], z: [z, z_1, z_2], o1: [o1], x0_1: [x0, x0_1, x0_2], y0_1: [y0, y0_1, y0_2], \
sum_3: [sum_1, sum_3, sum_5], sum_4: [sum_2, sum_4, sum_6], \
z_1: [z, z_1, z_2], x0_2: [x0, x0_1, x0_2], y0_2: [y0, y0_1, y0_2], sum_5: [sum_1, sum_3, sum_5], sum_6: [sum_2, sum_4, sum_6], \
z_2: [z, z_1, z_2], o4: [o4], mul_1: [mul_1], add_9: [add_9]}""",
        )
        key = next(iter(tracker.node_to_duplicates.keys()))
        # this will fail if the node is added again
        tracker.track_node(None, key)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()