Improve torch.cond usability: Return UserError with actionable error messages (#98909)

This is part of the effort to improve PT2 Export UX. This PR improves the usability of `torch.cond()` by separating user errors from Dynamo-internal errors. By definition, a user error means the usage of `torch.cond()` violates the restrictions of this API, so the user needs to take action and fix the error.

In notebook N3363227 we discovered a number of limitations of `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations fall into three categories (a minimal valid call is sketched after the list):
 - predicate restriction (`pred`)
 - operands restriction (`operands`)
 - branch restriction (`true_fn` & `false_fn`)
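
For reference, a call that satisfies all three restrictions looks roughly like the sketch below. This is not part of this PR; it mirrors the updated tests and uses the `functorch.experimental.control_flow.cond` entry point and export flags that those tests exercise:
```
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x):
    # pred: a bool/int or a single-item tensor; operands: a list of tensors whose
    # layout matches the branch signatures; each branch returns a single tensor
    # with matching metadata.
    return cond(x.shape[0] > 2, true_fn, false_fn, [x])

gm, _ = torch._dynamo.export(f, torch.randn(5), aten_graph=True, tracing_mode="symbolic")
```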

The error messages are now more accurate about where the (user) error comes from and more actionable for users to fix.

For example, `operands` must be a list of tensors, and the signatures of `true_fn` and `false_fn` must match `operands`.
If `operands` contains non-tensor types, the user will see an error message like:
```
torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"]

from user code:
   File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands
    return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```
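In this case, one way for the user to resolve the error (a sketch, not part of this PR) is to wrap the scalar in a tensor so that every operand is a tensor:
```
def f_tensor_operands(x):
    a = torch.tensor(3.14)  # wrap the Python float so all operands are tensors
    return cond(torch.tensor(True), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```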
If the signature of a branch function doesn't match `operands`, the user will see an error message like:
```
torch._dynamo.exc.UserError: too many positional arguments.
  func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {}
```
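The corresponding fix on the user side (again a sketch) is to give both branches a signature that matches `operands`:
```
def true_fn(x, y):
    return x.sin()

def false_fn(x, y):  # previously took only `x`, so binding the operands [x, y] failed
    return x.cos()

def f_branch_args_match(x, y):
    return cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
```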
Or if the tensors returned from the user-defined branches have different metadata (e.g. shapes, dtypes), the user will see an error message like:
```
TypeError: Expected each tensor to have same metadata but got:
  cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
  cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
```
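The remedy there is for the user to make both branches return tensors with identical metadata, for instance (sketch):
```
def true_fn(x):
    return torch.tensor([3.0])   # shape [1], dtype float32, matching false_fn

def false_fn(x):
    return torch.tensor([3.14])
```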

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909
Approved by: https://github.com/jansel
Guang Yang 2023-04-20 04:06:00 +00:00 committed by PyTorch MergeBot
parent e47e8c9d98
commit aa4ed332c3
8 changed files with 307 additions and 55 deletions


@ -22,6 +22,7 @@ from torch.utils._python_dispatch import (
_pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten
from torch._dynamo.exc import CondOpArgsMismatchError
@dataclass
@ -56,12 +57,22 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_outs)
assert(len(flat_true_outs) == len(flat_false_outs))
if len(flat_true_outs) != len(flat_false_outs):
raise CondOpArgsMismatchError(
f"Expected to return same number of outputs but got:"
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
)
for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
raise CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
)
# There are probably better ways - I know that create_arg has some self incrementing name
# magic to it, but since we explicitly have to get the name for register_module,


@ -2602,22 +2602,29 @@ class ExportTests(torch._dynamo.test_case.TestCase):
def false_fn(x):
return x.sin()
def f_pred_traced_as_constant_var(x):
return cond(x.dim() > 2, true_fn, false_fn, [x])
def f_pred_traced_as_symnode_var(x):
return cond(x.shape[0] > 10, true_fn, false_fn, [x])
return cond(x.shape[0] > 2, true_fn, false_fn, [x])
def f_pred_traced_as_tensor_var(x):
return cond(x.all(), true_fn, false_fn, [x])
example_inputs = (torch.rand(5),)
def f_pred_complex_expression_traced_as_symnode_var(x):
return cond(
x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
true_fn,
false_fn,
[x],
)
example_inputs = (torch.rand(5, 8),)
for f in [
f_pred_traced_as_constant_var,
f_pred_traced_as_symnode_var,
f_pred_traced_as_tensor_var,
f_pred_complex_expression_traced_as_symnode_var,
]:
gm, _ = torch._dynamo.export(f, *example_inputs)
gm, _ = torch._dynamo.export(
f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
)
self.assertEqual(gm(*example_inputs), f(*example_inputs))
def test_mixed_real_and_fake_inputs(self):
@ -2669,6 +2676,175 @@ class ExportTests(torch._dynamo.test_case.TestCase):
self.assertEqual(gm(*true_inp), f(*true_inp))
self.assertEqual(gm(*false_inp), f(*false_inp))
def test_cond_raise_user_error_on_missing_args(self):
def true_fn(x):
return x.cos()
def false_fn(x):
return x.sin()
def f(x):
return cond(x.shape[0] > 10, true_fn, false_fn)
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected 4 arguments",
):
torch._dynamo.export(
f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
)
def test_cond_raise_user_error_on_unsupported_pred(self):
def f_unsupported_pred(x):
pred = torch.nn.Module()
return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected pred to be bool/int or a tensor",
):
torch._dynamo.export(
f_unsupported_pred,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_non_list_operands(self):
def f_non_list_operands(x):
return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected a list but got",
):
torch._dynamo.export(
f_non_list_operands,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_non_tensor_operands(self):
def f_non_tensor_operands(x):
a: float = 3.14
return cond(
torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
)
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected a list of tensors",
):
torch._dynamo.export(
f_non_tensor_operands,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_branch_args_mismatch(self):
def true_fn(x, y):
return x.sin()
def false_fn(x):
return x.cos()
def f_branch_args_mismatch(x, y):
return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])
example_inputs = (torch.rand(5), torch.rand(2))
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"too many positional arguments",
):
torch._dynamo.export(
f_branch_args_mismatch,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_branch_return_non_tensor(self):
def f_branch_return_non_tensor(x):
return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected branch out type to be a single tensor",
):
torch._dynamo.export(
f_branch_return_non_tensor,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_branch_return_multiple_tenors(self):
def f_branch_return_multiple_tensors(x, y):
return cond(y, lambda x: (x, x), lambda x: (x, x), [x])
example_inputs = (torch.randn(4), torch.randn(2))
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected branch out type to be a single tensor",
):
torch._dynamo.export(
f_branch_return_multiple_tensors,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_mismatch_return_length(self):
def true_fn(x):
return x
def false_fn(x):
return (x, x)
def f_mismatch_return_length(x):
return cond(torch.tensor(100), true_fn, false_fn, [x])
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected branch out type to be a single tensor",
):
torch._dynamo.export(
f_mismatch_return_length,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
def true_fn(x):
return torch.tensor([[3], [2]])
def false_fn(x):
return torch.tensor([3.14])
def f_return_tensor_mismatch(x):
return cond(x.shape[0] < 3, true_fn, false_fn, [x])
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Expected each tensor to have same metadata but got",
):
torch._dynamo.export(
f_return_tensor_mismatch,
*example_inputs,
aten_graph=True,
tracing_mode="symbolic",
)
common_utils.instantiate_parametrized_tests(ExportTests)


@ -6,8 +6,8 @@ from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch._dynamo.exc import CondOpArgsMismatchError
class TestControlFlow(TestCase):
def test_cond_no_trace(self):
@ -378,7 +378,7 @@ class TestControlFlowTraced(TestCase):
out = "".join(out.split())
self.assertEqual(code, out)
def test_assert_on_mismatch_type_size(self):
def test_raise_error_on_mismatch_type_size(self):
def true_fn(x):
return x.sin()
@ -389,11 +389,13 @@ class TestControlFlowTraced(TestCase):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaises(AssertionError):
with self.assertRaisesRegex(
CondOpArgsMismatchError,
"Expected to return same number of outputs but got",
):
make_fx(f)(x, torch.tensor(False))
def test_assert_on_mismatch_tensor_size(self):
def test_raise_error_on_mismatch_tensor_size(self):
def true_fn(x):
return x.sin()
@ -404,7 +406,10 @@ class TestControlFlowTraced(TestCase):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaises(AssertionError):
with self.assertRaisesRegex(
CondOpArgsMismatchError,
"Expected each tensor to have same metadata but got",
):
make_fx(f)(x, torch.tensor(False))
def test_cond_traced_not_nested_fake_tensor(self):
@ -540,7 +545,7 @@ class TestControlFlowTraced(TestCase):
out = "".join(out.split())
self.assertEqual(code, out)
def test_assert_on_mismatch_type_size_fake_tensor(self):
def test_raise_error_on_mismatch_type_size_fake_tensor(self):
def true_fn(x):
return x.sin()
@ -551,11 +556,14 @@ class TestControlFlowTraced(TestCase):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaises(AssertionError):
with self.assertRaisesRegex(
CondOpArgsMismatchError,
"Expected to return same number of outputs but got",
):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
def test_assert_on_mismatch_tensor_size_fake_tensor(self):
def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
def true_fn(x):
return x.sin()
@ -566,7 +574,10 @@ class TestControlFlowTraced(TestCase):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaises(AssertionError):
with self.assertRaisesRegex(
CondOpArgsMismatchError,
"Expected each tensor to have same metadata but got",
):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
def check_map_graph(self, gm, key):


@ -45,7 +45,7 @@ else:
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
from . import config, convert_frame, skipfiles, utils
from .exc import ResetRequired
from .exc import CondOpArgsMismatchError, ResetRequired, UserError, UserErrorType
from .mutation_guard import install_generation_tagging_init
from .types import DynamoCallback
from .utils import compile_times
@ -843,12 +843,16 @@ def export(
return torch.fx.Interpreter(graph).run(*args)
with enable_python_dispatcher(), fake_mode:
graph = make_fx(
graph_with_interpreter,
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
)(*example_fake_inputs)
try:
graph = make_fx(
graph_with_interpreter,
decomposition_table=decomposition_table,
tracing_mode="real",
_allow_non_fake_inputs=True,
)(*example_fake_inputs)
except CondOpArgsMismatchError as e:
# Wrap the internal error to the user-facing error
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
new_graph = ChangeInputOutputSignature(
graph,


@ -78,6 +78,20 @@ class RecompileError(TorchDynamoException):
pass
class ArgsMismatchError(Unsupported):
def __init__(self, msg):
super().__init__(msg)
class CondOpArgsMismatchError(ArgsMismatchError):
"""
Internal error from cond() due to arguments mismatch.
"""
def __init__(self, msg):
super().__init__(msg)
class UserErrorType(Enum):
DYNAMIC_CONTROL_FLOW = auto()
ANTI_PATTERN = auto()
@ -96,6 +110,7 @@ class UserError(Unsupported):
"""
super().__init__(msg)
self.error_type = error_type
self.message = msg
class IncorrectUsage(Exception):


@ -43,7 +43,7 @@ from .bytecode_transformation import (
unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented, Unsupported
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .guards import GuardBuilder
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
@ -2116,15 +2116,15 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
try:
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
except TypeError as e:
log.warning(
"%s %s %s %s %s",
func.get_filename(),
func.get_function(),
args,
kwargs,
e,
# Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
raise ArgsMismatchError(
"{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
reason=str(e),
func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
args=[arg.python_type() for arg in args],
kwargs=kwargs,
),
)
unimplemented("arg mismatch inlining")
for v in itertools.chain(sub_locals.values(), closure_cells.values()):
if not isinstance(v, VariableTracker):


@ -435,7 +435,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
)
if self.kwdefaults:
func.__kwdefaults__ = self.kwdefaults.items
bound = inspect.signature(func).bind(*args, **kwargs)
bound.apply_defaults()
result = dict(bound.arguments.items())


@ -15,7 +15,7 @@ from torch._guards import GuardsCheckpointState
from .. import config, variables
from ..allowed_functions import torch_get_name
from ..exc import unimplemented
from ..exc import ArgsMismatchError, unimplemented, UserError, UserErrorType
from ..source import GeneratorStateSource, GetItemSource, NNModuleSource
from ..utils import (
check_constant_args,
@ -876,7 +876,12 @@ class TorchHigherOrderOperator(VariableTracker):
# Register output to graph
# Modeled off of compile_and_call_fx_graph
# TODO: support non single Tensor output
assert isinstance(output, TensorVariable)
if not isinstance(output, TensorVariable):
raise ArgsMismatchError(
"Expected branch out type to be a single tensor but got {}".format(
str(output.python_type())
),
)
tx.output.guards.update(output.guards)
tx.output.create_node(
"output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
@ -904,14 +909,45 @@ class TorchHigherOrderOperator(VariableTracker):
if self.value.__name__ == "cond":
# TODO(voz): Support fake tensor dispatch for recursive
# ops - see torch/dispatch/_dispatcher.py
assert len(args) == 4
assert type(args[0]) in (
TensorVariable,
SymNodeVariable,
ConstantVariable,
), str(
type(args[0])
) # predicate
if len(args) != 4:
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: cond(pred, true_fn, false_fn, operands)",
)
# predicate
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
f"Expected pred to be bool/int or a tensor with single "
f"item but got {str(type(args[0]))} "
f"with original python type {str(args[0].python_type())}.",
)
# operands
if type(args[3]) is not ListVariable:
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
f"Expected a list but got {args[3].python_type()}",
)
operands = args[3].unpack_var_sequence(tx)
if not all(
isinstance(operand, (TensorVariable, torch.Tensor))
for operand in operands
):
raise UserError(
UserErrorType.DYNAMIC_CONTROL_FLOW,
"Expected a list of tensors but got {actual_args}".format(
actual_args=[
str(operand.python_type())
if isinstance(operand, VariableTracker)
else str(type(operand))
for operand in operands
],
),
)
# branches
assert isinstance(
args[1], (UserFunctionVariable, NestedUserFunctionVariable)
), str(
@ -922,7 +958,6 @@ class TorchHigherOrderOperator(VariableTracker):
), str(
type(args[2])
) # false_fn
assert type(args[3]) is ListVariable, str(type(args[3])) # args
# Our strategy for tracing the true/false branches of cond
# are to checkpoint our graphstate, run the true branch,
@ -939,14 +974,15 @@ class TorchHigherOrderOperator(VariableTracker):
graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()
sub_args = args[3].unpack_var_sequence(tx)
def speculate_branch(branch):
# NB: 0 is predicate
ix = 1 if branch else 2
return speculate_subgraph(
args[ix], sub_args, graph_checkpoint, checkpoint
)
try:
# NB: 0 is predicate
ix = 1 if branch else 2
return speculate_subgraph(
args[ix], operands, graph_checkpoint, checkpoint
)
except ArgsMismatchError as e:
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
(
true_r,
@ -988,7 +1024,7 @@ class TorchHigherOrderOperator(VariableTracker):
args[0].as_proxy(),
true_node,
false_node,
[a.as_proxy() for a in sub_args],
[a.as_proxy() for a in operands],
)
# TODO: assert that the true/false return values are
# consistent