# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint


class ExceptionTests(torch._dynamo.test_case.TestCase):
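    # A raise inside try is caught by a broad "except Exception" handler.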
    def test_exception(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except Exception:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

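    # The raised exception is matched against a tuple of exception types.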
    def test_exception2(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except (NotImplementedError, AttributeError) as e:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

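    # Multiple except clauses plus a finally block; only the matching handler runs.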
    def test_exception3(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except AssertionError:
                x = torch.sigmoid(x)
            except NotImplementedError:
                x = torch.cos(x)
            finally:
                x = torch.cos(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

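    # try/except inside a loop body that also has an early return.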
    def test_exception4(self):
        def fn(x):
            for i in range(10):
                if i == 5:
                    return x
                try:
                    x = torch.sin(x)
                    raise NotImplementedError
                except Exception:
                    x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

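    # A second try/except nested inside the handler of the first.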
    def test_exception_with_another_exception(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
                try:
                    x = torch.cos(x)
                    raise AssertionError
                except AssertionError:
                    x = torch.cos(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

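    # The else clause runs when the try body does not raise.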
    def test_exception_else(self):
        def gn(x):
            return torch.cos(x)

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                x = gn(x)
            except Exception:
                x = torch.sigmoid(x)
            else:
                x = torch.cos(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

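    # The inner handler re-raises with a bare raise; the caller swallows it.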
    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_another_exception2(self):
        def gn(x):
            try:
                x = torch.cos(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
                raise

        def fn(x):
            try:
                x = torch.cos(x)
                gn(x)
            except Exception:
                pass
            return x

        x = torch.randn(4)
        ref = fn(x)
        # Can't use fullgraph=True because RERAISE is not supported
        opt_fn = torch.compile(fn, backend="eager")
        res = opt_fn(x)
        self.assertEqual(ref, res)

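    # The raise happens inside a context manager that must be exited before the handler runs.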
    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_ctx_manager(self):
        def fn(x):
            x = torch.cos(x)
            try:
                with torch.no_grad():
                    x = torch.sin(x)
                    raise NotImplementedError("Not implemented")
            except NotImplementedError as e:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        # Can't use fullgraph=True because WITH_EXCEPT_START is not supported
        opt_fn = torch.compile(fn, backend="eager")
        res = opt_fn(x)
        self.assertEqual(ref, res)

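    # The exception is raised in a callee and caught in the caller.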
    def test_exception_raised_from_child(self):
        def gn():
            raise NotImplementedError("foo")

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                gn()
                x = torch.sin(x)
            except Exception:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

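    # A call made with keyword arguments raises; Dynamo must unwind the call state
    # it set up for the kwargs.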
    def test_dynamo_undo_kw_names(self):
        def g(x, k=None):
            if k:
                raise TypeError("error")
            return x.sin()

        def fn(x):
            d = {"a": x}
            try:
                g(x, k=True)
            except Exception:
                y = 0
                for _, b in d.items():  # noqa: PERF102
                    y += b.sum()
            return y

        x = torch.randn(2, 3)
        expected = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        got = opt_fn(x)
        self.assertEqual(expected, got)

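    # __getattr__ on both a class and its subclass, chained via super() and AttributeError.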
    def test_nn_module_getattr(self):
        class A:
            def __init__(self) -> None:
                self._b = 20

            def __getattr__(self, name):
                fixed_name = "_" + name
                if fixed_name in self.__dict__:
                    return self.__dict__[fixed_name]
                raise AttributeError(f"{name} absent")

        class B(A):
            def __init__(self) -> None:
                self.a = 10

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)
                except AttributeError:
                    return 30

        obj = B()

        def fn(x):
            return x * obj.a * obj.b * obj.c

        x = torch.ones(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

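    # Custom __getattr__ on an nn.Module that defers to nn.Module.__getattr__ and
    # re-raises on a miss.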
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_custom_getattr_on_module_exception(self):
        class Foo(torch.nn.Module):
            def __init__(self, a=3):
                super().__init__()
                self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)  # defer to nn.Module's logic
                except AttributeError:
                    if name == "a_copy":
                        return self.a
                    raise

            def forward(self, x):
                return x * self.a * self.a_copy

        mod = Foo()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)

        x = torch.ones(4)
        self.assertEqual(mod(x), opt_mod(x))

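    # hasattr() returns False by way of __getattr__ raising AttributeError.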
    def test_attribute_error_from_getattr(self):
        class Mock:
            def __init__(self):
                self.a = 5

            def __getattr__(self, name):
                if name != "a":
                    raise AttributeError("missing")
                return self.__dict__["a"]

        mock = Mock()

        def fn(x):
            if hasattr(mock, "b"):
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

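    # StopIteration from next() drives control flow in a simplified zip_longest.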
    def test_stop_iteration(self):
        def zip_longest(*iterables, fillvalue=None):
            # Get the iterators for each iterable
            iterators = [iter(it) for it in iterables]

            result = []
            while True:
                for it in iterators:
                    try:
                        value = next(it)
                    except StopIteration:
                        result.append(fillvalue)
                        return result
                    result.append(value)

        def fn(x, y):
            torch.cos(torch.randn(4))
            return tuple(zip_longest(x, y))

        x = [1, 2, 3, 4]
        y = [10, 11, 12]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

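    # An exception raised inside a hooked module's forward must propagate with its
    # original stack trace; Dynamo records "Observed exception" instead of graph
    # breaking on the hook wrapper's re-raise (pytorch/pytorch#133239).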
    def test_nn_reraise(self):
        class M(torch.nn.Module):
            def forward(self, x):
                raise ValueError("woof")
                return x + 2

        m = M()
        m.register_forward_pre_hook(lambda m, go: None)

        torch._dynamo.utils.clear_compilation_metrics()
        opt_call = torch.compile(lambda x: m(x), backend="eager")
        self.assertRaises(ValueError, lambda: opt_call(torch.randn(3)))
        metrics = torch._dynamo.utils.get_compilation_metrics()
        self.assertEqual(metrics[0].fail_reason, "Observed exception")

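    # KeyError from a failed dict lookup is caught and replaced with a default.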
    def test_key_error(self):
        def fn(x, d):
            try:
                a = d["b"]
            except KeyError:
                a = 2
            return x * a

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d = {"a": 1}
        ref = fn(x, d)
        res = opt_fn(x, d)
        self.assertEqual(ref, res)

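    # AttributeError from a plain attribute miss is caught in the handler.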
    def test_attribute_error(self):
        class Mock:
            def __init__(self):
                self.a = 1

        mock = Mock()

        def fn(x):
            try:
                c = 2
                mock.b
            except AttributeError:
                c = 3
            return torch.sin(x) * c

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

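    # "raise ... from None" suppresses exception chaining, as in os.environ's __getitem__.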
    def test_raise_from_None(self):
        # Inspired by os.environ
        class MyMapping:
            def __init__(self, d):
                self._d = d

            def __getitem__(self, key):
                try:
                    value = self._d[key]
                except KeyError:
                    raise KeyError(key) from None
                return value

        d = MyMapping({"a": 10, "b": 20})

        def mapping_get(obj, key, value=None):
            try:
                return obj.__getitem__(key)
            except KeyError:
                return value

        def fn(x, d, key):
            x = torch.sin(x + 1)
            return x, mapping_get(d, key)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.rand(2, 3)
        ref = fn(x, d, "m")
        res = opt_fn(x, d, "m")
        self.assertEqual(ref[0], res[0])
        self.assertEqual(ref[1], res[1])


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()