import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings
from typing import Dict, Optional, Set

import numpy

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import is_safe_constant

"""
|
|
A note on allowed functions:
|
|
|
|
Dynamo consults this file to determine if a particular function/module
|
|
is allowed to appear as a node in its fx output.
|
|
|
|
If a function is disallowed, it may either be traced-through, or skipped.
|
|
|
|
Trace-through means dynamo will continue to trace the interior code for
|
|
the function/module rather than stopping at its boundary and recording it
|
|
as a node in the fx graph. Whether tracing through or allowing, the functionality
|
|
of the function/module is part of the dynamo graph. Caveat: if tracing through,
|
|
any interior operation could trigger its own graph-break.
|
|
|
|
Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on
|
|
skipfiles" there.
|
|
"""
|
|
|
|
|
|
def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph. Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()


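# Illustrative sketch (added for exposition, not part of the original module):
# how a set produced by `make_function_id_set` behaves. The decorated function
# is only the lazy initializer; membership, name lookup, and add/remove all
# operate on `id()` values. `_example_ids` and the use of `math.sqrt` are
# arbitrary choices for the example.
def _example_function_id_set_usage():  # hypothetical helper; never called
    @make_function_id_set
    def _example_ids():
        return {id(math.sqrt): "math.sqrt"}

    assert id(math.sqrt) in _example_ids  # first use triggers lazy_initializer()
    assert _example_ids.get_name(id(math.sqrt), "?") == "math.sqrt"
    _example_ids.remove(id(math.sqrt))
    assert id(math.sqrt) not in _example_ids

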
@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.amp.autocast_mode.autocast,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        torch.autograd.profiler.profile,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
    ]
    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage
    return {id(x) for x in remove}


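# Illustrative sketch (added for exposition, not part of the original module):
# membership in _disallowed_function_ids is keyed on id(), and the dtype /
# storage scans above mean objects such as torch.float32 land in the set
# alongside the explicitly listed callables.
def _example_disallowed_lookup():  # hypothetical helper; never called
    assert id(torch.autograd.grad) in _disallowed_function_ids
    assert id(torch.float32) in _disallowed_function_ids  # picked up via the dtype scan

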
@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaque-ly in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
        )
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids


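# Illustrative sketch (added for exposition, not part of the original module):
# the module-prefix filtering above collects ordinary torch.* callables while
# leaving out anything defined under a disallowed prefix such as "torch.fx."
# (is_fx_tracing is re-added explicitly in the loop above).
def _example_allowed_walk():  # hypothetical helper; never called
    assert id(torch.nn.functional.relu) in _allowed_function_ids
    assert id(torch.fx.symbolic_trace) not in _allowed_function_ids

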
@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv[id(functools.reduce)] = "functools.reduce"
    return rv


@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    for mod in (numpy, numpy.random):
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


def is_allowed(obj):
    """Is this safe to trace like torch.add ?"""
    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids. Figure it out by testing the type instead
    # in those cases
    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )


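# Illustrative sketch (added for exposition, not part of the original module):
# the isinstance fallback in is_allowed covers lazily-registered torch.ops
# entries that never appear in the precomputed id set, while arbitrary Python
# callables such as copy.deepcopy remain disallowed.
def _example_is_allowed():  # hypothetical helper; never called
    assert is_allowed(torch.add)
    assert is_allowed(torch.ops.aten.add)  # OpOverloadPacket, allowed by type
    assert not is_allowed(copy.deepcopy)

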
def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    return isinstance(obj, numpy.ndarray) or id(obj) in _numpy_function_ids


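# Illustrative sketch (added for exposition, not part of the original module):
# how the small predicates above classify a few everyday objects.
def _example_predicates():  # hypothetical helper; never called
    assert is_builtin_callable(len)
    assert is_builtin_constant(Ellipsis)
    assert is_numpy(numpy.zeros(3))  # ndarray instances count as numpy
    assert is_numpy(numpy.ndarray)  # and so do callables collected from numpy.*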