mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "[dynamo][exceptions] Use exception subclass whenever possible (#134610)"
This reverts commit 880e3d18a4.
Reverted https://github.com/pytorch/pytorch/pull/134610 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134610#issuecomment-2316568553))
This commit is contained in:
parent
67d7040fce
commit
f0fceed432
|
|
@ -267,29 +267,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||||
x = torch.ones(4)
|
x = torch.ones(4)
|
||||||
self.assertEqual(mod(x), opt_mod(x))
|
self.assertEqual(mod(x), opt_mod(x))
|
||||||
|
|
||||||
def test_attribute_error_from_getattr(self):
|
|
||||||
class Mock:
|
|
||||||
def __init__(self):
|
|
||||||
self.a = 5
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
if name != "a":
|
|
||||||
raise AttributeError("missing")
|
|
||||||
return self.__dict__["a"]
|
|
||||||
|
|
||||||
mock = Mock()
|
|
||||||
|
|
||||||
def fn(x):
|
|
||||||
if hasattr(mock, "b"):
|
|
||||||
return torch.cos(x)
|
|
||||||
return torch.sin(x)
|
|
||||||
|
|
||||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
||||||
x = torch.randn(4)
|
|
||||||
ref = fn(x)
|
|
||||||
res = opt_fn(x)
|
|
||||||
self.assertEqual(ref, res)
|
|
||||||
|
|
||||||
def test_stop_iteration(self):
|
def test_stop_iteration(self):
|
||||||
def zip_longest(*iterables, fillvalue=None):
|
def zip_longest(*iterables, fillvalue=None):
|
||||||
# Get the iterators for each iterable
|
# Get the iterators for each iterable
|
||||||
|
|
|
||||||
|
|
@ -1379,10 +1379,9 @@ class InstructionTranslatorBase(
|
||||||
|
|
||||||
# 2) when user raises exception instance
|
# 2) when user raises exception instance
|
||||||
if isinstance(val, variables.ExceptionVariable):
|
if isinstance(val, variables.ExceptionVariable):
|
||||||
if observed_exception_type := exc.observed_exception_map.get(
|
if val.exc_type is StopIteration:
|
||||||
val.exc_type
|
# StopIteration is used to find the end of iteration while tracing __next__
|
||||||
):
|
raise exc.ObservedUserStopIteration(f"raised exception {val}")
|
||||||
raise observed_exception_type(f"raised exception {val}")
|
|
||||||
raise exc.ObservedException(f"raised exception {val}")
|
raise exc.ObservedException(f"raised exception {val}")
|
||||||
unimplemented(f"raise {exc}")
|
unimplemented(f"raise {exc}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user