pytorch/torch/_inductor/compile_fx.py
Edward Z. Yang 17d7be68ee Delete functorch use_fake_tensor and debug_fake_cross_ref (#99314)
Using fake tensor with AOTAutograd is now mandatory, simplifying our
logic.  Unfortunately, this means debug_fake_cross_ref must go,
but I don't think anyone has used it recently.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99314
Approved by: https://github.com/eellison, https://github.com/zou3519
2023-04-18 02:09:54 +00:00

848 lines
27 KiB
Python

import dataclasses
import functools
import itertools
import logging
import sys
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional
import functorch
from functorch.compile import min_cut_rematerialization_partition
import torch._dynamo.config as dynamo_config
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo import logging as dynamo_logging, utils as dynamo_utils
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import make_boxed_func
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils._mode_utils import no_dispatch
from .._dynamo.backends.common import aot_autograd
from ..fx.graph import _PyTreeCodeGen
from . import config, metrics, overrides
from .debug import DebugContext
from .decomposition import select_decomp_table
from .fx_passes.joint_graph import joint_graph_passes
from .fx_passes.post_grad import post_grad_passes
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .utils import (
developer_warning,
get_dtype_size,
has_incompatible_cudagraph_ops,
is_cpu_device,
)
from .virtualized import V
# Byte alignment that tensor inputs are expected to satisfy; align_inputs and
# remove_unaligned_input_idxs test data pointers / byte offsets against it.
ALIGNMENT = 16

# Module-level logger (also wrapped by _step_logger for step messages).
log = logging.getLogger(__name__)
@dataclasses.dataclass
class BoxedBool:
    """
    A mutable boolean in a box, so several closures can share one flag.

    compile_fx shares one instance between its forward and backward compilers.
    compile_fx_inner can switch it off through ``BoxedBool.disable``.
    """

    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        """
        Turn ``obj`` off if it is a BoxedBool.

        Returns the same (now false) box for a BoxedBool, otherwise plain ``False``.
        """
        if not isinstance(obj, BoxedBool):
            return False
        obj.value = False
        return obj
@dataclasses.dataclass
class BoxedDeviceIndex:
    """
    Mutable holder for an optional CUDA device index.

    The forward compile records its cudagraph device here so that the later
    backward compile can look up the same cudagraph-tree manager.
    """

    value: Optional[int]

    def set(self, device_idx):
        """Store ``device_idx``, which must be an ``int`` or ``None``."""
        assert isinstance(device_idx, (int, type(None)))
        self.value = device_idx
# copy_ refuses to write into a tensor whose elements overlap in memory.
# Expanded dimensions (a dim broadcast from size 1 to a larger size) are the
# common source of that overlap. Writing through a single index of such a dim
# updates every position along it, so callers slice those dims down first.
def get_expanded_dims(t):
    """Return the dims of ``t`` that have stride 0 but a size other than 1."""
    expanded = []
    for dim in range(t.ndim):
        if t.size(dim) != 1 and t.stride(dim) == 0:
            expanded.append(dim)
    return expanded
def index_expanded_dims(t, expanded_dims):
    """
    Narrow each dimension listed in ``expanded_dims`` to its first element.

    Each step applies ``aten.slice(t, dim, 0, 1)``. If ``expanded_dims`` is
    empty, ``t`` itself is returned.
    """
    return functools.reduce(
        lambda acc, dim: torch.ops.aten.slice(acc, dim, 0, 1), expanded_dims, t
    )
def complex_memory_overlap(t):
    """
    Return True if ``t`` may have internal memory overlap that does not come
    from an expanded (stride-0) dimension.

    Expanded dims are sliced away first. A quick check with
    ``torch._debug_has_internal_overlap`` then decides whether the slower
    stride walk below is needed at all.
    """
    t = index_expanded_dims(t, get_expanded_dims(t))
    if torch._debug_has_internal_overlap(t) == 0:
        return False

    strides = t.stride()
    sizes = t.shape
    # Visit dims from smallest to largest stride (ties broken by dim index).
    # Each dim's stride must be at least the span covered by the previous dim.
    order = [dim for _, dim in sorted(zip(strides, range(len(strides))))]
    prev_stride, prev_size = 1, 1
    for dim in order:
        if strides[dim] < prev_stride * prev_size:
            return True
        prev_stride, prev_size = strides[dim], sizes[dim]
    return False
@functools.lru_cache(maxsize=None)
def _step_logger():
    """Build the dynamo step logger for this module's ``log`` once, then reuse it."""
    step_logger = dynamo_logging.get_step_logger(log)
    return step_logger
@functools.lru_cache(None)
def _warn_tf32_disabled():
if (
torch.cuda.is_available()
and not torch.backends.cuda.matmul.allow_tf32
and torch.cuda.get_device_capability() >= (8, 0)
):
warnings.warn(
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
)
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
    """
    Return True if ``gm`` contains a float32 CUDA matmul-family op.

    A node qualifies if it calls mm/addmm/bmm/baddbmm and its
    ``meta["val"]`` is a float32 tensor on a CUDA device.
    """
    aten = torch.ops.aten
    tf32_ops = {
        aten.mm.default,
        aten.addmm.default,
        aten.bmm.default,
        aten.baddbmm.default,
    }

    def _is_fp32_cuda_matmul(node):
        if node.op != "call_function" or node.target not in tf32_ops:
            return False
        val = node.meta.get("val", None)
        return (
            isinstance(val, torch.Tensor)
            and val.dtype == torch.float32
            and val.device.type == "cuda"
        )

    return any(_is_fp32_cuda_matmul(node) for node in gm.graph.nodes)
@DebugContext.wrap
def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
    """
    Lower ``gm`` only to measure memory traffic, then return its eager forward.

    The byte count and per-node element counts from ``GraphLowering.count_bytes``
    are accumulated into ``metrics``. Extra ``kwargs`` are accepted and ignored,
    so this can stand in for compile_fx_inner.
    """
    lowering = GraphLowering(
        gm,
        shape_env=_shape_env_from_inputs(example_inputs),
        num_static_inputs=num_fixed,
    )
    with V.set_graph_handler(lowering):
        lowering.run(*example_inputs)
        total_bytes, per_node_elems = lowering.count_bytes()
        metrics.num_bytes_accessed += total_bytes
        metrics.nodes_num_elem += per_node_elems
    return make_boxed_func(gm.forward)
@DebugContext.wrap
# NOTE(review): presumably disables any active dispatch modes (e.g. fake/proxy)
# for the duration of compilation — confirm against torch.utils._python_dispatch.
@torch.utils._python_dispatch._disable_current_modes()
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs=None,
    num_fixed=0,
    is_backward=False,
    graph_id=None,
    cpp_wrapper=False,
    aot_mode=False,
    is_inference=False,
    boxed_forward_device_index=None,
):
    """
    Lower one (forward, backward or inference) FX graph with Inductor.

    A graph with no calls is returned as a boxed ``gm.forward`` without any
    compilation. Otherwise the steps are:
    1. Fake-tensor propagation fills in node metadata.
    2. Post-grad passes run.
    3. GraphLowering compiles the graph.
    4. Unless ``aot_mode`` is set, the result may be wrapped with CUDA graphs
       and then with runtime input-alignment checks.

    Args:
        gm: graph to compile.
        example_inputs: inputs used for propagation and lowering. They may be
            real or fake tensors; a FakeTensorMode is created when none is
            detected.
        cudagraphs: bool or BoxedBool. ``None`` means
            ``config.triton.cudagraphs``. When this graph cannot be
            cudagraph'd, ``BoxedBool.disable`` switches a shared box off.
        num_fixed: the first ``num_fixed`` inputs are treated as static.
        is_backward: whether this is the backward graph (affects logging and
            cudagraph-tree bookkeeping).
        graph_id: identifier used in log messages and passed to GraphLowering.
        cpp_wrapper: forwarded to GraphLowering.
        aot_mode: forwarded to GraphLowering; when set, the raw compiled
            function is returned with no cudagraph/alignment wrapping.
        is_inference: whether this is an inference-only graph.
        boxed_forward_device_index: BoxedDeviceIndex shared between the forward
            and backward compiles. The forward records its cudagraph device
            here; the backward reads it.

    Returns:
        In ``aot_mode``, the result of ``graph.compile_to_fn()``. Otherwise a
        callable marked ``_boxed_call = True``, which takes its inputs as a
        single list.
    """
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    # Nothing to compile: return the eager forward (boxed calling convention).
    if dynamo_utils.count_calls(gm.graph) == 0:
        return make_boxed_func(gm.forward)

    # lift the maximum depth of the Python interpreter stack
    # to adapt large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)

    if cudagraphs is None:
        cudagraphs = config.triton.cudagraphs

    shape_env = _shape_env_from_inputs(example_inputs)
    # Reuse the caller's fake mode when inputs are already fake. Otherwise
    # create one that accepts real inputs and lets propagate() convert them.
    fake_mode = detect_fake_mode(example_inputs)
    if not fake_mode:
        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
    else:
        FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
            *example_inputs
        )

    # pattern matcher passes might not preserve striding information
    # on node.meta["val"]. if in the future we rely on these being
    # correct we will need to fix.
    with V.set_fake_mode(fake_mode):
        post_grad_passes(gm)
        V.debug.fx_graph_transformed(gm, example_inputs)

    with V.set_fake_mode(fake_mode):
        graph = GraphLowering(
            gm,
            shape_env=shape_env,
            num_static_inputs=num_fixed,
            graph_id=graph_id,
            cpp_wrapper=cpp_wrapper,
            aot_mode=aot_mode,
        )
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            compiled_fn = graph.compile_to_fn()

    if aot_mode:
        return compiled_fn

    if cudagraphs:
        # output args are tuple of first argument
        output = list(gm.graph.nodes)[-1]
        assert len(output.args) == 1
        # One stack trace per graph output (None for non-Node outputs),
        # handed to cudagraphify.
        stack_traces = [
            (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
            for arg in output.args[0]
        ]

        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t)
            for t in example_inputs
            if isinstance(t, torch.Tensor)
        )

        # Cudagraphs are only used for an all-CUDA graph with no input
        # mutation, no incompatible ops, no overlapping inputs and only tensor
        # inputs. With cudagraph trees, all devices must share one index.
        if (
            set(graph.device_types) == {"cuda"}
            and not graph.mutated_inputs
            and not has_incompatible_cudagraph_ops(gm)
            and not complex_memory_overlap_inputs
            and all(isinstance(t, torch.Tensor) for t in example_inputs)
            and (len(graph.device_idxs) == 1 or not config.triton.cudagraph_trees)
        ):
            # Record the forward's device so the backward compile can find the
            # matching cudagraph-tree manager.
            if (
                boxed_forward_device_index is not None
                and not is_inference
                and not is_backward
            ):
                boxed_forward_device_index.set(next(iter(graph.device_idxs)))

            compiled_fn = cudagraphify(
                compiled_fn,
                example_inputs,
                static_input_idxs=range(num_fixed),
                device_index=next(iter(graph.device_idxs)),
                stack_traces=stack_traces,
                is_backward=is_backward,
                is_inference=is_inference,
            )
        else:
            # Switch cudagraphs off for later compiles that share this box.
            BoxedBool.disable(cudagraphs)

            # See [Backward Generation Handling]
            # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
            # know we are we running the backward even if we will not run it in cudagraphs
            if is_backward and config.triton.cudagraph_trees:
                # NOTE(review): also assumes boxed_forward_device_index was
                # passed (is not None); otherwise this raises AttributeError.
                assert boxed_forward_device_index.value is not None
                compiled_fn_inner = compiled_fn

                manager = torch._inductor.cudagraph_trees.get_manager(
                    boxed_forward_device_index.value, create_if_none_exists=False
                )
                # should already exist from forward
                assert manager is not None

                def compiled_fn(new_inputs):
                    manager.set_to_running_backward()
                    return compiled_fn_inner(new_inputs)

            # Developer-facing explanation of why cudagraphs were skipped
            # (not every rejection reason above has a message).
            if len(set(graph.device_types)) > 1:
                developer_warning("skipping cudagraphs due to multiple devices")
            elif set(graph.device_types) == {"cuda"}:
                if graph.mutated_inputs:
                    developer_warning("skipping cudagraphs due to input mutation")
                elif complex_memory_overlap_inputs:
                    developer_warning(
                        "skipping cudagraphs due to complex input striding"
                    )
                elif len(graph.device_idxs) > 1 and config.triton.cudagraph_trees:
                    developer_warning(
                        "skipping cudagraphs due to multiple device indexes"
                    )

    # Wrap so that non-static CUDA inputs are re-checked for alignment at call time.
    result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    result._boxed_call = True
    return result
def clone_preserve_strides(x):
    """
    Return a copy of ``x`` in fresh storage with the same size and stride.

    Only the storage span that ``x`` actually addresses (starting at its
    storage offset) is cloned. The result therefore has storage offset 0, and
    stride-0 (expanded) dims stay expanded.

    Zero-element tensors are handled explicitly. For them the stride formula
    can report a positive span (e.g. size (2, 0), stride (1, 1) gives 1) that
    their empty storage cannot back, and ``as_strided`` would raise.
    """
    if 0 in x.size():
        # No elements: nothing to copy, and no backing storage is needed.
        needed_size = 0
    else:
        # The largest reachable element offset, plus one.
        needed_size = (
            sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride()))
            + 1
        )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())
def align_inputs(model, inputs, static_input_idxs=()):
    """
    Wrap ``model`` so that misaligned CUDA inputs are copied before each call.

    A CUDA tensor input is re-checked at call time unless it is static and its
    example storage offset is already ALIGNMENT-byte aligned. A checked input
    whose ``data_ptr()`` is not a multiple of ALIGNMENT is replaced in the
    caller's list by a stride-preserving clone. If no input needs checking,
    ``model`` is returned unchanged.
    """

    def is_aligned(storage_offset, dtype):
        return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0

    def _needs_runtime_check(idx, inp):
        if not isinstance(inp, torch.Tensor):
            return False
        if idx in static_input_idxs and is_aligned(inp.storage_offset(), inp.dtype):
            return False
        return inp.device.type == "cuda"

    check_inputs = [
        idx for idx, inp in enumerate(inputs) if _needs_runtime_check(idx, inp)
    ]
    if not check_inputs:
        return model

    def run(new_inputs):
        for idx in check_inputs:
            if new_inputs[idx].data_ptr() % ALIGNMENT:
                new_inputs[idx] = clone_preserve_strides(new_inputs[idx])
        return model(new_inputs)

    return run
@dynamo_utils.dynamo_timed
def cudagraphify(
    model,
    inputs,
    static_input_idxs=(),
    *,
    device_index: int,
    stack_traces: List[Optional[str]],
    is_backward: bool,
    is_inference: bool,
):
    """
    Wrap ``model`` in CUDA-graph replay.

    With ``config.triton.cudagraph_trees`` enabled the cudagraph-trees
    implementation is used; otherwise the local ``cudagraphify_impl``. If any
    input is a FakeTensor, recording is deferred until the first call with
    real inputs. That first recording runs under
    ``dynamo_utils.preserve_rng_state()``.
    """
    from torch._inductor.cudagraph_trees import (
        cudagraphify_impl as new_cudagraphify_impl,
    )

    if not config.triton.cudagraph_trees:
        cudagraphify_fn = cudagraphify_impl
    else:
        cudagraphify_fn = functools.partial(
            new_cudagraphify_impl,
            device_index=device_index,
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
        )

    has_fake_input = any(isinstance(inp, FakeTensor) for inp in inputs)
    if not has_fake_input:
        return cudagraphify_fn(model, inputs, static_input_idxs)

    # Fake example inputs: record lazily on the first real call.
    recorded = None

    def deferred(new_inputs):
        nonlocal recorded
        if recorded is None:
            with dynamo_utils.preserve_rng_state():
                recorded = cudagraphify_fn(model, new_inputs, static_input_idxs)
        return recorded(new_inputs)

    return deferred
def remove_unaligned_input_idxs(inputs, static_input_idxs):
    """
    Drop static inputs whose data pointer is not ALIGNMENT-byte aligned.

    Dropped indices are no longer treated as static, so callers copy them like
    any other input. If every static index is aligned, ``static_input_idxs`` is
    returned as-is; otherwise a set of the aligned indices is returned.
    """
    aligned = {
        idx for idx in static_input_idxs if inputs[idx].data_ptr() % ALIGNMENT == 0
    }
    return static_input_idxs if len(aligned) == len(static_input_idxs) else aligned
def static_input(x):
    """
    Allocate a new, uninitialized tensor with the same size, stride, dtype and
    device as ``x``.

    No data is copied; callers fill the buffer themselves. The backing storage
    is just large enough to hold every element the strides reach. Zero-element
    tensors get an empty buffer: for them the stride formula can go negative
    (e.g. size (0, 0), stride (1, 1) gives -1), and ``torch.empty`` would raise.
    """
    # TODO(jansel): figure out why this version doesn't work:
    # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    if 0 in x.size():
        needed_size = 0
    else:
        # The largest reachable element offset, plus one.
        needed_size = (
            sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride()))
            + 1
        )
    buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
    return torch.as_strided(buffer, x.size(), x.stride())
def index_expanded_dims_and_copy_(dst, src, expanded_dims):
    """
    Copy ``src`` into ``dst`` in place, after narrowing both to the first
    element of each expanded dim.

    This avoids copy_ into a stride-0 (overlapping) destination.
    """
    index_expanded_dims(dst, expanded_dims).copy_(
        index_expanded_dims(src, expanded_dims)
    )
def cudagraphify_impl(model, inputs, static_input_idxs=()):
    """
    Record ``model`` into a single ``torch.cuda.CUDAGraph`` and return a replayer.

    Assumes inputs[static_input_idxs[i]] are always the same memory address.
    Static inputs whose address is not ALIGNMENT-aligned are demoted to
    non-static. Every non-static input gets a freshly allocated buffer that is
    refilled on each call.

    Args:
        model: boxed callable taking a single list of tensors.
        inputs: example inputs (list or tuple) used for warmup and capture.
        static_input_idxs: indices whose tensors are used directly instead of
            being copied.

    Returns:
        ``run(new_inputs)``. It copies the non-static ``new_inputs`` into the
        captured buffers, clears the ``new_inputs`` list, replays the graph and
        returns ``static_outputs``. The same output objects come back on every
        call; presumably a later replay overwrites their contents.
    """
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)

    assert isinstance(inputs, (list, tuple))

    # Expanded dims are only needed for inputs that will be copied into.
    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # allocate static tensor inputs
    static_inputs = [
        static_input(x) if idx not in static_input_idxs else x.detach()
        for idx, x in enumerate(inputs)
    ]

    # copy over input values for fresh allocations
    for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
        if idx not in static_input_idxs:
            index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)

    # warmup
    # One eager run on a side stream before capture, fully synchronized on
    # both sides.
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:
        # Checked variant: also verifies the input count and that static
        # inputs still live at the captured addresses.

        def run(new_inputs):
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    index_expanded_dims_and_copy_(dst, src, expanded_dims)
            # Drop the caller's references to the inputs before replay.
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        # Fast variant: the list of indices to copy is computed once, up front.
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                expanded_dims = inps_expanded_dims[idx]
                index_expanded_dims_and_copy_(
                    static_inputs[idx], new_inputs[idx], expanded_dims
                )
            # Drop the caller's references to the inputs before replay.
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return run
def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph.

    Placeholders whose name does not contain "tangents" count as static. They
    must all come before any tangent placeholder (this is asserted). Returns
    how many there are.
    """
    placeholders = [node for node in fx_g.graph.nodes if node.op == "placeholder"]
    static_arg_idxs = [
        pos for pos, node in enumerate(placeholders) if "tangents" not in node.name
    ]
    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
def compile_fx_with_cpp_wrapper(
    module: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    inner_compile,
    decompositions: Optional[Dict[OpOverload, Callable]] = None,
):
    """
    Compile into cpp wrapper:
    For CPU, this is currently done in one pass.
    For GPU, this is done in two passes: JIT-compile the model with python wrapper code
    and run it to generate autotuned kernel binaries in the first pass; and then generate
    cpp wrapper code and compile it to a dynamic library in the second pass.

    Both paths re-enter compile_fx with ``config.cpp_wrapper`` patched to False,
    so the recursion does not loop back here. The choice of wrapper is instead
    passed through ``inner_compile(cpp_wrapper=...)``.

    Args:
        module: graph to compile. Its parameters/buffers must live on at most
            one device.
        example_inputs: example inputs (real or fake tensors).
        inner_compile: per-graph compiler, called with a ``cpp_wrapper`` kwarg.
        decompositions: forwarded to compile_fx.

    Returns:
        The callable returned by the final compile_fx call.
    """
    from torch.ao.quantization.fx.utils import assert_and_get_unique_device

    # Turns off cpp_wrapper before calling back into compile_fx
    config_patches = {"cpp_wrapper": False}

    device = assert_and_get_unique_device(module)
    if is_cpu_device(example_inputs):
        assert device is None or device.type == "cpu"
        with config.patch(config_patches):
            return compile_fx(
                module,
                example_inputs,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                decompositions=decompositions,
            )
    else:
        assert device is None or device.type == "cuda"
        # Cudagraphs are disabled for both passes. store_cubin is enabled,
        # presumably so the kernels built in pass 1 can be reused by the cpp
        # wrapper in pass 2 — confirm against the triton config docs.
        config_patches.update(
            {
                "triton.cudagraphs": False,
                "triton.store_cubin": True,
            }
        )
        with config.patch(config_patches):
            # first pass
            # Compile a deep copy so that `module` stays untouched for pass 2.
            module_copy = deepcopy(module)
            fake_mode = detect_fake_mode(example_inputs)
            inputs_copy = example_inputs if fake_mode else deepcopy(example_inputs)
            compiled = compile_fx(
                module_copy,
                inputs_copy,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=False),
                decompositions=decompositions,
            )
            # Pass 1 has to actually run, so fake inputs are swapped for
            # zero-filled real tensors on their fake device.
            if fake_mode:
                with no_dispatch():

                    def to_real_tensor(e):
                        if isinstance(e, FakeTensor):
                            out = torch.zeros_like(e, device=e.fake_device)
                            return out
                        return e

                    inputs_real = [to_real_tensor(t) for t in example_inputs]
            else:
                inputs_real = inputs_copy
            compiled(*inputs_real)
            del module_copy, inputs_real

            # second pass
            return compile_fx(
                module,
                example_inputs,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                decompositions=decompositions,
            )
def compile_fx_aot(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile=compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
    decompositions: Optional[Dict[OpOverload, Callable]] = None,
):
    """
    Run compile_fx with ``inner_compile`` bound to ``aot_mode=True``.

    All other arguments are forwarded unchanged.
    """
    aot_inner_compile = functools.partial(inner_compile, aot_mode=True)
    return compile_fx(
        model_,
        example_inputs_,
        inner_compile=aot_inner_compile,
        config_patches=config_patches,
        decompositions=decompositions,
    )
_graph_counter = itertools.count(0)
def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile=compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
    decompositions: Optional[Dict[OpOverload, Callable]] = None,
):
    """
    Main entrypoint to compile a given FX graph.

    The graph is first normalized by re-entering compile_fx through wrapper
    helpers:
    - config patches are applied;
    - cpp_wrapper mode is dispatched to compile_fx_with_cpp_wrapper;
    - a non-tuple return is turned into a tuple return;
    - a dynamo.export pytree codegen is stripped;
    - nested (list/tuple/dict) inputs are flattened.

    The graph is then handed to AOTAutograd. ``inner_compile`` is used for the
    forward, backward and inference graphs.

    Args:
        model_: graph (or module) to compile.
        example_inputs_: example inputs for tracing/compilation.
        inner_compile: per-graph compiler (defaults to compile_fx_inner).
        config_patches: inductor config overrides applied for this compile.
            They also wrap inner_compile, so the later backward compile sees
            them too.
        decompositions: decomposition table; ``None`` means
            ``select_decomp_table()``.

    Returns:
        The callable produced by ``aot_autograd(...)(model_, example_inputs_)``
        or by one of the wrapper helpers above.
    """
    if config_patches:
        with config.patch(config_patches):
            return compile_fx(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
                decompositions=decompositions,
            )

    if config.cpp_wrapper:
        return compile_fx_with_cpp_wrapper(
            model_,
            example_inputs_,
            inner_compile=inner_compile,
            decompositions=decompositions,
        )

    # Used by the wrapper helpers below to compile the normalized graph.
    recursive_compile_fx = functools.partial(
        compile_fx,
        inner_compile=inner_compile,
        decompositions=decompositions,
    )

    if not graph_returns_tuple(model_):
        return make_graph_return_tuple(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    if isinstance(model_, torch.fx.GraphModule):
        if isinstance(model_.graph._codegen, _PyTreeCodeGen):
            # this graph is the result of dynamo.export()
            return handle_dynamo_export_graph(
                model_,
                example_inputs_,
                recursive_compile_fx,
            )

        # Since handle_dynamo_export_graph will trigger compile_fx again,
        # Move these passes after handle_dynamo_export_graph to avoid repeated calls.
        model_ = pre_grad_passes(model_, example_inputs_)

    if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
        return flatten_graph_inputs(
            model_,
            example_inputs_,
            recursive_compile_fx,
        )

    assert not config._raise_error_for_testing

    # NOTE: this mutates the global functorch config, not a scoped setting.
    functorch.compile.config.use_functionalize = True
    num_example_inputs = len(example_inputs_)
    # One box shared by the forward and backward compiles, so either can
    # disable cudagraphs for the other. Off under dynamic shapes.
    cudagraphs = BoxedBool(
        config.triton.cudagraphs and not dynamo_config.dynamic_shapes
    )

    # The forward compile records its cudagraph device here; the backward reads it.
    forward_device = BoxedDeviceIndex(None)

    graph_id = next(_graph_counter)

    @dynamo_utils.dynamo_timed
    def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
        if is_inference:
            # partition_fn won't be called
            joint_graph_passes(model)

        # Inputs beyond the user's originals are treated as static
        # (num_fixed), presumably lifted params/buffers added by AOTAutograd
        # — confirm.
        fixed = len(example_inputs) - num_example_inputs
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
            is_inference=is_inference,
            boxed_forward_device_index=forward_device,
        )

    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
    inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    def partition_fn(graph, joint_inputs, **kwargs):
        joint_graph_passes(graph)
        return min_cut_rematerialization_partition(
            graph, joint_inputs, **kwargs, compiler="inductor"
        )

    # Save and restore dynamic shapes setting for backwards, as it is
    # sometimes done as a context manager which won't be set when we
    # hit backwards compile
    dynamic_shapes = dynamo_config.dynamic_shapes

    @dynamo_utils.dynamo_timed
    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
        with dynamo_config.patch(dynamic_shapes=dynamic_shapes):
            # Leading non-tangent inputs of the backward graph are static.
            fixed = count_tangents(model)
            return inner_compile(
                model,
                example_inputs,
                num_fixed=fixed,
                cudagraphs=cudagraphs,
                is_backward=True,
                graph_id=graph_id,
                boxed_forward_device_index=forward_device,
            )

    with overrides.patch_functions():
        if decompositions is None:
            decompositions = select_decomp_table()
        # TODO: can add logging before/after the call to create_aot_dispatcher_function
        # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
        # once torchdynamo is merged into pytorch
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            inference_compiler=inference_compiler,
            decompositions=decompositions,
            partition_fn=partition_fn,
            keep_inference_input_mutations=True,
        )(model_, example_inputs_)
def _shape_env_from_inputs(inputs):
shape_env = None
fake_mode = detect_fake_mode(inputs)
# TODO(voz): It would be nice to enable this assert, but there are lots of tests that
# pass in real inputs for now.
# if len(inputs) > 0:
# assert fake_mode is not None, breakpoint()
if fake_mode is not None:
return fake_mode.shape_env
# When there are no tensor inputs, get shape_env from the first SymInt.
for input in inputs:
if isinstance(input, torch.SymInt):
return input.node.shape_env
# TODO(voz): Should we always have one anyway?
return None
def output_node(gm: torch.fx.GraphModule):
    """Return the last node of ``gm``'s graph, asserting that it is the ``output`` node."""
    from_the_end = iter(reversed(gm.graph.nodes))
    last = next(from_the_end)
    assert last.op == "output"
    return last
def graph_returns_tuple(gm: torch.fx.GraphModule):
    """
    True if a FX graph returns a tuple.

    Also True when the single returned value is a node whose op schema declares
    more than one return, all of type ``Tensor``. Non-GraphModule inputs cannot
    be inspected and are assumed to return a tuple.
    """
    if not isinstance(gm, torch.fx.GraphModule):
        return True  # can't check this, assume true
    (rv,) = output_node(gm).args
    if isinstance(rv, (list, tuple)):
        return True
    if not (isinstance(rv, torch.fx.node.Node) and hasattr(rv.target, "_schema")):
        return False
    # for graphs whose result is one node with multiple outputs
    returns = rv.target._schema.returns
    return len(returns) > 1 and all(str(ret.type) == "Tensor" for ret in returns)
def make_graph_return_tuple(gm: torch.fx.GraphModule, inputs, compile_gm):
    """
    Mutate gm so it returns a tuple. This is only needed for graphs
    not created by torchdynamo that return non-tuples.

    The original return value is pytree-flattened into the new output. The
    returned wrapper calls the compiled graph and unflattens its result back to
    the original structure.
    """
    old_output = output_node(gm)
    (ret_val,) = old_output.args
    flat_outputs, out_spec = pytree.tree_flatten(ret_val)
    with gm.graph.inserting_before(old_output):
        gm.graph.output(flat_outputs)
    gm.graph.erase_node(old_output)
    assert graph_returns_tuple(gm)

    compiled_fn = compile_gm(gm, inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args, **kwargs):
        flat_result = compiled_fn(*args, **kwargs)
        return pytree.tree_unflatten(flat_result, out_spec)

    return wrapper
def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
    """
    Mutate inputs so that they are flat and wrap gm such that it
    accepts those inputs. This is only needed for graphs not created
    by torchdynamo that take nested (pytree) inputs.

    ``compile_gm`` receives a module that takes the flat leaves and rebuilds
    the original nesting before calling ``gm``. The returned wrapper flattens
    its positional arguments before calling the compiled function.
    """
    flat_inputs, in_spec = pytree.tree_flatten(inputs)

    class GmWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gm = gm

        def forward(self, *args):
            nested = pytree.tree_unflatten(args, in_spec)
            return self.gm(*nested)

    compiled_fn = compile_gm(GmWrapper(), flat_inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args):
        # note this doesn't check the spec, assuming it is the same
        leaves, _ = pytree.tree_flatten(args)
        return compiled_fn(*leaves)

    return wrapper
def handle_dynamo_export_graph(gm, inputs, compile_gm):
    """
    `torch._dynamo.export` embeds pytrees in the FX graph codegen object,
    convert that to a normal FX graph so inductor can compile it.

    The graph's codegen is swapped for a plain ``CodeGen`` and the module is
    recompiled. The original codegen's ``process_inputs``/``process_outputs``
    are then applied around each call to the compiled function.
    """
    export_codegen = gm.graph._codegen
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.recompile()

    flat_example_inputs = export_codegen.process_inputs(*inputs)
    compiled_fn = compile_gm(gm, flat_example_inputs)

    @functools.wraps(compiled_fn)
    def wrapper(*args):
        flat_args = export_codegen.process_inputs(*args)
        return export_codegen.process_outputs(compiled_fn(*flat_args))

    return wrapper