Extract some HOP utils to be importable (#159705)

Useful helpers for users doing stage 1 export -> manual partitioner -> stage 2 compile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159705
Approved by: https://github.com/zou3519
ghstack dependencies: #159134
Simon Fan authored 2025-08-04 13:46:03 -07:00, committed by PyTorch MergeBot
parent 49abc0e3f8
commit 22bedc429f
2 changed files with 70 additions and 72 deletions
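
For context on the workflow mentioned in the description, here is a minimal sketch of what "manual partitioner" wiring can look like using public functorch.compile entry points (aot_function, nop, min_cut_rematerialization_partition). It is an illustration only, not the exact stage 1 export / stage 2 compile pipeline this PR targets; the function f and the tensor shape are made up.

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop

def f(x):
    return torch.sin(x).cos().sum()

# AOTAutograd traces a joint forward+backward graph, hands it to partition_fn
# to split into forward/backward GraphModules, then passes each piece to the
# fw/bw compilers. nop returns the graphs unchanged, so only the partitioning
# step is exercised here.
compiled_f = aot_function(
    f,
    fw_compiler=nop,
    bw_compiler=nop,
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.randn(8, requires_grad=True)
compiled_f(x).backward()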


@@ -516,6 +516,48 @@ class InvokeSubgraphHopGraphs:
    new_num_saved_nodes: Optional[int] = None


def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
    # min-cut partitioner requires the placeholders to have primals and
    # tangents string in the node.name. The signature of the joint graph is
    # (*primals, *tangents)

    # We also have to update the output signature which is right now
    # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the
    # partitioner to work.
    new_graph = torch.fx.Graph()
    env = {}

    primals_counter = itertools.count(0)
    tangents_counter = itertools.count(0)

    for idx, node in enumerate(mod.graph.nodes):
        if node.op == "placeholder":
            if idx < num_primals:
                env[node] = new_graph.placeholder(f"primals_{next(primals_counter)}")
            else:
                env[node] = new_graph.placeholder(f"tangents_{next(tangents_counter)}")
            env[node].meta = copy.copy(node.meta)
        elif node.op == "output":
            # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads)
            # The reason for having the reversed signature in the first
            # place is to simplify step 3.
            old_outputs = node.args[0]
            new_outputs = (
                *old_outputs[-num_fw_outputs:],
                *old_outputs[:-num_fw_outputs],
            )
            new_outputs = [env[n] if n else None for n in new_outputs]
            new_graph.output(tuple(new_outputs))
        else:
            env[node] = new_graph.node_copy(node, lambda n: env[n])
            env[node].meta = copy.copy(node.meta)

    new_graph.lint()

    out = torch.fx.GraphModule(mod, new_graph)
    return out


def run_joint_graph_passes_on_hops(
    joint_gm: torch.fx.GraphModule,
    joint_inputs: Any,
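
The output reordering in prepare_for_partitioner above can be illustrated with plain tuples (a toy snippet, with strings standing in for fx nodes): the joint graph returns (*grads, *fw_outs), while the partitioner expects (*fw_outs, *grads).

# Toy stand-ins for the fx output nodes; num_fw_outputs mirrors the helper's argument.
old_outputs = ("grad_x", "grad_y", "fw_out_0")  # (*grads, *fw_outs)
num_fw_outputs = 1
new_outputs = (
    *old_outputs[-num_fw_outputs:],   # forward outputs first
    *old_outputs[:-num_fw_outputs],   # gradients after
)
assert new_outputs == ("fw_out_0", "grad_x", "grad_y")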
@@ -553,51 +595,6 @@ def run_joint_graph_passes_on_hops(
    def num_inputs(mod):
        return len(mod.graph.find_nodes(op="placeholder"))

    def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
        # min-cut partitioner requires the placeholders to have primals and
        # tangents string in the node.name. The signature of the joint graph is
        # (*primals, *tangents)

        # We also have to update the output signature which is right now
        # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the
        # partitioner to work.
        new_graph = torch.fx.Graph()
        env = {}

        primals_counter = itertools.count(0)
        tangents_counter = itertools.count(0)

        for idx, node in enumerate(mod.graph.nodes):
            if node.op == "placeholder":
                if idx < num_primals:
                    env[node] = new_graph.placeholder(
                        f"primals_{next(primals_counter)}"
                    )
                else:
                    env[node] = new_graph.placeholder(
                        f"tangents_{next(tangents_counter)}"
                    )
                env[node].meta = copy.copy(node.meta)
            elif node.op == "output":
                # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads)
                # The reason for having the reversed signature in the first
                # place is to simplify step 3.
                old_outputs = node.args[0]
                new_outputs = (
                    *old_outputs[-num_fw_outputs:],
                    *old_outputs[:-num_fw_outputs],
                )
                new_outputs = [env[n] if n else None for n in new_outputs]
                new_graph.output(tuple(new_outputs))
            else:
                env[node] = new_graph.node_copy(node, lambda n: env[n])
                env[node].meta = copy.copy(node.meta)

        new_graph.lint()

        out = torch.fx.GraphModule(mod, new_graph)
        return out

    new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict(
        lambda: InvokeSubgraphHopGraphs()
    )
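
As the comments in the relocated helper note, the min-cut partitioner keys off placeholder names containing "primals"/"tangents" and an output ordered as (*fw_outs, *grads). A hedged, standalone sketch of feeding such a joint graph to the partitioner follows; it assumes min_cut_rematerialization_partition can be applied directly to a make_fx-traced joint graph shaped this way (in the real pipeline the joint graph comes from AOTAutograd), and the joint function here is made up.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.compile import min_cut_rematerialization_partition

# A toy joint graph in the shape the partitioner expects: placeholders named
# primals_*/tangents_*, outputs ordered as (*fw_outs, *grads).
def joint(primals_1, tangents_1):
    fw_out = torch.sin(primals_1)
    grad_primals_1 = tangents_1 * torch.cos(primals_1)
    return fw_out, grad_primals_1

joint_gm = make_fx(joint, tracing_mode="fake")(torch.randn(8), torch.randn(8))
fw_gm, bw_gm = min_cut_rematerialization_partition(joint_gm, None, num_fwd_outputs=1)
print(fw_gm.graph)  # forward graph: fw_out plus values saved for backward
print(bw_gm.graph)  # backward graph: consumes the saved values and tangents_1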


@@ -2052,6 +2052,34 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[
    )


def partition_fn(
    gm: GraphModule,
    joint_inputs: Sequence[object],
    **kwargs: object,
) -> tuple[GraphModule, GraphModule]:
    cuda_context = get_cuda_device_context(gm)
    with cuda_context:
        # We can skip the invoke_subgraph because the
        # entire_partition_fn is called recursively for invoke_subgraph
        # in partitioning.
        _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True)

    static_lifetime_input_indices: Optional[list[int]] = kwargs.pop(  # type: ignore[assignment]
        "static_lifetime_input_indices", None
    )

    with dynamo_utils.dynamo_timed(
        "min_cut_rematerialization_partition", log_pt2_compile_event=True
    ):
        return min_cut_rematerialization_partition(
            gm,
            joint_inputs,
            compiler="inductor",
            static_lifetime_input_indices=static_lifetime_input_indices,
            **kwargs,
        )


def compile_fx(
    model_: GraphModule,
    example_inputs_: Sequence[InputType],
@@ -2370,33 +2398,6 @@ def compile_fx(
                OutputCode, inference_compiler
            )

        def partition_fn(
            gm: GraphModule,
            joint_inputs: Sequence[object],
            **kwargs: object,
        ) -> tuple[GraphModule, GraphModule]:
            cuda_context = get_cuda_device_context(gm)
            with cuda_context:
                # We can skip the invoke_subgraph because the
                # entire_partition_fn is called recursively for invoke_subgraph
                # in partitioning.
                _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True)

            static_lifetime_input_indices: Optional[list[int]] = kwargs.pop(  # type: ignore[assignment]
                "static_lifetime_input_indices", None
            )

            with dynamo_utils.dynamo_timed(
                "min_cut_rematerialization_partition", log_pt2_compile_event=True
            ):
                return min_cut_rematerialization_partition(
                    gm,
                    joint_inputs,
                    compiler="inductor",
                    static_lifetime_input_indices=static_lifetime_input_indices,
                    **kwargs,
                )

        @compile_time_strobelight_meta(phase_name="backward")
        def bw_compiler(
            gm: GraphModule, example_inputs: Sequence[InputType]
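
One small detail of the extracted partition_fn worth noting: it pops static_lifetime_input_indices out of **kwargs so that argument can be handled explicitly while everything else the caller supplied is forwarded untouched. A toy illustration of that pattern (the names here are made up and unrelated to Inductor):

def partition(joint_gm, joint_inputs, *, num_fwd_outputs, static_lifetime_input_indices=None):
    return joint_gm, joint_inputs, num_fwd_outputs, static_lifetime_input_indices

def wrapper(joint_gm, joint_inputs, **kwargs):
    # Pop the optional keyword so it can be defaulted or transformed explicitly,
    # then forward the remaining kwargs as-is.
    static_lifetime_input_indices = kwargs.pop("static_lifetime_input_indices", None)
    return partition(
        joint_gm,
        joint_inputs,
        static_lifetime_input_indices=static_lifetime_input_indices,
        **kwargs,
    )

assert wrapper("gm", "inputs", num_fwd_outputs=2) == ("gm", "inputs", 2, None)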