mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][user_defined][stable-diffusion] Raise ObservedAttributeError on UserDefinedObject var_getattr (#132806)
Fixes https://github.com/pytorch/pytorch/issues/132551 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132806 Approved by: https://github.com/williamwen42
This commit is contained in:
parent
40ce0a53bb
commit
25df063f04
|
|
@ -309,6 +309,27 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
res = opt_fn(x, d)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_atrribute_error(self):
|
||||
class Mock:
|
||||
def __init__(self):
|
||||
self.a = 1
|
||||
|
||||
mock = Mock()
|
||||
|
||||
def fn(x):
|
||||
try:
|
||||
c = 2
|
||||
mock.b
|
||||
except AttributeError:
|
||||
c = 3
|
||||
return torch.sin(x) * c
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
x = torch.randn(4)
|
||||
ref = fn(x)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -204,9 +204,15 @@ class ObservedKeyError(ObservedException):
|
|||
pass
|
||||
|
||||
|
||||
class ObservedAttributeError(ObservedException):
|
||||
# An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__
|
||||
pass
|
||||
|
||||
|
||||
observed_exception_map = {
|
||||
StopIteration: ObservedUserStopIteration,
|
||||
KeyError: ObservedKeyError,
|
||||
AttributeError: ObservedAttributeError,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -471,13 +471,14 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
|||
elif isinstance(value, UserDefinedObjectVariable):
|
||||
try:
|
||||
x = value.var_getattr(self, "__bool__") # type: ignore[arg-type]
|
||||
except exc.ObservedException:
|
||||
except exc.ObservedAttributeError:
|
||||
exc.handle_observed_exception(self)
|
||||
# if __bool__ is missing, trying __len__ to infer a truth value.
|
||||
x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
|
||||
else:
|
||||
if isinstance(x, GetAttrVariable):
|
||||
# if __bool__ is missing, trying __len__ to infer a truth value.
|
||||
try:
|
||||
x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
|
||||
except exc.ObservedAttributeError:
|
||||
exc.handle_observed_exception(self)
|
||||
x = None
|
||||
|
||||
# __bool__ or __len__ is function
|
||||
if isinstance(x, UserMethodVariable):
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from torch._guards import TracingContext
|
|||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
||||
from ..exc import ObservedException, unimplemented
|
||||
from ..exc import ObservedException, raise_observed_exception, unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import (
|
||||
AttrSource,
|
||||
|
|
@ -1034,7 +1034,10 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
source = UnspecializedParamBufferSource(self.source, name)
|
||||
source = self._wrap_source(source)
|
||||
|
||||
if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):
|
||||
if subobj is not NO_SUCH_SUBOBJ:
|
||||
if is_wrapper_or_member_descriptor(subobj):
|
||||
options = {"source": source}
|
||||
return variables.GetAttrVariable(self, name, **options)
|
||||
if source:
|
||||
return variables.LazyVariableTracker.create(subobj, source)
|
||||
else:
|
||||
|
|
@ -1042,8 +1045,8 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
|
||||
return SourcelessBuilder.create(tx, subobj)
|
||||
|
||||
options = {"source": source}
|
||||
return variables.GetAttrVariable(self, name, **options)
|
||||
# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
|
||||
raise_observed_exception(AttributeError, tx, self)
|
||||
|
||||
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
if tx.output.side_effects.is_attribute_mutation(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user