mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150782 Approved by: https://github.com/bdhirsh ghstack dependencies: #150666
668 lines
27 KiB
Python
668 lines
27 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
# flake8: noqa: B950
|
|
import torch
|
|
import torch.fx
|
|
from torch._dynamo.graph_deduplication import _flatten_args_kwargs
|
|
from torch._dynamo.graph_utils import _detect_cycles
|
|
from torch._dynamo.test_case import TestCase
|
|
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
|
|
|
|
|
|
def extract_graph(fn, *args, **kwargs):
|
|
backend = AotEagerAndRecordGraphs()
|
|
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
|
|
return result, backend.graphs, backend.fw_graphs
|
|
|
|
|
|
def graph_str(gm):
|
|
return normalize_gm(gm.print_readable(print_output=False))
|
|
|
|
|
|
class GraphDededuplicationTests(TestCase):
|
|
def run_and_return_graphs(self, fn, *args, **kwargs):
|
|
with torch._dynamo.config.patch("use_graph_deduplication", True):
|
|
return extract_graph(fn, *args, **kwargs)
|
|
|
|
def test_single_subgraph(self):
|
|
def inner_fn(x, y):
|
|
x0 = x + 1
|
|
y0 = y + 2
|
|
z = x0.sum() + y0.sum()
|
|
return z
|
|
|
|
def fn(x, y):
|
|
_o0 = inner_fn(x, y)
|
|
o1 = torch.sin(y)
|
|
o2 = inner_fn(x, o1)
|
|
o3 = inner_fn(x, y)
|
|
o4 = o3 * o3
|
|
return o2 * o4
|
|
|
|
x = torch.rand(10, 10, requires_grad=True)
|
|
y = torch.rand(10, 20, requires_grad=True)
|
|
x_clone = x.clone().requires_grad_(True)
|
|
y_clone = y.clone().requires_grad_(True)
|
|
|
|
ref_result = fn(x, y)
|
|
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
|
|
|
torch.allclose(ref_result, result)
|
|
ref_result.sum().backward()
|
|
result.sum().backward()
|
|
|
|
self.assertEqual(len(graphs), 1)
|
|
self.assertEqual(len(fw_graphs), 1)
|
|
self.assertExpectedInline(
|
|
graph_str(graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
|
|
subgraph_0 = self.subgraph_0
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
o1: "f32[10, 20]" = torch.sin(l_y_)
|
|
|
|
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None
|
|
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None
|
|
|
|
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
|
|
|
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
|
|
|
|
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
|
|
|
o4: "f32[]" = getitem_2 * getitem_2; getitem_2 = None
|
|
|
|
mul_1: "f32[]" = getitem_1 * o4; getitem_1 = o4 = None
|
|
return (mul_1,)
|
|
|
|
class subgraph_0(torch.nn.Module):
|
|
def forward(self, subgraph_input_l_x_, subgraph_input_l_y_):
|
|
x0: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
|
|
|
|
y0: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
|
|
|
|
sum_1: "f32[]" = x0.sum(); x0 = None
|
|
sum_2: "f32[]" = y0.sum(); y0 = None
|
|
z: "f32[]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
return (z,)
|
|
""",
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
graph_str(fw_graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
|
|
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
|
|
|
|
___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_post_graph = sin = None
|
|
getitem_1: "f32[]" = invoke_subgraph_5[0]; invoke_subgraph_5 = None
|
|
___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_1 = primals_1 = None
|
|
getitem_2: "f32[]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
|
|
|
|
mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2)
|
|
|
|
mul_1: "f32[]" = torch.ops.aten.mul.Tensor(getitem_1, mul); mul = None
|
|
return (mul_1, primals_2, getitem_1, getitem_2)
|
|
|
|
class ___forward_subgraph_0_post_graph(torch.nn.Module):
|
|
def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
|
|
|
|
add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
|
|
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
|
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
|
|
return (add_2,)
|
|
""",
|
|
)
|
|
|
|
def test_single_subgraph2(self):
|
|
def fn(x):
|
|
x0 = x + 2
|
|
o = inner_fn(x0)
|
|
o = torch.cos(o)
|
|
o = inner_fn(o)
|
|
return torch.sin(o)
|
|
|
|
def inner_fn(x):
|
|
o = x * 7
|
|
o += 1
|
|
o += 2
|
|
return o
|
|
|
|
x = torch.rand(10, 10, requires_grad=True)
|
|
x_clone = x.clone().requires_grad_(True)
|
|
|
|
ref_result = fn(x)
|
|
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone)
|
|
|
|
torch.allclose(ref_result, result)
|
|
ref_result.sum().backward()
|
|
result.sum().backward()
|
|
self.assertEqual(len(graphs), 1)
|
|
self.assertEqual(len(fw_graphs), 1)
|
|
self.assertExpectedInline(
|
|
graph_str(graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[10, 10]"):
|
|
subgraph_0 = self.subgraph_0
|
|
l_x_ = L_x_
|
|
|
|
x0: "f32[10, 10]" = l_x_ + 2; l_x_ = None
|
|
|
|
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (x0,)); x0 = None
|
|
|
|
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
|
|
|
|
o_3: "f32[10, 10]" = torch.cos(getitem); getitem = None
|
|
|
|
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (o_3,)); subgraph_0 = o_3 = None
|
|
|
|
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
|
|
|
sin: "f32[10, 10]" = torch.sin(getitem_1); getitem_1 = None
|
|
return (sin,)
|
|
|
|
class subgraph_0(torch.nn.Module):
|
|
def forward(self, subgraph_input_x0):
|
|
o: "f32[10, 10]" = subgraph_input_x0 * 7; subgraph_input_x0 = None
|
|
|
|
o += 1; o_1: "f32[10, 10]" = o; o = None
|
|
|
|
o_1 += 2; o_2: "f32[10, 10]" = o_1; o_1 = None
|
|
return (o_2,)
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graph_str(fw_graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, primals_1: "f32[10, 10]"):
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
|
|
|
|
___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (add,)); ___forward_subgraph_0_post_graph = add = None
|
|
getitem: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
|
|
|
cos: "f32[10, 10]" = torch.ops.aten.cos.default(getitem)
|
|
|
|
___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (cos,)); ___forward_subgraph_0_post_graph_1 = cos = None
|
|
getitem_1: "f32[10, 10]" = invoke_subgraph_5[0]; invoke_subgraph_5 = None
|
|
|
|
sin: "f32[10, 10]" = torch.ops.aten.sin.default(getitem_1)
|
|
cos_1: "f32[10, 10]" = torch.ops.aten.cos.default(getitem_1); getitem_1 = None
|
|
|
|
sin_1: "f32[10, 10]" = torch.ops.aten.sin.default(getitem); getitem = None
|
|
neg: "f32[10, 10]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
|
|
return (sin, cos_1, neg)
|
|
|
|
class ___forward_subgraph_0_post_graph(torch.nn.Module):
|
|
def forward(self, primals_0: "f32[10, 10]"):
|
|
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(primals_0, 7); primals_0 = None
|
|
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
|
|
|
|
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
|
return (add_1,)
|
|
""",
|
|
)
|
|
|
|
def test_multiple_subgraphs(self):
|
|
def inner_fn(x, y):
|
|
x1 = x + 1
|
|
y1 = y + 2
|
|
z = x1.sum() + y1.sum()
|
|
return z
|
|
|
|
def inner_fn2(a, b):
|
|
a0 = a + 2
|
|
b0 = b + 3
|
|
c = a0 * b0.cos().sum()
|
|
return c
|
|
|
|
def fn(x, y):
|
|
x0 = torch.cos(x)
|
|
y0 = torch.sin(y)
|
|
o1 = inner_fn2(x0, y0)
|
|
o0 = inner_fn(x, y)
|
|
o1 = torch.sin(o0)
|
|
o2 = inner_fn(x, y0)
|
|
o3 = inner_fn2(x0, y0)
|
|
o4 = inner_fn(x, y)
|
|
return o1 * o2 * o3 + o4
|
|
|
|
x = torch.rand(10, 10, requires_grad=True)
|
|
y = torch.rand(10, 20, requires_grad=True)
|
|
x_clone = x.clone().requires_grad_(True)
|
|
y_clone = y.clone().requires_grad_(True)
|
|
|
|
ref_result = fn(x, y)
|
|
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
|
|
|
torch.allclose(ref_result, result)
|
|
ref_result.sum().backward()
|
|
result.sum().backward()
|
|
self.assertEqual(len(graphs), 1)
|
|
self.assertEqual(len(fw_graphs), 1)
|
|
|
|
self.assertExpectedInline(
|
|
graph_str(graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
|
|
subgraph_1 = self.subgraph_1
|
|
subgraph_0 = self.subgraph_0
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
x0: "f32[10, 10]" = torch.cos(l_x_)
|
|
|
|
y0: "f32[10, 20]" = torch.sin(l_y_)
|
|
|
|
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
|
|
|
|
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
|
|
|
o1: "f32[]" = torch.sin(getitem); getitem = None
|
|
|
|
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0))
|
|
|
|
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
|
|
|
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
|
|
|
|
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
|
|
|
|
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
|
|
|
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None
|
|
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None
|
|
|
|
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
|
|
|
mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None
|
|
add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None
|
|
return (add_13,)
|
|
|
|
class subgraph_1(torch.nn.Module):
|
|
def forward(self, subgraph_input_x0, subgraph_input_y0):
|
|
a0: "f32[10, 10]" = subgraph_input_x0 + 2; subgraph_input_x0 = None
|
|
|
|
b0: "f32[10, 20]" = subgraph_input_y0 + 3; subgraph_input_y0 = None
|
|
|
|
cos_1: "f32[10, 20]" = b0.cos(); b0 = None
|
|
sum_1: "f32[]" = cos_1.sum(); cos_1 = None
|
|
c: "f32[10, 10]" = a0 * sum_1; a0 = sum_1 = None
|
|
return (c,)
|
|
|
|
class subgraph_0(torch.nn.Module):
|
|
def forward(self, subgraph_input_l_x_, subgraph_input_l_y_):
|
|
x1: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
|
|
|
|
y1: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
|
|
|
|
sum_2: "f32[]" = x1.sum(); x1 = None
|
|
sum_3: "f32[]" = y1.sum(); y1 = None
|
|
z: "f32[]" = sum_2 + sum_3; sum_2 = sum_3 = None
|
|
return (z,)
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graph_str(fw_graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
|
|
cos: "f32[10, 10]" = torch.ops.aten.cos.default(primals_1)
|
|
|
|
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
|
|
|
|
___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph = None
|
|
getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None
|
|
|
|
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem)
|
|
|
|
___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_post_graph_1 = None
|
|
getitem_1: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None
|
|
|
|
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None
|
|
|
|
___forward_subgraph_0_post_graph_2 = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None
|
|
getitem_2: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None
|
|
___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph
|
|
invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None
|
|
getitem_19: "f32[]" = invoke_subgraph_12[3]
|
|
getitem_18: "f32[10, 20]" = invoke_subgraph_12[2]
|
|
getitem_17: "f32[10, 10]" = invoke_subgraph_12[1]
|
|
getitem_4: "f32[10, 10]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None
|
|
|
|
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_4); mul = None
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_2); mul_1 = getitem_2 = None
|
|
return (add, primals_1, primals_2, getitem, getitem_1, getitem_19, getitem_18, getitem_17, getitem_4)
|
|
|
|
class ___forward_subgraph_0_post_graph(torch.nn.Module):
|
|
def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
|
|
|
|
add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
|
|
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
|
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
|
|
return (add_2,)
|
|
|
|
class ___forward_subgraph_1_post_graph(torch.nn.Module):
|
|
def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 2)
|
|
|
|
add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 3)
|
|
|
|
cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None
|
|
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = None
|
|
return (mul, primals_0, primals_1, sum_1)
|
|
""",
|
|
)
|
|
|
|
def test_dependent_subgraphs(self):
|
|
def inner_fn(x, y):
|
|
x0 = x + 1
|
|
y0 = y + 2
|
|
z = x0.sum() + y0.sum()
|
|
return z
|
|
|
|
def fn(x, y):
|
|
o0 = inner_fn(x, y)
|
|
o1 = inner_fn(x, o0)
|
|
return o1
|
|
|
|
x = torch.rand(10, 10, requires_grad=True)
|
|
y = torch.rand(10, 20, requires_grad=True)
|
|
x_clone = x.clone().requires_grad_(True)
|
|
y_clone = y.clone().requires_grad_(True)
|
|
|
|
ref_result = fn(x, y)
|
|
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
|
|
|
torch.allclose(ref_result, result)
|
|
ref_result.sum().backward()
|
|
result.sum().backward()
|
|
self.assertEqual(len(graphs), 1)
|
|
self.assertEqual(len(fw_graphs), 1)
|
|
self.assertExpectedInline(
|
|
graph_str(fw_graphs[0]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
|
|
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_2, 2); primals_2 = None
|
|
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
|
|
|
___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (primals_1, sum_1)); ___forward_subgraph_0_post_graph = sum_1 = None
|
|
getitem: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
|
|
|
add_1: "f32[]" = torch.ops.aten.add.Tensor(getitem, 2); getitem = None
|
|
|
|
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
|
|
___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph
|
|
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (primals_1, sum_2)); ___forward_subgraph_0_post_graph_1 = primals_1 = sum_2 = None
|
|
getitem_1: "f32[]" = invoke_subgraph_5[0]; invoke_subgraph_5 = None
|
|
return (getitem_1,)
|
|
|
|
class ___forward_subgraph_0_post_graph(torch.nn.Module):
|
|
def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"):
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
|
|
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
|
add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, primals_1); sum_1 = primals_1 = None
|
|
return (add_1,)
|
|
""",
|
|
)
|
|
|
|
def test_input_mutation(self):
|
|
def inner_fn(x, y):
|
|
x0 = x + 1
|
|
y0 = y + 2
|
|
z = x0.sum() + y0.sum()
|
|
return z
|
|
|
|
def inner_fn2(x, y):
|
|
x0 = x + 1
|
|
y0 = y + 1
|
|
x.add_(x0)
|
|
y.add_(y0)
|
|
return x.sum() + y.sum()
|
|
|
|
def fn(x, y):
|
|
x0 = torch.sin(x)
|
|
_y0 = torch.cos(y)
|
|
# o0 = inner_fn(x0, y0)
|
|
# o1 = inner_fn(x0, o0)
|
|
o2 = inner_fn2(x0, y)
|
|
o3 = inner_fn2(x0.clone(), y.clone())
|
|
return o2 + o3
|
|
|
|
x = torch.rand(10, 10, requires_grad=False)
|
|
y = torch.rand(10, 20, requires_grad=False)
|
|
x_clone = x.clone()
|
|
y_clone = y.clone()
|
|
|
|
ref_result = fn(x, y)
|
|
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
|
|
|
torch.allclose(ref_result, result)
|
|
self.assertEqual(len(graphs), 1)
|
|
self.assertEqual(len(fw_graphs), 1)
|
|
self.assertExpectedInline(
|
|
graph_str(fw_graphs[0]),
|
|
"""\
|
|
class <lambda>(torch.nn.Module):
|
|
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
|
|
sin: "f32[10, 10]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
|
|
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, 1)
|
|
|
|
add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 1)
|
|
|
|
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, add); sin = add = None
|
|
|
|
add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None
|
|
|
|
clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2)
|
|
clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3)
|
|
|
|
add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1)
|
|
|
|
add_5: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, 1)
|
|
|
|
add_6: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, add_4); clone = add_4 = None
|
|
|
|
add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
|
|
|
|
repeated_subgraph0 = self.repeated_subgraph0
|
|
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = add_2 = None
|
|
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
|
repeated_subgraph0_1 = self.repeated_subgraph0
|
|
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None
|
|
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
|
|
|
add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
|
|
|
copy_: "f32[10, 20]" = torch.ops.aten.copy_.default(arg1_1, add_3); arg1_1 = add_3 = copy_ = None
|
|
return (add_8,)
|
|
|
|
class repeated_subgraph0(torch.nn.Module):
|
|
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
|
|
sum_2: "f32[]" = torch.ops.aten.sum.default(arg1_1); arg1_1 = None
|
|
add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
def test_input_aliasing(self):
|
|
def inner_fn(x, y):
|
|
x0 = x.view(x.size())
|
|
return x0.view(x.size())
|
|
|
|
def inner_fn2(x, y):
|
|
x = x * 2
|
|
y = y * 2
|
|
return x.sum() + y.sum()
|
|
|
|
def fn(x, y):
|
|
o0 = inner_fn(x, y)
|
|
o1 = inner_fn(x, y)
|
|
o2 = inner_fn2(x, y)
|
|
o3 = inner_fn2(x, y)
|
|
return o0 + o1 + o2.sum() + o3.sum()
|
|
|
|
x = torch.rand(10, 10, requires_grad=False)
|
|
y = torch.rand(10, 20, requires_grad=False)
|
|
x_clone = x.clone()
|
|
y_clone = y.clone()
|
|
|
|
ref_result = fn(x, y)
|
|
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
|
|
|
torch.allclose(ref_result, result)
|
|
self.assertEqual(len(graphs), 1)
|
|
self.assertEqual(len(fw_graphs), 1)
|
|
self.assertExpectedInline(
|
|
graph_str(fw_graphs[0]),
|
|
"""\
|
|
class <lambda>(torch.nn.Module):
|
|
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
|
|
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
|
|
|
|
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
|
|
|
|
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
|
|
|
|
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
|
|
|
|
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
|
|
|
|
repeated_subgraph0 = self.repeated_subgraph0
|
|
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
|
|
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
|
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
|
|
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
|
|
|
|
repeated_subgraph0_1 = self.repeated_subgraph0
|
|
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
|
|
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
|
|
|
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
|
|
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
|
|
return (add_2,)
|
|
|
|
class repeated_subgraph0(torch.nn.Module):
|
|
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
|
|
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
|
|
|
|
mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
|
|
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
|
|
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
|
|
add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
def test_flatten_with_slices(self):
|
|
tree = [{"x": 3}, ["x", slice(1, 2, 3), 1], [4, 5, 6, [slice(3, 4, 5)]]]
|
|
out = _flatten_args_kwargs(tree)
|
|
self.assertExpectedInline(
|
|
str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]"""
|
|
)
|
|
|
|
def test_cycle_detection_no_cycle(self):
|
|
def fn(x, y):
|
|
x0 = x + 1
|
|
y0 = y + 2
|
|
z = x0.sum() + y0.sum()
|
|
return z
|
|
|
|
x = torch.rand(10, 10, requires_grad=False)
|
|
y = torch.rand(10, 20, requires_grad=False)
|
|
|
|
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
|
|
mod = fw_graphs[0]
|
|
self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""")
|
|
|
|
def test_cycle_detection_simple(self):
|
|
def fn(x, y):
|
|
x0 = x + 1
|
|
y0 = y + 2
|
|
z = x0.sum() + y0.sum()
|
|
return z
|
|
|
|
x = torch.rand(10, 10, requires_grad=False)
|
|
y = torch.rand(10, 20, requires_grad=False)
|
|
|
|
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
|
|
mod = fw_graphs[0]
|
|
add_node = next(n for n in mod.graph.nodes if n.name == "add")
|
|
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
|
|
args = add_node.args
|
|
add_node.args = (args[0], add_2)
|
|
self.assertExpectedInline(
|
|
_detect_cycles(mod.graph),
|
|
"""cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""",
|
|
)
|
|
|
|
def test_cycle_detection_complex(self):
|
|
def inner_fn(x, y):
|
|
x0 = x.view(x.size())
|
|
return x0.view(x.size())
|
|
|
|
def inner_fn2(x, y):
|
|
x = x * 2
|
|
y = y * 2
|
|
return x.sum() + y.sum()
|
|
|
|
def fn(x, y):
|
|
o0 = inner_fn(x, y)
|
|
o1 = inner_fn(x, y)
|
|
o2 = inner_fn2(x, y)
|
|
o3 = inner_fn2(x, y)
|
|
return o0 + o1 + o2.sum() + o3.sum()
|
|
|
|
x = torch.rand(10, 10, requires_grad=False)
|
|
y = torch.rand(10, 20, requires_grad=False)
|
|
x_clone = x.clone()
|
|
y_clone = y.clone()
|
|
|
|
_, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
|
mod = fw_graphs[0]
|
|
invoke_subgraph_node = next(
|
|
n for n in mod.graph.nodes if n.name == "invoke_subgraph"
|
|
)
|
|
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
|
|
args = invoke_subgraph_node.args
|
|
invoke_subgraph_node.args = (add_2, args[1])
|
|
self.assertExpectedInline(
|
|
_detect_cycles(mod.graph),
|
|
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|