[inductor][invoke_subgraph] Support None/int as input/output of invoke_subgraph (#139373)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139373
Approved by: https://github.com/eellison
Animesh Jain 2024-12-26 18:19:28 -08:00 committed by PyTorch MergeBot
parent 379bbef23c
commit 969415885d
4 changed files with 50 additions and 9 deletions
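
In short: `invoke_subgraph` subgraphs can now take and return non-tensor values, i.e. `None` (e.g. a missing gradient in the joint graph) and ints/SymInts (e.g. a shape flowing from forward to backward), and inductor learns to name and thread them through codegen. A minimal sketch of the newly supported pattern, assuming `mark_compile_region` is importable from `torch._higher_order_ops.invoke_subgraph` (shapes here are illustrative):

import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

@mark_compile_region
def gn(x, y):
    # The view target depends on y's (potentially dynamic) shape, so a
    # SymInt has to cross the subgraph boundary from forward to backward.
    a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
    return torch.matmul(a, y)

opt_gn = torch.compile(gn, backend="inductor", fullgraph=True)

# A second call with different shapes makes the sizes symbolic (dynamic).
for n in (8, 16):
    x = torch.randn(n * n, 1, requires_grad=True)
    y = torch.randn(n, n, requires_grad=True)
    opt_gn(x, y).sum().backward()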

test/higher_order_ops/test_invoke_subgraph.py

@@ -197,6 +197,37 @@ class TestInvokeSubgraphCompile(TestCase):
         res = opt_fn(q, k, v)
         res.sum().backward()
 
+    def test_symint_from_fwd_to_bwd(self):
+        @mark_compile_region
+        def gn(x, y):
+            a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
+            return torch.matmul(a, y)
+
+        def fn(x, y):
+            return gn(x, y)
+
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+
+        x = torch.randn(64, 1, requires_grad=True)
+        y = torch.randn(8, 8, requires_grad=True)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertEqual(ref, res)
+
+        x = torch.randn(256, 1, requires_grad=True)
+        y = torch.randn(16, 16, requires_grad=True)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertEqual(ref, res)
+        res.sum().backward()
+
+        x = torch.randn(16, 1, requires_grad=True)
+        y = torch.randn(4, 4, requires_grad=True)
+        ref = fn(x, y)
+        res = opt_fn(x, y)
+        self.assertEqual(ref, res)
+        res.sum().backward()
+
     def test_dedupe(self):
         @mark_compile_region
         def gn(x, y):
@@ -431,12 +462,17 @@ class GraphModule(torch.nn.Module):
             return gn(x)
 
         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
-        x = torch.randn(8, 8, requires_grad=True)
+        # requires_grad is False deliberately to force None outputs in the
+        # joint_graph
+        x = torch.randn(8, 8, requires_grad=False)
         ref = mod(x)
         res = opt_fn(x)
         self.assertEqual(ref, res)
+
+        ref.sum().backward()
+        res.sum().backward()
 
     def test_fail_with_direct_invoke_subgraph(self):
         from torch._higher_order_ops import invoke_subgraph
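
For intuition on the `None` outputs this test forces: when an input does not contribute a gradient, autograd represents that gradient as `None`, so the joint (forward+backward) graph's output list contains `None` entries. A plain-autograd analogy, independent of `invoke_subgraph` (here the `None` comes from an unused input rather than `requires_grad=False`, but the graph-level effect is the same):

import torch

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
y = (w * w).sum()  # x never participates in y

# allow_unused=True returns None for the unused input instead of erroring;
# this None is the kind of value a joint-graph output list now has to carry.
gx, gw = torch.autograd.grad(y, (x, w), allow_unused=True)
print(gx)        # None
print(gw.shape)  # torch.Size([4])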

torch/_inductor/codegen/wrapper.py

@@ -2338,6 +2338,7 @@ class PythonWrapperCodegen(CodeGen):
             return
 
         self.push_codegened_graph(subgraph.graph)
         self.writeline("")
+        self.writeline(f"{self.comment} subgraph: {subgraph.name}")
         self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
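
The added `writeline` only labels where each subgraph's code begins in the generated wrapper (for the Python wrapper the comment token is `#`, so the output gains lines like `# subgraph: ...`). One way to see these markers is to dump inductor's output code; this is standard inductor debugging, not specific to this PR, and the import path is an assumption as above:

# Run as: TORCH_LOGS="output_code" python repro.py
# then look for "# subgraph: ..." markers in the printed wrapper code.
import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

@mark_compile_region
def gn(x):
    return torch.sin(x)

def fn(x):
    return gn(x) + gn(x)  # two calls, one deduplicated subgraph

opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
opt_fn(torch.randn(8, 8))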

torch/_inductor/graph.py

@@ -2043,12 +2043,17 @@ class GraphLowering(torch.fx.Interpreter):
         return mod
 
     def get_output_names(self) -> List[str]:
-        return [
-            node.get_name()
-            for node in self.graph_outputs
-            if not isinstance(node, ir.NoneAsConstantBuffer)
-            and not isinstance(node, ir.ShapeAsConstantBuffer)
-        ]
+        names = []
+        shape_counter = itertools.count(0)
+        none_counter = itertools.count(0)
+        for node in self.graph_outputs:
+            if isinstance(node, ir.NoneAsConstantBuffer):
+                names.append(f"{self.name}_none{next(none_counter)}")
+            elif isinstance(node, ir.ShapeAsConstantBuffer):
+                names.append(f"{self.name}_shape{next(shape_counter)}")
+            else:
+                names.append(node.get_name())
+        return names
 
     def is_unspec_arg(self, name: str) -> bool:
         # dynamo wraps unspec variable as 0d CPU tensor,
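
The effect of the new `get_output_names`, in isolation: instead of dropping non-tensor outputs, each `None` and each shape/SymInt output now gets a stable, unique placeholder name derived from the graph name, keeping output positions aligned. A stand-in re-implementation for illustration (the real `ir.NoneAsConstantBuffer` / `ir.ShapeAsConstantBuffer` are inductor IR nodes; these stubs are not):

import itertools

class NoneAsConstantBuffer: pass
class ShapeAsConstantBuffer: pass

class FakeBuffer:
    def __init__(self, name):
        self._name = name
    def get_name(self):
        return self._name

def get_output_names(graph_name, graph_outputs):
    names = []
    shape_counter = itertools.count(0)
    none_counter = itertools.count(0)
    for node in graph_outputs:
        if isinstance(node, NoneAsConstantBuffer):
            names.append(f"{graph_name}_none{next(none_counter)}")
        elif isinstance(node, ShapeAsConstantBuffer):
            names.append(f"{graph_name}_shape{next(shape_counter)}")
        else:
            names.append(node.get_name())
    return names

outs = [FakeBuffer("buf0"), NoneAsConstantBuffer(), ShapeAsConstantBuffer(), NoneAsConstantBuffer()]
print(get_output_names("subgraph0", outs))
# ['buf0', 'subgraph0_none0', 'subgraph0_shape0', 'subgraph0_none1']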

torch/_inductor/ir.py

@@ -7149,7 +7149,7 @@ class InvokeSubgraph(ExternKernel):
         )
 
         def create_output(output: IRNode, ind: int):
-            if isinstance(output, NoneAsConstantBuffer):
+            if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
                 return output
             else:
                 return MultiOutput(
@@ -7165,7 +7165,6 @@ class InvokeSubgraph(ExternKernel):
             )
 
         outputs = [create_output(output, i) for i, output in enumerate(outputs)]
         invoke_subgraph.outputs = outputs
         return outputs
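
Finally, the `ir.py` change: `ShapeAsConstantBuffer` now joins `NoneAsConstantBuffer` in being returned directly from `create_output`, since both carry compile-time values with no runtime buffer to index; only real tensor outputs are wrapped in a `MultiOutput` projection into the packed subgraph result. A toy version of the dispatch (stub classes, illustrative only):

from dataclasses import dataclass

@dataclass
class NoneAsConstantBuffer:
    pass

@dataclass
class ShapeAsConstantBuffer:
    pass

@dataclass
class TensorOutput:
    name: str

@dataclass
class MultiOutput:
    # Stand-in for inductor's MultiOutput: selects the ind-th runtime
    # output of the packed invoke_subgraph result.
    parent: str
    index: int

def create_output(output, ind):
    # Constant-like outputs (None, shapes/symints) carry their value at
    # compile time, so they are passed through untouched.
    if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
        return output
    return MultiOutput("invoke_subgraph_out", ind)

outs = [TensorOutput("buf0"), ShapeAsConstantBuffer(), NoneAsConstantBuffer()]
print([create_output(o, i) for i, o in enumerate(outs)])
# [MultiOutput(parent='invoke_subgraph_out', index=0), ShapeAsConstantBuffer(), NoneAsConstantBuffer()]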