mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[invoke_subgraph] Simplify output code for subgraph output node (#152490)"
This reverts commit5fe335810a. Reverted https://github.com/pytorch/pytorch/pull/152490 on behalf of https://github.com/malfet due to Broke CI, see52cbcac640/1([comment](https://github.com/pytorch/pytorch/pull/152384#issuecomment-2845099985))
This commit is contained in:
parent
2fa39e60ed
commit
2f1800bc3d
|
|
@ -2884,12 +2884,10 @@ class PythonWrapperCodegen(CodeGen):
|
|||
def set_all_partition_names(self, num_partitions: int):
|
||||
self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)]
|
||||
|
||||
def codegen_subgraph_call_with_flattened_outputs(
|
||||
self, subgraph, outer_inputs, outer_flattened_outputs
|
||||
):
|
||||
def codegen_subgraph_call(self, subgraph, outer_inputs, outer_outputs):
|
||||
# Get the input and output names of the subgraph
|
||||
outer_output_names = ", ".join(outer_flattened_outputs) + (
|
||||
"," if len(outer_flattened_outputs) == 1 else ""
|
||||
outer_output_names = ", ".join(outer_outputs) + (
|
||||
"," if len(outer_outputs) == 1 else ""
|
||||
)
|
||||
outer_input_names = ", ".join(outer_inputs) + (
|
||||
"," if len(outer_inputs) == 1 else ""
|
||||
|
|
@ -2902,20 +2900,13 @@ class PythonWrapperCodegen(CodeGen):
|
|||
f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)"
|
||||
)
|
||||
|
||||
def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name):
|
||||
# Get the input and output names of the subgraph
|
||||
outer_input_names = ", ".join(outer_inputs) + (
|
||||
"," if len(outer_inputs) == 1 else ""
|
||||
)
|
||||
def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
|
||||
# Codegen subgraph by recursively calling the codegen for the subgraph.
|
||||
# This lifts the subgraph as a function in the output code.
|
||||
if V.graph.aot_mode:
|
||||
self.codegen_subgraph_by_inlining(subgraph, outer_inputs, outer_outputs)
|
||||
return
|
||||
|
||||
self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]")
|
||||
|
||||
# Call the subgraph launcher function
|
||||
self.writeline(
|
||||
f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)"
|
||||
)
|
||||
|
||||
def codegen_subgraph_common(self, subgraph):
|
||||
self.push_codegened_graph(subgraph.graph)
|
||||
self.writeline("")
|
||||
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
|
||||
|
|
@ -2934,40 +2925,21 @@ class PythonWrapperCodegen(CodeGen):
|
|||
self.already_codegened_subgraphs.add(subgraph.graph.name)
|
||||
self.define_subgraph_launcher_fn(subgraph_code.value)
|
||||
|
||||
def codegen_subgraph_with_flattened_outputs(
|
||||
self, subgraph, outer_inputs, outer_flattened_outputs
|
||||
):
|
||||
self.codegen_subgraph_common(subgraph)
|
||||
self.codegen_subgraph_call_with_flattened_outputs(
|
||||
subgraph, outer_inputs, outer_flattened_outputs
|
||||
)
|
||||
|
||||
def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name):
|
||||
# Codegen subgraph by recursively calling the codegen for the subgraph.
|
||||
# This lifts the subgraph as a function in the output code.
|
||||
self.codegen_subgraph_common(subgraph)
|
||||
self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name)
|
||||
self.codegen_subgraph_call(subgraph, outer_inputs, outer_outputs)
|
||||
|
||||
def codegen_invoke_subgraph(self, invoke_subgraph):
|
||||
name = invoke_subgraph.get_name()
|
||||
|
||||
self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}")
|
||||
outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs]
|
||||
|
||||
if V.graph.aot_mode:
|
||||
outer_outputs = [
|
||||
f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))
|
||||
]
|
||||
self.codegen_subgraph_by_inlining(
|
||||
invoke_subgraph.subgraph, outer_inputs, outer_outputs
|
||||
)
|
||||
else:
|
||||
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name)
|
||||
outer_outputs = [f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))]
|
||||
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, outer_outputs)
|
||||
|
||||
def codegen_conditional(self, conditional):
|
||||
name = conditional.get_name()
|
||||
|
||||
outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
|
||||
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
|
||||
|
||||
predicate = conditional.predicate.codegen_reference()
|
||||
if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
|
||||
|
|
@ -2977,24 +2949,11 @@ class PythonWrapperCodegen(CodeGen):
|
|||
self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
|
||||
self.writeline(f"if {predicate}:")
|
||||
self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
|
||||
if V.graph.aot_mode:
|
||||
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
|
||||
self.codegen_subgraph_by_inlining(
|
||||
conditional.true_subgraph, outer_inputs, outer_outputs
|
||||
)
|
||||
else:
|
||||
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name)
|
||||
|
||||
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
|
||||
self.writeline(ExitSubgraphLine(self))
|
||||
self.writeline("else:")
|
||||
self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
|
||||
if V.graph.aot_mode:
|
||||
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
|
||||
self.codegen_subgraph_by_inlining(
|
||||
conditional.false_subgraph, outer_inputs, outer_outputs
|
||||
)
|
||||
else:
|
||||
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name)
|
||||
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
|
||||
self.writeline(ExitSubgraphLine(self))
|
||||
|
||||
def codegen_while_loop(self, while_loop):
|
||||
|
|
@ -3026,28 +2985,17 @@ class PythonWrapperCodegen(CodeGen):
|
|||
|
||||
self.writeline("while True:")
|
||||
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
|
||||
|
||||
if V.graph.aot_mode:
|
||||
self.codegen_subgraph_by_inlining(
|
||||
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
|
||||
)
|
||||
else:
|
||||
self.codegen_subgraph_with_flattened_outputs(
|
||||
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
|
||||
)
|
||||
self.codegen_subgraph(
|
||||
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
|
||||
)
|
||||
self.writeline(
|
||||
f"if not {cond_outer_outputs[0]}: break"
|
||||
) # condition doesn't hold
|
||||
self.writeline(ExitSubgraphLine(self))
|
||||
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
|
||||
if V.graph.aot_mode:
|
||||
self.codegen_subgraph_by_inlining(
|
||||
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
|
||||
)
|
||||
else:
|
||||
self.codegen_subgraph_with_flattened_outputs(
|
||||
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
|
||||
)
|
||||
self.codegen_subgraph(
|
||||
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
|
||||
)
|
||||
self.writeline(ExitSubgraphLine(self))
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user