mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][user-defined] Simplify call_hasattr (#133935)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133935 Approved by: https://github.com/williamwen42, https://github.com/jansel ghstack dependencies: #133745, #133747, #133746, #133799, #133800
This commit is contained in:
parent
8d93fe510e
commit
33f1ee036e
|
|
@ -11,7 +11,7 @@ import torch.nn
|
|||
|
||||
from .. import trace_rules, variables
|
||||
from ..exc import (
|
||||
ObservedException,
|
||||
raise_observed_exception,
|
||||
unimplemented,
|
||||
UnspecializeRestartAnalysis,
|
||||
Unsupported,
|
||||
|
|
@ -1133,7 +1133,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
|||
if out is None:
|
||||
out = self.getattr_helper(tx, "_buffers", name_vt)
|
||||
if out is None:
|
||||
raise ObservedException(f"object has no attribute {name}")
|
||||
raise_observed_exception(AttributeError, tx, self)
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,12 @@ from torch._guards import TracingContext
|
|||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
||||
from ..exc import ObservedException, raise_observed_exception, unimplemented
|
||||
from ..exc import (
|
||||
handle_observed_exception,
|
||||
ObservedAttributeError,
|
||||
raise_observed_exception,
|
||||
unimplemented,
|
||||
)
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import (
|
||||
AttrSource,
|
||||
|
|
@ -926,7 +931,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
self._check_for_getattribute()
|
||||
|
||||
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
|
||||
return tx.output.side_effects.load_attr(self, name)
|
||||
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
|
||||
if isinstance(result, variables.DeletedVariable):
|
||||
raise_observed_exception(AttributeError, tx, self)
|
||||
return result
|
||||
|
||||
if name == "__dict__":
|
||||
options = {"source": source}
|
||||
|
|
@ -1097,47 +1105,22 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
raise_observed_exception(AttributeError, tx, self)
|
||||
|
||||
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
if tx.output.side_effects.is_attribute_mutation(self):
|
||||
try:
|
||||
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
|
||||
return variables.ConstantVariable.create(
|
||||
not isinstance(result, variables.DeletedVariable)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
if self._check_for_getattribute():
|
||||
unimplemented("hasattr with custom __getattribute__")
|
||||
|
||||
if self.source:
|
||||
install_guard(
|
||||
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
|
||||
)
|
||||
if self._check_for_getattribute():
|
||||
unimplemented("hasattr with custom __getattribute__")
|
||||
|
||||
try:
|
||||
self._getattr_static(name)
|
||||
return variables.ConstantVariable.create(True)
|
||||
except AttributeError:
|
||||
# Now check in __getattr__ function
|
||||
getattr_fn = self._check_for_getattr()
|
||||
if isinstance(getattr_fn, types.FunctionType):
|
||||
# Dynamo is going to trace the __getattr__ function with
|
||||
# args=name. Set the source accordingly.
|
||||
new_source = None
|
||||
if self.source:
|
||||
new_source = AttrSource(self.source, "__getattr__")
|
||||
try:
|
||||
result = variables.UserMethodVariable(
|
||||
getattr_fn, self, source=new_source
|
||||
).call_function(tx, [variables.ConstantVariable.create(name)], {})
|
||||
|
||||
var_vt = self.var_getattr(tx, name)
|
||||
return variables.ConstantVariable.create(
|
||||
not isinstance(result, variables.DeletedVariable)
|
||||
not isinstance(var_vt, variables.DeletedVariable)
|
||||
)
|
||||
except ObservedException:
|
||||
except ObservedAttributeError:
|
||||
handle_observed_exception(tx)
|
||||
return variables.ConstantVariable.create(False)
|
||||
elif getattr_fn is None:
|
||||
return variables.ConstantVariable.create(False)
|
||||
else:
|
||||
unimplemented("UserDefined with non-function __getattr__")
|
||||
|
||||
def odict_getitem(self, tx: "InstructionTranslator", key):
|
||||
from .builder import VariableBuilder
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user