Don't set up try-except handler when Dynamo compiling (#133239)
Dynamo does not support the reraise, so the try-except handler just gunks up our actual exception handling. You can trigger this by raising an exception inside an NN module that has hooks on it: you end up graph breaking on the reraise and losing the inner stack trace from the exception that was actually raised.

This might be somewhat controversial. An alternate strategy would be to support reraises in Dynamo, but that doesn't feel like the right place to apply force.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133239
Approved by: https://github.com/anijain2305
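For context, here is a minimal sketch of the failure mode described above. The module and hook are illustrative, not part of this PR: before the change, the exception raised in forward() hit the reraise in Module._call_impl's except block, Dynamo graph-broke there, and the original stack trace was obscured; after the change, the ValueError propagates directly.

    import torch

    # Illustrative repro sketch (names are made up; not taken from the PR).
    class Boom(torch.nn.Module):
        def forward(self, x):
            raise ValueError("boom")

    m = Boom()
    # Registering a hook forces Module._call_impl off the fast path and into the
    # hook-handling code that previously wrapped everything in try/except.
    m.register_forward_pre_hook(lambda mod, args: None)

    compiled = torch.compile(lambda x: m(x), backend="eager")

    try:
        compiled(torch.randn(3))
    except ValueError as e:
        # With the try/except skipped under compilation, the original ValueError
        # (and its stack trace) reaches the caller instead of being buried by a
        # graph break on the reraise.
        print("caught:", e)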
commit 208442ea18
parent ea01aec8b1
test/dynamo/test_exceptions.py
@@ -4,6 +4,7 @@ import torch
 import torch._dynamo.config
 import torch._dynamo.test_case
 import torch._functorch.config
+import torch.nn
 import torch.utils.checkpoint
 
 
@@ -317,6 +318,21 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
         res = opt_fn(x, y)
         self.assertEqual(ref, res)
 
+    def test_nn_reraise(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                raise ValueError("woof")
+                return x + 2
+
+        m = M()
+        m.register_forward_pre_hook(lambda m, go: None)
+
+        torch._dynamo.utils.clear_compilation_metrics()
+        opt_call = torch.compile(lambda x: m(x), backend="eager")
+        self.assertRaises(ValueError, lambda: opt_call(torch.randn(3)))
+        metrics = torch._dynamo.utils.get_compilation_metrics()
+        self.assertEqual(metrics[0].fail_reason, "Observed exception")
+
     def test_key_error(self):
         def fn(x, d):
             try:
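To exercise just the new test locally, something along these lines should work (the test file path is inferred from the hunk context above, not stated on this page):

    python test/dynamo/test_exceptions.py -k test_nn_reraise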
torch/nn/modules/module.py
@@ -1746,9 +1746,11 @@ class Module:
                 or _global_forward_hooks or _global_forward_pre_hooks):
             return forward_call(*args, **kwargs)
 
-        try:
-            result = None
-            called_always_called_hooks = set()
+        result = None
+        called_always_called_hooks = set()
+
+        def inner():
+            nonlocal result, args, kwargs
 
             full_backward_hooks, non_full_backward_hooks = [], []
             backward_pre_hooks = []
@@ -1826,6 +1828,20 @@ class Module:
 
             return result
 
+        from torch.compiler import is_compiling
+
+        # This is technically not behavior equivalent when compiling, but it's
+        # incredibly unlikely we will ever support throwing an exception in NN
+        # module, and then catching it here, and then reraising it, and then
+        # catching it again, and expecting the resulting frame to be compiled.
+        # The reraise here just gunks up our exception handling for no good
+        # reason. Don't try to run the always called hooks in event of
+        # exception.
+        if is_compiling():
+            return inner()
+
+        try:
+            return inner()
         except Exception:
             # run always called hooks if they have not already been run
             # For now only forward hooks have the always_call option but perhaps
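Pieced together from the two hunks above, the restructured Module._call_impl reads roughly as follows. This is an abridged sketch: the hook-dispatch body of inner() and the rest of the except block are elided.

        # Abridged sketch of Module._call_impl after this change (hook-dispatch
        # body of inner() elided).
        result = None
        called_always_called_hooks = set()

        def inner():
            nonlocal result, args, kwargs
            ...  # run pre-hooks, forward_call, and post-hooks; populate `result`
            return result

        from torch.compiler import is_compiling

        if is_compiling():
            # Under Dynamo, call inner() directly: the reraise in the except
            # block below would only cause a graph break and hide the original
            # exception.
            return inner()

        try:
            return inner()
        except Exception:
            # Eager path: run the always-called hooks that have not run yet,
            # then reraise.
            ...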