import inspect import logging import traceback from collections import namedtuple from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional, TYPE_CHECKING, Union import sympy import torch import torch.fx import torch.utils._pytree as pytree from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn from torch._dynamo.eval_frame import argument_names, check_user_input_output from torch._dynamo.exc import UserErrorType from torch._dynamo.utils import dynamo_timed, get_metrics_context from torch._export.utils import _compiling_state_context from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint from torch.fx import Node from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, ShapeEnv, StatelessSymbolicContext, ) from torch.fx.graph import _ExportCodeGen, _PyTreeCodeGen, _PyTreeInfo from torch.utils._pytree import TreeSpec if TYPE_CHECKING: from torch._subclasses.fake_tensor import FakeTensorMode log = logging.getLogger(__name__) def post_process_error_msg( constraint_violation_error: ConstraintViolationError, func: Callable[..., Any], args: Any, kwargs: Any, ): """ Because we trace a different callable, the sources are all messed up. Manually patch them so the error message looks correct. """ from torch.export._unlift import _get_input_paths, _replace_sources orig_sig = inspect.signature(func) flat_input_paths = _get_input_paths((args, kwargs), orig_sig) constraint_violation_error.args = ( _replace_sources(constraint_violation_error.args[0], flat_input_paths), ) return constraint_violation_error EXPORT_ROOT_REPLACEMENTS = [ ("__export_root_", "_"), ("_export_root.", ""), ("._export_root", ""), ] def clean_export_root_string(text: str) -> str: """Generic utility to clean export_root patterns from strings.""" result = text for pattern, replacement in EXPORT_ROOT_REPLACEMENTS: result = result.replace(pattern, replacement) return result def clean_nn_module_stack_and_source_fn( graph_module: torch.fx.GraphModule, is_inline_builtin=False ) -> torch.fx.GraphModule: """ Clean up nn_module_stack metadata by removing export_root references. Removes the _export_root module references from nn_module_stack metadata in graph nodes, which are artifacts from the export process. Fixes two patterns: 1. Keys: Removes "__export_root_" and "__modules['_export_root']_" prefixes - Normal case: "L__self____export_root_child" -> "L__self__child" - inline_builtin case: Uses numeric ID strings like "140468831433840" 2. Values: Removes "._export_root" and "._modules['_export_root']" from child names e.g., "L['self']._export_root.child" -> "L['self'].child" e.g., "L['self']._modules['_export_root'].child" -> "L['self'].child" Also removes the root export entry "L__self____export_root" entirely. Args: graph_module: The GraphModule to clean up is_inline_builtin: If True, keys are numeric ID strings and self references (L['self']) are filtered out Returns: The cleaned GraphModule (modified in-place) """ def _process_nn_module_stack(nn_module_stack): if "L__self____export_root" in nn_module_stack: del nn_module_stack["L__self____export_root"] # Clean up remaining entries cleaned_stack = {} for key, (child_name, child_class) in nn_module_stack.items(): # Clean key by removing export_root patterns clean_key = clean_export_root_string(key) # Clean child_name by removing export_root patterns clean_name = clean_export_root_string(child_name) # Skip self reference for inline builtin case if is_inline_builtin and clean_name == "L['self']": continue cleaned_stack[clean_key] = (clean_name, child_class) return cleaned_stack def _process_source_fn(source_fn_stack): cleaned_stack = [] for item in source_fn_stack: if isinstance(item, tuple) and len(item) == 2: name, cls = item if isinstance(name, str): clean_name = clean_export_root_string(name) cleaned_stack.append((clean_name, cls)) else: cleaned_stack.append(item) else: cleaned_stack.append(item) return cleaned_stack for node in graph_module.graph.nodes: if "nn_module_stack" in node.meta: node.meta["nn_module_stack"] = _process_nn_module_stack( node.meta["nn_module_stack"].copy() ) if "source_fn_stack" in node.meta: node.meta["source_fn_stack"] = _process_source_fn( node.meta["source_fn_stack"].copy() ) if "dynamo_flat_name_to_original_fqn" in graph_module.meta: # Clean up flat name to original fqn mapping clean_name_to_original_fqn = {} for flat_name, original_fqn in graph_module.meta[ "dynamo_flat_name_to_original_fqn" ].items(): clean_name_to_original_fqn[clean_export_root_string(flat_name)] = ( clean_export_root_string(original_fqn) ) graph_module.meta["dynamo_flat_name_to_original_fqn"] = ( clean_name_to_original_fqn ) return graph_module def clean_export_root(graph_module: torch.fx.GraphModule) -> None: """Remove export_root artifacts from FX graph in-place""" # Unlike getattr node, call_module can be invoked multiple times # In those cases, we should fix all invocations of call_module clean_named_module_map: dict[str, str] = {} # Update get_attr nodes in-place for node in graph_module.graph.nodes: if node.op == "get_attr": old_target = node.target new_target = clean_export_root_string(old_target) if new_target != old_target: node.target = new_target assert hasattr(graph_module, old_target) # Move the parameter to the new name param = torch.fx.graph_module._get_attr(graph_module, old_target) torch.fx.graph_module._assign_attr(param, graph_module, new_target) torch.fx.graph_module._del_attr(graph_module, old_target) # Dynamo will only have one nested level if node.op == "call_module": old_target = node.target assert isinstance(old_target, str) new_target = clean_export_root_string(old_target) assert isinstance(new_target, str) new_name = clean_export_root_string(node.name) if new_target == old_target: continue # if this module has already been cleaned before, just lookup from map. if old_target in clean_named_module_map: node.target = clean_named_module_map[old_target] node.name = new_name continue target = graph_module.get_submodule(old_target) graph_module.delete_submodule(old_target) graph_module.add_submodule(new_target, target) node.target = new_target node.name = new_name clean_named_module_map[old_target] = new_target class ModuleToTrace(torch.nn.Module): def __init__(self, foo: Any, in_spec: Any) -> None: super().__init__() self._export_root = foo self.in_spec = in_spec def forward(self, *flat_args: Any) -> "ExportTracerOutput": args, kwargs = pytree.tree_unflatten(flat_args, self.in_spec) res = self._export_root(*args, **kwargs) out_flat, out_spec = pytree.tree_flatten(res) return ExportTracerOutput(out_flat, out_spec) ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"]) # mypy: disable-error-code="no-untyped-def,var-annotated,assignment,index,operator" class DynamoGraphTransformer(torch.fx.Transformer): """Graph transformer for dynamo export that flattens inputs/outputs without complex matching.""" def __init__( self, module: torch.fx.GraphModule, flat_inputs: list[Any], flat_args_dynamic_dims: list[set[int]], graph_input_order: dict[int, int], graph_output_map: dict[int, tuple[str, Any]], fake_mode: Optional[Any] = None, ) -> None: super().__init__(module) assert len(flat_args_dynamic_dims) == len(flat_inputs) self.flat_inputs = flat_inputs self.flat_args_dynamic_dims = flat_args_dynamic_dims self.graph_input_order = graph_input_order self.graph_output_map = graph_output_map self.fake_mode = fake_mode # Get original placeholders and output self.placeholders = [n for n in module.graph.nodes if n.op == "placeholder"] self.output_node = next(n for n in module.graph.nodes if n.op == "output") # Create new flattened input placeholders self.new_input_nodes: dict[int, torch.fx.Node] = {} self._create_flattened_inputs() # Iterator for replacing old placeholders self.old_to_new_mapping = {} self._create_placeholder_mapping() def _create_flattened_inputs(self) -> None: """Create new placeholder nodes for flattened inputs with proper fake tensors.""" for i in range(len(self.flat_inputs)): placeholder = super().placeholder(f"arg_{i}", (), {}) # Check if this user input (index i) maps to a graph placeholder if i in self.graph_input_order: # graph_input_order[i] gives us which graph placeholder this user input corresponds to graph_placeholder_idx = self.graph_input_order[i] if graph_placeholder_idx < len(self.placeholders): orig_placeholder = self.placeholders[graph_placeholder_idx] # Copy other metadata but not "val" yet for key, value in orig_placeholder.meta.items(): if key != "val": placeholder.node.meta[key] = value # Always ensure we have proper "val" metadata from fake tensor if self.fake_mode is not None and isinstance( self.flat_inputs[i], torch.Tensor ): placeholder.node.meta["val"] = self.fake_mode.from_tensor( self.flat_inputs[i], symbolic_context=StatelessSymbolicContext( dynamic_sizes=[ ( DimDynamic.DYNAMIC if d in self.flat_args_dynamic_dims[i] else DimDynamic.STATIC ) for d in range(len(self.flat_inputs[i].shape)) ], constraint_sizes=[None] * len(self.flat_inputs[i].shape), ), ) elif hasattr(self.flat_inputs[i], "val"): # _IntWrapper case placeholder.node.meta["val"] = self.flat_inputs[i].val else: placeholder.node.meta["val"] = self.flat_inputs[i] # pyrefly: ignore [unsupported-operation] self.new_input_nodes[i] = placeholder def _create_placeholder_mapping(self) -> None: """Create mapping from old placeholders to new ones.""" # graph_input_order maps: user_input_index -> graph_placeholder_index # We need to create: old_graph_placeholder -> new_user_input_placeholder for user_input_idx, graph_placeholder_idx in self.graph_input_order.items(): if graph_placeholder_idx < len(self.placeholders): old_placeholder = self.placeholders[graph_placeholder_idx] new_placeholder = self.new_input_nodes[user_input_idx] self.old_to_new_mapping[old_placeholder] = new_placeholder def placeholder(self, target, args, kwargs) -> Any: """Replace old placeholders with new flattened ones.""" # Return the corresponding new placeholder if self.current_node in self.old_to_new_mapping: new_arg = self.old_to_new_mapping[self.current_node] # Copy over additional metadata from current node, but don't overwrite "val" for key in ["tensor_dict", "example_value", "unbacked_bindings"]: if key in self.current_node.meta: new_arg.node.meta[key] = self.current_node.meta[key] # Only copy "val" if we don't already have a good one if "val" in self.current_node.meta and "val" not in new_arg.node.meta: new_arg.node.meta["val"] = self.current_node.meta["val"] return new_arg else: # Shouldn't happen if mapping is correct, but fallback return super().placeholder(target, args, kwargs) def output(self, target, args, kwargs) -> Any: """Transform output according to graph_output_map.""" original_outputs = args[0] # Build new output list based on graph_output_map new_outputs = [] for i in sorted(self.graph_output_map.keys()): output_type, val = self.graph_output_map[i] if output_type == "graph_out": new_outputs.append(original_outputs[val]) elif output_type == "input": input_idx = val.index new_outputs.append(self.new_input_nodes[input_idx]) elif output_type == "constant": new_outputs.append(val) return super().output(target, (tuple(new_outputs),), {}) def run_node(self, node: Node) -> Any: """Run node transformation and preserve metadata.""" self.current_node = node result = super().run_node(node) # Copy important metadata if hasattr(result, "node") and result.node is not node: for key in ["val", "example_value", "unbacked_bindings"]: if key in node.meta: result.node.meta[key] = node.meta[key] # Preserve node names (except output) if node.op != "output" and hasattr(node, "name"): result.node._rename(node.name) return result def transform(self) -> torch.fx.GraphModule: """Perform the graph transformation and copy module metadata.""" result_gm = super().transform() # Copy module metadata like the original implementation if hasattr(self.module, "meta"): # pyrefly: ignore [unsupported-operation] if "dynamo_flat_name_to_original_fqn" in self.module.meta: # pyrefly: ignore [index-error] result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ # pyrefly: ignore [index-error] "dynamo_flat_name_to_original_fqn" ] # pyrefly: ignore [unsupported-operation] if "dynamo_compile_id" in self.module.meta: # pyrefly: ignore [index-error] result_gm.meta["dynamo_compile_id"] = self.module.meta[ # pyrefly: ignore [index-error] "dynamo_compile_id" ] return result_gm def _suggest_or_raise_constraint_violation( module_to_trace: torch.nn.Module, orig_callable: Callable, # type: ignore[type-arg] fake_mode: Optional["FakeTensorMode"], graph_capture_output: CaptureOutput, args: Any, kwargs: Any, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], ): constraint_violation_error = None try: # Check if we have any constraint violations fn, _ = get_traced_fn(module_to_trace) graph_capture_output.graph_capture_output.build_guards(fn.__code__) except ConstraintViolationError as e: constraint_violation_error = e if ( (shape_env := getattr(fake_mode, "shape_env", None)) is not None and (dim_constraints := shape_env.dim_constraints) is not None and not isinstance( module_to_trace.forward, torch._ops.OpOverloadPacket | torch._ops.OpOverload, ) ): dim_constraints.solve() forced_specializations = dim_constraints.forced_specializations() msg = dim_constraints.prettify_results( inspect.signature(orig_callable), # type: ignore[attr-defined] dynamic_shapes, constraint_violation_error, forced_specializations, ) if constraint_violation_error: constraint_violation_error.args = ( constraint_violation_error.args[0] + msg, ) else: if forced_specializations: constraint_violation_error = ConstraintViolationError(msg) else: log.info( "Summary of dimension constraints:%s", msg, ) # Error if we have any constraints on static values for k in shape_env.var_to_range.keys(): if isinstance(k, sympy.Integer): constraint_violation_error = ConstraintViolationError( f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n" "It appears that you're trying to set a constraint on a " f"value which we evaluated to have a static value of {k}. " 'Set TORCH_LOGS="+export" for more information.' ) if constraint_violation_error: constraint_violation_error = post_process_error_msg( constraint_violation_error, orig_callable, args, kwargs ) raise constraint_violation_error @dataclass(frozen=True) class PyTreeifyOutput: graph_module: torch.fx.GraphModule in_spec: TreeSpec in_shuffle_graph: torch.fx.GraphModule num_flat_args: int out_spec: TreeSpec out_shuffle_graph: torch.fx.GraphModule root: Optional[torch.nn.Module] = None def pytreeify( out: CaptureOutput, mod: Any, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> PyTreeifyOutput: """ Given a dynamo capture output, return a callable graph module that contain the following information: 1. input/output pytree spec 2. input/output shuffle functions Input shuffle functions are the converters taking pytree falttened inputs and reorder them to the calling convention of dynamo raw graph module. Output shuffle functions are the converters taking the outputs of the dynamo raw graph module and convert them to the pytree format. This function will replay any side effects that happened during the bytecode, so it is important to check against side effects before calling this function. """ assert out.backend_input is not None backend_input = out.backend_input backend = out.backend_input.graph_module root = None if isinstance(mod, torch.nn.Module): args = (mod,) + args root = mod elif inspect.ismethod(mod): args = (mod.__self__,) + args root = mod.__self__ flat_real_args, in_spec = pytree.tree_flatten((args, kwargs)) class Yield(Exception): pass class InShuffle(torch.nn.Module): def __init__(self): super().__init__() self.mod = mod self.num_inputs = len(flat_real_args) self.gm_inputs = None def forward(self, *flat_proxy_args): args, kwargs = pytree.tree_unflatten( [flat_proxy_args[i] for i in range(self.num_inputs)], in_spec ) def backend_dummy(*example_inputs): self.gm_inputs = example_inputs raise Yield backend_input.graph_module = backend_dummy # type: ignore[assignment] try: out.forward_callable()(*args, **kwargs) except Yield: assert self.gm_inputs is not None return self.gm_inputs finally: backend_input.graph_module = backend raise RuntimeError fake_mode = torch._dynamo.utils.detect_fake_mode(flat_real_args) if fake_mode and fake_mode.shape_env is None: fake_mode.shape_env = ShapeEnv() in_shuffle_graph = make_fx( InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True )(*flat_real_args) output_node = next(iter(reversed(backend_input.graph_module.graph.nodes))) class OutShuffle(torch.nn.Module): def __init__(self): super().__init__() self.num_inputs = len(flat_real_args) self.num_outputs = len(output_node.args[0]) self.out_spec: Optional[TreeSpec] = None def forward(self, *flat_proxy_args): args, kwargs = pytree.tree_unflatten( [flat_proxy_args[i] for i in range(self.num_inputs)], in_spec ) def backend_dummy(*example_inputs): return [ flat_proxy_args[self.num_inputs + i] for i in range(self.num_outputs) ] backend_input.graph_module = backend_dummy # type: ignore[assignment] try: results = out.forward_callable()(*args, **kwargs) finally: backend_input.graph_module = backend ret, self.out_spec = pytree.tree_flatten(results) return ret out_shuffle = OutShuffle() flat_out_shuffle_args = [ *flat_real_args, *pytree.tree_map_only( torch.fx.Node, lambda x: fake_mode.from_tensor(x.meta["example_value"]) if fake_mode else x.meta["example_value"], output_node.args[0], ), ] fake_mode = torch._dynamo.utils.detect_fake_mode(flat_out_shuffle_args) if fake_mode and fake_mode.shape_env is None: fake_mode.shape_env = ShapeEnv() out_shuffle_graph = make_fx( out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True )(*flat_out_shuffle_args) assert out_shuffle.out_spec is not None return PyTreeifyOutput( backend_input.graph_module, in_spec, in_shuffle_graph, len(flat_real_args), out_shuffle.out_spec, out_shuffle_graph, root=root, # type: ignore[arg-type] ) def normalize_graph_module(gm): for node in gm.graph.nodes: if node.op == "placeholder": node.meta["val"] = node.meta["example_value"] def dynamo_graph_capture_for_export( mod: Callable[..., Any], constraints: Optional[list[Constraint]] = None, ) -> Callable[..., Any]: def inner(*args: Any, **kwargs: Any) -> Any: assert not torch._dynamo.config.install_free_tensors with ( get_metrics_context(), dynamo_timed("fullgraph_capture"), ): out = fullgraph_capture( mod, args, kwargs, constraints=constraints, ) # TODO filter out side effects. pyt = pytreeify(out, mod, args, kwargs) graph_module = pyt.graph_module tree_leaf_names = [ graph_module.graph._graph_namespace.create_name(f"_tree_leaf_{i}", None) for i in range(pyt.num_flat_args) ] graph_module.graph._codegen = _ExportCodeGen( _PyTreeInfo( # TODO we should be able to use the names from dynamo graph directly. argument_names(inspect.signature(mod), args, kwargs), pyt.in_spec, pyt.out_spec, ), pyt.in_shuffle_graph, pyt.out_shuffle_graph, tree_leaf_names, pyt.root, ) # type: ignore[attr-defined] normalize_graph_module(graph_module) if pyt.root is not None: graph_module._parameters = pyt.root._parameters.copy() graph_module._buffers = pyt.root._buffers.copy() assert all(not hasattr(graph_module, m) for m in pyt.root._modules) graph_module._modules.update(pyt.root._modules) graph_module._non_persistent_buffers_set = ( pyt.root._non_persistent_buffers_set.copy() ) graph_module._in_spec = pyt.in_spec graph_module._out_spec = pyt.out_spec assert not hasattr(graph_module, "_in_shuffle_graph") assert not hasattr(graph_module, "_out_shuffle_graph") graph_module._in_shuffle_graph = pyt.in_shuffle_graph graph_module._out_shuffle_graph = pyt.out_shuffle_graph delattr(graph_module, "_param_name_to_source") graph_module.recompile() graph_module.meta["module_call_specs"] = ( out.graph_capture_output.output_graph.export_metadata.module_call_spec ) assert out.backend_input is not None graph_module.meta["fake_mode"] = out.backend_input.fake_mode # type: ignore[attr-defined] return graph_module return inner def _dynamo_graph_capture_for_export( mod: Callable[..., Any], *, constraints: Optional[list[Constraint]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ) -> Callable[..., torch.fx.GraphModule]: """ Improved dynamo graph capture using transformer approach with proper fake tensor handling. This function creates a capture instance that handles: 1. PyTree flattening/unflattening with proper input ordering 2. Dynamo graph capture with export-specific context 3. FX graph transformation for export compatibility 4. Proper fake tensor metadata preservation 5. Dynamic dimension constraint handling Notable improvements over manual approach: - Uses FX Transformer for cleaner graph manipulation - Properly handles fake tensor metadata and dynamic dimensions - Preserves all necessary metadata for export - More robust error handling and edge case management TODO: 1. Are we actually gonna run the bytecode? 2. Need to attach guards """ _dynamic_shapes = dynamic_shapes _constraints = constraints def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule: # This sets the is_exporting flag when building guards. with _compiling_state_context(): flat_inputs, in_spec = pytree.tree_flatten((args, kwargs)) check_user_input_output(flat_inputs, UserErrorType.INVALID_INPUT) module_to_trace = ModuleToTrace(mod, in_spec) orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod constraints: Optional[list[Constraint]] = _constraints dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = ( _dynamic_shapes ) from . import reset # type: ignore[attr-defined] reset() dynamo_config_ctx = torch._dynamo.config.patch( specialize_int=True, specialize_float=True, assume_static_by_default=True, automatic_dynamic_shapes=False, capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, constant_fold_autograd_profiler_enabled=True, log_graph_in_out_metadata=True, # install_free_tensors ensures that params and buffers are still # added as graph attributes, and makes Dynamo emits graphs that # follow export pytree-able input requirements In future, if we # fully rely on bytecode for the runtime, we can turn this flag # off. install_free_tensors=torch._dynamo.config.install_free_tensors_for_export, ) with ( get_metrics_context(), dynamo_timed("fullgraph_capture"), dynamo_config_ctx, ): out = fullgraph_capture( module_to_trace, tuple(flat_inputs), constraints=_constraints, _is_export_deprecated_do_not_use=True, ) assert out.graph_capture_output.output_graph is not None example_inputs: list[Any] = [] if out.backend_input is not None: graph = out.backend_input.graph_module fake_mode = out.backend_input.fake_mode example_inputs = out.backend_input.example_inputs else: graph = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) graph.graph.output(None) graph.recompile() fake_mode = None _suggest_or_raise_constraint_violation( module_to_trace, orig_callable, fake_mode, out, args, kwargs, dynamic_shapes, ) # Extract export metadata from the new location export_metadata = out.graph_capture_output.output_graph.export_metadata graph_inputs = export_metadata.graph_input_idx_to_local_source graph_output_map = export_metadata.output_return_type out_spec = export_metadata.out_spec module_call_spec = export_metadata.module_call_spec # Compute dynamic dimensions for each input based on constraints flat_args_dynamic_dims = [ { c.dim for c in (constraints or ()) if ( c.t_id == id(x) and not isinstance(c, _RelaxedConstraint) and c.constraint_range.vr.lower != c.constraint_range.vr.upper ) } for x in flat_inputs ] # Create input order mapping from dynamo's internal order to user order graph_input_order: dict[int, int] = {} for inp in graph_inputs: source = graph_inputs[inp] assert isinstance(source, torch._dynamo.source.GetItemSource) graph_input_order[source.index] = len(graph_input_order) for real_idx, graph_idx in graph_input_order.items(): flat_inputs[real_idx] = example_inputs[graph_idx] # Use FX transformer to rebuild the graph cleanly transformed_graph = DynamoGraphTransformer( graph, flat_inputs, flat_args_dynamic_dims, graph_input_order, graph_output_map, fake_mode, ).transform() # Set up PyTree codegen for proper input/output handling transformed_graph.graph._codegen = _PyTreeCodeGen( _PyTreeInfo( argument_names(inspect.signature(orig_callable), args, kwargs), # type: ignore[attr-defined, arg-type] in_spec, out_spec, ) ) transformed_graph.recompile() clean_nn_module_stack_and_source_fn( transformed_graph, torch._dynamo.config.inline_inbuilt_nn_modules ) clean_export_root(transformed_graph) transformed_graph.meta["module_call_specs"] = module_call_spec transformed_graph.meta["fake_mode"] = fake_mode return transformed_graph return inner