diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py
index 99ed2f5a8dd..2ff363a5f5c 100644
--- a/test/dynamo/test_graph_deduplication.py
+++ b/test/dynamo/test_graph_deduplication.py
@@ -112,7 +112,9 @@ class GraphModule(torch.nn.Module):
     class ___forward_subgraph_0_post_graph(torch.nn.Module):
         def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
             add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
+
             add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
+
             sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
             sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
             add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
@@ -205,7 +207,9 @@ class GraphModule(torch.nn.Module):
     class ___forward_subgraph_0_post_graph(torch.nn.Module):
         def forward(self, primals_0: "f32[10, 10]"):
             mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(primals_0, 7); primals_0 = None
+
             add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
+
             add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None
             return (add_1,)
 """,
@@ -349,7 +353,9 @@ class GraphModule(torch.nn.Module):
     class ___forward_subgraph_0_post_graph(torch.nn.Module):
         def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
             add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
+
             add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
+
             sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
             sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
             add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
@@ -358,7 +364,9 @@ class GraphModule(torch.nn.Module):
     class ___forward_subgraph_1_post_graph(torch.nn.Module):
         def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
             add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 2)
+
             add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 3)
+
             cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None
             sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None
             mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = None
@@ -416,6 +424,7 @@ class GraphModule(torch.nn.Module):
     class ___forward_subgraph_0_post_graph(torch.nn.Module):
         def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"):
             add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
+
             sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
             add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, primals_1); sum_1 = primals_1 = None
             return (add_1,)
@@ -564,7 +573,9 @@ class <lambda>(torch.nn.Module):
     class repeated_subgraph0(torch.nn.Module):
         def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
             mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
+
             mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
+
             sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
             sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
             add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py
index db585afaafd..69394e0e642 100644
--- a/test/higher_order_ops/test_invoke_subgraph.py
+++ b/test/higher_order_ops/test_invoke_subgraph.py
@@ -921,6 +921,7 @@ class GraphModule(torch.nn.Module):
         def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"):
             mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3)
             mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
+
             add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
             return (add,)
 """,
@@ -1030,7 +1031,9 @@ class GraphModule(torch.nn.Module):
     class ___forward_invoke_subgraph_0_post_graph(torch.nn.Module):
         def forward(self, primals_0: "f32[8, 8]", primals_1: "f32[8, 8]"):
             mm: "f32[8, 8]" = torch.ops.aten.mm.default(primals_0, primals_1)
+
             sin: "f32[8, 8]" = torch.ops.aten.sin.default(mm)
+
             t: "f32[8, 8]" = torch.ops.aten.t.default(primals_0); primals_0 = None
             t_1: "f32[8, 8]" = torch.ops.aten.t.default(primals_1); primals_1 = None
             return (sin, mm, t, t_1)
@@ -1055,6 +1058,7 @@ class GraphModule(torch.nn.Module):
         def forward(self, mm: "f32[8, 8]", t: "f32[8, 8]", t_1: "f32[8, 8]", tangents_0: "f32[8, 8]"):
             cos: "f32[8, 8]" = torch.ops.aten.cos.default(mm); mm = None
             mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_0, cos); tangents_0 = cos = None
+
             mm_1: "f32[8, 8]" = torch.ops.aten.mm.default(t, mul); t = None
             mm_2: "f32[8, 8]" = torch.ops.aten.mm.default(mul, t_1); mul = t_1 = None
             return (mm_2, mm_1)
diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py
index 833b04e78e4..c899370b8d5 100644
--- a/torch/_higher_order_ops/invoke_subgraph.py
+++ b/torch/_higher_order_ops/invoke_subgraph.py
@@ -298,7 +298,7 @@ def get_output_metadata(subgraph, operands):
 
 
 def trace_joint_graph_as_bwd(
-    fn, num_primals, joint_operands, include_key_set, exclude_key_set
+    subgraph, num_primals, joint_operands, include_key_set, exclude_key_set
 ):
     """
     Naively trace out a joint graph. This simplifies the reconstruction of joint
@@ -308,6 +308,17 @@ def trace_joint_graph_as_bwd(
 
     dummy_aot_config = get_dummy_aot_autograd_config()
 
+    if isinstance(subgraph, torch.fx.GraphModule):
+
+        def graph_with_interpreter(*args):
+            # Running graph with interpreter is needed for propagating the stack_trace
+            with torch.fx.traceback.preserve_node_meta():
+                return torch.fx.Interpreter(subgraph).run(*args)
+
+        fn = graph_with_interpreter
+    else:
+        fn = subgraph
+
     # This joint_fn is inserted as the backward graph as is. This simplifies the
     # min-cut partitioner work later on.
     # Input signature - (*primals, *tangents)
diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py
index ca9884687f3..27f4e739eb4 100644
--- a/torch/_higher_order_ops/utils.py
+++ b/torch/_higher_order_ops/utils.py
@@ -821,4 +821,10 @@ class FunctionalizeCtxWrapper:
         return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})"
 
     def __call__(self, *args, **kwargs):
+        if isinstance(self.subgraph, torch.fx.GraphModule):
+            # Running graph with interpreter is needed for propagating the stack_trace
+            with fx_traceback.preserve_node_meta():
+                return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)(
+                    *args, **kwargs
+                )
        return self.ctx.functionalize(self.subgraph)(*args, **kwargs)