mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Add UserError type (#97705)
To get started the dynamo error message improvement effort, we discussed about adding new user error type which covers cases where the user used something that TorchDynamo doesn't support and there is clear actions they can take. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97705 Approved by: https://github.com/anijain2305, https://github.com/yanboliang
This commit is contained in:
parent
ee9a9b7add
commit
7f9533e224
|
|
@ -2281,6 +2281,21 @@ class ExportTests(torch._dynamo.test_case.TestCase):
|
|||
constraints = [dynamic_dim(y, 0)]
|
||||
torch._dynamo.export(my_dyn_fn, y, constraints=constraints)
|
||||
|
||||
@config.patch(dynamic_shapes=True, capture_dynamic_output_shape_ops=True)
|
||||
def test_export_dynamic_control_flow_error(self):
|
||||
def f(x):
|
||||
if x.nonzero() > 3:
|
||||
return x.cos()
|
||||
return x.sin()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Dynamic control flow is not supported at the moment",
|
||||
):
|
||||
gm, _ = torch._dynamo.export(
|
||||
f, torch.randn(5, 6), aten_graph=True, tracing_mode="symbolic"
|
||||
)
|
||||
|
||||
|
||||
common_utils.instantiate_parametrized_tests(ExportTests)
|
||||
|
||||
|
|
|
|||
|
|
@ -2324,7 +2324,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
self.assertNoUnraisable(f)
|
||||
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
|
||||
def test_rewrite_assert_with_msg(self):
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
|
|
@ -2345,7 +2344,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, ""):
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
|
||||
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
|
||||
def test_not_rewrite_assert_for_other_errors(self):
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
|
|
@ -2358,7 +2356,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
|
||||
opt_fn(*args)
|
||||
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
|
||||
def test_rewrite_assert_without_msg(self):
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
|
|
@ -2372,7 +2369,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, ""):
|
||||
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
|
||||
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
|
||||
def test_rewrite_assert_with_non_string_msg(self):
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
|
|
@ -2391,7 +2387,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
1,
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", True)
|
||||
def test_rewrite_assert_noop(self):
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
|
|
@ -2430,16 +2425,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertTrue(same(f(x2, y), opt_f(x2, y)))
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@torch._dynamo.config.patch("rewrite_assert_with_torch_assert", False)
|
||||
def test_not_rewrite_assert(self):
|
||||
def f(x):
|
||||
b = x.sin()
|
||||
assert x[0] == 3
|
||||
return x.cos() + b
|
||||
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
|
||||
torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
|
||||
|
||||
@torch._dynamo.config.patch(dynamic_shapes=True)
|
||||
def test_batchnorm_e2e(self):
|
||||
class Repro(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import textwrap
|
||||
from enum import auto, Enum
|
||||
from traceback import extract_stack, format_exc, format_list, FrameSummary
|
||||
from typing import cast, List
|
||||
|
||||
|
|
@ -77,6 +78,23 @@ class RecompileError(TorchDynamoException):
|
|||
pass
|
||||
|
||||
|
||||
class UserErrorType(Enum):
|
||||
DYNAMIC_CONTROL_FLOW = auto()
|
||||
|
||||
|
||||
class UserError(Unsupported):
|
||||
def __init__(self, error_type: UserErrorType, msg):
|
||||
"""
|
||||
Type of errors that would be valid in Eager, but not supported in TorchDynamo.
|
||||
The error message should tell user about next actions.
|
||||
|
||||
error_type: Type of user error
|
||||
msg: Actionable error message
|
||||
"""
|
||||
super().__init__(msg)
|
||||
self.error_type = error_type
|
||||
|
||||
|
||||
def unimplemented(msg: str):
|
||||
assert msg != os.environ.get("BREAK", False)
|
||||
raise Unsupported(msg)
|
||||
|
|
|
|||
|
|
@ -329,7 +329,12 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
|||
push and self.push(value)
|
||||
self.jump(inst)
|
||||
else:
|
||||
unimplemented(f"generic_jump {typestr(value)}")
|
||||
# TODO link the torch.cond doc later
|
||||
raise exc.UserError(
|
||||
exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
|
||||
"Dynamic control flow is not supported at the moment. Please use "
|
||||
"functorch.experimental.control_flow.cond to explicitly capture the control flow",
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user