Don't setup try-except handler when Dynamo compiling (#133239)

The reraise is not supported and so this just gunks up our actual exception handling. You can trigger this by hitting an exception inside of an NN module that has hooks on it. You end up graph breaking on the reraise here, and losing the inner stack trace from the actual exception that was raised.

This might be kind of controversial.  An alternate strategy is to support reraises in Dynamo or something but IDK this doesn't feel like the right place to apply force.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133239
Approved by: https://github.com/anijain2305
This commit is contained in:
Edward Z. Yang 2024-09-01 12:36:16 -07:00 committed by PyTorch MergeBot
parent ea01aec8b1
commit 208442ea18
2 changed files with 35 additions and 3 deletions

View File

@ -4,6 +4,7 @@ import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
@ -317,6 +318,21 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_nn_reraise(self):
class M(torch.nn.Module):
def forward(self, x):
raise ValueError("woof")
return x + 2
m = M()
m.register_forward_pre_hook(lambda m, go: None)
torch._dynamo.utils.clear_compilation_metrics()
opt_call = torch.compile(lambda x: m(x), backend="eager")
self.assertRaises(ValueError, lambda: opt_call(torch.randn(3)))
metrics = torch._dynamo.utils.get_compilation_metrics()
self.assertEqual(metrics[0].fail_reason, "Observed exception")
def test_key_error(self):
def fn(x, d):
try:

View File

@ -1746,9 +1746,11 @@ class Module:
or _global_forward_hooks or _global_forward_pre_hooks):
return forward_call(*args, **kwargs)
try:
result = None
called_always_called_hooks = set()
result = None
called_always_called_hooks = set()
def inner():
nonlocal result, args, kwargs
full_backward_hooks, non_full_backward_hooks = [], []
backward_pre_hooks = []
@ -1826,6 +1828,20 @@ class Module:
return result
from torch.compiler import is_compiling
# This is technically not behavior equivalent when compiling, but it's
# incredibly unlikely we will ever support throwing an exception in NN
# module, and then catching it here, and then reraising it, and then
# catching it again, and expecting the resulting frame to be compiled.
# The reraise here just gunks up our exception handling for no good
# reason. Don't try to run the always called hooks in event of
# exception.
if is_compiling():
return inner()
try:
return inner()
except Exception:
# run always called hooks if they have not already been run
# For now only forward hooks have the always_call option but perhaps