Testing whether the minor change breaks other test cases.
For the added test case, TorchDynamo graph-breaks on `torch.ops.foo.custom` but then starts running again on the recursively invoked frame - `foo_cpu` on L48 in the test file. This raises an assertion like this:
~~~
Traceback (most recent call last):
File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 65, in test_disallow_in_graph_for_custom_op
res = opt_fn(x)
File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
return fn(*args, **kwargs)
File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 56, in fn
b = torch.ops.foo.custom(a)
File "/scratch/anijain/work/pytorch/torch/_ops.py", line 646, in __call__
return self._op(*args, **kwargs or {})
File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 401, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 495, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
return fn(*args, **kwargs)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
return _compile(
File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/anijain/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 371, in transform
tracer = InstructionTranslator(
File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1890, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1893, in <genexpr>
VariableBuilder(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 165, in __call__
return self._wrap(value).clone(**self.options())
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 290, in _wrap
return type_dispatch(self, value)
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 776, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 983, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1213, in wrap_to_fake_tensor_and_record
fake_e = wrap_fake_exception(
File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
return fn()
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1214, in <lambda>
lambda: tx.fake_mode.from_tensor(
File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 1434, in from_tensor
return self.fake_tensor_converter(
File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 329, in __call__
return self.from_real_tensor(
File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 283, in from_real_tensor
out = self.meta_converter(
File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 531, in __call__
r = self.meta_tensor(
File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 184, in meta_tensor
assert not torch._C._dispatch_tls_local_exclude_set().has(
AssertionError:
~~~
It seems `_dynamo.disable` is the right option for custom ops added by `torch.library`.
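A minimal sketch of that pattern (a hypothetical `foo::custom` op and the `eager` backend, not the PR's test verbatim): the op is kept out of the graph with `disallow_in_graph`, and its Python kernel is wrapped in `_dynamo.disable` so Dynamo does not try to compile the recursively invoked frame:

~~~
import torch
import torch._dynamo

# Hypothetical custom op registered through torch.library (illustration only).
lib = torch.library.Library("foo", "DEF")
lib.define("custom(Tensor self) -> Tensor")


# The CPU kernel is an ordinary Python frame the dispatcher calls back into;
# _dynamo.disable keeps Dynamo from compiling it and hitting the assert above.
@torch._dynamo.disable
def foo_cpu(x):
    return x.sin()


lib.impl("custom", foo_cpu, "CPU")

# Force a graph break at the op itself, as in the added test case.
torch._dynamo.disallow_in_graph(torch.ops.foo.custom)


def fn(x):
    a = torch.cos(x)
    b = torch.ops.foo.custom(a)  # graph break; the kernel runs eagerly
    return torch.sin(b)


opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(4))
~~~

With the kernel disabled, the graph break at `torch.ops.foo.custom` falls back to eager execution instead of re-entering Dynamo inside the dispatcher.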
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99600
Approved by: https://github.com/jansel
282 lines · 8.6 KiB · Python
import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings
from typing import Dict, Optional, Set

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import HAS_NUMPY, is_safe_constant, np

"""
|
|
A note on allowed functions:
|
|
|
|
Dynamo consults this file to determine if a particular function/module
|
|
is allowed to appear as a node in its fx output.
|
|
|
|
If a function is disallowed, it may either be traced-through, or skipped.
|
|
|
|
Trace-through means dynamo will continue to trace the interior code for
|
|
the function/module rather than stopping at its boundary and recording it
|
|
as a node in the fx graph. Whether tracing through or allowing, the functionality
|
|
of the function/module is part of the dynamo graph. Caveat: if tracing through,
|
|
any interior operation could trigger its own graph-break.
|
|
|
|
Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on
|
|
skipfiles" there.
|
|
"""
|
|
|
|
|
|
def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph. Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()


@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
    ]
    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage
    return {id(x) for x in remove}


@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaque-ly in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
        )
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(method, types.MethodDescriptorType):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids


@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv[id(functools.reduce)] = "functools.reduce"
    return rv


@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    if HAS_NUMPY:
        for mod in (np, np.random):
            rv.update(
                {
                    id(v): f"{mod.__name__}.{k}"
                    for k, v in mod.__dict__.items()
                    if callable(v)
                    and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
                }
            )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


def is_allowed(obj):
    """Is this safe to trace like torch.add ?"""
    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids. Figure it out by testing the type instead
    # in those cases
    if id(obj) in _disallowed_function_ids:
        return False
    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )


def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    if HAS_NUMPY:
        return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids
    else:
        return False
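
For context on why a `torch.library` op reaches the FX graph at all even though the `torch.*` walk above never records it: `is_allowed` falls back to an `isinstance` check on the `torch._ops` types. A small illustrative sketch, assuming this module is importable as `torch._dynamo.allowed_functions`:

~~~
import torch
from torch._dynamo.allowed_functions import is_allowed, is_builtin_callable

assert is_allowed(torch.add)                # collected by the torch.* walk
assert is_allowed(torch.ops.aten.add)       # OpOverloadPacket, matched by isinstance
assert not is_allowed(torch.autograd.grad)  # listed in _disallowed_function_ids
assert is_builtin_callable(len)             # id(len) is in _builtin_function_ids
~~~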