diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py
index cf6de8200e3..8add10a398b 100644
--- a/test/higher_order_ops/test_invoke_subgraph.py
+++ b/test/higher_order_ops/test_invoke_subgraph.py
@@ -197,6 +197,37 @@ class TestInvokeSubgraphCompile(TestCase):
         res = opt_fn(q, k, v)
         res.sum().backward()

+    def test_symint_from_fwd_to_bwd(self):
+        @mark_compile_region
+        def gn(x, y):
+            a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
+            return torch.matmul(a, y)
+
+        def fn(x, y):
+            return gn(x, y)
+
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+
+        x = torch.randn(64, 1, requires_grad=True)
+        y = torch.randn(8, 8, requires_grad=True)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertEqual(ref, res)
+
+        x = torch.randn(256, 1, requires_grad=True)
+        y = torch.randn(16, 16, requires_grad=True)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertEqual(ref, res)
+        res.sum().backward()
+
+        x = torch.randn(16, 1, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertEqual(ref, res)
+        res.sum().backward()
+
     def test_dedupe(self):
         @mark_compile_region
         def gn(x, y):
@@ -431,12 +462,17 @@ class GraphModule(torch.nn.Module):
             return gn(x)

         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
-        x = torch.randn(8, 8, requires_grad=True)
+        # requires_grad is deliberately False to force None outputs in the
+        # joint_graph
+        x = torch.randn(8, 8, requires_grad=False)
         ref = mod(x)
         res = opt_fn(x)
         self.assertEqual(ref, res)

+        ref.sum().backward()
+        res.sum().backward()
+
     def test_fail_with_direct_invoke_subgraph(self):
         from torch._higher_order_ops import invoke_subgraph

diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 4ab06805430..679b0e2ffe7 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -2338,6 +2338,7 @@ class PythonWrapperCodegen(CodeGen):
             return

         self.push_codegened_graph(subgraph.graph)
+        self.writeline("")
         self.writeline(f"{self.comment} subgraph: {subgraph.name}")
         self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 7f8f6826191..a16f69de758 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -2043,12 +2043,17 @@ class GraphLowering(torch.fx.Interpreter):
         return mod

     def get_output_names(self) -> List[str]:
-        return [
-            node.get_name()
-            for node in self.graph_outputs
-            if not isinstance(node, ir.NoneAsConstantBuffer)
-            and not isinstance(node, ir.ShapeAsConstantBuffer)
-        ]
+        names = []
+        shape_counter = itertools.count(0)
+        none_counter = itertools.count(0)
+        for node in self.graph_outputs:
+            if isinstance(node, ir.NoneAsConstantBuffer):
+                names.append(f"{self.name}_none{next(none_counter)}")
+            elif isinstance(node, ir.ShapeAsConstantBuffer):
+                names.append(f"{self.name}_shape{next(shape_counter)}")
+            else:
+                names.append(node.get_name())
+        return names

     def is_unspec_arg(self, name: str) -> bool:
         # dynamo wraps unspec variable as 0d CPU tensor,
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 38210f11178..7f1686e54a7 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -7149,7 +7149,7 @@ class InvokeSubgraph(ExternKernel):
         )

         def create_output(output: IRNode, ind: int):
-            if isinstance(output, NoneAsConstantBuffer):
+            if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
                 return output
             else:
                 return MultiOutput(
@@ -7165,7 +7165,6 @@ class InvokeSubgraph(ExternKernel):
                 )

        outputs = [create_output(output, i) for i, output in enumerate(outputs)]
-        invoke_subgraph.outputs = outputs
        return outputs
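
For reference, a standalone sketch (not part of the patch) distilled from the new test_symint_from_fwd_to_bwd test above. It runs a mark_compile_region subgraph whose view() shape depends on the second input's sizes, then recompiles with new shapes so that symbolic size values produced in the forward have to reach the backward subgraph. The import path for mark_compile_region is assumed to match this test file's imports and may differ across PyTorch versions; the loop condenses the test's three shape configurations and drops the eager-vs-compiled comparison.

import torch
# Assumed import path for mark_compile_region; adjust for your PyTorch version.
from torch._higher_order_ops.invoke_subgraph import mark_compile_region


@mark_compile_region
def gn(x, y):
    # The view shape depends on y's sizes, which become SymInts once the
    # compiled function is re-run with different shapes.
    a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
    return torch.matmul(a, y)


def fn(x, y):
    return gn(x, y)


opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)

# The second and third shapes trigger dynamic-shape recompilation, exercising
# the symint-from-forward-to-backward path that the patch addresses.
for n in (8, 16, 4):
    x = torch.randn(n * n, 1, requires_grad=True)
    y = torch.randn(n, n, requires_grad=True)
    res = opt_fn(x, y)
    res.sum().backward()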