import builtins
import collections
import copy
import dataclasses
import functools
import inspect
import itertools
import math
import operator
import sys
import types
import warnings

from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Set, Union

np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    pass

import torch
import torch._functorch.deprecated as deprecated_func
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import hashable, is_safe_constant, NP_SUPPORTED_MODULES

"""
A note on allowed functions:

Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.

If a function is disallowed, it may either be traced-through, or skipped.

Trace-through means dynamo will continue to trace the interior code for
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph. Whether tracing through or allowing, the functionality
of the function/module is part of the dynamo graph. Caveat: if tracing through,
any interior operation could trigger its own graph-break.

Skips are determined by torch/_dynamo/skipfiles.py - see "a note on
skipfiles" there.
"""
class FunctionIdSet:
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph. Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    function_ids: Optional[Set[int]] = None
    function_names: Optional[Dict[int, str]] = None

    def __init__(self, lazy_initializer: Callable[[], Union[Dict[int, str], Set[int]]]):
        self.lazy_initializer = lazy_initializer

    def __call__(self):
        if self.function_ids is None:
            value = self.lazy_initializer()
            if isinstance(value, dict):
                self.function_ids = set(value.keys())
                self.function_names = value
            else:
                assert isinstance(value, set)
                self.function_ids = value
        return self.function_ids

    def get_name(self, idx: int, default: str):
        self()  # lazy init
        assert self.function_names is not None
        return self.function_names.get(idx, default)

    def add(self, idx: int):
        function_ids = self()  # lazy init
        function_ids.add(idx)

    def remove(self, idx: int):
        function_ids = self()
        if idx in function_ids:
            function_ids.remove(idx)

    def __contains__(self, idx: int):
        return idx in self()


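# Illustrative usage of FunctionIdSet (a sketch, not part of the original
# file; `_example_ids` is a hypothetical helper). Decorating a lazy
# initializer yields an object that runs the initializer on first use and
# supports `in` checks against object ids:
#
#   @FunctionIdSet
#   def _example_ids() -> Set[int]:
#       return {id(print), id(len)}
#
#   id(print) in _example_ids  # True; first check triggers the lazy init

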
@FunctionIdSet
def _disallowed_function_ids() -> Set[int]:
    remove: List[Any] = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.set_device,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.jit.isinstance,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
        torch.Tensor.__init__,
        torch.resize_as_,
        torch._tensor._convert,
    ]

    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage

    # Distributed APIs don't work well with torch.compile.
    if torch.distributed.is_available():
        remove.extend(
            torch.distributed.distributed_c10d.dynamo_unsupported_distributed_c10d_ops
        )

    return {id(x) for x in remove}


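# Example membership check (illustrative; assumes torch is importable and the
# lazy init above has run):
#
#   id(copy.deepcopy) in _disallowed_function_ids  # True -> never a graph node

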
# Helper function to dump the torch name rule map generated based on
# the heuristic defined in gen_allowed_objs_and_ids.
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')


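# A printed line from dump_allowed_torch_name_rule_map() looks like the
# following (the key shown is illustrative):
#
#   "torch.abs": TorchInGraphFunctionVariable,

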
@dataclasses.dataclass
class AllowedObjects:
    """
    Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs
    from the heuristic defined in `gen_allowed_objs_and_ids`.
    TODO: Remove the overlap/duplication between these fields
    after allowed_functions refactor is done.
    """

    object_ids: Dict[int, str]
    ctx_manager_classes: Set[Any]
    c_binding_in_graph_functions: Set[Any]
    non_c_binding_in_graph_functions: Set[Any]
    name_rule_map: Dict[str, Any]


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    from .variables import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable

    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()
    ctx_manager_classes = set()
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = dict()

    # Add obj to ctx_manager_classes set if it's a torch context manager class.
    # This is used to generate the ctx manager class list based on heuristic.
    def heuristic_record_if_ctx_manager(obj, module, name):
        if (
            issubclass(type(obj), type)
            and hasattr(obj, "__enter__")
            and hasattr(obj, "__exit__")
        ):
            torch_name_rule_map[
                f"{module.__name__}.{name}"
            ] = TorchCtxManagerClassVariable
            ctx_manager_classes.add(obj)

    # On some platforms, these functions were loaded as classes instead of functions.
    # To mitigate these weird cases, we need this special check.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }

    # Add obj to c_binding_in_graph_functions set or non_c_binding_in_graph_functions set
    # if it's a torch function or method.
    # This is used to generate the in-graph function list based on heuristic.
    def heuristic_record_if_in_graph_function(obj, module, name):
        try:
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if isinstance(
            obj,
            (
                types.FunctionType,
                types.MethodType,
                types.BuiltinFunctionType,
                types.MethodDescriptorType,
                types.WrapperDescriptorType,
            ),
        ) or is_special_functions(obj):
            torch_name_rule_map[
                f"{module.__name__}.{name}"
            ] = TorchInGraphFunctionVariable
            if c_binding_only:
                if not hasattr(obj, "__code__"):
                    c_binding_in_graph_functions.add(obj)
            else:
                if hasattr(obj, "__code__"):
                    non_c_binding_in_graph_functions.add(obj)
                else:
                    c_binding_in_graph_functions.add(obj)

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch._C._autograd",
            "torch._C._cudart",
            "torch._C._distributed_autograd",
            "torch._C._distributed_c10d",
            "torch._C._distributed_rpc",
            "torch._C._functorch",
            "torch._C._monitor",
            "torch._C._nvtx",
            "torch._C._lazy",
            "torch._C._profiler",
            "torch.__config__",
            "torch._custom_op",
            "torch._decomp",
            "torch._dispatch",
            "torch._export",
            "torch._functorch.make_functional",
            "torch._functorch.compile_utils",
            "torch._functorch.partitioners",
            "torch._functorch.aot_autograd",
            "torch._functorch.compilers",
            "torch._functorch.fx_minifier",
            "torch.autograd.profiler_util",
            "torch.autograd.profiler",
            "torch._jit_internal",
            "torch._library",
            "torch._lobpcg",
            "torch._logging",
            "torch._meta_registrations",
            "torch._namedtensor_internals",
            "torch._numpy",
            "torch._sources",
            "torch._subclasses",
            "torch._tensor",
            "torch._tensor_str",
            "torch._utils",
            "torch._utils_internal",
            "torch._vmap_internals",
            "torch.compiler",
            "torch.distributed",
            "torch.export",
            "torch.hub",
            "torch.jit",
            "torch.library",
            "torch.masked.maskedtensor",
            "torch.nn.init",
            "torch.nn.modules.module",
            "torch.nn.parallel",
            "torch.nn.utils",
            "torch.multiprocessing",
            "torch.onnx",
            "torch.overrides",
            "torch.package",
            "torch.profiler",
            "torch.serialization",
            "torch.storage",
            "torch.utils",
        ]
        if config.trace_distributed:
            disallowed_modules.append("torch.distributed.")

        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperator
                # first, decide they are safe, and then allow them into the graph).
                # So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    if record:
                        heuristic_record_if_ctx_manager(obj, module, name)
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    if record:
                        heuristic_record_if_ctx_manager(obj, module, name)
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    if config.trace_distributed:
        from torch.distributed import _functional_collectives_impl as fci

        for f in [
            fci._all_gather_into_tensor,
            fci._all_reduce,
            fci._reduce_scatter_tensor,
            fci._all_reduce_coalesced,
            fci._all_gather_into_tensor_coalesced,
            fci._reduce_scatter_tensor_coalesced,
        ]:
            torch_object_ids[id(f)] = repr(f)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(
            method, (types.MethodDescriptorType, types.WrapperDescriptorType)
        ):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return AllowedObjects(
        torch_object_ids,
        ctx_manager_classes,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )


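# Illustrative call (a sketch; walking torch.* is expensive, which is why the
# FunctionIdSet wrappers below cache the result lazily):
#
#   allowed = gen_allowed_objs_and_ids(record=True, c_binding_only=False)
#   allowed.object_ids.get(id(torch.add))  # e.g. "torch.add"

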
@FunctionIdSet
def _allowed_function_ids() -> Dict[int, str]:
    return gen_allowed_objs_and_ids().object_ids


@FunctionIdSet
def _allowed_user_defined_function_ids() -> Dict[int, str]:
    rv: Dict[int, str] = {}
    return rv


@FunctionIdSet
def _builtin_function_ids() -> Dict[int, str]:
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv.update(
        {
            id(cast): "typing.cast",
            id(functools.reduce): "functools.reduce",
            id(copy.deepcopy): "copy.deepcopy",
        }
    )
    return rv


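# Example membership checks (illustrative):
#
#   id(len) in _builtin_function_ids           # True -> named "builtins.len"
#   id(operator.add) in _builtin_function_ids  # True -> named "operator.add"

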
@FunctionIdSet
def _numpy_function_ids() -> Dict[int, str]:
    rv = dict()
    for mod in NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv


@FunctionIdSet
def _builtin_constant_ids() -> Dict[int, str]:
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


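# Example (illustrative): non-callable public builtins such as True, False,
# None, and Ellipsis land here.
#
#   id(Ellipsis) in _builtin_constant_ids  # True -> named "builtins.Ellipsis"

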
_lazy_module_init: Dict[str, List[Callable[[], None]]] = defaultdict(list)


def add_module_init_func(name: str, init_func: Callable[[], None]) -> None:
    """Register a module without eagerly importing it"""
    # If the module is already imported, eagerly run init
    assert "." not in name, f"Expected a root module name, but got {name}"
    if name in sys.modules:
        init_func()
        return

    # Module is not yet imported, delay processing until needed
    assert name not in _lazy_module_init
    _lazy_module_init[name].append(init_func)


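# Illustrative registration (the callback `_register_somepkg_ids` is
# hypothetical):
#
#   add_module_init_func("somepkg", lambda: _register_somepkg_ids())
#
# If `somepkg` is already imported the callback runs immediately; otherwise
# the first query on a `somepkg` object runs it via _maybe_init_lazy_module.

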
def _maybe_init_lazy_module(obj: object) -> None:
    module = getattr(obj, "__module__", None)
    if module is None:
        return

    base_module = module.split(".")[0]
    init_funcs = _lazy_module_init.pop(base_module, None)
    if init_funcs is not None:
        for fn in init_funcs:
            fn()


def is_allowed(obj) -> bool:
    """Is this safe to trace like torch.add?"""
    _maybe_init_lazy_module(obj)

    if id(obj) in _disallowed_function_ids:
        return False

    if id(obj) in _allowed_function_ids:
        return True

    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids. Figure it out by testing the type instead
    # in those cases
    return isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )


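# Example results (illustrative):
#
#   is_allowed(torch.add)      # True: appears as a node in the FX graph
#   is_allowed(copy.deepcopy)  # False: its id is in _disallowed_function_ids

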
def is_user_defined_allowed(obj) -> bool:
    _maybe_init_lazy_module(obj)
    return id(obj) in _allowed_user_defined_function_ids


def is_forbidden(obj) -> bool:
    _maybe_init_lazy_module(obj)
    return getattr(obj, "_dynamo_forbidden", False)


def torch_get_name(obj, default) -> str:
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj) -> bool:
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj) -> bool:
    return id(obj) in _builtin_constant_ids


def is_numpy(obj) -> bool:
    if np is None:
        return False
    return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids
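

# Example results (illustrative; requires numpy to be installed):
#
#   is_numpy(np.ones(3))  # True: ndarray instance
#   is_numpy(np.sum)      # True: its id is in _numpy_function_ids
#   is_numpy(torch.add)   # False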