pytorch/torch/_dynamo/allowed_functions.py
Animesh Jain 31eb9949e4 [dynamo] disallow_in_graph bugfix (#99600)
Testing if the minor change breaks other test cases.

For the added test case, TorchDynamo graph-breaks on `torch.ops.foo.custom`, but then starts compiling again on the recursively invoked frame - `foo_cpu` on L48 in the test file. This raises an assertion like this:

~~~
Traceback (most recent call last):
  File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 65, in test_disallow_in_graph_for_custom_op
    res = opt_fn(x)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 56, in fn
    b = torch.ops.foo.custom(a)
  File "/scratch/anijain/work/pytorch/torch/_ops.py", line 646, in __call__
    return self._op(*args, **kwargs or {})
  File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 401, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 495, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
    return fn(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
    return _compile(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 371, in transform
    tracer = InstructionTranslator(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1890, in __init__
    self.symbolic_locals = collections.OrderedDict(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1893, in <genexpr>
    VariableBuilder(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 165, in __call__
    return self._wrap(value).clone(**self.options())
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 290, in _wrap
    return type_dispatch(self, value)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 776, in wrap_tensor
    tensor_variable = wrap_fx_proxy(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 983, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1213, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
  File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1214, in <lambda>
    lambda: tx.fake_mode.from_tensor(
  File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 1434, in from_tensor
    return self.fake_tensor_converter(
  File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 329, in __call__
    return self.from_real_tensor(
  File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 283, in from_real_tensor
    out = self.meta_converter(
  File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 531, in __call__
    r = self.meta_tensor(
  File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 184, in meta_tensor
    assert not torch._C._dispatch_tls_local_exclude_set().has(
AssertionError:

~~~

It seems `_dynamo.disable` is the right option for custom ops added by `torch.library`.
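
As a rough illustration of that workaround (not the code from the test; the `foo::custom` schema, `foo_cpu` impl, and `call_custom` wrapper below are made-up names), one can wrap the call site and disable Dynamo on it so that neither the op nor the recursively invoked impl frame gets compiled:

~~~python
import torch
import torch._dynamo
from torch.library import Library

# Hypothetical custom op foo::custom registered via torch.library (CPU only).
lib = Library("foo", "DEF")
lib.define("custom(Tensor x) -> Tensor")

def foo_cpu(x):
    return x.sin()

lib.impl("custom", foo_cpu, "CPU")

# Instead of disallow_in_graph, disable Dynamo on the wrapper that calls the op.
@torch._dynamo.disable
def call_custom(x):
    return torch.ops.foo.custom(x)

@torch.compile(backend="eager")
def fn(x):
    return torch.sin(call_custom(torch.cos(x)))

fn(torch.randn(4))
~~~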

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99600
Approved by: https://github.com/jansel
2023-04-22 12:40:33 +00:00


import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings

from typing import Dict, Optional, Set

import torch

from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import HAS_NUMPY, is_safe_constant, np
"""
A note on allowed functions:
Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.
If a function is disallowed, it may either be traced-through, or skipped.
Trace-through means dynamo will continue to trace the interior code for
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph. Whether tracing through or allowing, the functionality
of the function/module is part of the dynamo graph. Caveat: if tracing through,
any interior operation could trigger its own graph-break.
Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on
skipfiles" there.
"""


def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph. Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()
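

# The FunctionIdSet instances below are mutated at runtime by the public
# decorators (e.g. torch._dynamo.allow_in_graph / disallow_in_graph add or
# remove id(fn) entries), which is the "user modification" the docstring
# above refers to.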


@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
    ]

    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage
    return {id(x) for x in remove}


@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaque-ly in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
        )
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__
        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(method, types.MethodDescriptorType):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids
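

# The mapping built by _allowed_function_ids above has the shape
# {id(obj): dotted_name}; for example (illustrative), the entry for
# torch.nn.functional.relu maps its id() to "torch.nn.functional.relu".
# Exact contents depend on the installed torch build and the config ignorelist.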


@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"functools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv[id(functools.reduce)] = "functools.reduce"
    return rv
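

# Example entries (illustrative): id(len) -> "builtins.len" and
# id(operator.add) -> "operator.add"; note that the itertools helpers above
# are registered under a "functools." label.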


@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    if HAS_NUMPY:
        for mod in (np, np.random):
            rv.update(
                {
                    id(v): f"{mod.__name__}.{k}"
                    for k, v in mod.__dict__.items()
                    if callable(v)
                    and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
                }
            )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv
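

# For example (illustrative): is_allowed(torch.add) is True via the id lookup,
# and is_allowed(torch.ops.aten.sin.default) is True via the isinstance check,
# because torch.ops overloads are created lazily and may not appear in the
# precomputed id sets.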
def is_allowed(obj):
    """Is this safe to trace like torch.add ?"""
    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids. Figure it out by testing the type instead
    # in those cases
    if id(obj) in _disallowed_function_ids:
        return False

    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )


def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    if HAS_NUMPY:
        return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids
    else:
        return False