mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improve torch.cond useability: Return UserError with actionable error messages (#98909)
It's part of the effort to improve PT2 Export UX. This PR is to improve the usability of `torch.cond()` by separating user errors from the dynamo internal errors. By definition, user error means the usage of `torch.cond()` violates the restrictions of this API therefore needs users to take action and fix the error.
In this notebook N3363227 we discovered a bunch of limitations of using `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations can be categorized as:
- predicate restriction (`pred`)
- operands restriction (`operands`)
- branch restriction (`true_fn` & `false_fn`)
The error message will be more accurate about where the (user) error is from and more actionable for users to fix it.
For example, `operands` must be a list of tensors and the signature of `true_fn` and `false_fn` must match with the `operands`.
If the operands contains non-tensor types, user will see error message like:
```
torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"]
from user code:
File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands
return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```
If the signature of the branch function doesn't match with `operands`, user will see error message like:
```
torch._dynamo.exc.UserError: too many positional arguments.
func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {}
```
Or if the tensor returned from user defined branches has different metadata, e.g. shapes, dtypes, etc., user will see error message like:
```
TypeError: Expected each tensor to have same metadata but got:
cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909
Approved by: https://github.com/jansel
This commit is contained in:
parent
e47e8c9d98
commit
aa4ed332c3
|
|
@ -22,6 +22,7 @@ from torch.utils._python_dispatch import (
|
|||
_pop_mode_temporarily,
|
||||
)
|
||||
from torch.utils._pytree import tree_flatten
|
||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -56,12 +57,22 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
|||
|
||||
flat_true_outs, _ = pytree.tree_flatten(true_outs)
|
||||
flat_false_outs, _ = pytree.tree_flatten(false_outs)
|
||||
assert(len(flat_true_outs) == len(flat_false_outs))
|
||||
if len(flat_true_outs) != len(flat_false_outs):
|
||||
raise CondOpArgsMismatchError(
|
||||
f"Expected to return same number of outputs but got:"
|
||||
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
|
||||
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
|
||||
)
|
||||
|
||||
for i in range(0, len(flat_true_outs)):
|
||||
true_out = flat_true_outs[i]
|
||||
false_out = flat_false_outs[i]
|
||||
assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
|
||||
if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
|
||||
raise CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
||||
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
|
||||
)
|
||||
|
||||
# There are probably better ways - I know that create_arg has some self incrementing name
|
||||
# magic to it, but since we explicitly have to get the name for register_module,
|
||||
|
|
|
|||
|
|
@ -2602,22 +2602,29 @@ class ExportTests(torch._dynamo.test_case.TestCase):
|
|||
def false_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def f_pred_traced_as_constant_var(x):
|
||||
return cond(x.dim() > 2, true_fn, false_fn, [x])
|
||||
|
||||
def f_pred_traced_as_symnode_var(x):
|
||||
return cond(x.shape[0] > 10, true_fn, false_fn, [x])
|
||||
return cond(x.shape[0] > 2, true_fn, false_fn, [x])
|
||||
|
||||
def f_pred_traced_as_tensor_var(x):
|
||||
return cond(x.all(), true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
def f_pred_complex_expression_traced_as_symnode_var(x):
|
||||
return cond(
|
||||
x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
|
||||
true_fn,
|
||||
false_fn,
|
||||
[x],
|
||||
)
|
||||
|
||||
example_inputs = (torch.rand(5, 8),)
|
||||
for f in [
|
||||
f_pred_traced_as_constant_var,
|
||||
f_pred_traced_as_symnode_var,
|
||||
f_pred_traced_as_tensor_var,
|
||||
f_pred_complex_expression_traced_as_symnode_var,
|
||||
]:
|
||||
gm, _ = torch._dynamo.export(f, *example_inputs)
|
||||
gm, _ = torch._dynamo.export(
|
||||
f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
|
||||
)
|
||||
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
||||
|
||||
def test_mixed_real_and_fake_inputs(self):
|
||||
|
|
@ -2669,6 +2676,175 @@ class ExportTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(gm(*true_inp), f(*true_inp))
|
||||
self.assertEqual(gm(*false_inp), f(*false_inp))
|
||||
|
||||
def test_cond_raise_user_error_on_missing_args(self):
|
||||
def true_fn(x):
|
||||
return x.cos()
|
||||
|
||||
def false_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def f(x):
|
||||
return cond(x.shape[0] > 10, true_fn, false_fn)
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected 4 arguments",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f, *example_inputs, aten_graph=True, tracing_mode="symbolic"
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_unsupported_pred(self):
|
||||
def f_unsupported_pred(x):
|
||||
pred = torch.nn.Module()
|
||||
return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected pred to be bool/int or a tensor",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_unsupported_pred,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_non_list_operands(self):
|
||||
def f_non_list_operands(x):
|
||||
return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected a list but got",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_non_list_operands,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_non_tensor_operands(self):
|
||||
def f_non_tensor_operands(x):
|
||||
a: float = 3.14
|
||||
return cond(
|
||||
torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
|
||||
)
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected a list of tensors",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_non_tensor_operands,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_branch_args_mismatch(self):
|
||||
def true_fn(x, y):
|
||||
return x.sin()
|
||||
|
||||
def false_fn(x):
|
||||
return x.cos()
|
||||
|
||||
def f_branch_args_mismatch(x, y):
|
||||
return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])
|
||||
|
||||
example_inputs = (torch.rand(5), torch.rand(2))
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"too many positional arguments",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_branch_args_mismatch,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_branch_return_non_tensor(self):
|
||||
def f_branch_return_non_tensor(x):
|
||||
return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected branch out type to be a single tensor",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_branch_return_non_tensor,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_branch_return_multiple_tenors(self):
|
||||
def f_branch_return_multiple_tensors(x, y):
|
||||
return cond(y, lambda x: (x, x), lambda x: (x, x), [x])
|
||||
|
||||
example_inputs = (torch.randn(4), torch.randn(2))
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected branch out type to be a single tensor",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_branch_return_multiple_tensors,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_mismatch_return_length(self):
|
||||
def true_fn(x):
|
||||
return x
|
||||
|
||||
def false_fn(x):
|
||||
return (x, x)
|
||||
|
||||
def f_mismatch_return_length(x):
|
||||
return cond(torch.tensor(100), true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected branch out type to be a single tensor",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_mismatch_return_length,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
|
||||
def true_fn(x):
|
||||
return torch.tensor([[3], [2]])
|
||||
|
||||
def false_fn(x):
|
||||
return torch.tensor([3.14])
|
||||
|
||||
def f_return_tensor_mismatch(x):
|
||||
return cond(x.shape[0] < 3, true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Expected each tensor to have same metadata but got",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_return_tensor_mismatch,
|
||||
*example_inputs,
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
|
||||
common_utils.instantiate_parametrized_tests(ExportTests)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from functorch.experimental import control_flow
|
|||
from functorch.experimental.control_flow import cond
|
||||
from functorch.experimental.control_flow import UnsupportedAliasMutationException
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch._dynamo.exc import CondOpArgsMismatchError
|
||||
|
||||
class TestControlFlow(TestCase):
|
||||
def test_cond_no_trace(self):
|
||||
|
|
@ -378,7 +378,7 @@ class TestControlFlowTraced(TestCase):
|
|||
out = "".join(out.split())
|
||||
self.assertEqual(code, out)
|
||||
|
||||
def test_assert_on_mismatch_type_size(self):
|
||||
def test_raise_error_on_mismatch_type_size(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
|
|
@ -389,11 +389,13 @@ class TestControlFlowTraced(TestCase):
|
|||
return cond(y, true_fn, false_fn, [x])
|
||||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaisesRegex(
|
||||
CondOpArgsMismatchError,
|
||||
"Expected to return same number of outputs but got",
|
||||
):
|
||||
make_fx(f)(x, torch.tensor(False))
|
||||
|
||||
|
||||
def test_assert_on_mismatch_tensor_size(self):
|
||||
def test_raise_error_on_mismatch_tensor_size(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
|
|
@ -404,7 +406,10 @@ class TestControlFlowTraced(TestCase):
|
|||
return cond(y, true_fn, false_fn, [x])
|
||||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaisesRegex(
|
||||
CondOpArgsMismatchError,
|
||||
"Expected each tensor to have same metadata but got",
|
||||
):
|
||||
make_fx(f)(x, torch.tensor(False))
|
||||
|
||||
def test_cond_traced_not_nested_fake_tensor(self):
|
||||
|
|
@ -540,7 +545,7 @@ class TestControlFlowTraced(TestCase):
|
|||
out = "".join(out.split())
|
||||
self.assertEqual(code, out)
|
||||
|
||||
def test_assert_on_mismatch_type_size_fake_tensor(self):
|
||||
def test_raise_error_on_mismatch_type_size_fake_tensor(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
|
|
@ -551,11 +556,14 @@ class TestControlFlowTraced(TestCase):
|
|||
return cond(y, true_fn, false_fn, [x])
|
||||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaisesRegex(
|
||||
CondOpArgsMismatchError,
|
||||
"Expected to return same number of outputs but got",
|
||||
):
|
||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||
|
||||
|
||||
def test_assert_on_mismatch_tensor_size_fake_tensor(self):
|
||||
def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
|
|
@ -566,7 +574,10 @@ class TestControlFlowTraced(TestCase):
|
|||
return cond(y, true_fn, false_fn, [x])
|
||||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaisesRegex(
|
||||
CondOpArgsMismatchError,
|
||||
"Expected each tensor to have same metadata but got",
|
||||
):
|
||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||
|
||||
def check_map_graph(self, gm, key):
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ else:
|
|||
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
|
||||
|
||||
from . import config, convert_frame, skipfiles, utils
|
||||
from .exc import ResetRequired
|
||||
from .exc import CondOpArgsMismatchError, ResetRequired, UserError, UserErrorType
|
||||
from .mutation_guard import install_generation_tagging_init
|
||||
from .types import DynamoCallback
|
||||
from .utils import compile_times
|
||||
|
|
@ -843,12 +843,16 @@ def export(
|
|||
return torch.fx.Interpreter(graph).run(*args)
|
||||
|
||||
with enable_python_dispatcher(), fake_mode:
|
||||
graph = make_fx(
|
||||
graph_with_interpreter,
|
||||
decomposition_table=decomposition_table,
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=True,
|
||||
)(*example_fake_inputs)
|
||||
try:
|
||||
graph = make_fx(
|
||||
graph_with_interpreter,
|
||||
decomposition_table=decomposition_table,
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=True,
|
||||
)(*example_fake_inputs)
|
||||
except CondOpArgsMismatchError as e:
|
||||
# Wrap the internal error to the user-facing error
|
||||
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
|
||||
|
||||
new_graph = ChangeInputOutputSignature(
|
||||
graph,
|
||||
|
|
|
|||
|
|
@ -78,6 +78,20 @@ class RecompileError(TorchDynamoException):
|
|||
pass
|
||||
|
||||
|
||||
class ArgsMismatchError(Unsupported):
|
||||
def __init__(self, msg):
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class CondOpArgsMismatchError(ArgsMismatchError):
|
||||
"""
|
||||
Internal error from cond() due to arguments mismatch.
|
||||
"""
|
||||
|
||||
def __init__(self, msg):
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class UserErrorType(Enum):
|
||||
DYNAMIC_CONTROL_FLOW = auto()
|
||||
ANTI_PATTERN = auto()
|
||||
|
|
@ -96,6 +110,7 @@ class UserError(Unsupported):
|
|||
"""
|
||||
super().__init__(msg)
|
||||
self.error_type = error_type
|
||||
self.message = msg
|
||||
|
||||
|
||||
class IncorrectUsage(Exception):
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ from .bytecode_transformation import (
|
|||
unique_id,
|
||||
)
|
||||
from .codegen import PyCodegen
|
||||
from .exc import BackendCompilerFailed, unimplemented, Unsupported
|
||||
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
|
||||
from .guards import GuardBuilder
|
||||
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
|
||||
from .replay_record import DummyModule, ExecutionRecorder
|
||||
|
|
@ -2116,15 +2116,15 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||
try:
|
||||
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
|
||||
except TypeError as e:
|
||||
log.warning(
|
||||
"%s %s %s %s %s",
|
||||
func.get_filename(),
|
||||
func.get_function(),
|
||||
args,
|
||||
kwargs,
|
||||
e,
|
||||
# Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
|
||||
raise ArgsMismatchError(
|
||||
"{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
|
||||
reason=str(e),
|
||||
func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
|
||||
args=[arg.python_type() for arg in args],
|
||||
kwargs=kwargs,
|
||||
),
|
||||
)
|
||||
unimplemented("arg mismatch inlining")
|
||||
|
||||
for v in itertools.chain(sub_locals.values(), closure_cells.values()):
|
||||
if not isinstance(v, VariableTracker):
|
||||
|
|
|
|||
|
|
@ -435,7 +435,6 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
|||
)
|
||||
if self.kwdefaults:
|
||||
func.__kwdefaults__ = self.kwdefaults.items
|
||||
|
||||
bound = inspect.signature(func).bind(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
result = dict(bound.arguments.items())
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from torch._guards import GuardsCheckpointState
|
|||
|
||||
from .. import config, variables
|
||||
from ..allowed_functions import torch_get_name
|
||||
from ..exc import unimplemented
|
||||
from ..exc import ArgsMismatchError, unimplemented, UserError, UserErrorType
|
||||
from ..source import GeneratorStateSource, GetItemSource, NNModuleSource
|
||||
from ..utils import (
|
||||
check_constant_args,
|
||||
|
|
@ -876,7 +876,12 @@ class TorchHigherOrderOperator(VariableTracker):
|
|||
# Register output to graph
|
||||
# Modeled off of compile_and_call_fx_graph
|
||||
# TODO: support non single Tensor output
|
||||
assert isinstance(output, TensorVariable)
|
||||
if not isinstance(output, TensorVariable):
|
||||
raise ArgsMismatchError(
|
||||
"Expected branch out type to be a single tensor but got {}".format(
|
||||
str(output.python_type())
|
||||
),
|
||||
)
|
||||
tx.output.guards.update(output.guards)
|
||||
tx.output.create_node(
|
||||
"output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
|
||||
|
|
@ -904,14 +909,45 @@ class TorchHigherOrderOperator(VariableTracker):
|
|||
if self.value.__name__ == "cond":
|
||||
# TODO(voz): Support fake tensor dispatch for recursive
|
||||
# ops - see torch/dispatch/_dispatcher.py
|
||||
assert len(args) == 4
|
||||
assert type(args[0]) in (
|
||||
TensorVariable,
|
||||
SymNodeVariable,
|
||||
ConstantVariable,
|
||||
), str(
|
||||
type(args[0])
|
||||
) # predicate
|
||||
if len(args) != 4:
|
||||
raise UserError(
|
||||
UserErrorType.DYNAMIC_CONTROL_FLOW,
|
||||
f"Expected 4 arguments but got {len(args)}.\n"
|
||||
f"Usage: cond(pred, true_fn, false_fn, operands)",
|
||||
)
|
||||
# predicate
|
||||
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
|
||||
raise UserError(
|
||||
UserErrorType.DYNAMIC_CONTROL_FLOW,
|
||||
f"Expected pred to be bool/int or a tensor with single "
|
||||
f"item but got {str(type(args[0]))} "
|
||||
f"with original python type {str(args[0].python_type())}.",
|
||||
)
|
||||
|
||||
# operands
|
||||
if type(args[3]) is not ListVariable:
|
||||
raise UserError(
|
||||
UserErrorType.DYNAMIC_CONTROL_FLOW,
|
||||
f"Expected a list but got {args[3].python_type()}",
|
||||
)
|
||||
operands = args[3].unpack_var_sequence(tx)
|
||||
if not all(
|
||||
isinstance(operand, (TensorVariable, torch.Tensor))
|
||||
for operand in operands
|
||||
):
|
||||
raise UserError(
|
||||
UserErrorType.DYNAMIC_CONTROL_FLOW,
|
||||
"Expected a list of tensors but got {actual_args}".format(
|
||||
actual_args=[
|
||||
str(operand.python_type())
|
||||
if isinstance(operand, VariableTracker)
|
||||
else str(type(operand))
|
||||
for operand in operands
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# branches
|
||||
assert isinstance(
|
||||
args[1], (UserFunctionVariable, NestedUserFunctionVariable)
|
||||
), str(
|
||||
|
|
@ -922,7 +958,6 @@ class TorchHigherOrderOperator(VariableTracker):
|
|||
), str(
|
||||
type(args[2])
|
||||
) # false_fn
|
||||
assert type(args[3]) is ListVariable, str(type(args[3])) # args
|
||||
|
||||
# Our strategy for tracing the true/false branches of cond
|
||||
# are to checkpoint our graphstate, run the true branch,
|
||||
|
|
@ -939,14 +974,15 @@ class TorchHigherOrderOperator(VariableTracker):
|
|||
|
||||
graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()
|
||||
|
||||
sub_args = args[3].unpack_var_sequence(tx)
|
||||
|
||||
def speculate_branch(branch):
|
||||
# NB: 0 is predicate
|
||||
ix = 1 if branch else 2
|
||||
return speculate_subgraph(
|
||||
args[ix], sub_args, graph_checkpoint, checkpoint
|
||||
)
|
||||
try:
|
||||
# NB: 0 is predicate
|
||||
ix = 1 if branch else 2
|
||||
return speculate_subgraph(
|
||||
args[ix], operands, graph_checkpoint, checkpoint
|
||||
)
|
||||
except ArgsMismatchError as e:
|
||||
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
|
||||
|
||||
(
|
||||
true_r,
|
||||
|
|
@ -988,7 +1024,7 @@ class TorchHigherOrderOperator(VariableTracker):
|
|||
args[0].as_proxy(),
|
||||
true_node,
|
||||
false_node,
|
||||
[a.as_proxy() for a in sub_args],
|
||||
[a.as_proxy() for a in operands],
|
||||
)
|
||||
# TODO: assert that the true/false return values are
|
||||
# consistent
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user