mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor][invoke_subgraph] Support None/int as input/output of invoke_subgraph (#139373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139373 Approved by: https://github.com/eellison
This commit is contained in:
parent
379bbef23c
commit
969415885d
|
|
@ -197,6 +197,37 @@ class TestInvokeSubgraphCompile(TestCase):
|
|||
res = opt_fn(q, k, v)
|
||||
res.sum().backward()
|
||||
|
||||
def test_symint_from_fwd_to_bwd(self):
|
||||
@mark_compile_region
|
||||
def gn(x, y):
|
||||
a = torch.sum(x, (1,), keepdim=True).view(y.shape[1], y.shape[0])
|
||||
return torch.matmul(a, y)
|
||||
|
||||
def fn(x, y):
|
||||
return gn(x, y)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
|
||||
|
||||
x = torch.randn(64, 1, requires_grad=True)
|
||||
y = torch.randn(8, 8, requires_grad=True)
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x, y)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
x = torch.randn(256, 1, requires_grad=True)
|
||||
y = torch.randn(16, 16, requires_grad=True)
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x, y)
|
||||
self.assertEqual(ref, res)
|
||||
res.sum().backward()
|
||||
|
||||
x = torch.randn(16, 1, requires_grad=True)
|
||||
y = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x, y)
|
||||
self.assertEqual(ref, res)
|
||||
res.sum().backward()
|
||||
|
||||
def test_dedupe(self):
|
||||
@mark_compile_region
|
||||
def gn(x, y):
|
||||
|
|
@ -431,12 +462,17 @@ class GraphModule(torch.nn.Module):
|
|||
return gn(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
# requires_grad is False deliberately to force None the joint_graph
|
||||
# outputs
|
||||
x = torch.randn(8, 8, requires_grad=False)
|
||||
|
||||
ref = mod(x)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
ref.sum().backward()
|
||||
res.sum().backward()
|
||||
|
||||
def test_fail_with_direct_invoke_subgraph(self):
|
||||
from torch._higher_order_ops import invoke_subgraph
|
||||
|
||||
|
|
|
|||
|
|
@ -2338,6 +2338,7 @@ class PythonWrapperCodegen(CodeGen):
|
|||
return
|
||||
|
||||
self.push_codegened_graph(subgraph.graph)
|
||||
self.writeline("")
|
||||
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
|
||||
self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -2043,12 +2043,17 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
return mod
|
||||
|
||||
def get_output_names(self) -> List[str]:
|
||||
return [
|
||||
node.get_name()
|
||||
for node in self.graph_outputs
|
||||
if not isinstance(node, ir.NoneAsConstantBuffer)
|
||||
and not isinstance(node, ir.ShapeAsConstantBuffer)
|
||||
]
|
||||
names = []
|
||||
shape_counter = itertools.count(0)
|
||||
none_counter = itertools.count(0)
|
||||
for node in self.graph_outputs:
|
||||
if isinstance(node, ir.NoneAsConstantBuffer):
|
||||
names.append(f"{self.name}_none{next(none_counter)}")
|
||||
elif isinstance(node, ir.ShapeAsConstantBuffer):
|
||||
names.append(f"{self.name}_shape{next(shape_counter)}")
|
||||
else:
|
||||
names.append(node.get_name())
|
||||
return names
|
||||
|
||||
def is_unspec_arg(self, name: str) -> bool:
|
||||
# dynamo wraps unspec variable as 0d CPU tensor,
|
||||
|
|
|
|||
|
|
@ -7149,7 +7149,7 @@ class InvokeSubgraph(ExternKernel):
|
|||
)
|
||||
|
||||
def create_output(output: IRNode, ind: int):
|
||||
if isinstance(output, NoneAsConstantBuffer):
|
||||
if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
|
||||
return output
|
||||
else:
|
||||
return MultiOutput(
|
||||
|
|
@ -7165,7 +7165,6 @@ class InvokeSubgraph(ExternKernel):
|
|||
)
|
||||
|
||||
outputs = [create_output(output, i) for i, output in enumerate(outputs)]
|
||||
|
||||
invoke_subgraph.outputs = outputs
|
||||
return outputs
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user