Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Extract some HOP utils to be importable (#159705)
Useful helper function for stage 1 export -> manual partitioner -> stage 2 compile users.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159705
Approved by: https://github.com/zou3519
ghstack dependencies: #159134
This commit is contained in:
parent 49abc0e3f8
commit 22bedc429f
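The workflow the commit message refers to, in rough outline. This is a sketch only: the import path below is an assumption (the diff does not show file names), and joint_gm, joint_inputs, and num_fw_outputs stand in for whatever the user's stage 1 export produced.

# Stage 1: export/trace a joint forward+backward graph with the user's own
# machinery, producing joint_gm (a torch.fx.GraphModule), joint_inputs, and
# num_fw_outputs.

# Stage 2: manual partitioning, reusing the partitioner Inductor itself uses.
from torch._inductor.compile_fx import partition_fn  # assumed location of compile_fx

fw_gm, bw_gm = partition_fn(
    joint_gm,
    joint_inputs,
    num_fwd_outputs=num_fw_outputs,  # forwarded via **kwargs to the min-cut partitioner
)

# Stage 3: compile fw_gm and bw_gm separately with whichever backend the user controls.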
@@ -516,6 +516,48 @@ class InvokeSubgraphHopGraphs:
    new_num_saved_nodes: Optional[int] = None


def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
    # min-cut partitioner requires the placeholders to have primals and
    # tangents string in the node.name. The signature of the joint graph is
    # (*primals, *tangents)

    # We also have to update the output signature which is right now
    # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the
    # partitioner to work.
    new_graph = torch.fx.Graph()
    env = {}

    primals_counter = itertools.count(0)
    tangents_counter = itertools.count(0)

    for idx, node in enumerate(mod.graph.nodes):
        if node.op == "placeholder":
            if idx < num_primals:
                env[node] = new_graph.placeholder(f"primals_{next(primals_counter)}")
            else:
                env[node] = new_graph.placeholder(f"tangents_{next(tangents_counter)}")
            env[node].meta = copy.copy(node.meta)
        elif node.op == "output":
            # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads)
            # The reason for having the reversed signature in the first
            # place is to simplify step 3.
            old_outputs = node.args[0]
            new_outputs = (
                *old_outputs[-num_fw_outputs:],
                *old_outputs[:-num_fw_outputs],
            )
            new_outputs = [env[n] if n else None for n in new_outputs]
            new_graph.output(tuple(new_outputs))
        else:
            env[node] = new_graph.node_copy(node, lambda n: env[n])
            env[node].meta = copy.copy(node.meta)

    new_graph.lint()

    out = torch.fx.GraphModule(mod, new_graph)
    return out


def run_joint_graph_passes_on_hops(
    joint_gm: torch.fx.GraphModule,
    joint_inputs: Any,
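A self-contained illustration of what the newly extracted helper expects and produces. The import path is an assumption (file names are not shown in this diff); everything else uses only public torch.fx APIs.

import torch
import torch.fx

# Assumed import path; the diff only shows that the helper is now module-level.
from torch._functorch._aot_autograd.graph_compile import prepare_for_partitioner

# A toy "joint" graph: one primal, one tangent, returning (*grads, *fw_outs).
def joint(p, t):
    fw_out = torch.sin(p)
    grad = torch.cos(p) * t
    return grad, fw_out

joint_gm = torch.fx.symbolic_trace(joint)
prepared = prepare_for_partitioner(joint_gm, num_primals=1, num_fw_outputs=1)

# Placeholders are now named primals_0 / tangents_0, and the output node
# returns (*fw_outs, *grads), which is what the min-cut partitioner expects.
print([n.name for n in prepared.graph.find_nodes(op="placeholder")])
print(prepared.graph)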
@@ -553,51 +595,6 @@ def run_joint_graph_passes_on_hops(
    def num_inputs(mod):
        return len(mod.graph.find_nodes(op="placeholder"))

    def prepare_for_partitioner(mod, num_primals, num_fw_outputs):
        # min-cut partitioner requires the placeholders to have primals and
        # tangents string in the node.name. The signature of the joint graph is
        # (*primals, *tangents)

        # We also have to update the output signature which is right now
        # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the
        # partitioner to work.
        new_graph = torch.fx.Graph()
        env = {}

        primals_counter = itertools.count(0)
        tangents_counter = itertools.count(0)

        for idx, node in enumerate(mod.graph.nodes):
            if node.op == "placeholder":
                if idx < num_primals:
                    env[node] = new_graph.placeholder(
                        f"primals_{next(primals_counter)}"
                    )
                else:
                    env[node] = new_graph.placeholder(
                        f"tangents_{next(tangents_counter)}"
                    )
                env[node].meta = copy.copy(node.meta)
            elif node.op == "output":
                # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads)
                # The reason for having the reversed signature in the first
                # place is to simplify step 3.
                old_outputs = node.args[0]
                new_outputs = (
                    *old_outputs[-num_fw_outputs:],
                    *old_outputs[:-num_fw_outputs],
                )
                new_outputs = [env[n] if n else None for n in new_outputs]
                new_graph.output(tuple(new_outputs))
            else:
                env[node] = new_graph.node_copy(node, lambda n: env[n])
                env[node].meta = copy.copy(node.meta)

        new_graph.lint()

        out = torch.fx.GraphModule(mod, new_graph)
        return out

    new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict(
        lambda: InvokeSubgraphHopGraphs()
    )
@@ -2052,6 +2052,34 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[
    )


def partition_fn(
    gm: GraphModule,
    joint_inputs: Sequence[object],
    **kwargs: object,
) -> tuple[GraphModule, GraphModule]:
    cuda_context = get_cuda_device_context(gm)
    with cuda_context:
        # We can skip the invoke_subgraph because the
        # entire_partition_fn is called recursively for invoke_subgraph
        # in partitioning.
        _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True)

    static_lifetime_input_indices: Optional[list[int]] = kwargs.pop(  # type: ignore[assignment]
        "static_lifetime_input_indices", None
    )

    with dynamo_utils.dynamo_timed(
        "min_cut_rematerialization_partition", log_pt2_compile_event=True
    ):
        return min_cut_rematerialization_partition(
            gm,
            joint_inputs,
            compiler="inductor",
            static_lifetime_input_indices=static_lifetime_input_indices,
            **kwargs,
        )


def compile_fx(
    model_: GraphModule,
    example_inputs_: Sequence[InputType],
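For reference, the new module-level partition_fn is a thin wrapper: after the recursive joint-graph passes and the dynamo_timed instrumentation, it delegates to min_cut_rematerialization_partition. On a build without this change the partitioner can be called directly, roughly as below; joint_gm, joint_inputs, and num_fw_outputs are placeholders, and passing num_fwd_outputs as a keyword is an assumption about the partitioner's signature.

from torch._functorch.partitioners import min_cut_rematerialization_partition

# Equivalent direct call, minus the joint-graph passes and timing that
# partition_fn adds around it.
fw_gm, bw_gm = min_cut_rematerialization_partition(
    joint_gm,
    joint_inputs,
    compiler="inductor",
    num_fwd_outputs=num_fw_outputs,
)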
@@ -2370,33 +2398,6 @@ def compile_fx(
            OutputCode, inference_compiler
        )

        def partition_fn(
            gm: GraphModule,
            joint_inputs: Sequence[object],
            **kwargs: object,
        ) -> tuple[GraphModule, GraphModule]:
            cuda_context = get_cuda_device_context(gm)
            with cuda_context:
                # We can skip the invoke_subgraph because the
                # entire_partition_fn is called recursively for invoke_subgraph
                # in partitioning.
                _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True)

            static_lifetime_input_indices: Optional[list[int]] = kwargs.pop(  # type: ignore[assignment]
                "static_lifetime_input_indices", None
            )

            with dynamo_utils.dynamo_timed(
                "min_cut_rematerialization_partition", log_pt2_compile_event=True
            ):
                return min_cut_rematerialization_partition(
                    gm,
                    joint_inputs,
                    compiler="inductor",
                    static_lifetime_input_indices=static_lifetime_input_indices,
                    **kwargs,
                )

        @compile_time_strobelight_meta(phase_name="backward")
        def bw_compiler(
            gm: GraphModule, example_inputs: Sequence[InputType]