[Dynamo][15/N] Merge allow_in_graph/inline/skip trace rules check into trace_rule.lookup (#118971)

Finally we have this PR to merge allow_in_graph/inline/skip trace rules into ```trace_rules.lookup_inner```, where we can define and lookup trace rules at both function level and file level. Going forward, this is the central place that we define and consulte Dynamo trace rule for any function.
* ```trace_rules.looup``` is the API can return allow_in_graph, inline or skip.
* ```skipfiles.check``` is the API can return inline or skip, since we have multiple places that only do inline/skip check.
  *  I'll move ```skipfiles.check``` to ```trace_rules.check``` as one of the follow-ups.
* Both functions consulte ```trace_rules.lookup_inner``` to get the tracing rule.

To avoid a single big PR, I left a few items as the follow-ups:
* Remove ```skipfiles.py``` and merge the code into ```trace_rules.py```.
* We do double check in ```symbolic_convert.check_inlineable```, will refactor and simplify it. We should only do inline/skip check before generating ```SkipFilesVariable``` and ```UserFunctionVariable```.
* Rename ```SkipFilesVariable``` as ```SkipFunctionVariable```, since we only handle functions.
* The inline/skip reasons are not logged for some cases, since the new lookup framework doesn't always return inline/skip reasons. I'll refactor loggings to record the inline/skip reason in next step.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118971
Approved by: https://github.com/jansel
This commit is contained in:
Yanbo Liang 2024-02-07 05:15:36 +00:00 committed by PyTorch MergeBot
parent 284b0b5f44
commit 0f478d9d61
12 changed files with 147 additions and 175 deletions

View File

@ -12,18 +12,15 @@ import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.skipfiles import (
FUNC_INLINELIST,
LEGACY_MOD_INLINELIST,
MOD_INLINELIST,
)
from torch._dynamo.skipfiles import LEGACY_MOD_INLINELIST, MOD_INLINELIST
from torch._dynamo.trace_rules import (
load_object,
manual_torch_name_rule_map,
torch_c_binding_in_graph_functions,
torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable
try:
from .utils import create_dummy_module_and_function
@ -282,19 +279,6 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
)
def gen_get_func_inlinelist(dummy_func_inlinelist):
def get_func_inlinelist():
inlinelist = set()
for f in dummy_func_inlinelist:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist
return get_func_inlinelist
class TraceRuleTests(torch._dynamo.test_case.TestCase):
def _check_set_equality(self, generated, used, rule_map, ignored_set):
x = generated - used
@ -321,13 +305,6 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
isinstance(importlib.import_module(m), types.ModuleType),
f"{m} from skipfiles.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.",
)
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
self.assertTrue(
isinstance(getattr(m, fn_name), types.FunctionType),
f"{f} from skipfiles.FUNC_INLINELIST is not a python function, please check and correct it.",
)
def test_torch_name_rule_map_updated(self):
# Generate the allowed objects based on heuristic defined in `allowed_functions.py`,
@ -363,15 +340,23 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
)
)
def test_func_inlinelist_torch_function(self):
def test_force_inline_torch_function(self):
# `torch._dynamo.utils.istype` is skipped by default
def fn(x):
if istype(x, torch.Tensor):
return x + 1
else:
return x - 1
func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add("torch._dynamo.utils.istype")
_manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
# Force inline `torch._dynamo.utils.istype` by setting trace rule.
_manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable
_torch_name_rule_map = [
_manual_torch_name_rule_map,
torch_c_binding_in_graph_functions,
torch_non_c_binding_in_graph_functions,
]
self.assertTrue(
"torch._dynamo" not in torch._dynamo.skipfiles.LEGACY_MOD_INLINELIST
@ -379,8 +364,11 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
self.assertTrue("torch._dynamo" not in torch._dynamo.skipfiles.MOD_INLINELIST)
with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
), unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache
):
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
@ -388,23 +376,32 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_func_inlinelist_third_party_function(self):
def test_force_inline_custom_function(self):
mod, func = create_dummy_module_and_function()
def fn(x):
return func(x)
func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add(f"{mod.__name__}.{func.__name__}")
_manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
# Force inline `mod.func` by setting trace rule.
_manual_torch_name_rule_map[
f"{mod.__name__}.{func.__name__}"
] = UserFunctionVariable
_torch_name_rule_map = [
_manual_torch_name_rule_map,
torch_c_binding_in_graph_functions,
torch_non_c_binding_in_graph_functions,
]
with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
), unittest.mock.patch(
"torch._dynamo.skipfiles.SKIP_DIRS",
torch._dynamo.skipfiles.SKIP_DIRS.copy(),
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
):
# First adding the module to SKIP_DIRS so that it will be skipped.
# First adding the module to SKIP_DIRS so that it will be skipped by default.
torch._dynamo.skipfiles.add(mod.__name__)
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)

View File

@ -93,7 +93,7 @@ def allow_in_graph(fn):
if isinstance(fn, (list, tuple)):
return [allow_in_graph(x) for x in fn]
assert callable(fn), "allow_in_graph expects a callable"
if trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable:
if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
trace_rules._disallowed_callable_ids.remove(id(fn))
trace_rules._allowed_callable_ids.add(id(fn))
return fn
@ -106,8 +106,9 @@ def _disallow_in_graph_helper(throw_if_not_allowed):
assert callable(fn), "disallow_in_graph expects a callable"
if (
throw_if_not_allowed
and trace_rules.lookup_callable(fn)
!= variables.TorchInGraphFunctionVariable
and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
and fn not in trace_rules._allowed_callable_ids
):
raise IncorrectUsage(
"disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "

View File

@ -37,8 +37,10 @@ import torch.utils._content_store
from ..utils import _config_module
from .utils import getfile
from .variables.functions import (
from .variables import (
FunctorchVmapHigherOrderVariable,
NestedUserFunctionVariable,
SkipFilesVariable,
UserFunctionVariable,
UserMethodVariable,
)
@ -160,17 +162,6 @@ def _module_dir(m: types.ModuleType):
return file and _strip_init_py(file)
# TODO: Add a decoractor for easily adding functions to FUNC_INLINELIST
# after resolving all circular import issues.
FUNC_INLINELIST = {
"torch._constrain_as_size",
"torch._constrain_as_value",
"torch._tensor._convert",
"torch.backends.mha.get_fastpath_enabled",
"torch.jit._unwrap_optional",
}
# These are legacy workarounds, don't add new modules to this list.
# Please use the MOD_INLINELIST instead to force inline functions under particular modules.
LEGACY_MOD_INLINELIST = {
@ -240,18 +231,6 @@ if torch.distributed.is_available():
MOD_INLINELIST.add("torch.distributed._functional_collectives")
# TODO: support adding bound method into this list
@functools.lru_cache(None)
def get_func_inlinelist():
inlinelist = set()
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist
@functools.lru_cache(None)
def get_legacy_mod_inlinelist():
inlinelist = set()
@ -401,20 +380,22 @@ def check_verbose(obj, is_inlined_call=False):
)
else:
fi = FunctionInfo(obj, None, getfile(obj), None)
# Go through function based skip/inline rules.
if fi.code in get_func_inlinelist():
# Consulte the central trace rules defined in torch._dynamo.trace_rules.
rule = torch._dynamo.trace_rules.lookup_inner(
fi.py_obj, fi.name, fi.filename, is_inlined_call
)
if rule in [UserFunctionVariable, FunctorchVmapHigherOrderVariable]:
return SkipResult(
False,
"inlined according skipfiles.FUNC_INLINELIST",
"inlined according trace_rules.lookup",
)
else:
assert rule == SkipFilesVariable, rule
return SkipResult(
True,
"skipped according trace_rules.lookup",
)
if is_inlined_call:
if fi.name == "patched_init":
return SkipResult(True, "patched init cannot be inlined.")
elif fi.name == "__torch_function__":
return SkipResult(False, "allow inlining __torch_function__")
# Go through file based skip/inline rules.
return check_file(fi.filename, is_inlined_call)
def check(obj, is_inlined_call=False):

View File

@ -102,6 +102,7 @@ from .variables.misc import (
InlinedClosureVariable,
NullVariable,
PythonModuleVariable,
SkipFilesVariable,
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
@ -2290,6 +2291,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
def inline_call_(
parent, func: VariableTracker, args: List[VariableTracker], kwargs
):
if isinstance(func, SkipFilesVariable):
unimplemented("inline with functions in skip files")
assert isinstance(
func,
(UserFunctionVariable, NestedUserFunctionVariable),

View File

@ -17,9 +17,10 @@ except ModuleNotFoundError:
import torch
from .utils import hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .variables import (
BuiltinVariable,
FunctorchVmapHigherOrderVariable,
SkipFilesVariable,
TorchInGraphFunctionVariable,
@ -151,6 +152,11 @@ manual_torch_name_rule_map = {
"torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
"torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
"torch._functorch.vmap.wrap_batched": UserFunctionVariable,
"torch._constrain_as_size": UserFunctionVariable,
"torch._constrain_as_value": UserFunctionVariable,
"torch._tensor._convert": UserFunctionVariable,
"torch.jit._unwrap_optional": UserFunctionVariable,
"torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
}
@ -2062,8 +2068,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._check_with",
"torch._check",
"torch._compile._disable_dynamo",
"torch._constrain_as_size",
"torch._constrain_as_value",
"torch._functorch.apis.chunk_vmap",
"torch._functorch.autograd_function.custom_function_call_functionalize",
"torch._functorch.autograd_function.custom_function_call_grad",
@ -2765,8 +2769,7 @@ def load_object(name):
else:
assert len(x) == 1, f"Invalid obj name {name}"
val = _load_obj_from_str(x[0])
if hasattr(val, "__wrapped__"):
val = val.__wrapped__
val = unwrap_if_wrapper(val)
except (AttributeError, ImportError):
val = None
return val
@ -2969,6 +2972,23 @@ def is_numpy(obj) -> bool:
return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids
"""
Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object.
"""
def lookup_callable(obj):
if not hashable(obj):
return None
# Custom allow/disallow in graph takes precedence over the general lookup.
if is_callable_disallowed(obj):
return SkipFilesVariable
if is_callable_allowed(obj):
return TorchInGraphFunctionVariable
if is_builtin_callable(obj):
return BuiltinVariable
"""
Main entry point for looking up the trace rule (the Dynamo variable) for a given function object.
E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`.
@ -2976,16 +2996,35 @@ E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`.
def lookup(obj):
# Unwrap if it's a functools.lru_cache wrapper
obj = unwrap_if_wrapper(obj)
return lookup_inner(obj)
def lookup_inner(obj, name=None, filename=None, is_direct_call=True):
# Step 1: lookup obj's tracing rule in `torch_name_rule_map`.
# The rules defined in `torch_name_rule_map` mainly includes two parts:
# - Manually defined rules for any functions.
# - The list of torch in graph functions.
if not hashable(obj):
return None
# Custom allow/disallow in graph takes precedence over the `torch_name_rule_map`.
if callable(obj) and is_callable_disallowed(obj):
if obj is not None:
if is_aten_op_or_tensor_method(obj):
return TorchInGraphFunctionVariable
rule = get_torch_obj_rule_map().get(obj, None)
if rule is not None:
return rule
# Step 2: lookup obj's tracing rule by function name.
if is_direct_call:
if name == "patched_init":
return SkipFilesVariable
elif name == "__torch_function__":
return UserFunctionVariable
# Step 3: lookup obj's tracing rule by filename.
if filename is None:
filename = getfile(obj)
if torch._dynamo.skipfiles.check_file(filename, is_direct_call).skipped:
return SkipFilesVariable
if callable(obj) and is_callable_allowed(obj):
return TorchInGraphFunctionVariable
if is_aten_op_or_tensor_method(obj):
return TorchInGraphFunctionVariable
rule = get_torch_obj_rule_map().get(obj, None)
return rule
else:
return UserFunctionVariable

View File

@ -524,6 +524,7 @@ def is_function_or_wrapper(value):
is_function(value)
or isinstance(value, functools._lru_cache_wrapper)
and is_function(inspect.getattr_static(value, "__wrapped__"))
or isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload))
)
@ -535,14 +536,20 @@ def is_function(value):
types.BuiltinFunctionType,
types.MethodDescriptorType,
types.WrapperDescriptorType,
torch.jit.ScriptFunction,
),
)
def unwrap_if_wrapper(value):
if isinstance(value, functools._lru_cache_wrapper):
value = inspect.getattr_static(value, "__wrapped__")
return value
def unwrap_if_wrapper(fn):
if isinstance(fn, functools._lru_cache_wrapper):
fn = inspect.getattr_static(fn, "__wrapped__")
# unpack @torch._dynamo.optimize()(fn) wrapped function
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
# unpack torch.jit.script_if_tracing
if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
fn = inspect.getattr_static(fn, "__original_fn", fn)
return fn
def is_numpy_ndarray(value):

View File

@ -39,7 +39,7 @@ from torch.fx.immutable_collections import immutable_list
from torch.nested._internal.nested_tensor import NestedTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, skipfiles, trace_rules
from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
@ -59,7 +59,7 @@ from ..source import (
Source,
TupleIteratorGetItemSource,
)
from ..trace_rules import is_builtin_callable, is_callable_allowed, is_numpy
from ..trace_rules import is_callable_allowed, is_numpy
from ..utils import (
build_checkpoint_variable,
clone_input,
@ -82,7 +82,6 @@ from ..utils import (
)
from .base import MutableLocal, typestr, VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
AutocastModeVariable,
@ -138,7 +137,6 @@ from .misc import (
NumpyVariable,
PythonModuleVariable,
SavedTensorBox,
SkipFilesVariable,
TypingVariable,
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
@ -474,9 +472,6 @@ class VariableBuilder:
elif isinstance(value, enum.Enum):
self.install_guards(GuardBuilder.ID_MATCH)
return EnumVariable(value=value, source=self.source)
elif is_builtin_callable(value):
self.install_guards(GuardBuilder.BUILTIN_MATCH)
return BuiltinVariable(value, source=self.source)
elif is_utils_checkpoint(value):
return build_checkpoint_variable(source=self.source)
elif isinstance(value, functools.partial):
@ -571,6 +566,12 @@ class VariableBuilder:
),
"apply",
)
elif callable(value) and trace_rules.lookup_callable(value) is not None:
if is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True
return trace_rules.lookup_callable(value).create_with_source(
value, source=self.source
)
elif np and isinstance(value, np.number):
return self.wrap_unspecialized_primitive(value)
elif DataClassVariable.is_matching_object(value):
@ -708,11 +709,8 @@ class VariableBuilder:
elif TorchCtxManagerClassVariable.is_matching_cls(value):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return TorchCtxManagerClassVariable(value, source=self.source)
elif (is_function_or_wrapper(value) or callable(value)) and trace_rules.lookup(
value
) is not None:
if is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True
elif is_function_or_wrapper(value):
value = unwrap_if_wrapper(value)
return trace_rules.lookup(value).create_with_source(
value, source=self.source
)
@ -725,25 +723,6 @@ class VariableBuilder:
value,
source=self.source,
)
elif (
is_function_or_wrapper(value)
and skipfiles.check(value, is_inlined_call=True)
and not inspect.getattr_static(value, "_torchdynamo_inline", False)
and not inspect.getattr_static(value, "__script_if_tracing_wrapper", False)
):
value = unwrap_if_wrapper(value)
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return SkipFilesVariable(
value,
skipfiles.check_verbose(value, is_inlined_call=True).reason,
source=self.source,
)
elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
self.install_guards(GuardBuilder.CLOSURE_MATCH)
return UserFunctionVariable(
value,
source=self.source,
)
elif isinstance(value, types.MethodType) and isinstance(
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec)
):
@ -1871,14 +1850,11 @@ class SourcelessBuilder:
return UserDefinedObjectVariable(value)
if ConstantVariable.is_literal(value):
return SourcelessBuilder.wrap_constant_literal(value)
elif is_builtin_callable(value):
return BuiltinVariable(value)
elif (is_function_or_wrapper(value) or callable(value)) and trace_rules.lookup(
value
) is not None:
value = unwrap_if_wrapper(value)
elif callable(value) and trace_rules.lookup_callable(value) is not None:
if is_callable_allowed(value):
self.tx.output.has_user_defined_allowed_in_graph = True
return trace_rules.lookup_callable(value)(value)
elif is_function_or_wrapper(value):
return trace_rules.lookup(value)(value)
elif isinstance(value, types.FunctionType):
return UserFunctionVariable(value)

View File

@ -99,6 +99,11 @@ def _polyfill_call_impl(name):
class BuiltinVariable(VariableTracker):
_SENTINEL = object()
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH))
return BuiltinVariable(value, source=source)
@staticmethod
@functools.lru_cache(None)
def _constant_fold_functions():

View File

@ -11,6 +11,7 @@ import torch
from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, identity, istype, make_cell
from .base import typestr, VariableTracker
@ -110,6 +111,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
return cls(
value,
source=source,

View File

@ -293,7 +293,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self.value,
source=self.source,
).call_function(tx, args, kwargs)
elif self.value is torch.overrides.get_default_nowrap_functions:
elif self.value is torch.overrides.get_default_nowrap_functions.__wrapped__:
# [Note: __torch_function__] we return empty here because we restrict
# the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require implementing

View File

@ -735,7 +735,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
).call_function(tx, [self], {})
elif isinstance(subobj, staticmethod):
func = subobj.__get__(self.value)
if trace_rules.lookup(func) is not None:
if source is not None and trace_rules.lookup(func) is not None:
return trace_rules.lookup(func).create_with_source(func, source=source)
else:
return variables.UserFunctionVariable(func, source=source)
@ -768,7 +768,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
elif inspect.isfunction(dynamic_subobj):
if is_utils_checkpoint(func):
return build_checkpoint_variable(source=source)
elif trace_rules.lookup(func) is not None:
elif source is not None and trace_rules.lookup(func) is not None:
return trace_rules.lookup(func).create_with_source(
func, source=source
)

View File

@ -420,13 +420,7 @@ dynamo_expected_failures = {
"TestAsArrayCPU.test_copy_list_cpu_complex128", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_int16", # test_tensor_creation_ops
"TestTensorCreationCPU.test_cartesian_prod_cpu", # test_tensor_creation_ops
"TestSubclass.test_parametrization_non_wrapper_tensor_leave_parametrized_True", # test_subclass
"TestSubclass.test_module_optimization_non_wrapper_tensor", # test_subclass
"TestSubclass.test_serialization_non_wrapper_tensor_as_param_True", # test_subclass
"TestSubclass.test_parametrization_non_wrapper_tensor_leave_parametrized_False", # test_subclass
"TestSubclass.test_type_propagation_non_wrapper_tensor_as_param_False", # test_subclass
"TestSubclass.test_parametrization_base_tensor_leave_parametrized_True", # test_subclass
"TestSubclass.test_type_propagation_non_wrapper_tensor_as_param_True", # test_subclass
"TestSubclass.test_parametrization_base_tensor_leave_parametrized_False", # test_subclass
"TestStatelessFunctionalAPI.test_reparametrize_module_fail_reset_to_original_torch_func", # test_stateless
"TestStatelessFunctionalAPI.test_reparametrized_module_change_parametrization_original_stateless", # test_stateless
@ -560,29 +554,16 @@ dynamo_expected_failures = {
"PackedSequenceTest.test_pack_sequence", # nn/test_packed_sequence
"PackedSequenceTest.test_total_length", # nn/test_packed_sequence
"TestModuleHookNN.test_hook_inplace", # nn/test_module_hooks
"TestLazyModules.test_lazy_batchnorm2d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv3d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transposed1d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv2d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm3d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm3d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose1d_pickle", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm2d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm2d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv3d_pickle", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm2d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm1d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm1d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm1d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose3d_pickle", # nn/test_lazy_modules
"TestLazyModules.test_lazy_instancenorm3d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm3d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv2d_pickle", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv1d_pickle", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv1d", # nn/test_lazy_modules
"TestLazyModules.test_linear", # nn/test_lazy_modules
"TestLazyModules.test_lazy_module_buffer", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm1d_state", # nn/test_lazy_modules
"TestLazyModules.test_lazy_batchnorm_with_dict_input", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose2d", # nn/test_lazy_modules
"TestLazyModules.test_lazy_conv_transpose2d_pickle", # nn/test_lazy_modules
@ -1167,10 +1148,8 @@ dynamo_expected_failures = {
"AutogradFunctionTests.test_print_in_bwd", # dynamo/test_autograd_function
"AutogradFunctionTests.test_graph_break_if_lifted_free_variable", # dynamo/test_autograd_function
"AotAutogradFallbackTests.test_aot_sequence_nr", # dynamo/test_aot_autograd
"TestTorchFunctionOverride.test_tensor_subclass_propagation", # test_overrides
"TestNamedTuple.test_max", # test_overrides
"TestTorchFunctionMode.test_mode_notimplemented_loop", # test_overrides
"TestTorchFunctionMode.test_disable_enable_subclass", # test_overrides
"TestTorchFunctionOverride.test_mean_semantics", # test_overrides
"TestGradCheckOverride.test_gradcheck", # test_overrides
"TestTorchFunctionOverride.test_Tensor___cuda_array_interface_____get__", # test_overrides
@ -1580,7 +1559,6 @@ dynamo_expected_failures = {
"TestAutograd.test_increment_version", # test_autograd
"TestAutograd.test_record_function_callbacks", # test_autograd
"TestAutograd.test_save_on_cpu_and_checkpoint", # test_autograd
"TestAutogradForwardMode.test_make_dual_torch_dispatch", # test_autograd
"TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128", # test_autograd
"TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False", # test_autograd
"TestAutograd.test_gradcheck_nondeterministic", # test_autograd
@ -1825,7 +1803,6 @@ dynamo_expected_failures = {
"TestFakeQuantizeOps.test_learnable_forward_per_tensor_cuda", # test_quantization
"TestQuantizedTensor.test_repeat", # test_quantization
"TestStaticQuantizedModule.test_linear_leaky_relu", # test_quantization
"TestBitsCPU.test_subclass_cpu", # test_quantization
"TestFakeQuantizeOps.test_learnable_backward_per_channel_cpu", # test_quantization
"TestFXNumericSuiteCoreAPIs.test_add_shadow_loggers_fun_ptq", # test_quantization
"TestQuantizeFx.test_static_lstm", # test_quantization
@ -2064,28 +2041,11 @@ dynamo_expected_failures = {
"TestGenerateOpcheckTests.test_opcheck_bad_op", # test_custom_ops
"TestCustomOp.test_legacy_define", # test_custom_ops
"TestPythonRegistration.test_alias_analysis", # test_python_dispatch
"TestPythonDispatch.test_torch_dispatch_mode_subclass_priority", # test_python_dispatch
"TestPythonDispatch.test_strides_slow_path", # test_python_dispatch
"TestPythonDispatch.test_invalid_ret", # test_python_dispatch
"TestPythonDispatch.test_dim_slowpath", # test_python_dispatch
"TestWrapperSubclassAliasingCPU.test_wrapper_subclass_aliasing_conv2d_cpu", # test_python_dispatch
"TestPythonDispatch.test_fancy_strides", # test_python_dispatch
"TestPythonDispatch.test_layout_slow_path", # test_python_dispatch
"TestPythonDispatch.test_dispatch_super_dont_autograd", # test_python_dispatch
"TestPythonDispatch.test_sizes_slow_path", # test_python_dispatch
"TestPythonRegistration.test_finalizer", # test_python_dispatch
"TestPythonDispatch.test_dispatch_super_call_list_arg", # test_python_dispatch
"TestPythonDispatch.test_is_contiguous_slow_path", # test_python_dispatch
"TestPythonRegistration.test_override_cpu_sum", # test_python_dispatch
"TestPythonDispatch.test_mode_with_make_subclass", # test_python_dispatch
"TestPythonDispatch.test_multiple_ops_subclass", # test_python_dispatch
"TestPythonDispatch.test_subclass_autograd_device_check", # test_python_dispatch
"TestPythonDispatch.test_data_ptr_respects_numel_slow_path", # test_python_dispatch
"TestPythonDispatch.test_make_subclass_with_modes", # test_python_dispatch
"TestPythonDispatch.test_dispatch_super_call", # test_python_dispatch
"TestPythonDispatch.test_subclass_priority", # test_python_dispatch
"TestPythonDispatch.test_exception_handling", # test_python_dispatch
"TestPythonDispatch.test_list_ret", # test_python_dispatch
"LoggingTests.test_trace_source_nested", # dynamo/test_logging
"LoggingTests.test_guards_recompiles", # dynamo/test_logging
"LoggingTests.test_inductor_info", # dynamo/test_logging
@ -6499,6 +6459,7 @@ dynamo_skips = {
"TestScalarOpsMisc.test_scalar_integer_operation_divbyzero_dtype_Q_operation1",
"TestArgmax.test_combinations_data61", # torch_np/test_ndarray_methods.py
"TestArgmax.test_combinations_data58", # torch_np/test_ndarray_methods.py
"TestPythonDispatch.test_list_ret", # test_python_dispatch.py
"TestCustomOpTestingCPU.test_opcheck_fails_basic_cpu", # test_custom_ops.py
"TestVmapAPI.test_functools_partial", # functorch/test_vmap.py
"TestSaveLoadForOpVersion.test_versioned_div_tensor_out", # test_jit.py