Revert "[invoke_subgraph] Simplify output code for subgraph output node (#152490)"

This reverts commit 5fe335810a.

Reverted https://github.com/pytorch/pytorch/pull/152490 on behalf of https://github.com/malfet due to Broke CI, see 52cbcac640/1 ([comment](https://github.com/pytorch/pytorch/pull/152384#issuecomment-2845099985))
This commit is contained in:
PyTorch MergeBot 2025-05-01 15:46:07 +00:00
parent 2fa39e60ed
commit 2f1800bc3d

View File

@ -2884,12 +2884,10 @@ class PythonWrapperCodegen(CodeGen):
def set_all_partition_names(self, num_partitions: int):
self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)]
def codegen_subgraph_call_with_flattened_outputs(
self, subgraph, outer_inputs, outer_flattened_outputs
):
def codegen_subgraph_call(self, subgraph, outer_inputs, outer_outputs):
# Get the input and output names of the subgraph
outer_output_names = ", ".join(outer_flattened_outputs) + (
"," if len(outer_flattened_outputs) == 1 else ""
outer_output_names = ", ".join(outer_outputs) + (
"," if len(outer_outputs) == 1 else ""
)
outer_input_names = ", ".join(outer_inputs) + (
"," if len(outer_inputs) == 1 else ""
@ -2902,20 +2900,13 @@ class PythonWrapperCodegen(CodeGen):
f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)"
)
def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name):
# Get the input and output names of the subgraph
outer_input_names = ", ".join(outer_inputs) + (
"," if len(outer_inputs) == 1 else ""
)
def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
# Codegen subgraph by recursively calling the codegen for the subgraph.
# This lifts the subgraph as a function in the output code.
if V.graph.aot_mode:
self.codegen_subgraph_by_inlining(subgraph, outer_inputs, outer_outputs)
return
self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]")
# Call the subgraph launcher function
self.writeline(
f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)"
)
def codegen_subgraph_common(self, subgraph):
self.push_codegened_graph(subgraph.graph)
self.writeline("")
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
@ -2934,40 +2925,21 @@ class PythonWrapperCodegen(CodeGen):
self.already_codegened_subgraphs.add(subgraph.graph.name)
self.define_subgraph_launcher_fn(subgraph_code.value)
def codegen_subgraph_with_flattened_outputs(
self, subgraph, outer_inputs, outer_flattened_outputs
):
self.codegen_subgraph_common(subgraph)
self.codegen_subgraph_call_with_flattened_outputs(
subgraph, outer_inputs, outer_flattened_outputs
)
def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name):
# Codegen subgraph by recursively calling the codegen for the subgraph.
# This lifts the subgraph as a function in the output code.
self.codegen_subgraph_common(subgraph)
self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name)
self.codegen_subgraph_call(subgraph, outer_inputs, outer_outputs)
def codegen_invoke_subgraph(self, invoke_subgraph):
name = invoke_subgraph.get_name()
self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}")
outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs]
if V.graph.aot_mode:
outer_outputs = [
f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))
]
self.codegen_subgraph_by_inlining(
invoke_subgraph.subgraph, outer_inputs, outer_outputs
)
else:
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name)
outer_outputs = [f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))]
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, outer_outputs)
def codegen_conditional(self, conditional):
name = conditional.get_name()
outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
predicate = conditional.predicate.codegen_reference()
if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
@ -2977,24 +2949,11 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
self.writeline(f"if {predicate}:")
self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
if V.graph.aot_mode:
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
self.codegen_subgraph_by_inlining(
conditional.true_subgraph, outer_inputs, outer_outputs
)
else:
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name)
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
self.writeline(ExitSubgraphLine(self))
self.writeline("else:")
self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
if V.graph.aot_mode:
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
self.codegen_subgraph_by_inlining(
conditional.false_subgraph, outer_inputs, outer_outputs
)
else:
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name)
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
self.writeline(ExitSubgraphLine(self))
def codegen_while_loop(self, while_loop):
@ -3026,28 +2985,17 @@ class PythonWrapperCodegen(CodeGen):
self.writeline("while True:")
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
if V.graph.aot_mode:
self.codegen_subgraph_by_inlining(
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
)
else:
self.codegen_subgraph_with_flattened_outputs(
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
)
self.codegen_subgraph(
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
)
self.writeline(
f"if not {cond_outer_outputs[0]}: break"
) # condition doesn't hold
self.writeline(ExitSubgraphLine(self))
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
if V.graph.aot_mode:
self.codegen_subgraph_by_inlining(
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
)
else:
self.codegen_subgraph_with_flattened_outputs(
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
)
self.codegen_subgraph(
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
)
self.writeline(ExitSubgraphLine(self))
@staticmethod