[dynamo][user_defined][stable-diffusion] Raise ObservedAttributeError on UserDefinedObject var_getattr (#132806)

Fixes https://github.com/pytorch/pytorch/issues/132551

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132806
Approved by: https://github.com/williamwen42
This commit is contained in:
Animesh Jain 2024-08-06 22:05:38 -07:00 committed by PyTorch MergeBot
parent 40ce0a53bb
commit 25df063f04
5 changed files with 40 additions and 9 deletions

View File

@ -309,6 +309,27 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x, d)
self.assertEqual(ref, res)
def test_atrribute_error(self):
class Mock:
def __init__(self):
self.a = 1
mock = Mock()
def fn(x):
try:
c = 2
mock.b
except AttributeError:
c = 3
return torch.sin(x) * c
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(4)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -204,9 +204,15 @@ class ObservedKeyError(ObservedException):
pass
class ObservedAttributeError(ObservedException):
# An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__
pass
observed_exception_map = {
StopIteration: ObservedUserStopIteration,
KeyError: ObservedKeyError,
AttributeError: ObservedAttributeError,
}

View File

@ -471,13 +471,14 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
elif isinstance(value, UserDefinedObjectVariable):
try:
x = value.var_getattr(self, "__bool__") # type: ignore[arg-type]
except exc.ObservedException:
except exc.ObservedAttributeError:
exc.handle_observed_exception(self)
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
else:
if isinstance(x, GetAttrVariable):
# if __bool__ is missing, trying __len__ to infer a truth value.
try:
x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
except exc.ObservedAttributeError:
exc.handle_observed_exception(self)
x = None
# __bool__ or __len__ is function
if isinstance(x, UserMethodVariable):

View File

@ -20,7 +20,7 @@ from torch._guards import TracingContext
from .. import polyfill, variables
from ..bytecode_transformation import create_call_function
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import ObservedException, unimplemented
from ..exc import ObservedException, raise_observed_exception, unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
@ -1034,7 +1034,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
source = UnspecializedParamBufferSource(self.source, name)
source = self._wrap_source(source)
if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):
if subobj is not NO_SUCH_SUBOBJ:
if is_wrapper_or_member_descriptor(subobj):
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
if source:
return variables.LazyVariableTracker.create(subobj, source)
else:
@ -1042,8 +1045,8 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return SourcelessBuilder.create(tx, subobj)
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
raise_observed_exception(AttributeError, tx, self)
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
if tx.output.side_effects.is_attribute_mutation(self):