Improve torch.cond usability: Return UserError with actionable error messages (#98909)

This is part of the effort to improve PT2 Export UX. This PR improves the usability of `torch.cond()` by separating user errors from Dynamo internal errors. By definition, a user error means the usage of `torch.cond()` violates the restrictions of this API, and therefore the user needs to take action and fix the error.

In notebook N3363227 we discovered several limitations of using `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations fall into three categories:
 - predicate restriction (`pred`)
 - operands restriction (`operands`)
 - branch restriction (`true_fn` & `false_fn`)

The error messages now pinpoint where the (user) error comes from and are more actionable for users to fix.
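
For reference, here is a minimal sketch of a call that satisfies all three restrictions (the function and variable names are illustrative, not from this PR; the import path and export flags mirror the tests added below):
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x):
    # pred: a bool/int expression or a single-element tensor.
    # operands: a list of tensors whose entries match the positional
    # signature of both branches; both branches return tensors with the
    # same shape and dtype.
    return cond(x.shape[0] > 4, true_fn, false_fn, [x])

gm, _ = torch._dynamo.export(f, torch.randn(5), aten_graph=True, tracing_mode="symbolic")
```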

For example, `operands` must be a list of tensors, and the signatures of `true_fn` and `false_fn` must match the `operands`.
If `operands` contains non-tensor types, users will see an error message like:
```
torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"]

from user code:
   File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands
    return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```
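One way to address this (a sketch, not part of the PR; names are illustrative) is to keep `operands` a list of tensors, e.g. by converting the scalar to a tensor:
```python
def f_tensor_operands(x):
    # Convert the Python float to a tensor so that operands stays a list of tensors.
    a = torch.tensor(3.14)
    # Both branches accept the same positional arguments as operands and
    # return tensors with identical metadata (shape of x, float32).
    return cond(x.shape[0] > 4, lambda x, a: x.sin() + a, lambda x, a: x.cos() + a, [x, a])
```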
If the signature of a branch function doesn't match `operands`, users will see an error message like:
```
torch._dynamo.exc.UserError: too many positional arguments.
  func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {}
```
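A sketch of a fix (illustrative, not part of the PR) is to give both branches the same positional signature as `operands`:
```python
def true_fn(x, y):
    return x.sin()

def false_fn(x, y):
    # Accept the same positional arguments as true_fn, even if some are unused,
    # so that both branch signatures match the two-element operands list.
    return x.cos()

def f(x, y):
    return cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
```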
Or if the tensors returned from the user-defined branches have different metadata (e.g. shapes or dtypes), users will see an error message like:
```
TypeError: Expected each tensor to have same metadata but got:
  cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
  cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
```
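In that case a fix (again a sketch with illustrative names) is to make both branches return tensors with identical shape and dtype:
```python
def true_fn(x):
    return torch.tensor([3.0])   # shape [1], dtype float32

def false_fn(x):
    return torch.tensor([3.14])  # shape [1], dtype float32 -- metadata matches true_fn

def f(x):
    return cond(x.shape[0] < 3, true_fn, false_fn, [x])
```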

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909
Approved by: https://github.com/jansel
Commit aa4ed332c3 (parent e47e8c9d98), authored by Guang Yang on 2023-04-20 04:06:00 +00:00, committed by PyTorch MergeBot.
8 changed files with 307 additions and 55 deletions

View File

@@ -22,6 +22,7 @@ from torch.utils._python_dispatch import (
     _pop_mode_temporarily,
 )
 from torch.utils._pytree import tree_flatten
+from torch._dynamo.exc import CondOpArgsMismatchError


 @dataclass
@@ -56,12 +57,22 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     flat_true_outs, _ = pytree.tree_flatten(true_outs)
     flat_false_outs, _ = pytree.tree_flatten(false_outs)
-    assert(len(flat_true_outs) == len(flat_false_outs))
+    if len(flat_true_outs) != len(flat_false_outs):
+        raise CondOpArgsMismatchError(
+            f"Expected to return same number of outputs but got:"
+            f"\n  {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
+            f"\n  {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
+        )

     for i in range(0, len(flat_true_outs)):
         true_out = flat_true_outs[i]
         false_out = flat_false_outs[i]
-        assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
+        if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
+            raise CondOpArgsMismatchError(
+                f"Expected each tensor to have same metadata but got:"
+                f"\n  {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
+                f"\n  {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
+            )

     # There are probably better ways - I know that create_arg has some self incrementing name
     # magic to it, but since we explicitly have to get the name for register_module,

View File

@@ -2602,22 +2602,29 @@ class ExportTests(torch._dynamo.test_case.TestCase):
         def false_fn(x):
             return x.sin()

-        def f_pred_traced_as_constant_var(x):
-            return cond(x.dim() > 2, true_fn, false_fn, [x])
-
         def f_pred_traced_as_symnode_var(x):
-            return cond(x.shape[0] > 10, true_fn, false_fn, [x])
+            return cond(x.shape[0] > 2, true_fn, false_fn, [x])

         def f_pred_traced_as_tensor_var(x):
             return cond(x.all(), true_fn, false_fn, [x])

-        example_inputs = (torch.rand(5),)
+        def f_pred_complex_expression_traced_as_symnode_var(x):
+            return cond(
+                x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
+                true_fn,
+                false_fn,
+                [x],
+            )
+
+        example_inputs = (torch.rand(5, 8),)
         for f in [
-            f_pred_traced_as_constant_var,
             f_pred_traced_as_symnode_var,
             f_pred_traced_as_tensor_var,
+            f_pred_complex_expression_traced_as_symnode_var,
         ]:
-            gm, _ = torch._dynamo.export(f, *example_inputs)
+            gm, _ = torch._dynamo.export(
+                f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
+            )
             self.assertEqual(gm(*example_inputs), f(*example_inputs))

     def test_mixed_real_and_fake_inputs(self):
@@ -2669,6 +2676,175 @@ class ExportTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(gm(*true_inp), f(*true_inp))
         self.assertEqual(gm(*false_inp), f(*false_inp))

+    def test_cond_raise_user_error_on_missing_args(self):
+        def true_fn(x):
+            return x.cos()
+
+        def false_fn(x):
+            return x.sin()
+
+        def f(x):
+            return cond(x.shape[0] > 10, true_fn, false_fn)
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected 4 arguments",
+        ):
+            torch._dynamo.export(
+                f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
+            )
+
+    def test_cond_raise_user_error_on_unsupported_pred(self):
+        def f_unsupported_pred(x):
+            pred = torch.nn.Module()
+            return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected pred to be bool/int or a tensor",
+        ):
+            torch._dynamo.export(
+                f_unsupported_pred,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_non_list_operands(self):
+        def f_non_list_operands(x):
+            return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected a list but got",
+        ):
+            torch._dynamo.export(
+                f_non_list_operands,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_non_tensor_operands(self):
+        def f_non_tensor_operands(x):
+            a: float = 3.14
+            return cond(
+                torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
+            )
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected a list of tensors",
+        ):
+            torch._dynamo.export(
+                f_non_tensor_operands,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_branch_args_mismatch(self):
+        def true_fn(x, y):
+            return x.sin()
+
+        def false_fn(x):
+            return x.cos()
+
+        def f_branch_args_mismatch(x, y):
+            return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])
+
+        example_inputs = (torch.rand(5), torch.rand(2))
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "too many positional arguments",
+        ):
+            torch._dynamo.export(
+                f_branch_args_mismatch,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_branch_return_non_tensor(self):
+        def f_branch_return_non_tensor(x):
+            return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected branch out type to be a single tensor",
+        ):
+            torch._dynamo.export(
+                f_branch_return_non_tensor,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_branch_return_multiple_tenors(self):
+        def f_branch_return_multiple_tensors(x, y):
+            return cond(y, lambda x: (x, x), lambda x: (x, x), [x])
+
+        example_inputs = (torch.randn(4), torch.randn(2))
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected branch out type to be a single tensor",
+        ):
+            torch._dynamo.export(
+                f_branch_return_multiple_tensors,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_mismatch_return_length(self):
+        def true_fn(x):
+            return x
+
+        def false_fn(x):
+            return (x, x)
+
+        def f_mismatch_return_length(x):
+            return cond(torch.tensor(100), true_fn, false_fn, [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected branch out type to be a single tensor",
+        ):
+            torch._dynamo.export(
+                f_mismatch_return_length,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
+        def true_fn(x):
+            return torch.tensor([[3], [2]])
+
+        def false_fn(x):
+            return torch.tensor([3.14])
+
+        def f_return_tensor_mismatch(x):
+            return cond(x.shape[0] < 3, true_fn, false_fn, [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected each tensor to have same metadata but got",
+        ):
+            torch._dynamo.export(
+                f_return_tensor_mismatch,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+

 common_utils.instantiate_parametrized_tests(ExportTests)

View File

@@ -6,8 +6,8 @@ from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond
 from functorch.experimental.control_flow import UnsupportedAliasMutationException
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import run_tests, TestCase
+from torch._dynamo.exc import CondOpArgsMismatchError


 class TestControlFlow(TestCase):
     def test_cond_no_trace(self):
@@ -378,7 +378,7 @@ class TestControlFlowTraced(TestCase):
         out = "".join(out.split())
         self.assertEqual(code, out)

-    def test_assert_on_mismatch_type_size(self):
+    def test_raise_error_on_mismatch_type_size(self):
         def true_fn(x):
             return x.sin()

@@ -389,11 +389,13 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])

         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected to return same number of outputs but got",
+        ):
             make_fx(f)(x, torch.tensor(False))

-    def test_assert_on_mismatch_tensor_size(self):
+    def test_raise_error_on_mismatch_tensor_size(self):
         def true_fn(x):
             return x.sin()

@@ -404,7 +406,10 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])

         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected each tensor to have same metadata but got",
+        ):
             make_fx(f)(x, torch.tensor(False))

     def test_cond_traced_not_nested_fake_tensor(self):
@@ -540,7 +545,7 @@ class TestControlFlowTraced(TestCase):
         out = "".join(out.split())
         self.assertEqual(code, out)

-    def test_assert_on_mismatch_type_size_fake_tensor(self):
+    def test_raise_error_on_mismatch_type_size_fake_tensor(self):
         def true_fn(x):
             return x.sin()

@@ -551,11 +556,14 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])

         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected to return same number of outputs but got",
+        ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))

-    def test_assert_on_mismatch_tensor_size_fake_tensor(self):
+    def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
         def true_fn(x):
             return x.sin()

@@ -566,7 +574,10 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])

         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected each tensor to have same metadata but got",
+        ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))

     def check_map_graph(self, gm, key):

View File

@@ -45,7 +45,7 @@ else:
     globals()[name] = getattr(torch._C._dynamo.eval_frame, name)

 from . import config, convert_frame, skipfiles, utils
-from .exc import ResetRequired
+from .exc import CondOpArgsMismatchError, ResetRequired, UserError, UserErrorType
 from .mutation_guard import install_generation_tagging_init
 from .types import DynamoCallback
 from .utils import compile_times
@@ -843,12 +843,16 @@ def export(
                 return torch.fx.Interpreter(graph).run(*args)

         with enable_python_dispatcher(), fake_mode:
-            graph = make_fx(
-                graph_with_interpreter,
-                decomposition_table=decomposition_table,
-                tracing_mode="real",
-                _allow_non_fake_inputs=True,
-            )(*example_fake_inputs)
+            try:
+                graph = make_fx(
+                    graph_with_interpreter,
+                    decomposition_table=decomposition_table,
+                    tracing_mode="real",
+                    _allow_non_fake_inputs=True,
+                )(*example_fake_inputs)
+            except CondOpArgsMismatchError as e:
+                # Wrap the internal error to the user-facing error
+                raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))

         new_graph = ChangeInputOutputSignature(
             graph,

View File

@@ -78,6 +78,20 @@ class RecompileError(TorchDynamoException):
     pass


+class ArgsMismatchError(Unsupported):
+    def __init__(self, msg):
+        super().__init__(msg)
+
+
+class CondOpArgsMismatchError(ArgsMismatchError):
+    """
+    Internal error from cond() due to arguments mismatch.
+    """
+
+    def __init__(self, msg):
+        super().__init__(msg)
+
+
 class UserErrorType(Enum):
     DYNAMIC_CONTROL_FLOW = auto()
     ANTI_PATTERN = auto()
@@ -96,6 +110,7 @@ class UserError(Unsupported):
         """
         super().__init__(msg)
         self.error_type = error_type
+        self.message = msg


 class IncorrectUsage(Exception):

View File

@@ -43,7 +43,7 @@ from .bytecode_transformation import (
     unique_id,
 )
 from .codegen import PyCodegen
-from .exc import BackendCompilerFailed, unimplemented, Unsupported
+from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
 from .guards import GuardBuilder
 from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
 from .replay_record import DummyModule, ExecutionRecorder
@@ -2116,15 +2116,15 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         try:
             sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
         except TypeError as e:
-            log.warning(
-                "%s %s %s %s %s",
-                func.get_filename(),
-                func.get_function(),
-                args,
-                kwargs,
-                e,
+            # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
+            raise ArgsMismatchError(
+                "{reason}.\n  func = {func}, args = {args}, kwargs = {kwargs}".format(
+                    reason=str(e),
+                    func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
+                    args=[arg.python_type() for arg in args],
+                    kwargs=kwargs,
+                ),
             )
-            unimplemented("arg mismatch inlining")

         for v in itertools.chain(sub_locals.values(), closure_cells.values()):
             if not isinstance(v, VariableTracker):

View File

@@ -435,7 +435,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
         )
         if self.kwdefaults:
             func.__kwdefaults__ = self.kwdefaults.items
-
         bound = inspect.signature(func).bind(*args, **kwargs)
         bound.apply_defaults()
         result = dict(bound.arguments.items())

View File

@@ -15,7 +15,7 @@ from torch._guards import GuardsCheckpointState

 from .. import config, variables
 from ..allowed_functions import torch_get_name
-from ..exc import unimplemented
+from ..exc import ArgsMismatchError, unimplemented, UserError, UserErrorType
 from ..source import GeneratorStateSource, GetItemSource, NNModuleSource
 from ..utils import (
     check_constant_args,
@@ -876,7 +876,12 @@ class TorchHigherOrderOperator(VariableTracker):
                 # Register output to graph
                 # Modeled off of compile_and_call_fx_graph
                 # TODO: support non single Tensor output
-                assert isinstance(output, TensorVariable)
+                if not isinstance(output, TensorVariable):
+                    raise ArgsMismatchError(
+                        "Expected branch out type to be a single tensor but got {}".format(
+                            str(output.python_type())
+                        ),
+                    )
                 tx.output.guards.update(output.guards)
                 tx.output.create_node(
                     "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
@@ -904,14 +909,45 @@ class TorchHigherOrderOperator(VariableTracker):
         if self.value.__name__ == "cond":
             # TODO(voz): Support fake tensor dispatch for recursive
             # ops - see torch/dispatch/_dispatcher.py
-            assert len(args) == 4
-            assert type(args[0]) in (
-                TensorVariable,
-                SymNodeVariable,
-                ConstantVariable,
-            ), str(
-                type(args[0])
-            )  # predicate
+            if len(args) != 4:
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    f"Expected 4 arguments but got {len(args)}.\n"
+                    f"Usage: cond(pred, true_fn, false_fn, operands)",
+                )
+            # predicate
+            if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    f"Expected pred to be bool/int or a tensor with single "
+                    f"item but got {str(type(args[0]))} "
+                    f"with original python type {str(args[0].python_type())}.",
+                )
+
+            # operands
+            if type(args[3]) is not ListVariable:
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    f"Expected a list but got {args[3].python_type()}",
+                )
+            operands = args[3].unpack_var_sequence(tx)
+            if not all(
+                isinstance(operand, (TensorVariable, torch.Tensor))
+                for operand in operands
+            ):
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    "Expected a list of tensors but got {actual_args}".format(
+                        actual_args=[
+                            str(operand.python_type())
+                            if isinstance(operand, VariableTracker)
+                            else str(type(operand))
+                            for operand in operands
+                        ],
+                    ),
+                )
+
+            # branches
             assert isinstance(
                 args[1], (UserFunctionVariable, NestedUserFunctionVariable)
             ), str(
@@ -922,7 +958,6 @@ class TorchHigherOrderOperator(VariableTracker):
             ), str(
                 type(args[2])
             )  # false_fn
-            assert type(args[3]) is ListVariable, str(type(args[3]))  # args

             # Our strategy for tracing the true/false branches of cond
             # are to checkpoint our graphstate, run the true branch,
@@ -939,14 +974,15 @@ class TorchHigherOrderOperator(VariableTracker):
             graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()

-            sub_args = args[3].unpack_var_sequence(tx)
-
             def speculate_branch(branch):
-                # NB: 0 is predicate
-                ix = 1 if branch else 2
-                return speculate_subgraph(
-                    args[ix], sub_args, graph_checkpoint, checkpoint
-                )
+                try:
+                    # NB: 0 is predicate
+                    ix = 1 if branch else 2
+                    return speculate_subgraph(
+                        args[ix], operands, graph_checkpoint, checkpoint
+                    )
+                except ArgsMismatchError as e:
+                    raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))

             (
                 true_r,
@@ -988,7 +1024,7 @@ class TorchHigherOrderOperator(VariableTracker):
                 args[0].as_proxy(),
                 true_node,
                 false_node,
-                [a.as_proxy() for a in sub_args],
+                [a.as_proxy() for a in operands],
             )

             # TODO: assert that the true/false return values are
             # consistent