mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][exceptions] Use exception subclass whenever possible (#134610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134610 Approved by: https://github.com/drisspg, https://github.com/jansel
This commit is contained in:
parent
bf7db4e4f9
commit
880e3d18a4
|
|
@ -267,6 +267,29 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
x = torch.ones(4)
|
||||
self.assertEqual(mod(x), opt_mod(x))
|
||||
|
||||
def test_attribute_error_from_getattr(self):
|
||||
class Mock:
|
||||
def __init__(self):
|
||||
self.a = 5
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name != "a":
|
||||
raise AttributeError("missing")
|
||||
return self.__dict__["a"]
|
||||
|
||||
mock = Mock()
|
||||
|
||||
def fn(x):
|
||||
if hasattr(mock, "b"):
|
||||
return torch.cos(x)
|
||||
return torch.sin(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref = fn(x)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_stop_iteration(self):
|
||||
def zip_longest(*iterables, fillvalue=None):
|
||||
# Get the iterators for each iterable
|
||||
|
|
|
|||
|
|
@ -1379,9 +1379,10 @@ class InstructionTranslatorBase(
|
|||
|
||||
# 2) when user raises exception instance
|
||||
if isinstance(val, variables.ExceptionVariable):
|
||||
if val.exc_type is StopIteration:
|
||||
# StopIteration is used to find the end of iteration while tracing __next__
|
||||
raise exc.ObservedUserStopIteration(f"raised exception {val}")
|
||||
if observed_exception_type := exc.observed_exception_map.get(
|
||||
val.exc_type
|
||||
):
|
||||
raise observed_exception_type(f"raised exception {val}")
|
||||
raise exc.ObservedException(f"raised exception {val}")
|
||||
unimplemented(f"raise {exc}")
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user