diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 4d4d494191b..f64ef6e5231 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -42,7 +42,6 @@ from .variables.base import ValueMutationExisting, VariableTracker from .variables.functions import ( ContextlibContextManagerLocalGeneratorObjectVariable, LocalGeneratorObjectVariable, - UserMethodVariable, ) from .variables.nn_module import NNModuleVariable from .variables.tensor import ( @@ -251,10 +250,7 @@ class PyCodegen: value.source is not None and allow_cache and not ( - value.is_realized() - and isinstance( - value, (LocalGeneratorObjectVariable, UserMethodVariable) - ) + value.is_realized() and isinstance(value, LocalGeneratorObjectVariable) ) ): # There's a corner case for export: for instance, if the computation diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index e628a955bc9..0da182c022b 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1122,26 +1122,13 @@ class UserMethodVariable(UserFunctionVariable): return super().inspect_parameter_names()[1:] def var_getattr(self, tx: "InstructionTranslator", name: str): - if name == "__func__": - # self.source points to the source of the function object and not - # the method object - return VariableTracker.build(tx, self.fn, self.source) + source = self.source and AttrSource(self.source, name) if name == "__self__": return self.obj + if name == "__func__": + return VariableTracker.build(tx, self.fn, source) return super().var_getattr(tx, name) - def reconstruct(self, codegen): - if not self.obj.source or not self.source: - raise NotImplementedError - - def get_bound_method(): - codegen(self.source) - codegen.extend_output(codegen.create_load_attrs("__get__")) - - codegen.add_push_null(get_bound_method) - codegen(self.obj.source) - codegen.extend_output(create_call_function(1, False)) - class WrappedUserMethodVariable(UserMethodVariable): def __init__(self, wrapped, context, **kwargs) -> None: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 1b6d9ffacf1..7cb21ab3728 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1380,9 +1380,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): self.value.__class__, name, NO_SUCH_SUBOBJ ) is_accessible_from_type_mro = ( - subobj_from_class is subobj - and self.cls_source is not None - and self.source is not None + subobj_from_class is subobj and self.cls_source is not None ) if isinstance(subobj, property): @@ -1414,11 +1412,6 @@ class UserDefinedObjectVariable(UserDefinedVariable): func = subobj.__get__(self.value) return VariableTracker.build(tx, func, source) elif isinstance(subobj, classmethod): - if is_accessible_from_type_mro: - # Accessing from __dict__ does not resolve the descriptor, it - # returns a classmethod object, so access the __func__ - # attribute to get to the actual function. - source = AttrSource(self.get_source_by_walking_mro(name), "__func__") return variables.UserMethodVariable( subobj.__func__, self.var_getattr(tx, "__class__"), source=source ) @@ -1468,9 +1461,6 @@ class UserDefinedObjectVariable(UserDefinedVariable): isinstance(subobj, types.MethodType) and isinstance(self.value, torch.nn.Module) ): - if is_accessible_from_type_mro: - source = self.get_source_by_walking_mro(name) - # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup. # Static lookup can't tell us it's a method or function correctly, # so we trigger dynamic lookup here to get the correct type.