# Owner(s): ["module: dynamo"] # flake8: noqa: B950 import torch import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import ( CompileCounter, CompileCounterWithBackend, EagerAndRecordGraphs, normalize_gm, ) class TestInputAttrTracking(torch._dynamo.test_case.TestCase): def test_tensor_property_on_tensor(self): def fn(x): return x * x.y x_ = torch.randn([2, 2]) y_ = torch.randn([2, 2]) x_.y = y_ eager_result = fn(x_) graph = None def grab_graph_backend(gm, inps): nonlocal graph graph = gm return gm fn = torch.compile(fn, backend=grab_graph_backend, fullgraph=True) compile_result = fn(x_) self.assertEqual(eager_result, compile_result) placeholder_cnt = 0 for node in graph.graph.nodes: if node.op == "placeholder": placeholder_cnt += 1 # We want to be very sure that this lifts y to inputs! self.assertEqual(placeholder_cnt, 2) def test_tensor_property_assigned_on_tensor(self): def fn(x, y): x.y = y return x * x.y x_ = torch.randn([2, 2]) y_ = torch.randn([2, 2]) eager_result = fn(x_, y_) graph = None def grab_graph_backend(gm, inps): nonlocal graph graph = gm return gm fn = torch.compile(fn, backend=grab_graph_backend, fullgraph=True) compile_result = fn(x_, y_) self.assertEqual(eager_result, compile_result) placeholder_cnt = 0 for node in graph.graph.nodes: if node.op == "placeholder": placeholder_cnt += 1 # y is already an input self.assertEqual(placeholder_cnt, 2) def test_const_property_on_tensor(self): def fn(x): return x * x.y x_ = torch.randn([2, 2]) y_ = 4 x_.y = y_ eager_result = fn(x_) graph = None def grab_graph_backend(gm, inps): nonlocal graph graph = gm return gm fn = torch.compile(fn, backend=grab_graph_backend, fullgraph=True) compile_result = fn(x_) self.assertEqual(eager_result, compile_result) placeholder_cnt = 0 for node in graph.graph.nodes: if node.op == "placeholder": placeholder_cnt += 1 # We want to be very sure that this does not lifts y to inputs, as its a const self.assertEqual(placeholder_cnt, 1) def test_const_property_assigned_on_tensor(self): def fn(x, y): x.y = y return x * x.y x_ = torch.randn([2, 2]) y_ = 4 eager_result = fn(x_, y_) fn = torch.compile(fn, backend="eager", fullgraph=True) compile_result = fn(x_, y_) self.assertEqual(eager_result, compile_result) def test_guards_correctly_property_assigned_on_tensor_type_change(self): def fn(x, y): x.y = y return x * x.y x_ = torch.randn([2, 2]) fn = torch.compile(fn, backend="eager", fullgraph=True) compile_result_const = fn(x_, 4) self.assertEqual(compile_result_const, x_ * 4) y = torch.randn([2, 2]) compile_result_tensor = fn(x_, y) self.assertEqual(compile_result_tensor, x_ * y) def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self): def fn(x, y): x.y = y return x * x.y x_ = torch.randn([2, 2]) fn = torch.compile(fn, backend="inductor", fullgraph=True) compile_result_const = fn(x_, 4) self.assertEqual(compile_result_const, x_ * 4) y = torch.randn([2, 2]) compile_result_tensor = fn(x_, y) self.assertEqual(compile_result_tensor, x_ * y) def test_complex_attr_access_without_graph_breaks(self): def fn(x, y, z): for t in x: t.y = y t.z = y * z new_y = 1 new_z = 1 for t in x: new_y = t.y * new_y new_z = t.z * new_z return new_y, new_z x_0 = torch.randn([2, 2]) x_1 = torch.randn([2, 2]) x_2 = torch.randn([2, 2]) x = [x_0, x_1, x_2] y = torch.randn([2, 2]) z = 5 eager_result = fn(x, y, z) counter = CompileCounter() fn = torch.compile(fn, backend=counter, fullgraph=True) compile_result = fn(x, y, z) self.assertEqual(compile_result, eager_result) self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) # Graph for reference # ------------- ------ ----------------------- ------------------------------------ -------- # placeholder l_y_ L_y_ () {} # call_function mul (l_y_, 5) {} # call_function mul_1 (l_y_, 5) {} # call_function mul_2 (l_y_, 5) {} # call_function mul_3 (l_y_, 1) {} # call_function mul_4 (mul, 1) {} # call_function mul_5 (l_y_, mul_3) {} # call_function mul_6 (mul_1, mul_4) {} # call_function mul_7 (l_y_, mul_5) {} # call_function mul_8 (mul_2, mul_6) {} # output output output ((mul_7, mul_8, mul, mul_1, mul_2),) {} def test_complex_attr_access_with_graph_breaks(self): def fn(x, y, z): for t in x: t.y = y t.z = y * z print("Break!") new_y = 1 new_z = 1 for t in x: new_y = t.y * new_y new_z = t.z * new_z return new_y, new_z x_0 = torch.randn([2, 2]) x_1 = torch.randn([2, 2]) x_2 = torch.randn([2, 2]) x = [x_0, x_1, x_2] y = torch.randn([2, 2]) z = 5 eager_result = fn(x, y, z) counter = CompileCounter() fn = torch.compile(fn, backend=counter, fullgraph=False) compile_result = fn(x, y, z) self.assertEqual(compile_result, eager_result) self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.op_count, 9) # Graph for reference # ------------- ------ ----------------------- ---------------------- -------- # placeholder l_y_ L_y_ () {} # call_function mul (l_y_, 5) {} # call_function mul_1 (l_y_, 5) {} # call_function mul_2 (l_y_, 5) {} # output output output ((mul, mul_1, mul_2),) {} # [GRAPH BREAK!] # ------------- ------- ----------------------- ----------------- -------- # placeholder l_x_0_y L_x_0_y () {} # placeholder l_x_0_z L_x_0_z () {} # placeholder l_x_1_y L_x_1_y () {} # placeholder l_x_1_z L_x_1_z () {} # placeholder l_x_2_y L_x_2_y () {} # placeholder l_x_2_z L_x_2_z () {} # call_function mul (l_x_0_y, 1) {} # call_function mul_1 (l_x_0_z, 1) {} # call_function mul_2 (l_x_1_y, mul) {} # call_function mul_3 (l_x_1_z, mul_1) {} # call_function mul_4 (l_x_2_y, mul_2) {} # call_function mul_5 (l_x_2_z, mul_3) {} # output output output ((mul_4, mul_5),) {} def test_complex_attr_access_with_inline_reconstruct(self): def inline_test_fn(x, y, z): print("f") return x.a + y.a + z.a def fn(x, y, z): x.a = 1 y.a = 2 z.a = 3 mult = inline_test_fn(x, y, z) y = y * mult x = x * mult return x, y x = torch.randn([2, 2]) y = torch.randn([2, 2]) z = torch.randn([2, 2]) eager_result = fn(x, y, z) counter = CompileCounter() fn = torch.compile(fn, backend=counter, fullgraph=False) compile_result = fn(x, y, z) self.assertEqual(compile_result, eager_result) self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 2) # Graph for reference # __compiled_fn_2 .0 opcode name target args kwargs # ------------- ------ ----------------------- --------------- -------- # placeholder l_x_ L_x_ () {} # placeholder l_y_ L_y_ () {} # call_function mul (l_y_, 6) {} # call_function mul_1 (l_x_, 6) {} # output output output ((mul_1, mul),) {} def test_set_data_on_input_tensor(self): def fn(x, y): x.data = y.data if x.size() == y.size(): return x * y else: return y * y x = torch.randn([5, 5]) y = torch.randn([2, 2]) eager_result = fn(x, y) eager_and_record = EagerAndRecordGraphs() counter = CompileCounterWithBackend(eager_and_record) fn = torch.compile(fn, backend=counter, fullgraph=True) compile_result = fn(x, y) graph = eager_and_record.graphs[0] actual = normalize_gm(graph.print_readable(False)) self.assertEqual(compile_result, eager_result) self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 6) self.assertExpectedInline( actual, """\ class GraphModule(torch.nn.Module): def forward(self, L_y_: "f32[2, 2]", L_x_: "f32[2, 2]"): l_y_ = L_y_ l_x_ = L_x_ _get_data_attr: "f32[2, 2]" = torch._C._autograd._get_data_attr(l_y_) _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None set_: "f32[2, 2]" = torch_Tensor_set_(l_x_, _get_data_attr); _get_data_attr = None _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = _lower_version_count_by_1 = None mul: "f32[2, 2]" = l_x_ * l_y_; l_x_ = l_y_ = None return (mul,) """, ) # Note - this does not actually get captured in the graph yet. # The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function # in the fx graph, and let aot_autograd handle it. def test_set_data_on_scoped_tensor(self): def fn(x): z = torch.zeros([4, 4]) z.data = x.data if x.size() == z.size(): return z * x else: return x x = torch.randn([5, 5]) eager_result = fn(x) counter = CompileCounter() fn = torch.compile(fn, backend=counter, fullgraph=False) compile_result = fn(x) self.assertEqual(compile_result, eager_result) self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.op_count, 3) def test_set_data_on_user_defined_class_input_tensor(self): class MyUserDefinedClass: def __init__(self, x, y): self.x = x self.y = y def do_some_setattr_stuff(self): self.z = x * y self.a = x + x return self.z * self.a x = torch.randn([5, 5]) y = torch.randn([5, 5]) mudc_1 = MyUserDefinedClass(x, y) eager_result = mudc_1.do_some_setattr_stuff() counter = CompileCounter() mudc_2 = MyUserDefinedClass(x, y) do_some_setattr_stuff = torch.compile( mudc_2.do_some_setattr_stuff, backend=counter, fullgraph=True ) compile_result = do_some_setattr_stuff() self.assertEqual(compile_result, eager_result) self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 3) # Graph for reference # __compiled_fn_0 .0 opcode name target args kwargs # ------------- ------ ----------------------- -------------------- -------- # placeholder l_x_ L_x_ () {} # placeholder l_y_ L_y_ () {} # call_function mul (l_x_, l_y_) {} # call_function add (l_x_, l_x_) {} # call_function mul_1 (mul, add) {} # output output output ((mul_1, mul, add),) {} if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()