This is my first attempt at building the new export API. The main thing it addresses is correctly capturing the input and output relations. Subsequent diffs will add functionality for dynamic shapes, nn_module_stack, etc. Differential Revision: [D81793205](https://our.internmc.facebook.com/intern/diff/D81793205) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162167 Approved by: https://github.com/zhxchen17, https://github.com/avikchaudhuri
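As a rough usage sketch (the toy module below and the torch._dynamo.functional_export import path are assumptions for illustration), the helper wraps a module and returns a callable that captures a torch.fx.GraphModule from example inputs:

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export  # import path is an assumption

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

capture = _dynamo_graph_capture_for_export(Add())
gm = capture(torch.randn(3), torch.randn(3))  # dynamo-level torch.fx.GraphModule
print(gm.graph)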
143 lines
5.3 KiB
Python
import builtins
import inspect
from collections import namedtuple
from typing import Any, Callable

import torch
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

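
# ModuleToTrace wraps the user callable behind a flat positional calling
# convention: forward() unflattens *flat_args back into (args, kwargs) with
# in_spec, runs the wrapped module, and returns the flattened result along
# with its pytree out_spec.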
class ModuleToTrace(torch.nn.Module):
    def __init__(self, foo: Any, in_spec: Any) -> None:
        super().__init__()
        self._export_root = foo
        self.in_spec = in_spec

    def forward(self, *flat_args: Any) -> "ExportTracerOutput":
        args, kwargs = pytree.tree_unflatten(flat_args, self.in_spec)
        res = self._export_root(*args, **kwargs)
        out_flat, out_spec = pytree.tree_flatten(res)
        return ExportTracerOutput(out_flat, out_spec)


ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])


def _dynamo_graph_capture_for_export(
    mod: torch.nn.Module,
) -> Callable[..., torch.fx.GraphModule]:
    """
    This is a lower-level API used by export to capture dynamo-level torch IR.

    Notable TODOs:
    1. Are we actually going to run the bytecode?
    2. Need to attach guards
    """

    def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
        flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
        module_to_trace = ModuleToTrace(mod, in_spec)

        signature = inspect.signature(module_to_trace.forward)

        bound_arguments = signature.bind(*flat_inputs)
        bound_arguments.apply_defaults()

        f_locals = {"self": module_to_trace, **bound_arguments.arguments}

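        # Hand-construct the frame dynamo would normally see on a call to
        # module_to_trace.forward(*flat_inputs): its code object, globals, the
        # simulated f_locals above, builtins, and an empty closure.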
        frame = FrameInfo(
            module_to_trace.forward.__func__.__code__,  # type: ignore[attr-defined]
            module_to_trace.forward.__func__.__globals__,  # type: ignore[attr-defined]
            f_locals,
            builtins,  # type: ignore[arg-type]
            closure=(),  # type: ignore[arg-type]
        )

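        # Patch dynamo to record graph input/output metadata during capture;
        # that metadata is what lets us recover the input ordering and
        # out_spec below without executing the generated bytecode.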
        dynamo_config_ctx = torch._dynamo.config.patch(
            "log_graph_in_out_metadata", True
        )

        with (
            compile_context(CompileContext(get_compile_id({}))),
            get_metrics_context(),
            dynamo_timed("fullgraph_capture"),
            dynamo_config_ctx,
        ):
            out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)

        assert out.dynamo_output.tracer_output.output_graph is not None

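        # Pull the recorded export metadata off the traced output graph: the
        # local source of each dynamo graph input, how each flat output is
        # produced, and the output pytree spec.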
        export_metadata = (
            out.dynamo_output.tracer_output.output_graph.export_metadata
        )
        graph_inputs = export_metadata.graph_input_idx_to_local_source
        output_return_type = export_metadata.output_return_type
        # We need to extract out_spec here because we are not actually running the bytecode
        out_spec = export_metadata.out_spec

        graph = out.backend_input.graph_module

        # It is not guaranteed that dynamo puts inputs in the right order, so we
        # need to map the actual user order to the dynamo order.
        graph_input_order: dict[int, int] = {}
        for inp in graph_inputs:
            source = graph_inputs[inp]
            assert isinstance(source, torch._dynamo.source.GetItemSource)
            graph_input_order[source.index] = len(graph_input_order)
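        # graph_input_order now maps flat user input index -> position of the
        # corresponding dynamo placeholder.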

        placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
        output = next(n for n in list(graph.graph.nodes) if n.op == "output")
        # Sometimes there can be empty inputs
        anchor = placeholders[0] if len(placeholders) > 0 else output
        inp_to_node = {}

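        # Create one new placeholder per flat user input, in user order, and
        # copy the metadata from the dynamo placeholder it replaces (if dynamo
        # actually used that input).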
        with graph.graph.inserting_before(anchor):
            for i in range(len(flat_inputs)):
                node_new = graph.graph.placeholder(f"arg_{i}")
                if i in graph_input_order:
                    node_new.meta = placeholders[graph_input_order[i]].meta.copy()
                inp_to_node[i] = node_new

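        # Rebuild the output node's args: each flat output either comes out of
        # the captured graph ("graph_out"), aliases one of the inputs
        # ("input"), or is a constant ("constant").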
        new_args = []
        for i in output_return_type:
            type, val = output_return_type[i]
            if type == "graph_out":
                new_args.append(output.args[0][val])
            if type == "input":
                input_idx = val.index
                new_args.append(inp_to_node[input_idx])
            if type == "constant":
                new_args.append(val)
        output.args = (tuple(new_args),)

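        # Swap every original dynamo placeholder for the new user-ordered one
        # and erase it, so the graph's signature matches the flat user inputs.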
        for src_idx, i in graph_input_order.items():
            old = placeholders[i]
            new = inp_to_node[src_idx]
            old.replace_all_uses_with(new)
            graph.graph.erase_node(old)

        # Dynamo uses _LazyGraphModule, so we need to force recompile
        from torch.fx._lazy_graph_module import _LazyGraphModule

        _LazyGraphModule.force_recompile(graph)

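        # Install pytree-aware codegen so the recompiled GraphModule accepts
        # the original (args, kwargs) structure and returns outputs laid out
        # according to out_spec, instead of flat positional tensors.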
        graph.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                argument_names(signature, args, kwargs),  # type: ignore[arg-type]
                in_spec,
                out_spec,
            )
        )

        graph.recompile()
        return graph

    return inner