[Dynamo] Add UserError type (#97705)

To get started the dynamo error message improvement effort, we discussed about adding new user error type which covers cases where the user used something that TorchDynamo doesn't support and there is clear actions they can take.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97705
Approved by: https://github.com/anijain2305, https://github.com/yanboliang
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2023-03-31 13:59:18 -07:00 committed by PyTorch MergeBot
parent ee9a9b7add
commit 7f9533e224
4 changed files with 39 additions and 16 deletions

View File

@ -2281,6 +2281,21 @@ class ExportTests(torch._dynamo.test_case.TestCase):
constraints = [dynamic_dim(y, 0)]
torch._dynamo.export(my_dyn_fn, y, constraints=constraints)
@config.patch(dynamic_shapes=True, capture_dynamic_output_shape_ops=True)
def test_export_dynamic_control_flow_error(self):
def f(x):
if x.nonzero() > 3:
return x.cos()
return x.sin()
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Dynamic control flow is not supported at the moment",
):
gm, _ = torch._dynamo.export(
f, torch.randn(5, 6), aten_graph=True, tracing_mode="symbolic"
)
common_utils.instantiate_parametrized_tests(ExportTests)

View File

@ -2324,7 +2324,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertNoUnraisable(f)
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_msg(self):
def f(x):
b = x.sin()
@ -2345,7 +2344,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(AssertionError, ""):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
def test_not_rewrite_assert_for_other_errors(self):
def f(x):
b = x.sin()
@ -2358,7 +2356,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
opt_fn(*args)
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_without_msg(self):
def f(x):
b = x.sin()
@ -2372,7 +2369,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(AssertionError, ""):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_non_string_msg(self):
def f(x):
b = x.sin()
@ -2391,7 +2387,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
1,
)
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_noop(self):
def f(x):
b = x.sin()
@ -2430,16 +2425,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(f(x2, y), opt_f(x2, y)))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", False)
def test_not_rewrite_assert(self):
def f(x):
b = x.sin()
assert x[0] == 3
return x.cos() + b
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
@torch._dynamo.config.patch(dynamic_shapes=True)
def test_batchnorm_e2e(self):
class Repro(torch.nn.Module):

View File

@ -1,5 +1,6 @@
import os
import textwrap
from enum import auto, Enum
from traceback import extract_stack, format_exc, format_list, FrameSummary
from typing import cast, List
@ -77,6 +78,23 @@ class RecompileError(TorchDynamoException):
pass
class UserErrorType(Enum):
DYNAMIC_CONTROL_FLOW = auto()
class UserError(Unsupported):
def __init__(self, error_type: UserErrorType, msg):
"""
Type of errors that would be valid in Eager, but not supported in TorchDynamo.
The error message should tell user about next actions.
error_type: Type of user error
msg: Actionable error message
"""
super().__init__(msg)
self.error_type = error_type
def unimplemented(msg: str):
assert msg != os.environ.get("BREAK", False)
raise Unsupported(msg)

View File

@ -329,7 +329,12 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
push and self.push(value)
self.jump(inst)
else:
unimplemented(f"generic_jump {typestr(value)}")
# TODO link the torch.cond doc later
raise exc.UserError(
exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
"Dynamic control flow is not supported at the moment. Please use "
"functorch.experimental.control_flow.cond to explicitly capture the control flow",
)
return inner