Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[Dynamo][15/N] Merge allow_in_graph/inline/skip trace rules check into trace_rule.lookup (#118971)
Finally we have this PR to merge the allow_in_graph/inline/skip trace rules into ```trace_rules.lookup_inner```, where we can define and look up trace rules at both the function level and the file level. Going forward, this is the central place where we define and consult the Dynamo trace rule for any function.
* ```trace_rules.lookup``` is the API that can return allow_in_graph, inline, or skip.
* ```skipfiles.check``` is the API that can return inline or skip, since we have multiple places that only do the inline/skip check.
  * I'll move ```skipfiles.check``` to ```trace_rules.check``` as one of the follow-ups.
* Both functions consult ```trace_rules.lookup_inner``` to get the tracing rule.

To avoid a single big PR, I left a few items as follow-ups:
* Remove ```skipfiles.py``` and merge the code into ```trace_rules.py```.
* We do a double check in ```symbolic_convert.check_inlineable```; I will refactor and simplify it. We should only do the inline/skip check before generating ```SkipFilesVariable``` and ```UserFunctionVariable```.
* Rename ```SkipFilesVariable``` to ```SkipFunctionVariable```, since we only handle functions.
* The inline/skip reasons are not logged in some cases, since the new lookup framework doesn't always return them. I'll refactor logging to record the inline/skip reason in the next step.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118971
Approved by: https://github.com/jansel
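For orientation, a minimal sketch of the new lookup surface (my own illustration, not part of the PR; it assumes a PyTorch build that already contains this change, and the expected result for ```torch.sin``` comes from the docstring added in ```trace_rules.py``` below):

```python
import torch
import torch._dynamo.trace_rules as trace_rules

# `lookup` maps a function object to the Dynamo variable class that will handle it.
# Per the docstring added in this PR, a torch in-graph function such as torch.sin
# is expected to resolve to TorchInGraphFunctionVariable.
print(trace_rules.lookup(torch.sin))

# `lookup_callable` is the entry point for arbitrary callables; it also honors the
# custom registrations made by torch._dynamo.allow_in_graph / disallow_in_graph.
print(trace_rules.lookup_callable(torch.sin))
```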
This commit is contained in: parent 284b0b5f44, commit 0f478d9d61
@@ -12,18 +12,15 @@ import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.skipfiles import (
    FUNC_INLINELIST,
    LEGACY_MOD_INLINELIST,
    MOD_INLINELIST,
)
from torch._dynamo.skipfiles import LEGACY_MOD_INLINELIST, MOD_INLINELIST
from torch._dynamo.trace_rules import (
    load_object,
    manual_torch_name_rule_map,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable

try:
    from .utils import create_dummy_module_and_function

@@ -282,19 +279,6 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
    )


def gen_get_func_inlinelist(dummy_func_inlinelist):
    def get_func_inlinelist():
        inlinelist = set()
        for f in dummy_func_inlinelist:
            module_name, fn_name = f.rsplit(".", 1)
            m = importlib.import_module(module_name)
            fn = getattr(m, fn_name)
            inlinelist.add(fn.__code__)
        return inlinelist

    return get_func_inlinelist


class TraceRuleTests(torch._dynamo.test_case.TestCase):
    def _check_set_equality(self, generated, used, rule_map, ignored_set):
        x = generated - used

@@ -321,13 +305,6 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
                isinstance(importlib.import_module(m), types.ModuleType),
                f"{m} from skipfiles.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.",
            )
        for f in FUNC_INLINELIST:
            module_name, fn_name = f.rsplit(".", 1)
            m = importlib.import_module(module_name)
            self.assertTrue(
                isinstance(getattr(m, fn_name), types.FunctionType),
                f"{f} from skipfiles.FUNC_INLINELIST is not a python function, please check and correct it.",
            )

    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on heuristic defined in `allowed_functions.py`,

@@ -363,15 +340,23 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
            )
        )

    def test_func_inlinelist_torch_function(self):
    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
        func_inlinelist.add("torch._dynamo.utils.istype")
        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `torch._dynamo.utils.istype` by setting trace rule.
        _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.skipfiles.LEGACY_MOD_INLINELIST

@@ -379,8 +364,11 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
        self.assertTrue("torch._dynamo" not in torch._dynamo.skipfiles.MOD_INLINELIST)

        with unittest.mock.patch(
            "torch._dynamo.skipfiles.get_func_inlinelist",
            gen_get_func_inlinelist(func_inlinelist),
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache
        ):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)

@@ -388,23 +376,32 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_func_inlinelist_third_party_function(self):
    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
        func_inlinelist.add(f"{mod.__name__}.{func.__name__}")
        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `mod.func` by setting trace rule.
        _manual_torch_name_rule_map[
            f"{mod.__name__}.{func.__name__}"
        ] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        with unittest.mock.patch(
            "torch._dynamo.skipfiles.get_func_inlinelist",
            gen_get_func_inlinelist(func_inlinelist),
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.skipfiles.SKIP_DIRS",
            torch._dynamo.skipfiles.SKIP_DIRS.copy(),
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
        ):
            # First adding the module to SKIP_DIRS so that it will be skipped.
            # First adding the module to SKIP_DIRS so that it will be skipped by default.
            torch._dynamo.skipfiles.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)

@@ -93,7 +93,7 @@ def allow_in_graph(fn):
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    if trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable:
    if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
        trace_rules._disallowed_callable_ids.remove(id(fn))
        trace_rules._allowed_callable_ids.add(id(fn))
    return fn

@@ -106,8 +106,9 @@ def _disallow_in_graph_helper(throw_if_not_allowed):
        assert callable(fn), "disallow_in_graph expects a callable"
        if (
            throw_if_not_allowed
            and trace_rules.lookup_callable(fn)
            != variables.TorchInGraphFunctionVariable
            and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
            and fn not in trace_rules._allowed_callable_ids
        ):
            raise IncorrectUsage(
                "disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "

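For context on the ```allow_in_graph``` path above, a small usage sketch (my own example, not part of the diff; ```my_helper``` is a hypothetical user function):

```python
import torch
import torch._dynamo as dynamo

def my_helper(x):
    # A helper we want Dynamo to keep as a single node in the graph
    # instead of tracing into its body.
    return torch.sin(x) + 1

dynamo.allow_in_graph(my_helper)

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    return my_helper(x) * 2

print(fn(torch.rand(3)))
```

With the change above, ```allow_in_graph``` consults ```trace_rules.lookup_callable``` instead of ```trace_rules.lookup``` before updating the allowed/disallowed callable id sets.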
@@ -37,8 +37,10 @@ import torch.utils._content_store
from ..utils import _config_module
from .utils import getfile

from .variables.functions import (
from .variables import (
    FunctorchVmapHigherOrderVariable,
    NestedUserFunctionVariable,
    SkipFilesVariable,
    UserFunctionVariable,
    UserMethodVariable,
)

@@ -160,17 +162,6 @@ def _module_dir(m: types.ModuleType):
    return file and _strip_init_py(file)


# TODO: Add a decoractor for easily adding functions to FUNC_INLINELIST
# after resolving all circular import issues.
FUNC_INLINELIST = {
    "torch._constrain_as_size",
    "torch._constrain_as_value",
    "torch._tensor._convert",
    "torch.backends.mha.get_fastpath_enabled",
    "torch.jit._unwrap_optional",
}


# These are legacy workarounds, don't add new modules to this list.
# Please use the MOD_INLINELIST instead to force inline functions under particular modules.
LEGACY_MOD_INLINELIST = {

@@ -240,18 +231,6 @@ if torch.distributed.is_available():
    MOD_INLINELIST.add("torch.distributed._functional_collectives")


# TODO: support adding bound method into this list
@functools.lru_cache(None)
def get_func_inlinelist():
    inlinelist = set()
    for f in FUNC_INLINELIST:
        module_name, fn_name = f.rsplit(".", 1)
        m = importlib.import_module(module_name)
        fn = getattr(m, fn_name)
        inlinelist.add(fn.__code__)
    return inlinelist


@functools.lru_cache(None)
def get_legacy_mod_inlinelist():
    inlinelist = set()

@@ -401,20 +380,22 @@ def check_verbose(obj, is_inlined_call=False):
        )
    else:
        fi = FunctionInfo(obj, None, getfile(obj), None)
    # Go through function based skip/inline rules.
    if fi.code in get_func_inlinelist():

    # Consulte the central trace rules defined in torch._dynamo.trace_rules.
    rule = torch._dynamo.trace_rules.lookup_inner(
        fi.py_obj, fi.name, fi.filename, is_inlined_call
    )
    if rule in [UserFunctionVariable, FunctorchVmapHigherOrderVariable]:
        return SkipResult(
            False,
            "inlined according skipfiles.FUNC_INLINELIST",
            "inlined according trace_rules.lookup",
        )
    else:
        assert rule == SkipFilesVariable, rule
        return SkipResult(
            True,
            "skipped according trace_rules.lookup",
        )
    if is_inlined_call:
        if fi.name == "patched_init":
            return SkipResult(True, "patched init cannot be inlined.")
        elif fi.name == "__torch_function__":
            return SkipResult(False, "allow inlining __torch_function__")

    # Go through file based skip/inline rules.
    return check_file(fi.filename, is_inlined_call)


def check(obj, is_inlined_call=False):

@@ -102,6 +102,7 @@ from .variables.misc import (
    InlinedClosureVariable,
    NullVariable,
    PythonModuleVariable,
    SkipFilesVariable,
    UnknownVariable,
)
from .variables.nn_module import NNModuleVariable

@@ -2290,6 +2291,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
    def inline_call_(
        parent, func: VariableTracker, args: List[VariableTracker], kwargs
    ):
        if isinstance(func, SkipFilesVariable):
            unimplemented("inline with functions in skip files")
        assert isinstance(
            func,
            (UserFunctionVariable, NestedUserFunctionVariable),

@@ -17,9 +17,10 @@ except ModuleNotFoundError:

import torch

from .utils import hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper

from .variables import (
    BuiltinVariable,
    FunctorchVmapHigherOrderVariable,
    SkipFilesVariable,
    TorchInGraphFunctionVariable,

@@ -151,6 +152,11 @@ manual_torch_name_rule_map = {
    "torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
    "torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
    "torch._functorch.vmap.wrap_batched": UserFunctionVariable,
    "torch._constrain_as_size": UserFunctionVariable,
    "torch._constrain_as_value": UserFunctionVariable,
    "torch._tensor._convert": UserFunctionVariable,
    "torch.jit._unwrap_optional": UserFunctionVariable,
    "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
}


@@ -2062,8 +2068,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
        "torch._check_with",
        "torch._check",
        "torch._compile._disable_dynamo",
        "torch._constrain_as_size",
        "torch._constrain_as_value",
        "torch._functorch.apis.chunk_vmap",
        "torch._functorch.autograd_function.custom_function_call_functionalize",
        "torch._functorch.autograd_function.custom_function_call_grad",

@@ -2765,8 +2769,7 @@ def load_object(name):
        else:
            assert len(x) == 1, f"Invalid obj name {name}"
            val = _load_obj_from_str(x[0])
        if hasattr(val, "__wrapped__"):
            val = val.__wrapped__
        val = unwrap_if_wrapper(val)
    except (AttributeError, ImportError):
        val = None
    return val

@@ -2969,6 +2972,23 @@ def is_numpy(obj) -> bool:
    return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids


"""
Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object.
"""


def lookup_callable(obj):
    if not hashable(obj):
        return None
    # Custom allow/disallow in graph takes precedence over the general lookup.
    if is_callable_disallowed(obj):
        return SkipFilesVariable
    if is_callable_allowed(obj):
        return TorchInGraphFunctionVariable
    if is_builtin_callable(obj):
        return BuiltinVariable


"""
Main entry point for looking up the trace rule (the Dynamo variable) for a given function object.
E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`.

@@ -2976,16 +2996,35 @@ E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`.


def lookup(obj):
    # Unwrap if it's a functools.lru_cache wrapper
    obj = unwrap_if_wrapper(obj)
    return lookup_inner(obj)


def lookup_inner(obj, name=None, filename=None, is_direct_call=True):
    # Step 1: lookup obj's tracing rule in `torch_name_rule_map`.
    # The rules defined in `torch_name_rule_map` mainly includes two parts:
    # - Manually defined rules for any functions.
    # - The list of torch in graph functions.
    if not hashable(obj):
        return None
    # Custom allow/disallow in graph takes precedence over the `torch_name_rule_map`.
    if callable(obj) and is_callable_disallowed(obj):
    if obj is not None:
        if is_aten_op_or_tensor_method(obj):
            return TorchInGraphFunctionVariable
        rule = get_torch_obj_rule_map().get(obj, None)
        if rule is not None:
            return rule

    # Step 2: lookup obj's tracing rule by function name.
    if is_direct_call:
        if name == "patched_init":
            return SkipFilesVariable
        elif name == "__torch_function__":
            return UserFunctionVariable

    # Step 3: lookup obj's tracing rule by filename.
    if filename is None:
        filename = getfile(obj)

    if torch._dynamo.skipfiles.check_file(filename, is_direct_call).skipped:
        return SkipFilesVariable
    if callable(obj) and is_callable_allowed(obj):
        return TorchInGraphFunctionVariable
    if is_aten_op_or_tensor_method(obj):
        return TorchInGraphFunctionVariable
    rule = get_torch_obj_rule_map().get(obj, None)
    return rule
    else:
        return UserFunctionVariable

@@ -524,6 +524,7 @@ def is_function_or_wrapper(value):
        is_function(value)
        or isinstance(value, functools._lru_cache_wrapper)
        and is_function(inspect.getattr_static(value, "__wrapped__"))
        or isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload))
    )


@@ -535,14 +536,20 @@ def is_function(value):
            types.BuiltinFunctionType,
            types.MethodDescriptorType,
            types.WrapperDescriptorType,
            torch.jit.ScriptFunction,
        ),
    )


def unwrap_if_wrapper(value):
    if isinstance(value, functools._lru_cache_wrapper):
        value = inspect.getattr_static(value, "__wrapped__")
    return value
def unwrap_if_wrapper(fn):
    if isinstance(fn, functools._lru_cache_wrapper):
        fn = inspect.getattr_static(fn, "__wrapped__")
    # unpack @torch._dynamo.optimize()(fn) wrapped function
    fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
    # unpack torch.jit.script_if_tracing
    if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
        fn = inspect.getattr_static(fn, "__original_fn", fn)
    return fn


def is_numpy_ndarray(value):

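The rewritten ```unwrap_if_wrapper``` above now also unpacks ```_torchdynamo_inline``` and ```torch.jit.script_if_tracing``` wrappers before rule lookup. A small illustration of the ```functools.lru_cache``` case (a sketch assuming a build with this PR; the helper function is hypothetical):

```python
import functools

import torch._dynamo.utils as dynamo_utils

@functools.lru_cache(None)
def cached_add_one(x):
    return x + 1

# The decorated object is an lru_cache wrapper, not a plain function; unwrapping
# recovers the original function object so that trace-rule lookup keys on it.
inner = dynamo_utils.unwrap_if_wrapper(cached_add_one)
print(inner is cached_add_one.__wrapped__)  # expected: True
```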
@@ -39,7 +39,7 @@ from torch.fx.immutable_collections import immutable_list
from torch.nested._internal.nested_tensor import NestedTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, skipfiles, trace_rules
from .. import config, mutation_guard, replay_record, trace_rules

from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented

@@ -59,7 +59,7 @@ from ..source import (
    Source,
    TupleIteratorGetItemSource,
)
from ..trace_rules import is_builtin_callable, is_callable_allowed, is_numpy
from ..trace_rules import is_callable_allowed, is_numpy
from ..utils import (
    build_checkpoint_variable,
    clone_input,

@@ -82,7 +82,6 @@ from ..utils import (
)

from .base import MutableLocal, typestr, VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
    AutocastModeVariable,

@@ -138,7 +137,6 @@ from .misc import (
    NumpyVariable,
    PythonModuleVariable,
    SavedTensorBox,
    SkipFilesVariable,
    TypingVariable,
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable

@@ -474,9 +472,6 @@ class VariableBuilder:
        elif isinstance(value, enum.Enum):
            self.install_guards(GuardBuilder.ID_MATCH)
            return EnumVariable(value=value, source=self.source)
        elif is_builtin_callable(value):
            self.install_guards(GuardBuilder.BUILTIN_MATCH)
            return BuiltinVariable(value, source=self.source)
        elif is_utils_checkpoint(value):
            return build_checkpoint_variable(source=self.source)
        elif isinstance(value, functools.partial):

@@ -571,6 +566,12 @@ class VariableBuilder:
                ),
                "apply",
            )
        elif callable(value) and trace_rules.lookup_callable(value) is not None:
            if is_callable_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
            return trace_rules.lookup_callable(value).create_with_source(
                value, source=self.source
            )
        elif np and isinstance(value, np.number):
            return self.wrap_unspecialized_primitive(value)
        elif DataClassVariable.is_matching_object(value):

@@ -708,11 +709,8 @@ class VariableBuilder:
        elif TorchCtxManagerClassVariable.is_matching_cls(value):
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return TorchCtxManagerClassVariable(value, source=self.source)
        elif (is_function_or_wrapper(value) or callable(value)) and trace_rules.lookup(
            value
        ) is not None:
            if is_callable_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
        elif is_function_or_wrapper(value):
            value = unwrap_if_wrapper(value)
            return trace_rules.lookup(value).create_with_source(
                value, source=self.source
            )

@@ -725,25 +723,6 @@ class VariableBuilder:
                value,
                source=self.source,
            )
        elif (
            is_function_or_wrapper(value)
            and skipfiles.check(value, is_inlined_call=True)
            and not inspect.getattr_static(value, "_torchdynamo_inline", False)
            and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
        ):
            value = unwrap_if_wrapper(value)
            self.install_guards(GuardBuilder.FUNCTION_MATCH)
            return SkipFilesVariable(
                value,
                skipfiles.check_verbose(value, is_inlined_call=True).reason,
                source=self.source,
            )
        elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
            self.install_guards(GuardBuilder.CLOSURE_MATCH)
            return UserFunctionVariable(
                value,
                source=self.source,
            )
        elif isinstance(value, types.MethodType) and isinstance(
            value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
        ):

@@ -1871,14 +1850,11 @@ class SourcelessBuilder:
            return UserDefinedObjectVariable(value)
        if ConstantVariable.is_literal(value):
            return SourcelessBuilder.wrap_constant_literal(value)
        elif is_builtin_callable(value):
            return BuiltinVariable(value)
        elif (is_function_or_wrapper(value) or callable(value)) and trace_rules.lookup(
            value
        ) is not None:
            value = unwrap_if_wrapper(value)
        elif callable(value) and trace_rules.lookup_callable(value) is not None:
            if is_callable_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
            return trace_rules.lookup_callable(value)(value)
        elif is_function_or_wrapper(value):
            return trace_rules.lookup(value)(value)
        elif isinstance(value, types.FunctionType):
            return UserFunctionVariable(value)

@@ -99,6 +99,11 @@ def _polyfill_call_impl(name):
class BuiltinVariable(VariableTracker):
    _SENTINEL = object()

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
        return BuiltinVariable(value, source=source)

    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():

@@ -11,6 +11,7 @@ import torch
from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, identity, istype, make_cell
from .base import typestr, VariableTracker

@@ -110,6 +111,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
        return cls(
            value,
            source=source,

@@ -293,7 +293,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                self.value,
                source=self.source,
            ).call_function(tx, args, kwargs)
        elif self.value is torch.overrides.get_default_nowrap_functions:
        elif self.value is torch.overrides.get_default_nowrap_functions.__wrapped__:
            # [Note: __torch_function__] we return empty here because we restrict
            # the set of functions that we trace __torch_function__ on to
            # functions outside of the actual set. Implementing this properly will require implementing

@@ -735,7 +735,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
            ).call_function(tx, [self], {})
        elif isinstance(subobj, staticmethod):
            func = subobj.__get__(self.value)
            if trace_rules.lookup(func) is not None:
            if source is not None and trace_rules.lookup(func) is not None:
                return trace_rules.lookup(func).create_with_source(func, source=source)
            else:
                return variables.UserFunctionVariable(func, source=source)

@@ -768,7 +768,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
        elif inspect.isfunction(dynamic_subobj):
            if is_utils_checkpoint(func):
                return build_checkpoint_variable(source=source)
            elif trace_rules.lookup(func) is not None:
            elif source is not None and trace_rules.lookup(func) is not None:
                return trace_rules.lookup(func).create_with_source(
                    func, source=source
                )

@@ -420,13 +420,7 @@ dynamo_expected_failures = {
    "TestAsArrayCPU.test_copy_list_cpu_complex128", # test_tensor_creation_ops
    "TestAsArrayCPU.test_copy_list_cpu_int16", # test_tensor_creation_ops
    "TestTensorCreationCPU.test_cartesian_prod_cpu", # test_tensor_creation_ops
    "TestSubclass.test_parametrization_non_wrapper_tensor_leave_parametrized_True", # test_subclass
    "TestSubclass.test_module_optimization_non_wrapper_tensor", # test_subclass
    "TestSubclass.test_serialization_non_wrapper_tensor_as_param_True", # test_subclass
    "TestSubclass.test_parametrization_non_wrapper_tensor_leave_parametrized_False", # test_subclass
    "TestSubclass.test_type_propagation_non_wrapper_tensor_as_param_False", # test_subclass
    "TestSubclass.test_parametrization_base_tensor_leave_parametrized_True", # test_subclass
    "TestSubclass.test_type_propagation_non_wrapper_tensor_as_param_True", # test_subclass
    "TestSubclass.test_parametrization_base_tensor_leave_parametrized_False", # test_subclass
    "TestStatelessFunctionalAPI.test_reparametrize_module_fail_reset_to_original_torch_func", # test_stateless
    "TestStatelessFunctionalAPI.test_reparametrized_module_change_parametrization_original_stateless", # test_stateless

@@ -560,29 +554,16 @@ dynamo_expected_failures = {
    "PackedSequenceTest.test_pack_sequence", # nn/test_packed_sequence
    "PackedSequenceTest.test_total_length", # nn/test_packed_sequence
    "TestModuleHookNN.test_hook_inplace", # nn/test_module_hooks
    "TestLazyModules.test_lazy_batchnorm2d_state", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv3d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv_transposed1d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv2d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_instancenorm3d_state", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_batchnorm3d_state", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv_transpose1d_pickle", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_instancenorm2d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_instancenorm2d_state", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv3d_pickle", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_batchnorm2d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_instancenorm1d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_batchnorm1d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_instancenorm1d_state", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv_transpose3d_pickle", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_instancenorm3d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_batchnorm3d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv2d_pickle", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv1d_pickle", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv1d", # nn/test_lazy_modules
    "TestLazyModules.test_linear", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_module_buffer", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_batchnorm1d_state", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_batchnorm_with_dict_input", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv_transpose2d", # nn/test_lazy_modules
    "TestLazyModules.test_lazy_conv_transpose2d_pickle", # nn/test_lazy_modules

@@ -1167,10 +1148,8 @@ dynamo_expected_failures = {
    "AutogradFunctionTests.test_print_in_bwd", # dynamo/test_autograd_function
    "AutogradFunctionTests.test_graph_break_if_lifted_free_variable", # dynamo/test_autograd_function
    "AotAutogradFallbackTests.test_aot_sequence_nr", # dynamo/test_aot_autograd
    "TestTorchFunctionOverride.test_tensor_subclass_propagation", # test_overrides
    "TestNamedTuple.test_max", # test_overrides
    "TestTorchFunctionMode.test_mode_notimplemented_loop", # test_overrides
    "TestTorchFunctionMode.test_disable_enable_subclass", # test_overrides
    "TestTorchFunctionOverride.test_mean_semantics", # test_overrides
    "TestGradCheckOverride.test_gradcheck", # test_overrides
    "TestTorchFunctionOverride.test_Tensor___cuda_array_interface_____get__", # test_overrides

@@ -1580,7 +1559,6 @@ dynamo_expected_failures = {
    "TestAutograd.test_increment_version", # test_autograd
    "TestAutograd.test_record_function_callbacks", # test_autograd
    "TestAutograd.test_save_on_cpu_and_checkpoint", # test_autograd
    "TestAutogradForwardMode.test_make_dual_torch_dispatch", # test_autograd
    "TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128", # test_autograd
    "TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False", # test_autograd
    "TestAutograd.test_gradcheck_nondeterministic", # test_autograd

@@ -1825,7 +1803,6 @@ dynamo_expected_failures = {
    "TestFakeQuantizeOps.test_learnable_forward_per_tensor_cuda", # test_quantization
    "TestQuantizedTensor.test_repeat", # test_quantization
    "TestStaticQuantizedModule.test_linear_leaky_relu", # test_quantization
    "TestBitsCPU.test_subclass_cpu", # test_quantization
    "TestFakeQuantizeOps.test_learnable_backward_per_channel_cpu", # test_quantization
    "TestFXNumericSuiteCoreAPIs.test_add_shadow_loggers_fun_ptq", # test_quantization
    "TestQuantizeFx.test_static_lstm", # test_quantization

@@ -2064,28 +2041,11 @@ dynamo_expected_failures = {
    "TestGenerateOpcheckTests.test_opcheck_bad_op", # test_custom_ops
    "TestCustomOp.test_legacy_define", # test_custom_ops
    "TestPythonRegistration.test_alias_analysis", # test_python_dispatch
    "TestPythonDispatch.test_torch_dispatch_mode_subclass_priority", # test_python_dispatch
    "TestPythonDispatch.test_strides_slow_path", # test_python_dispatch
    "TestPythonDispatch.test_invalid_ret", # test_python_dispatch
    "TestPythonDispatch.test_dim_slowpath", # test_python_dispatch
    "TestWrapperSubclassAliasingCPU.test_wrapper_subclass_aliasing_conv2d_cpu", # test_python_dispatch
    "TestPythonDispatch.test_fancy_strides", # test_python_dispatch
    "TestPythonDispatch.test_layout_slow_path", # test_python_dispatch
    "TestPythonDispatch.test_dispatch_super_dont_autograd", # test_python_dispatch
    "TestPythonDispatch.test_sizes_slow_path", # test_python_dispatch
    "TestPythonRegistration.test_finalizer", # test_python_dispatch
    "TestPythonDispatch.test_dispatch_super_call_list_arg", # test_python_dispatch
    "TestPythonDispatch.test_is_contiguous_slow_path", # test_python_dispatch
    "TestPythonRegistration.test_override_cpu_sum", # test_python_dispatch
    "TestPythonDispatch.test_mode_with_make_subclass", # test_python_dispatch
    "TestPythonDispatch.test_multiple_ops_subclass", # test_python_dispatch
    "TestPythonDispatch.test_subclass_autograd_device_check", # test_python_dispatch
    "TestPythonDispatch.test_data_ptr_respects_numel_slow_path", # test_python_dispatch
    "TestPythonDispatch.test_make_subclass_with_modes", # test_python_dispatch
    "TestPythonDispatch.test_dispatch_super_call", # test_python_dispatch
    "TestPythonDispatch.test_subclass_priority", # test_python_dispatch
    "TestPythonDispatch.test_exception_handling", # test_python_dispatch
    "TestPythonDispatch.test_list_ret", # test_python_dispatch
    "LoggingTests.test_trace_source_nested", # dynamo/test_logging
    "LoggingTests.test_guards_recompiles", # dynamo/test_logging
    "LoggingTests.test_inductor_info", # dynamo/test_logging

@@ -6499,6 +6459,7 @@ dynamo_skips = {
    "TestScalarOpsMisc.test_scalar_integer_operation_divbyzero_dtype_Q_operation1",
    "TestArgmax.test_combinations_data61", # torch_np/test_ndarray_methods.py
    "TestArgmax.test_combinations_data58", # torch_np/test_ndarray_methods.py
    "TestPythonDispatch.test_list_ret", # test_python_dispatch.py
    "TestCustomOpTestingCPU.test_opcheck_fails_basic_cpu", # test_custom_ops.py
    "TestVmapAPI.test_functools_partial", # functorch/test_vmap.py
    "TestSaveLoadForOpVersion.test_versioned_div_tensor_out", # test_jit.py