diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 3bd845d64a3..9c4fb05df95 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -11,7 +11,7 @@ import torch.nn from .. import trace_rules, variables from ..exc import ( - ObservedException, + raise_observed_exception, unimplemented, UnspecializeRestartAnalysis, Unsupported, @@ -1133,7 +1133,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): if out is None: out = self.getattr_helper(tx, "_buffers", name_vt) if out is None: - raise ObservedException(f"object has no attribute {name}") + raise_observed_exception(AttributeError, tx, self) return out diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 3a1cef7f397..25a1663cc93 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -20,7 +20,12 @@ from torch._guards import TracingContext from .. import polyfill, variables from ..bytecode_transformation import create_call_function from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import ObservedException, raise_observed_exception, unimplemented +from ..exc import ( + handle_observed_exception, + ObservedAttributeError, + raise_observed_exception, + unimplemented, +) from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, @@ -926,7 +931,10 @@ class UserDefinedObjectVariable(UserDefinedVariable): self._check_for_getattribute() if tx.output.side_effects.has_pending_mutation_of_attr(self, name): - return tx.output.side_effects.load_attr(self, name) + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception(AttributeError, tx, self) + return result if name == "__dict__": options = {"source": source} @@ -1097,47 +1105,22 @@ class UserDefinedObjectVariable(UserDefinedVariable): raise_observed_exception(AttributeError, tx, self) def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - if tx.output.side_effects.is_attribute_mutation(self): - try: - result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) - return variables.ConstantVariable.create( - not isinstance(result, variables.DeletedVariable) - ) - except KeyError: - pass + if self._check_for_getattribute(): + unimplemented("hasattr with custom __getattribute__") + if self.source: install_guard( AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) ) - if self._check_for_getattribute(): - unimplemented("hasattr with custom __getattribute__") try: - self._getattr_static(name) - return variables.ConstantVariable.create(True) - except AttributeError: - # Now check in __getattr__ function - getattr_fn = self._check_for_getattr() - if isinstance(getattr_fn, types.FunctionType): - # Dynamo is going to trace the __getattr__ function with - # args=name. Set the source accordingly. - new_source = None - if self.source: - new_source = AttrSource(self.source, "__getattr__") - try: - result = variables.UserMethodVariable( - getattr_fn, self, source=new_source - ).call_function(tx, [variables.ConstantVariable.create(name)], {}) - - return variables.ConstantVariable.create( - not isinstance(result, variables.DeletedVariable) - ) - except ObservedException: - return variables.ConstantVariable.create(False) - elif getattr_fn is None: - return variables.ConstantVariable.create(False) - else: - unimplemented("UserDefined with non-function __getattr__") + var_vt = self.var_getattr(tx, name) + return variables.ConstantVariable.create( + not isinstance(var_vt, variables.DeletedVariable) + ) + except ObservedAttributeError: + handle_observed_exception(tx) + return variables.ConstantVariable.create(False) def odict_getitem(self, tx: "InstructionTranslator", key): from .builder import VariableBuilder