# Owner(s): ["module: dynamo"]

# flake8: noqa
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import (
    CompileCounter,
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    normalize_gm,
)
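
# These tests cover dynamo's handling of attributes assigned onto input
# tensors (e.g. `x.y = other_tensor`). Tensor-valued attributes are expected
# to be lifted to graph inputs, while constant attributes are specialized
# into the graph.
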
class TestInputAttrTracking(torch._dynamo.test_case.TestCase):
    def test_tensor_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch.compile(fn, backend=grab_graph_backend, fullgraph=True)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this lifts y to inputs!
        self.assertEqual(placeholder_cnt, 2)

    def test_tensor_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])

        eager_result = fn(x_, y_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch.compile(fn, backend=grab_graph_backend, fullgraph=True)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # y is already an input
        self.assertEqual(placeholder_cnt, 2)

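    # Constant (non-tensor) attributes take a different path: they are treated
    # as constants and specialized into the graph rather than lifted to inputs.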
    def test_const_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch.compile(fn, backend=grab_graph_backend, fullgraph=True)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this does not lift y to inputs, as it's a const
        self.assertEqual(placeholder_cnt, 1)

    def test_const_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4

        eager_result = fn(x_, y_)

        fn = torch.compile(fn, backend="eager", fullgraph=True)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

    def test_guards_correctly_property_assigned_on_tensor_type_change(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch.compile(fn, backend="eager", fullgraph=True)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)

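    # Same scenario as above, but through the inductor backend, to check that
    # the guard on the attribute's type (const vs. tensor) also holds end to
    # end under a real compile.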
    def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch.compile(fn, backend="inductor", fullgraph=True)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)

    def test_complex_attr_access_without_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch.compile(fn, backend=counter, fullgraph=True)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # ------------- ------ ----------------------- ------------------------------------ --------
        # placeholder   l_y_   L_y_                    ()                                    {}
        # call_function mul    <built-in function mul> (l_y_, 5)                             {}
        # call_function mul_1  <built-in function mul> (l_y_, 5)                             {}
        # call_function mul_2  <built-in function mul> (l_y_, 5)                             {}
        # call_function mul_3  <built-in function mul> (l_y_, 1)                             {}
        # call_function mul_4  <built-in function mul> (mul, 1)                              {}
        # call_function mul_5  <built-in function mul> (l_y_, mul_3)                         {}
        # call_function mul_6  <built-in function mul> (mul_1, mul_4)                        {}
        # call_function mul_7  <built-in function mul> (l_y_, mul_5)                         {}
        # call_function mul_8  <built-in function mul> (mul_2, mul_6)                        {}
        # output        output output                  ((mul_7, mul_8, mul, mul_1, mul_2),) {}

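    # The print() in fn below forces a graph break, so the same workload
    # compiles as two frames: the attribute writes land in the first graph and
    # the attribute reads in the second (see the reference graphs below).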
    def test_complex_attr_access_with_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            print("Break!")

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch.compile(fn, backend=counter, fullgraph=False)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # ------------- ------ ----------------------- ---------------------- --------
        # placeholder   l_y_   L_y_                    ()                     {}
        # call_function mul    <built-in function mul> (l_y_, 5)              {}
        # call_function mul_1  <built-in function mul> (l_y_, 5)              {}
        # call_function mul_2  <built-in function mul> (l_y_, 5)              {}
        # output        output output                  ((mul, mul_1, mul_2),) {}
        # [GRAPH BREAK!]
        # ------------- ------- ----------------------- ----------------- --------
        # placeholder   l_x_0_y L_x_0_y                 ()                {}
        # placeholder   l_x_0_z L_x_0_z                 ()                {}
        # placeholder   l_x_1_y L_x_1_y                 ()                {}
        # placeholder   l_x_1_z L_x_1_z                 ()                {}
        # placeholder   l_x_2_y L_x_2_y                 ()                {}
        # placeholder   l_x_2_z L_x_2_z                 ()                {}
        # call_function mul     <built-in function mul> (l_x_0_y, 1)      {}
        # call_function mul_1   <built-in function mul> (l_x_0_z, 1)      {}
        # call_function mul_2   <built-in function mul> (l_x_1_y, mul)    {}
        # call_function mul_3   <built-in function mul> (l_x_1_z, mul_1)  {}
        # call_function mul_4   <built-in function mul> (l_x_2_y, mul_2)  {}
        # call_function mul_5   <built-in function mul> (l_x_2_z, mul_3)  {}
        # output        output  output                  ((mul_4, mul_5),) {}

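    # The print() inside the inlined helper graph-breaks; only the two
    # trailing multiplications are captured, with mult specialized to the
    # constant 6 (see the reference graph below).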
    def test_complex_attr_access_with_inline_reconstruct(self):
        def inline_test_fn(x, y, z):
            print("f")
            return x.a + y.a + z.a

        def fn(x, y, z):
            x.a = 1
            y.a = 2
            z.a = 3

            mult = inline_test_fn(x, y, z)
            y = y * mult
            x = x * mult
            return x, y

        x = torch.randn([2, 2])
        y = torch.randn([2, 2])
        z = torch.randn([2, 2])

        eager_result = fn(x, y, z)

        counter = CompileCounter()

        fn = torch.compile(fn, backend=counter, fullgraph=False)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_2 <eval_with_key>.0 opcode name target args kwargs
        # ------------- ------ ----------------------- --------------- --------
        # placeholder   l_x_   L_x_                    ()              {}
        # placeholder   l_y_   L_y_                    ()              {}
        # call_function mul    <built-in function mul> (l_y_, 6)       {}
        # call_function mul_1  <built-in function mul> (l_x_, 6)       {}
        # output        output output                  ((mul_1, mul),) {}

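    # Assigning to `x.data` swaps in new storage (and shape) without autograd
    # tracking. As the expected graph below shows, dynamo traces this as
    # _get_data_attr + Tensor.set_ under no_grad, then lowers the version
    # counter so autograd does not observe the mutation.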
    def test_set_data_on_input_tensor(self):
        def fn(x, y):
            x.data = y.data
            if x.size() == y.size():
                return x * y
            else:
                return y * y

        x = torch.randn([5, 5])
        y = torch.randn([2, 2])

        eager_result = fn(x, y)

        eager_and_record = EagerAndRecordGraphs()

        counter = CompileCounterWithBackend(eager_and_record)

        fn = torch.compile(fn, backend=counter, fullgraph=True)

        compile_result = fn(x, y)

        graph = eager_and_record.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 6)
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[2, 2]", L_x_: "f32[2, 2]"):
        l_y_ = L_y_
        l_x_ = L_x_

        _get_data_attr: "f32[2, 2]" = torch._C._autograd._get_data_attr(l_y_)

        _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None

        set_: "f32[2, 2]" = torch_Tensor_set_(l_x_, _get_data_attr); _get_data_attr = None

        _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None

        _lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = _lower_version_count_by_1 = None

        mul: "f32[2, 2]" = l_x_ * l_y_; l_x_ = l_y_ = None
        return (mul,)
""",
        )

    # Note - this does not actually get captured in the graph yet.
    # The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function
    # in the fx graph, and let aot_autograd handle it.
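    # For now, `.data` assignment on a tensor created inside the compiled
    # region is not captured, which is why the test expects frame_count == 2
    # (one graph break).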
    def test_set_data_on_scoped_tensor(self):
        def fn(x):
            z = torch.zeros([4, 4])
            z.data = x.data
            if x.size() == z.size():
                return z * x
            else:
                return x

        x = torch.randn([5, 5])

        eager_result = fn(x)

        counter = CompileCounter()

        fn = torch.compile(fn, backend=counter, fullgraph=False)

        compile_result = fn(x)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 3)

    def test_set_data_on_user_defined_class_input_tensor(self):
        class MyUserDefinedClass:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def do_some_setattr_stuff(self):
                # NB: x and y here resolve, via closure, to the tensors
                # defined in the test body below.
                self.z = x * y
                self.a = x + x
                return self.z * self.a

        x = torch.randn([5, 5])
        y = torch.randn([5, 5])
        mudc_1 = MyUserDefinedClass(x, y)

        eager_result = mudc_1.do_some_setattr_stuff()

        counter = CompileCounter()

        mudc_2 = MyUserDefinedClass(x, y)
        do_some_setattr_stuff = torch.compile(
            mudc_2.do_some_setattr_stuff, backend=counter, fullgraph=True
        )

        compile_result = do_some_setattr_stuff()
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)
        # Graph for reference
        # __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs
        # ------------- ------ ----------------------- -------------------- --------
        # placeholder   l_x_   L_x_                    ()                   {}
        # placeholder   l_y_   L_y_                    ()                   {}
        # call_function mul    <built-in function mul> (l_x_, l_y_)         {}
        # call_function add    <built-in function add> (l_x_, l_x_)         {}
        # call_function mul_1  <built-in function mul> (mul, add)           {}
        # output        output output                  ((mul_1, mul, add),) {}


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()