[dynamo][user-defined] Simplify call_hasattr (#133935)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133935
Approved by: https://github.com/williamwen42, https://github.com/jansel
ghstack dependencies: #133745, #133747, #133746, #133799, #133800
This commit is contained in:
Animesh Jain 2024-08-19 18:46:27 -07:00 committed by PyTorch MergeBot
parent 8d93fe510e
commit 33f1ee036e
2 changed files with 22 additions and 39 deletions

View File

@ -11,7 +11,7 @@ import torch.nn
from .. import trace_rules, variables
from ..exc import (
ObservedException,
raise_observed_exception,
unimplemented,
UnspecializeRestartAnalysis,
Unsupported,
@ -1133,7 +1133,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
if out is None:
out = self.getattr_helper(tx, "_buffers", name_vt)
if out is None:
raise ObservedException(f"object has no attribute {name}")
raise_observed_exception(AttributeError, tx, self)
return out

View File

@ -20,7 +20,12 @@ from torch._guards import TracingContext
from .. import polyfill, variables
from ..bytecode_transformation import create_call_function
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import ObservedException, raise_observed_exception, unimplemented
from ..exc import (
handle_observed_exception,
ObservedAttributeError,
raise_observed_exception,
unimplemented,
)
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
@ -926,7 +931,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
self._check_for_getattribute()
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
if isinstance(result, variables.DeletedVariable):
raise_observed_exception(AttributeError, tx, self)
return result
if name == "__dict__":
options = {"source": source}
@ -1097,47 +1105,22 @@ class UserDefinedObjectVariable(UserDefinedVariable):
raise_observed_exception(AttributeError, tx, self)
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
if tx.output.side_effects.is_attribute_mutation(self):
try:
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
return variables.ConstantVariable.create(
not isinstance(result, variables.DeletedVariable)
)
except KeyError:
pass
if self._check_for_getattribute():
unimplemented("hasattr with custom __getattribute__")
if self.source:
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
if self._check_for_getattribute():
unimplemented("hasattr with custom __getattribute__")
try:
self._getattr_static(name)
return variables.ConstantVariable.create(True)
except AttributeError:
# Now check in __getattr__ function
getattr_fn = self._check_for_getattr()
if isinstance(getattr_fn, types.FunctionType):
# Dynamo is going to trace the __getattr__ function with
# args=name. Set the source accordingly.
new_source = None
if self.source:
new_source = AttrSource(self.source, "__getattr__")
try:
result = variables.UserMethodVariable(
getattr_fn, self, source=new_source
).call_function(tx, [variables.ConstantVariable.create(name)], {})
var_vt = self.var_getattr(tx, name)
return variables.ConstantVariable.create(
not isinstance(result, variables.DeletedVariable)
not isinstance(var_vt, variables.DeletedVariable)
)
except ObservedException:
except ObservedAttributeError:
handle_observed_exception(tx)
return variables.ConstantVariable.create(False)
elif getattr_fn is None:
return variables.ConstantVariable.create(False)
else:
unimplemented("UserDefined with non-function __getattr__")
def odict_getitem(self, tx: "InstructionTranslator", key):
from .builder import VariableBuilder