Improve torch.cond usability: Return UserError with actionable error messages (#98909)
This is part of the effort to improve PT2 Export UX. This PR improves the usability of `torch.cond()` by separating user errors from Dynamo-internal errors. By definition, a user error means the usage of `torch.cond()` violates the restrictions of the API and therefore requires the user to take action to fix it.
In notebook N3363227 we found a number of limitations of `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations fall into three categories (a minimal valid call is sketched after the list):
- predicate restriction (`pred`)
- operands restriction (`operands`)
- branch restriction (`true_fn` & `false_fn`)
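For reference, here is a minimal sketch (not taken from the PR; names are illustrative) of a `torch.cond()` call that satisfies all three restrictions: a bool/int or single-item tensor predicate, tensor-only operands, and branches whose signatures match the operands and return a single tensor.
```python
# Minimal sketch, assuming eager execution with the functorch control_flow API.
import torch
from functorch.experimental.control_flow import cond


def true_fn(x):
    return x.sin()


def false_fn(x):
    return x.cos()


def f(x):
    # pred: a boolean expression on a shape; operands: a list of tensors whose
    # count matches the branch signatures.
    return cond(x.shape[0] > 4, true_fn, false_fn, [x])


print(f(torch.randn(8)))
```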
The error messages now state more precisely where the (user) error comes from and are more actionable for users to fix.
For example, `operands` must be a list of tensors, and the signatures of `true_fn` and `false_fn` must match the `operands`.
If `operands` contains non-tensor types, users will see an error message like:
```
torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"]
from user code:
File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands
return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```
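A hedged sketch of the corresponding fix (illustrative names, not from the PR): keep `operands` tensor-only, e.g. by wrapping the Python float into a tensor before the call.
```python
# Sketch of a fix for the non-tensor-operands error above.
import torch
from functorch.experimental.control_flow import cond


def f_tensor_operands(x):
    a = torch.tensor(3.14)  # previously a Python float; now a tensor operand
    return cond(
        x.shape[0] > 4,
        lambda x, a: x.sin() + a,
        lambda x, a: x.cos() + a,
        [x, a],
    )


print(f_tensor_operands(torch.rand(5)))
```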
If the signature of a branch function doesn't match `operands`, users will see an error message like:
```
torch._dynamo.exc.UserError: too many positional arguments.
func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {}
```
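A hedged sketch of how to resolve that mismatch (illustrative names): give both branches the same signature, matching the operands list.
```python
# Sketch: both branches accept exactly the two operands that are passed in.
import torch
from functorch.experimental.control_flow import cond


def true_fn(x, y):
    return x.sin() + y.sum()


def false_fn(x, y):  # same arity as true_fn and as the operands list
    return x.cos() + y.sum()


def f(x, y):
    return cond(x.shape[0] > 4, true_fn, false_fn, [x, y])


print(f(torch.rand(5), torch.rand(2)))
```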
Or if the tensors returned from the user-defined branches have different metadata (e.g. shapes, dtypes), users will see an error message like:
```
TypeError: Expected each tensor to have same metadata but got:
cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
```
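A hedged sketch of the metadata rule (illustrative names): both branches must return tensors with matching shape, dtype, and layout.
```python
# Sketch: make false_fn's output match true_fn's output metadata exactly.
import torch
from functorch.experimental.control_flow import cond


def true_fn(x):
    return torch.zeros(2, 1, dtype=torch.int64)


def false_fn(x):
    # Same shape, dtype, and layout as true_fn's output.
    return torch.ones(2, 1, dtype=torch.int64)


def f(x):
    return cond(x.shape[0] < 3, true_fn, false_fn, [x])


print(f(torch.rand(5)))
```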
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909
Approved by: https://github.com/jansel
parent e47e8c9d98
commit aa4ed332c3
@@ -22,6 +22,7 @@ from torch.utils._python_dispatch import (
     _pop_mode_temporarily,
 )
 from torch.utils._pytree import tree_flatten
+from torch._dynamo.exc import CondOpArgsMismatchError
 
 
 @dataclass
@@ -56,12 +57,22 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
 
     flat_true_outs, _ = pytree.tree_flatten(true_outs)
     flat_false_outs, _ = pytree.tree_flatten(false_outs)
-    assert(len(flat_true_outs) == len(flat_false_outs))
+    if len(flat_true_outs) != len(flat_false_outs):
+        raise CondOpArgsMismatchError(
+            f"Expected to return same number of outputs but got:"
+            f"\n  {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
+            f"\n  {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
+        )
 
     for i in range(0, len(flat_true_outs)):
         true_out = flat_true_outs[i]
         false_out = flat_false_outs[i]
-        assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
+        if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
+            raise CondOpArgsMismatchError(
+                f"Expected each tensor to have same metadata but got:"
+                f"\n  {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
+                f"\n  {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
+            )
 
     # There are probably better ways - I know that create_arg has some self incrementing name
     # magic to it, but since we explicitly have to get the name for register_module,
@@ -2602,22 +2602,29 @@ class ExportTests(torch._dynamo.test_case.TestCase):
         def false_fn(x):
             return x.sin()
 
-        def f_pred_traced_as_constant_var(x):
-            return cond(x.dim() > 2, true_fn, false_fn, [x])
-
         def f_pred_traced_as_symnode_var(x):
-            return cond(x.shape[0] > 10, true_fn, false_fn, [x])
+            return cond(x.shape[0] > 2, true_fn, false_fn, [x])
 
         def f_pred_traced_as_tensor_var(x):
             return cond(x.all(), true_fn, false_fn, [x])
 
-        example_inputs = (torch.rand(5),)
+        def f_pred_complex_expression_traced_as_symnode_var(x):
+            return cond(
+                x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
+                true_fn,
+                false_fn,
+                [x],
+            )
+
+        example_inputs = (torch.rand(5, 8),)
         for f in [
-            f_pred_traced_as_constant_var,
             f_pred_traced_as_symnode_var,
             f_pred_traced_as_tensor_var,
+            f_pred_complex_expression_traced_as_symnode_var,
         ]:
-            gm, _ = torch._dynamo.export(f, *example_inputs)
+            gm, _ = torch._dynamo.export(
+                f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
+            )
             self.assertEqual(gm(*example_inputs), f(*example_inputs))
 
     def test_mixed_real_and_fake_inputs(self):
@@ -2669,6 +2676,175 @@ class ExportTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(gm(*true_inp), f(*true_inp))
         self.assertEqual(gm(*false_inp), f(*false_inp))
 
+    def test_cond_raise_user_error_on_missing_args(self):
+        def true_fn(x):
+            return x.cos()
+
+        def false_fn(x):
+            return x.sin()
+
+        def f(x):
+            return cond(x.shape[0] > 10, true_fn, false_fn)
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected 4 arguments",
+        ):
+            torch._dynamo.export(
+                f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
+            )
+
+    def test_cond_raise_user_error_on_unsupported_pred(self):
+        def f_unsupported_pred(x):
+            pred = torch.nn.Module()
+            return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected pred to be bool/int or a tensor",
+        ):
+            torch._dynamo.export(
+                f_unsupported_pred,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_non_list_operands(self):
+        def f_non_list_operands(x):
+            return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected a list but got",
+        ):
+            torch._dynamo.export(
+                f_non_list_operands,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_non_tensor_operands(self):
+        def f_non_tensor_operands(x):
+            a: float = 3.14
+            return cond(
+                torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
+            )
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected a list of tensors",
+        ):
+            torch._dynamo.export(
+                f_non_tensor_operands,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_branch_args_mismatch(self):
+        def true_fn(x, y):
+            return x.sin()
+
+        def false_fn(x):
+            return x.cos()
+
+        def f_branch_args_mismatch(x, y):
+            return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])
+
+        example_inputs = (torch.rand(5), torch.rand(2))
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "too many positional arguments",
+        ):
+            torch._dynamo.export(
+                f_branch_args_mismatch,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_branch_return_non_tensor(self):
+        def f_branch_return_non_tensor(x):
+            return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected branch out type to be a single tensor",
+        ):
+            torch._dynamo.export(
+                f_branch_return_non_tensor,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_branch_return_multiple_tenors(self):
+        def f_branch_return_multiple_tensors(x, y):
+            return cond(y, lambda x: (x, x), lambda x: (x, x), [x])
+
+        example_inputs = (torch.randn(4), torch.randn(2))
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected branch out type to be a single tensor",
+        ):
+            torch._dynamo.export(
+                f_branch_return_multiple_tensors,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_mismatch_return_length(self):
+        def true_fn(x):
+            return x
+
+        def false_fn(x):
+            return (x, x)
+
+        def f_mismatch_return_length(x):
+            return cond(torch.tensor(100), true_fn, false_fn, [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected branch out type to be a single tensor",
+        ):
+            torch._dynamo.export(
+                f_mismatch_return_length,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
+    def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
+        def true_fn(x):
+            return torch.tensor([[3], [2]])
+
+        def false_fn(x):
+            return torch.tensor([3.14])
+
+        def f_return_tensor_mismatch(x):
+            return cond(x.shape[0] < 3, true_fn, false_fn, [x])
+
+        example_inputs = (torch.rand(5),)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.UserError,
+            "Expected each tensor to have same metadata but got",
+        ):
+            torch._dynamo.export(
+                f_return_tensor_mismatch,
+                *example_inputs,
+                aten_graph=True,
+                tracing_mode="symbolic",
+            )
+
 
 common_utils.instantiate_parametrized_tests(ExportTests)
 
@@ -6,8 +6,8 @@ from functorch.experimental import control_flow
 from functorch.experimental.control_flow import cond
 from functorch.experimental.control_flow import UnsupportedAliasMutationException
 from torch.fx.experimental.proxy_tensor import make_fx
-
 from torch.testing._internal.common_utils import run_tests, TestCase
+from torch._dynamo.exc import CondOpArgsMismatchError
 
 class TestControlFlow(TestCase):
     def test_cond_no_trace(self):
@@ -378,7 +378,7 @@ class TestControlFlowTraced(TestCase):
         out = "".join(out.split())
         self.assertEqual(code, out)
 
-    def test_assert_on_mismatch_type_size(self):
+    def test_raise_error_on_mismatch_type_size(self):
         def true_fn(x):
             return x.sin()
 
@@ -389,11 +389,13 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])
 
         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected to return same number of outputs but got",
+        ):
             make_fx(f)(x, torch.tensor(False))
 
-    def test_assert_on_mismatch_tensor_size(self):
+    def test_raise_error_on_mismatch_tensor_size(self):
         def true_fn(x):
             return x.sin()
 
@@ -404,7 +406,10 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])
 
         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected each tensor to have same metadata but got",
+        ):
             make_fx(f)(x, torch.tensor(False))
 
     def test_cond_traced_not_nested_fake_tensor(self):
@@ -540,7 +545,7 @@ class TestControlFlowTraced(TestCase):
         out = "".join(out.split())
         self.assertEqual(code, out)
 
-    def test_assert_on_mismatch_type_size_fake_tensor(self):
+    def test_raise_error_on_mismatch_type_size_fake_tensor(self):
         def true_fn(x):
             return x.sin()
 
@@ -551,11 +556,14 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])
 
         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected to return same number of outputs but got",
+        ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
 
 
-    def test_assert_on_mismatch_tensor_size_fake_tensor(self):
+    def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
         def true_fn(x):
             return x.sin()
 
@@ -566,7 +574,10 @@ class TestControlFlowTraced(TestCase):
             return cond(y, true_fn, false_fn, [x])
 
         x = torch.randn(4)
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(
+            CondOpArgsMismatchError,
+            "Expected each tensor to have same metadata but got",
+        ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
 
     def check_map_graph(self, gm, key):
@@ -45,7 +45,7 @@ else:
         globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
 
 from . import config, convert_frame, skipfiles, utils
-from .exc import ResetRequired
+from .exc import CondOpArgsMismatchError, ResetRequired, UserError, UserErrorType
 from .mutation_guard import install_generation_tagging_init
 from .types import DynamoCallback
 from .utils import compile_times
@@ -843,12 +843,16 @@ def export(
             return torch.fx.Interpreter(graph).run(*args)
 
     with enable_python_dispatcher(), fake_mode:
-        graph = make_fx(
-            graph_with_interpreter,
-            decomposition_table=decomposition_table,
-            tracing_mode="real",
-            _allow_non_fake_inputs=True,
-        )(*example_fake_inputs)
+        try:
+            graph = make_fx(
+                graph_with_interpreter,
+                decomposition_table=decomposition_table,
+                tracing_mode="real",
+                _allow_non_fake_inputs=True,
+            )(*example_fake_inputs)
+        except CondOpArgsMismatchError as e:
+            # Wrap the internal error to the user-facing error
+            raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
 
     new_graph = ChangeInputOutputSignature(
         graph,
@@ -78,6 +78,20 @@ class RecompileError(TorchDynamoException):
     pass
 
 
+class ArgsMismatchError(Unsupported):
+    def __init__(self, msg):
+        super().__init__(msg)
+
+
+class CondOpArgsMismatchError(ArgsMismatchError):
+    """
+    Internal error from cond() due to arguments mismatch.
+    """
+
+    def __init__(self, msg):
+        super().__init__(msg)
+
+
 class UserErrorType(Enum):
     DYNAMIC_CONTROL_FLOW = auto()
     ANTI_PATTERN = auto()
@@ -96,6 +110,7 @@ class UserError(Unsupported):
         """
         super().__init__(msg)
         self.error_type = error_type
+        self.message = msg
 
 
 class IncorrectUsage(Exception):
@@ -43,7 +43,7 @@ from .bytecode_transformation import (
     unique_id,
 )
 from .codegen import PyCodegen
-from .exc import BackendCompilerFailed, unimplemented, Unsupported
+from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
 from .guards import GuardBuilder
 from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
 from .replay_record import DummyModule, ExecutionRecorder
@@ -2116,15 +2116,15 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         try:
             sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
         except TypeError as e:
-            log.warning(
-                "%s %s %s %s %s",
-                func.get_filename(),
-                func.get_function(),
-                args,
-                kwargs,
-                e,
+            # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
+            raise ArgsMismatchError(
+                "{reason}.\n  func = {func}, args = {args}, kwargs = {kwargs}".format(
+                    reason=str(e),
+                    func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
+                    args=[arg.python_type() for arg in args],
+                    kwargs=kwargs,
+                ),
             )
-            unimplemented("arg mismatch inlining")
 
         for v in itertools.chain(sub_locals.values(), closure_cells.values()):
             if not isinstance(v, VariableTracker):
@@ -435,7 +435,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
         )
         if self.kwdefaults:
             func.__kwdefaults__ = self.kwdefaults.items
-
         bound = inspect.signature(func).bind(*args, **kwargs)
         bound.apply_defaults()
         result = dict(bound.arguments.items())
@@ -15,7 +15,7 @@ from torch._guards import GuardsCheckpointState
 
 from .. import config, variables
 from ..allowed_functions import torch_get_name
-from ..exc import unimplemented
+from ..exc import ArgsMismatchError, unimplemented, UserError, UserErrorType
 from ..source import GeneratorStateSource, GetItemSource, NNModuleSource
 from ..utils import (
     check_constant_args,
@@ -876,7 +876,12 @@ class TorchHigherOrderOperator(VariableTracker):
             # Register output to graph
             # Modeled off of compile_and_call_fx_graph
             # TODO: support non single Tensor output
-            assert isinstance(output, TensorVariable)
+            if not isinstance(output, TensorVariable):
+                raise ArgsMismatchError(
+                    "Expected branch out type to be a single tensor but got {}".format(
+                        str(output.python_type())
+                    ),
+                )
             tx.output.guards.update(output.guards)
             tx.output.create_node(
                 "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
@@ -904,14 +909,45 @@ class TorchHigherOrderOperator(VariableTracker):
         if self.value.__name__ == "cond":
             # TODO(voz): Support fake tensor dispatch for recursive
             # ops - see torch/dispatch/_dispatcher.py
-            assert len(args) == 4
-            assert type(args[0]) in (
-                TensorVariable,
-                SymNodeVariable,
-                ConstantVariable,
-            ), str(
-                type(args[0])
-            )  # predicate
+            if len(args) != 4:
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    f"Expected 4 arguments but got {len(args)}.\n"
+                    f"Usage: cond(pred, true_fn, false_fn, operands)",
+                )
+            # predicate
+            if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    f"Expected pred to be bool/int or a tensor with single "
+                    f"item but got {str(type(args[0]))} "
+                    f"with original python type {str(args[0].python_type())}.",
+                )
+
+            # operands
+            if type(args[3]) is not ListVariable:
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    f"Expected a list but got {args[3].python_type()}",
+                )
+            operands = args[3].unpack_var_sequence(tx)
+            if not all(
+                isinstance(operand, (TensorVariable, torch.Tensor))
+                for operand in operands
+            ):
+                raise UserError(
+                    UserErrorType.DYNAMIC_CONTROL_FLOW,
+                    "Expected a list of tensors but got {actual_args}".format(
+                        actual_args=[
+                            str(operand.python_type())
+                            if isinstance(operand, VariableTracker)
+                            else str(type(operand))
+                            for operand in operands
+                        ],
+                    ),
+                )
+
+            # branches
             assert isinstance(
                 args[1], (UserFunctionVariable, NestedUserFunctionVariable)
             ), str(
@@ -922,7 +958,6 @@ class TorchHigherOrderOperator(VariableTracker):
             ), str(
                 type(args[2])
             )  # false_fn
-            assert type(args[3]) is ListVariable, str(type(args[3]))  # args
 
             # Our strategy for tracing the true/false branches of cond
             # are to checkpoint our graphstate, run the true branch,
@@ -939,14 +974,15 @@ class TorchHigherOrderOperator(VariableTracker):
 
             graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()
 
-            sub_args = args[3].unpack_var_sequence(tx)
-
             def speculate_branch(branch):
-                # NB: 0 is predicate
-                ix = 1 if branch else 2
-                return speculate_subgraph(
-                    args[ix], sub_args, graph_checkpoint, checkpoint
-                )
+                try:
+                    # NB: 0 is predicate
+                    ix = 1 if branch else 2
+                    return speculate_subgraph(
+                        args[ix], operands, graph_checkpoint, checkpoint
+                    )
+                except ArgsMismatchError as e:
+                    raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
 
             (
                 true_r,
@@ -988,7 +1024,7 @@ class TorchHigherOrderOperator(VariableTracker):
                 args[0].as_proxy(),
                 true_node,
                 false_node,
-                [a.as_proxy() for a in sub_args],
+                [a.as_proxy() for a in operands],
             )
             # TODO: assert that the true/false return values are
             # consistent