diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 5938b25e5de..2653f08b4f8 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2851,7 +2851,8 @@ def forward(self, x): with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"): gm(torch.randn(3, 4, 5)) - def test_access_class_method_from_user_class(self): + @common_utils.parametrize("type_fn", [type, lambda obj: obj.__class__]) + def test_access_class_method_from_user_class(self, type_fn): class A: @classmethod def func(cls): @@ -2859,19 +2860,11 @@ def forward(self, x): def f(x): a = A() - return x.sum() + type(a).func().sum() + return x.sum() + type_fn(a).func().sum() gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) - def f_correct(x): - a = A() - return x.sum() + a.__class__.func().sum() - - gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4)) - - self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4))) - def test_not_functionalize(self): class Foo(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c369f1681e9..0cb43d4be9d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -11,7 +11,6 @@ import torch from torch import sym_float, sym_int from .. import config, polyfill, variables -from ..allowed_functions import is_allowed from ..exc import ( AttributeMutationError, unimplemented, @@ -1177,12 +1176,10 @@ class BuiltinVariable(VariableTracker): return build_checkpoint_variable(**options) elif trace_rules.lookup(member) is not None: return trace_rules.lookup(member)(member, **options) - elif is_allowed(member): - return TorchVariable(member, **options) - elif ConstantVariable.is_literal(member): - return ConstantVariable.create(member, **options) - else: + elif source is not None: return VariableBuilder(tx, source)(member) + else: + return SourcelessBuilder()(tx, member) elif isinstance(obj, (PythonModuleVariable, DummyModule)): member = obj.value.__dict__[name] @@ -1278,24 +1275,17 @@ class BuiltinVariable(VariableTracker): try: py_type = obj.python_type() - except NotImplementedError: - py_type = None + except NotImplementedError as error: + raise UserError( + UserErrorType.INVALID_INPUT, + str(error), + case_name="unknown_python_type", + ) from None - if istype(obj, variables.TupleVariable): - return BuiltinVariable(py_type) - - if py_type is not None and obj.source: - return VariableBuilder(tx, TypeSource(obj.source))(py_type) - - if py_type is not None: + if obj.source is None: return SourcelessBuilder()(tx, py_type) - - raise UserError( - UserErrorType.ANTI_PATTERN, - f"Can't call type() on generated custom object {obj}. " - "Please use __class__ instead", - case_name="type_reflection_method", - ) + else: + return VariableBuilder(tx, TypeSource(obj.source))(py_type) def call_reversed(self, tx, obj: VariableTracker): if obj.has_unpack_var_sequence(tx):