import contextlib
import functools
import inspect
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
import warnings
from enum import Enum
from importlib import import_module
from typing import Optional, Tuple, TYPE_CHECKING, Union
from unittest.mock import patch

import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.nn.parallel.distributed import DistributedDataParallel

from .hooks import Hooks

if TYPE_CHECKING:
    from torch._C._dynamo.eval_frame import (  # noqa: F401
        reset_code,
        set_eval_frame,
        set_guard_error_hook,
        set_guard_fail_hook,
        skip_code,
        unsupported,
    )
else:
    for name in dir(torch._C._dynamo.eval_frame):
        if name.startswith("__"):
            continue
        globals()[name] = getattr(torch._C._dynamo.eval_frame, name)

from . import config, convert_frame, skipfiles, utils
from .exc import ResetRequired
from .mutation_guard import install_generation_tagging_init
from .output_graph import CompilerFn
from .types import DynamoCallback
from .utils import compile_times

log = logging.getLogger(__name__)

from torch.fx.experimental import proxy_tensor

always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext


# See https://github.com/python/typing/pull/240
class Unset(Enum):
    token = 0


unset = Unset.token

compile_lock = threading.RLock()
most_recent_backend: Optional[CompilerFn] = None


class OptimizedModule(torch.nn.Module):
    """
    Wraps the original nn.Module object and later patches its
    forward method to the optimized self.forward method.
    """

    def __init__(self, mod, dynamo_ctx):
        super().__init__()
        # Installs the params/buffer
        self._orig_mod = mod
        self.dynamo_ctx = dynamo_ctx

    def __getattr__(self, name):
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def forward(self, *args, **kwargs):
        return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)


def remove_from_cache(f):
    """
    Make sure f.__code__ is not cached to force a recompile
    """
    if isinstance(f, types.CodeType):
        reset_code(f)
    elif hasattr(f, "__code__"):
        reset_code(f.__code__)
    elif hasattr(getattr(f, "forward", None), "__code__"):
        reset_code(f.forward.__code__)
    else:
        from . import reset

        reset()
        log.warning("could not determine __code__ for %s", f)


def nothing():
    pass


def innermost_fn(fn):
    """
    In case of nesting of _TorchDynamoContext calls, find the innermost
    function. TorchDynamo caches on fn.__code__ object, so it's necessary to
    find the innermost function to pass on the optimize, run, disable etc.
    """
    unaltered_fn = fn
    while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
        unaltered_fn = unaltered_fn._torchdynamo_orig_callable
        assert callable(unaltered_fn)
    return unaltered_fn

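
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original module): how innermost_fn
# behaves when wrappers are nested. `_demo_innermost_unwrapping` and `toy_fn`
# are hypothetical names; the function is defined here but never called.
def _demo_innermost_unwrapping():
    def toy_fn(x):
        return x + 1

    # Each wrapper records the callable it wraps on `_torchdynamo_orig_callable`,
    # so unwrapping the chain recovers the plain `toy_fn` again.
    wrapped = disable(disable(toy_fn))
    assert innermost_fn(wrapped) is toy_fn
# --------------------------------------------------------------------------- #
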
""" unaltered_fn = fn while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): unaltered_fn = unaltered_fn._torchdynamo_orig_callable assert callable(unaltered_fn) return unaltered_fn @contextlib.contextmanager def enable_dynamic(enable: bool = True): if not enable: yield return with patch("torch._dynamo.config.dynamic_shapes", True), patch( "torch._functorch.config.use_dynamic_shapes", True ): yield class _TorchDynamoContext: def __init__( self, callback: DynamoCallback, on_enter=nothing, backend_ctx_ctor=null_context, patch_fn=nothing, first_ctx=False, *, dynamic=False, ): super().__init__() assert callable(callback) or callback is False or callback is None self.callback: DynamoCallback = callback self.prior: Union[Unset, DynamoCallback] = unset self.on_enter = on_enter self.extra_ctx_ctor = backend_ctx_ctor self.first_ctx = first_ctx self.dynamic = dynamic patch_fn() def __enter__(self): if config.raise_on_ctx_manager_usage: raise RuntimeError( "torch._dynamo.optimize(...) is used with a context manager. " "Please refer to https://github.com/pytorch/torchdynamo#usage-example " "to use torch._dynamo.optimize(...) as an annotation/decorator. " ) self.on_enter() self.prior = set_eval_frame(self.callback) self.backend_ctx = self.extra_ctx_ctor() self.backend_ctx.__enter__() self.dynamic_ctx = enable_dynamic(self.dynamic) self.dynamic_ctx.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): assert self.prior is not unset set_eval_frame(self.prior) self.prior = unset # TODO: This is totally not the right way to chain contexts manually self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb) self.backend_ctx.__exit__(exc_type, exc_val, exc_tb) def __call__(self, fn): fn = innermost_fn(fn) # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn new_mod = OptimizedModule(mod, self) # Save the function pointer to find the original callable while nesting # of decorators. new_mod._torchdynamo_orig_callable = mod.forward return new_mod assert callable(fn) callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @functools.wraps(fn) def _fn(*args, **kwargs): if ( not isinstance(self, DisableContext) and torch.fx._symbolic_trace.is_fx_tracing() ): if config.error_on_nested_fx_trace: raise RuntimeError( "Detected that you are using FX to symbolically trace " "a dynamo-optimized function. This is not supported at the moment." ) else: return fn(*args, **kwargs) on_enter() prior = set_eval_frame(callback) backend_ctx = backend_ctx_ctor() backend_ctx.__enter__() dynamic_ctx = enable_dynamic(self.dynamic) dynamic_ctx.__enter__() try: return fn(*args, **kwargs) finally: set_eval_frame(prior) dynamic_ctx.__exit__(None, None, None) backend_ctx.__exit__(None, None, None) # hooks to properly handle inlining if isinstance(self, DisableContext): _fn._torchdynamo_disable = True # type: ignore[attr-defined] else: _fn._torchdynamo_inline = fn # type: ignore[attr-defined] # Save the function pointer to find the original callable while nesting # of decorators. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): if not hasattr(fn, "__code__"): raise RuntimeError( textwrap.dedent( """ torch._dynamo.optimize is called on a non function object. If this is a callable class, please wrap the relevant code into a function and optimize the wrapper function. 

class OptimizeContext(_TorchDynamoContext):
    def __init__(self, callback, backend_ctx_ctor, first_ctx=False, *, dynamic=False):
        def on_enter():
            global most_recent_backend
            if (
                most_recent_backend is not None
                and most_recent_backend is not compiler_fn
            ):
                raise ResetRequired()
            most_recent_backend = compiler_fn
            install_generation_tagging_init()

        compiler_fn = innermost_fn(callback)
        super().__init__(
            callback=callback,
            on_enter=on_enter,
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
            dynamic=dynamic,
        )


class RunOnlyContext(_TorchDynamoContext):
    def __init__(self):
        super().__init__(callback=False)


class DisableContext(_TorchDynamoContext):
    def __init__(self):
        super().__init__(callback=None)


def catch_errors_wrapper(callback, hooks: Hooks):
    @functools.wraps(callback)
    def catch_errors(frame, cache_size):
        if (
            frame.f_lasti >= 0
            or skipfiles.check(frame.f_code.co_filename)
            or config.disable
        ):
            log.debug(f"skipping {frame.f_code.co_name} {frame.f_code.co_filename}")
            return None
        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return None
        if config.optimize_ddp:
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
                    from .optimizations.distributed import DDPOptimizer

                    ddp_optimizer = DDPOptimizer(
                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                        backend_compile_fn=callback._torchdynamo_orig_callable,
                    )
                    hijacked_callback = convert_frame.convert_frame(
                        ddp_optimizer.compile_fn,
                        hooks=hooks,
                    )
                    return hijacked_callback(frame, cache_size, hooks)

        with compile_lock:
            return callback(frame, cache_size, hooks)

    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
    return catch_errors


def _optimize_catch_errors(
    compile_fn, hooks: Hooks, backend_ctx_ctor=null_context, dynamic=False
):
    return OptimizeContext(
        catch_errors_wrapper(compile_fn, hooks),
        backend_ctx_ctor=backend_ctx_ctor,
        first_ctx=True,
        dynamic=dynamic,
    )


def get_compiler_fn(compiler_fn):
    from .debug_utils import wrap_backend_debug

    if isinstance(compiler_fn, torch._TorchCompileInductorWrapper):
        compiler_str = "inductor"
    elif isinstance(compiler_fn, str):
        compiler_str = compiler_fn
    else:
        compiler_str = None
    compiler_fn = lookup_backend(compiler_fn)
    return wrap_backend_debug(compiler_fn, compiler_str)


@functools.lru_cache(1)
def lookup_backend(compiler_fn):
    """Expand backend strings to functions"""
    if compiler_fn == "inductor":
        if torch.cuda.is_available():
            if (
                torch.backends.cuda.matmul.allow_tf32 is False
                and torch.cuda.get_device_capability() >= (8, 0)
            ):
                warnings.warn(
                    "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
                    "Consider setting `torch.set_float32_matmul_precision('high')`"
                )
        compiler_fn = import_module(f"{config.inductor_import}.compile_fx").compile_fx
    elif isinstance(compiler_fn, str):
        from .optimizations import BACKENDS

        compiler_fn = BACKENDS[compiler_fn]
    return compiler_fn

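
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original module): the callable contract
# that get_compiler_fn/lookup_backend expect from a custom backend. A backend
# receives the captured torch.fx.GraphModule plus example inputs and returns a
# callable that runs the graph. `_example_debug_backend` is a hypothetical
# name; it is defined here but never used.
def _example_debug_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print the captured graph, then run it unchanged (an "eager"-style backend).
    gm.graph.print_tabular()
    return gm.forward
# --------------------------------------------------------------------------- #
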
"Consider setting `torch.set_float32_matmul_precision('high')`" ) compiler_fn = import_module(f"{config.inductor_import}.compile_fx").compile_fx elif isinstance(compiler_fn, str): from .optimizations import BACKENDS compiler_fn = BACKENDS[compiler_fn] return compiler_fn class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] def __call__(self, fn): assert callable(fn) return fn def optimize( backend="inductor", *, nopython=False, guard_export_fn=None, guard_fail_fn=None, disable=False, dynamic=False, ): """ The main entrypoint of TorchDynamo. Do graph capture and call backend() to optimize extracted graphs. Args: backend: One of the two things: - Either, a function/callable taking a torch.fx.GraphModule and example_inputs and returning a python callable that runs the graph faster. One can also provide additional context for the backend, like torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute. See AOTAutogradMemoryEfficientFusionWithContext for the usage. - Or, a string backend name in `torch._dynamo.list_backends()` nopython: If True, graph breaks will be errors and there will be a single whole-program graph. disable: If True, turn this decorator into a no-op dynamic: If True, turn on dynamic shapes support Example Usage:: @torch._dynamo.optimize() def toy_example(a, b): ... """ # Note: The hooks object could be global instead of passed around, *however* that would make # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an # easier to understand UX at the cost of a little more plumbing on our end. hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn) torch._C._log_api_usage_once("torch._dynamo.optimize") if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1": return _NullDecorator() if sys.platform == "win32": warnings.warn( "Windows is not currently supported, " + f"{config.dynamo_import}.optimize() will do nothing" ) return _NullDecorator() if sys.version_info >= (3, 11): warnings.warn( "Python 3.11+ not yet supported, " f"{config.dynamo_import}.optimize() will do nothing" ) return _NullDecorator() backend = get_compiler_fn(backend) # Find if backend has any extra context manager backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) if nopython: return optimize_assert( backend, dynamic=dynamic, hooks=hooks, ) return _optimize_catch_errors( convert_frame.convert_frame(backend, hooks=hooks), hooks, backend_ctx_ctor, dynamic=dynamic, ) # TODO(voz): Consider making "explain" output alongside a run / part of a run @patch("torch._dynamo.symbolic_convert.explain", True) def explain(f, *args, **kwargs): # TODO(voz): Do we want a decorator for this? from . 

# TODO(voz): Consider making "explain" output alongside a run / part of a run
@patch("torch._dynamo.symbolic_convert.explain", True)
def explain(f, *args, **kwargs):
    # TODO(voz): Do we want a decorator for this?
    from . import reset

    reset()

    out_guards = []
    graphs = []
    ops_per_graph = []
    op_count = 0
    break_reasons = []

    def dynamo_graph_accumulating_compiler(gm: torch.fx.GraphModule, example_inputs):
        nonlocal graphs
        nonlocal op_count
        nonlocal ops_per_graph

        graphs.append(gm)
        ops = []
        for node in gm.graph.nodes:
            if node.op == "call_function":
                ops.append(node.target)

        op_count += len(ops)
        ops_per_graph.append(ops)
        if gm.compile_subgraph_reason is not None:
            break_reasons.append(gm.compile_subgraph_reason)
        return gm.forward

    def guard_export_print(guards):
        nonlocal out_guards
        out_guards.append(guards)

    with patch(f"{__name__}.most_recent_backend", None):
        opt_f = optimize(
            dynamo_graph_accumulating_compiler,
            nopython=False,
            guard_export_fn=guard_export_print,
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
        opt_f(*args, **kwargs)

    graph_count = len(graphs)

    # For the explanation summary, dedupe reasons by the innermost stack frame and use it as the key.
    deduped_reasons = {}
    for reason in break_reasons:
        innermost_frame = reason.user_stack[-1]
        # __repr__ uniquely identifies a FrameSummary so we can use it for deduping
        deduped_reasons[repr(innermost_frame)] = reason

    formatted_list = ""
    for idx, break_reason in enumerate(deduped_reasons.values()):
        formatted_stack = "".join(traceback.format_list(break_reason.user_stack))
        msg = f"{break_reason.reason}\n{formatted_stack}"
        formatted_list += f"{idx + 1}. {msg} \n"

    explanation = f"Dynamo produced {graph_count} graphs "
    explanation += f"with {graph_count - 1} graph break and {op_count} ops"
    explanation_verbose = explanation
    explanation_verbose += f"\n Break reasons: \n\n{formatted_list}"

    explanation_verbose += compile_times()

    # TODO(voz): Do we want a decorator for this?
    reset()
    return (
        explanation,
        out_guards,
        graphs,
        ops_per_graph,
        break_reasons,
        explanation_verbose,
    )

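
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original module): typical use of
# explain(). `toy_fn` is a hypothetical user function; the call is kept in a
# comment because running it would trigger real compilation at import time.
#
#   explanation, out_guards, graphs, ops_per_graph, break_reasons, verbose = (
#       torch._dynamo.explain(toy_fn, torch.randn(4), torch.randn(4))
#   )
#   print(explanation)  # e.g. "Dynamo produced 2 graphs with 1 graph break and 7 ops"
# --------------------------------------------------------------------------- #
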
export entails exactly one graph" graph = gm def result_capturing_wrapper(*graph_inputs): nonlocal graph_captured_result nonlocal graph_captured_input graph_captured_input = graph_inputs assert graph is not None graph_captured_result = graph(*graph_inputs) return graph_captured_result return result_capturing_wrapper # TODO(voz): Handle kwargs properly? flat_args, in_spec = pytree.tree_flatten(args) remove_from_cache(f) with patch(f"{__name__}.most_recent_backend", None): opt_f = optimize_assert( dynamo_normalization_capturing_compiler, hooks=Hooks(guard_export_fn=guard_export_print, guard_fail_fn=None), export=True, )(f) # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject. result_traced = opt_f(*args, **kwargs) remove_from_cache(f) assert graph is not None, "whole graph export entails exactly one call" assert out_guards is not None, "whole graph export entails exactly one guard export" matched_input_elements_positions = produce_matching(flat_args, graph_captured_input) flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced) assert graph_captured_result is not None flat_both = list(graph_captured_result) + flat_args matched_output_elements_positions = produce_matching(flat_both, flat_results_traced) class ChangeInputOutputSignature(torch.fx.interpreter.Transformer): def __init__( self, m, ): super().__init__(m) arg_len = len(flat_args) self.new_args = [ super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {}) for i in range(0, arg_len) ] self.old_args_gen = ( self.new_args[i] for i in matched_input_elements_positions ) def placeholder(self, target, args, kwargs): arg = next(self.old_args_gen) if "val" in self.current_node.meta: arg.node.meta["val"] = self.current_node.meta["val"] return arg def output(self, target, args, kwargs): dynamo_result_flat = args[0] lookup = [*dynamo_result_flat, *self.new_args] new_result_flat = [lookup[i] for i in matched_output_elements_positions] return super().output(target, (new_result_flat,), {}) def run_node(self, n): self.current_node = n return super().run_node(n) if aten_graph: # Running graph with interpreter is needed for propagating the stack_trace def graph_with_interpreter(*args): with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(graph).run(*args) graph = make_fx( graph_with_interpreter, decomposition_table=decomposition_table, tracing_mode=tracing_mode, _allow_non_fake_inputs=True, )(*graph_captured_input) new_graph = ChangeInputOutputSignature( graph, ).transform() # Make dynamo graph to have same input/output spec as user code new_graph.graph._codegen = _PyTreeCodeGen( _PyTreeInfo( [f"orig_arg_{i}" for i in range(len(args))], in_spec, out_spec_traced, ) ) new_graph.recompile() return (new_graph, out_guards) def assume_constant_result(fn): fn._dynamo_marked_constant = True return fn def optimize_assert(backend, *, hooks=Hooks(None, None), export=False, dynamic=False): """ The same as `torch._dynamo.optimize(backend, nopython=True)` """ backend = get_compiler_fn(backend) # Find if backend has any extra context manager backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) return _optimize_catch_errors( convert_frame.convert_frame_assert(backend, export=export), hooks, backend_ctx_ctor, dynamic=dynamic, ) def run(fn=None): """Don't do any dynamic compiles, just use prior optimizations""" if fn is not None: fn = innermost_fn(fn) assert callable(fn) return RunOnlyContext()(fn) return RunOnlyContext() def disable(fn=None): """Decorator 

def assume_constant_result(fn):
    fn._dynamo_marked_constant = True
    return fn


def optimize_assert(backend, *, hooks=Hooks(None, None), export=False, dynamic=False):
    """
    The same as `torch._dynamo.optimize(backend, nopython=True)`
    """
    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    return _optimize_catch_errors(
        convert_frame.convert_frame_assert(backend, export=export),
        hooks,
        backend_ctx_ctor,
        dynamic=dynamic,
    )


def run(fn=None):
    """Don't do any dynamic compiles, just use prior optimizations"""
    if fn is not None:
        fn = innermost_fn(fn)
        assert callable(fn)
        return RunOnlyContext()(fn)
    return RunOnlyContext()


def disable(fn=None):
    """Decorator and context manager to disable TorchDynamo"""
    if fn is not None:
        fn = innermost_fn(fn)
        assert callable(fn)
        return DisableContext()(fn)
    return DisableContext()


def skip(fn=None):
    """
    Skip frames associated with the function code, but still process recursively
    invoked frames
    """
    if fn is None:
        return skip
    fn = innermost_fn(fn)
    assert callable(fn)
    skip_code(fn.__code__)
    fn._torchdynamo_disable = True
    return fn


class TorchPatcher:
    @staticmethod
    @functools.lru_cache(None)
    def patch():
        # Disable TorchDynamo on some torch.* compilers generated frames
        torch.jit.trace = disable(torch.jit.trace)
        torch.jit.trace_module = disable(torch.jit.trace_module)
        torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)

        # symbolic_trace creates new frames. We disable Dynamo on such frames
        torch.fx._symbolic_trace.Tracer.trace = disable(
            torch.fx._symbolic_trace.Tracer.trace
        )

        torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
        torch.distributions.Distribution.set_default_validate_args(False)

        proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)

        optimizers = [
            opt
            for opt in torch.optim.__dict__.values()
            if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
        ]

        # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
        if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
            DistributedDataParallel._inside_ddp_forward = skip(
                DistributedDataParallel._inside_ddp_forward
            )

        from ..optim import adagrad, adam, adamax, adamw, asgd, nadam, sgd

        for opt_mod in adagrad, adam, adamax, adamw, asgd, nadam, sgd:
            multi_tensor_fn_name = f"_multi_tensor_{opt_mod.__name__.split('.')[-1]}"
            if hasattr(opt_mod, multi_tensor_fn_name):
                setattr(
                    opt_mod,
                    multi_tensor_fn_name,
                    disable(getattr(opt_mod, multi_tensor_fn_name)),
                )

        excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS}
        for opt in optimizers:
            if opt in excluded_opts:
                opt.step = disable(opt.step)

            opt._cuda_graph_capture_health_check = disable(
                opt._cuda_graph_capture_health_check
            )
            opt.zero_grad = disable(opt.zero_grad)

            if hasattr(opt, "_init_group"):
                opt._init_group = disable(opt._init_group)

            # disable any currently set hooks
            # Note: we only want to disable the profiling hook, which is the
            # *last* hook applied; we want to keep the no_grad hook
            hooked = getattr(opt.step, "hooked", False)
            if hooked:
                unwrapped_step = getattr(opt.step, "__wrapped__", None)
                if unwrapped_step:
                    opt.step = unwrapped_step

            # disable future hooking
            opt.step.hooked = True

    @staticmethod
    def suppress_torch_distributed_warnings(fn):
        def inner_fn(*args, **kwargs):
            warnings.filterwarnings(
                "ignore", category=UserWarning, module="torch.distributed"
            )
            return fn(*args, **kwargs)

        return inner_fn

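
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original module): disable() used as a
# decorator. `_example_eager_only` is a hypothetical name; it is defined here
# but never called. skip() is used the same way, except that frames invoked
# recursively from the skipped function are still processed by TorchDynamo.
@disable
def _example_eager_only(x):
    # TorchDynamo will not trace this frame or any frame it calls.
    return x.sin() + 1
# --------------------------------------------------------------------------- #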