mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][exceptions] Use exception subclass whenever possible (#134610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134610 Approved by: https://github.com/drisspg, https://github.com/jansel
This commit is contained in:
parent
bf7db4e4f9
commit
880e3d18a4
|
|
@ -267,6 +267,29 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||||
x = torch.ones(4)
|
x = torch.ones(4)
|
||||||
self.assertEqual(mod(x), opt_mod(x))
|
self.assertEqual(mod(x), opt_mod(x))
|
||||||
|
|
||||||
|
def test_attribute_error_from_getattr(self):
|
||||||
|
class Mock:
|
||||||
|
def __init__(self):
|
||||||
|
self.a = 5
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name != "a":
|
||||||
|
raise AttributeError("missing")
|
||||||
|
return self.__dict__["a"]
|
||||||
|
|
||||||
|
mock = Mock()
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
if hasattr(mock, "b"):
|
||||||
|
return torch.cos(x)
|
||||||
|
return torch.sin(x)
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||||
|
x = torch.randn(4)
|
||||||
|
ref = fn(x)
|
||||||
|
res = opt_fn(x)
|
||||||
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
def test_stop_iteration(self):
|
def test_stop_iteration(self):
|
||||||
def zip_longest(*iterables, fillvalue=None):
|
def zip_longest(*iterables, fillvalue=None):
|
||||||
# Get the iterators for each iterable
|
# Get the iterators for each iterable
|
||||||
|
|
|
||||||
|
|
@ -1379,9 +1379,10 @@ class InstructionTranslatorBase(
|
||||||
|
|
||||||
# 2) when user raises exception instance
|
# 2) when user raises exception instance
|
||||||
if isinstance(val, variables.ExceptionVariable):
|
if isinstance(val, variables.ExceptionVariable):
|
||||||
if val.exc_type is StopIteration:
|
if observed_exception_type := exc.observed_exception_map.get(
|
||||||
# StopIteration is used to find the end of iteration while tracing __next__
|
val.exc_type
|
||||||
raise exc.ObservedUserStopIteration(f"raised exception {val}")
|
):
|
||||||
|
raise observed_exception_type(f"raised exception {val}")
|
||||||
raise exc.ObservedException(f"raised exception {val}")
|
raise exc.ObservedException(f"raised exception {val}")
|
||||||
unimplemented(f"raise {exc}")
|
unimplemented(f"raise {exc}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user