import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings

from typing import cast, Dict, Optional, Set

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch
import torch._functorch.deprecated as deprecated_func
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import is_safe_constant, NP_SUPPORTED_MODULES

"""
A note on allowed functions:

Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.

If a function is disallowed, it may either be traced-through, or skipped.

Trace-through means dynamo will continue to trace the interior code for
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph.  Whether tracing through or allowing, the
functionality of the function/module is part of the dynamo graph.  Caveat:
if tracing through, any interior operation could trigger its own graph-break.

Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on
skipfiles" there.
"""


def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph.  Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()
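
# Example (an illustrative sketch, not part of the real allowlists below): a
# FunctionIdSet initialized from a dict maps id() -> display name, and its
# add/remove methods let callers customize what may appear in the graph.
# `_example_ids` and `_my_helper` are hypothetical names.
#
#     @make_function_id_set
#     def _example_ids():
#         return {id(print): "builtins.print"}
#
#     def _my_helper():
#         pass
#
#     _example_ids.add(id(_my_helper))      # now allowed in the graph
#     assert id(_my_helper) in _example_ids
#     _example_ids.remove(id(_my_helper))   # back to causing a graph break
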
@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.set_device,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.jit.isinstance,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
        torch.Tensor.__init__,
    ]

    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage

    # Distributed APIs don't work well with torch.compile.
    if torch.distributed.is_available():
        remove.extend(
            torch.distributed.distributed_c10d.dynamo_unsupported_distributed_c10d_ops
        )

    return {id(x) for x in remove}
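
# Illustrative note (added for exposition): lookups against these sets are by
# object identity, so each membership test is a single O(1) set probe on an
# integer and never invokes __eq__ or __hash__ on the object under test, e.g.:
#
#     assert id(copy.deepcopy) in _disallowed_function_ids
#     assert id(torch.add) not in _disallowed_function_ids
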
@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters.  This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break.  To ensure this, we inline
        # these functions, rather than keep them opaque-ly in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.utils._foreach_utils",  # omit the period so we match all the functions in this module
            "torch.utils._pytree",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
            "torch.distributed._tensor.",
            # Inline through the ActivationWrapper in
            # torch.distributed.algorithms._checkpoint.checkpoint_wrapper. This
            # nn module calls torch.utils.checkpoint internally. If Dynamo does
            # not trace this, AOT Autograd will try to trace this and can cause
            # issues observed in
            # https://github.com/pytorch/pytorch/issues/108269
            "torch.distributed.algorithms.",
        )
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them.  We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperator
                # first, decide they are safe, and then allow them into the
                # graph).  So we exclude HigherOrderOperator from being a
                # builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(
            method, (types.MethodDescriptorType, types.WrapperDescriptorType)
        ):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids


@make_function_id_set
def _allowed_user_defined_function_ids():
    rv = {}
    return rv


@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv.update({id(cast): "typing.cast"})
    rv[id(functools.reduce)] = "functools.reduce"
    return rv


@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    for mod in NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


def is_allowed(obj):
    """Is this safe to trace like torch.add ?"""
    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids.  Figure it out by testing the type instead
    # in those cases.
    if id(obj) in _disallowed_function_ids:
        return False

    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )


def is_user_defined_allowed(obj):
    return id(obj) in _allowed_user_defined_function_ids


def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    if np is None:
        return False
    return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids
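
# A minimal self-check of the predicates above (an illustrative sketch added
# for exposition; guarded under __main__ so it never runs on import).
if __name__ == "__main__":
    assert is_allowed(torch.add)  # captured as a single FX node
    assert not is_allowed(copy.deepcopy)  # listed in _disallowed_function_ids
    assert is_builtin_callable(len) and not is_builtin_constant(len)
    assert is_builtin_constant(True)
    print("allowed_functions self-check passed")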