[cond] support output sizes mismatch in front end (#147130)

This PR finishes https://github.com/pytorch/pytorch/pull/137615 by addressing the TODOs and comments left there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147130
Approved by: https://github.com/zou3519
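As a quick illustration of what this enables (a hedged sketch modeled on the new test_cond_mismatched_branch_output test added below; the function names here are illustrative), torch.cond branches that return tensors of different sizes along a dim can now be traced, with the differing dim represented as an unbacked symint:

import torch

def true_fn(x):
    # shape [x.size(0) - 2, 4]; clone so both branches share storage_offset 0
    return (x + 1)[2:].clone()

def false_fn(x):
    # shape [2, 4]
    return (x - 1)[:2].clone()

def f(x):
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

x = torch.randn(5, 4)
# Previously this raised UncapturedHigherOrderOpError because the branch outputs'
# metadata differed; with this PR the leading dim is merged into an unbacked symint.
out = torch.compile(f, backend="eager", fullgraph=True)(x)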
Yidi Wu 2025-02-24 13:34:31 -08:00 committed by PyTorch MergeBot
parent de80b6f0d3
commit 824474cb35
6 changed files with 477 additions and 109 deletions

View File

@ -1897,14 +1897,18 @@ def forward(self, x, y):
def forward(self, x): def forward(self, x):
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
l_x_ = arg0 l_x_ = arg0
size = l_x_.size() sym_size_int = torch.ops.aten.sym_size.int(l_x_, 0)
getitem = size[0]; size = None le = sym_size_int <= 2; sym_size_int = None
le = getitem <= 2; getitem = None
cond_true_0 = self.cond_true_0 cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0 cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None
getitem_3 = cond[0]
sym_size_int_1 = torch.ops.aten.sym_size.int(getitem_3, 0); getitem_3 = None
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
ge = sym_size_int_1 >= 2; sym_size_int_1 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 2 on node 'ge'"); ge = _assert_scalar_default = None
getitem_2 = cond[0]; cond = None getitem_2 = cond[0]; cond = None
return pytree.tree_unflatten([getitem_2], self._out_spec)""", return pytree.tree_unflatten([getitem_2], self._out_spec)""", # noqa: B950
) )
self.assertExpectedInline( self.assertExpectedInline(
out_graph.cond_true_0.code.strip(), out_graph.cond_true_0.code.strip(),
@ -1922,12 +1926,8 @@ def forward(self, l_x_):
getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
return (getitem,)""", return (getitem,)""",
) )
with self.assertRaisesRegex( # We could successfully export branches that return different sizes
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.export(mod)(torch.randn(3, 2))
"Expected true_fn_output and false_fn_output to have same metadata but found",
):
# True branch and false branch return tensors of different shape
torch._dynamo.export(mod)(torch.randn(3, 2))
# We specialize into one of the branches since predicate is a python boolean. # We specialize into one of the branches since predicate is a python boolean.
test_x = torch.randn(3, 2) test_x = torch.randn(3, 2)
@ -3334,8 +3334,8 @@ def forward(self, x):
example_inputs = (torch.rand(5),) example_inputs = (torch.rand(5),)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.exc.TorchRuntimeError,
"Expected true_fn_output and false_fn_output to have same number of outputs but got", "Unmatched output spec from torch.cond branches",
): ):
torch._dynamo.export( torch._dynamo.export(
f_mismatch_return_length, f_mismatch_return_length,
@ -3354,8 +3354,8 @@ def forward(self, x):
example_inputs = (torch.rand(5),) example_inputs = (torch.rand(5),)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.exc.TorchRuntimeError,
"Expected true_fn_output and false_fn_output to have same metadata but found", "When merging two branches' output in torch.cond",
): ):
torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)( torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(
*example_inputs, *example_inputs,

View File

@ -6961,11 +6961,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
return torch.cond(x.sum() > 0, true_fn, false_fn) return torch.cond(x.sum() > 0, true_fn, false_fn)
x = torch.randn(2, 3) x = torch.randn(2, 3)
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError): output_mismatch_test(x)
output_mismatch_test(x)
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError): torch.compile(output_mismatch_test, backend="eager")(x)
torch.compile(output_mismatch_test)(x)
def test_non_aliasing_util(self): def test_non_aliasing_util(self):
from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing

View File

@ -968,7 +968,7 @@ def forward(self, pred_1):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},)) result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
return result return result
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred) gm = make_fx(f, tracing_mode="real", _allow_non_fake_inputs=True)(pred)
self.assertExpectedInline( self.assertExpectedInline(
gm.code.strip(), gm.code.strip(),
"""\ """\
@ -5379,8 +5379,8 @@ def forward(self, arg0_1):
x = torch.randn(4) x = torch.randn(4)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.exc.TorchRuntimeError,
"Expected true_fn_output and false_fn_output to have same number of outputs but got", "Unmatched output spec from torch.cond branches",
): ):
make_fx(f)(x, torch.tensor(False)) make_fx(f)(x, torch.tensor(False))
@ -5396,8 +5396,8 @@ def forward(self, arg0_1):
x = torch.randn(4) x = torch.randn(4)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.exc.TorchRuntimeError,
"Expected true_fn_output and false_fn_output to have same metadata but found", "When merging two branches' output in torch.cond",
): ):
make_fx(f)(x, torch.tensor(False)) make_fx(f)(x, torch.tensor(False))
@ -5552,8 +5552,8 @@ def forward(self, arg0_1):
x = torch.randn(4) x = torch.randn(4)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.exc.TorchRuntimeError,
"Expected true_fn_output and false_fn_output to have same number of outputs but got", "Unmatched output spec from torch.cond branches",
): ):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False)) make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
@ -5569,8 +5569,8 @@ def forward(self, arg0_1):
x = torch.randn(4) x = torch.randn(4)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError, torch._dynamo.exc.TorchRuntimeError,
"Expected true_fn_output and false_fn_output to have same metadata but found", "When merging two branches' output in torch.cond",
): ):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False)) make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
@ -7547,6 +7547,215 @@ class GraphModule(torch.nn.Module):
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}} dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes) _ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
@skipIfTorchDynamo(
"Skip because _merge_tensors is not intended for dynamo to compile"
)
def test_merge_tensors(self):
from torch._higher_order_ops.cond import _merge_tensors
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
# The shapes and strides are from randomly generated pairs of tensors followed by swapaxes
valid_test_cases = [
# [(size1, stride1), (size2, stride2), (expected_size, expected_stride)]
[((3,), (1,)), ((4,), (1,)), ("(u0,)", "(1,)")],
[((1, 3), (3, 1)), ((3, 2), (2, 1)), ("(u0, u1)", "(u1, 1)")],
[((2, 1), (1, 1)), ((7, 3), (3, 1)), ("(u0, u1)", "(u1, 1)")],
[((5, 5), (1, 5)), ((4, 5), (1, 4)), ("(u0, 5)", "(1, u0)")],
[
((7, 3, 1), (1, 7, 1)),
((4, 3, 3), (3, 12, 1)),
("(u0, 3, u1)", "(u1, u0*u1, 1)"),
],
[
((5, 7, 4), (7, 1, 35)),
((7, 4, 4), (4, 1, 28)),
("(u0, u1, 4)", "(u1, 1, u0*u1)"),
],
[
((1, 6, 3, 2), (36, 1, 6, 18)),
((4, 2, 2, 6), (24, 1, 2, 4)),
("(u0, u1, u2, u3)", "(u1*u2*u3, 1, u1, u1*u2)"),
],
[
((6, 1, 6, 3), (18, 1, 1, 6)),
((2, 1, 3, 4), (12, 1, 1, 3)),
("(u0, 1, u1, u2)", "(u1*u2, 1, 1, u1)"),
],
[
((3, 1, 2, 4, 1), (8, 8, 4, 1, 1)),
((2, 4, 1, 4, 1), (16, 4, 4, 1, 1)),
("(u0, u1, u2, 4, 1)", "(4*u1*u2, 4*u2, 4, 1, 1)"),
],
]
def _inner(case):
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
(size1, stride1), (size2, stride2), (merged_size, merged_stride) = case
with fake_mode:
t1 = torch.empty_strided(size1, stride1)
t2 = torch.empty_strided(size2, stride2)
out = _merge_tensors(t1, t2, fake_mode)
self.assertEqual(str(tuple(out.size())), merged_size)
self.assertEqual(str(tuple(out.stride())), merged_stride)
for case in valid_test_cases:
_inner(case)
# The shapes and strides are from randomly generated pairs of tensors followed by swapaxes
invalid_test_cases = [
# [(size1, stride1), (size2, stride2)]
[((1,), (1,)), ((1,), (0,))],
[
((1, 3), (1, 1)),
((5, 6), (6, 1)),
], # t1 is not contiguous, t2 is contiguous
[
((2, 1), (1, 1)),
((7, 3), (1, 3)),
], # t1 is contiguous, t2 is not contiguous
[
((5, 4), (4, 1)),
((5, 5), (1, 5)),
], # t1 is contiguous, t2 is not contiguous
[((7, 3, 1), (1, 7, 1)), ((4, 3, 3), (9, 1, 3))], # layout is different
[((5, 7, 4), (7, 1, 35)), ((7, 4, 4), (4, 28, 1))], # layout is different
[
((1, 6, 3, 2), (36, 1, 6, 18)),
((4, 1, 1, 6), (1, 4, 4, 4)),
], # layout is different
[
((6, 1, 6, 3), (18, 1, 1, 6)),
((1, 1, 1, 1), (1, 1, 1, 1)),
], # layout is different
[
((6, 1, 1, 6, 3), (3, 18, 18, 18, 1)),
((5, 1, 2, 1, 1), (2, 10, 1, 10, 1)),
], # layout is different
]
for case in invalid_test_cases:
with self.assertRaisesRegex(Exception, r"."):
_inner(case)
@parametrize("dynamic", [True, False])
@parametrize("backend", ["eager", "aot_eager"])
def test_cond_mismatched_branch_output(self, dynamic, backend):
from torch._dynamo.testing import EagerAndRecordGraphs
class M(torch.nn.Module):
def forward(self, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
# clone the outputs so branches have the same storage_offset
return (x + a)[2:].clone()
def false_fn(x):
# clone the outputs so branches have the same storage_offset
return (x + b * z)[:2].clone()
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
return y.sum() - ret
m = M()
x, y, z = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)
out = m(x, y, z)
if not (backend == "eager" and dynamic and not TEST_WITH_CROSSREF):
compiled_out = torch.compile(
m, backend=backend, dynamic=dynamic, fullgraph=True
)(x, y, z)
self.assertEqual(compiled_out, out)
else:
bk = EagerAndRecordGraphs()
compiled_out = torch.compile(
m, backend=bk, dynamic=dynamic, fullgraph=True
)(x, y, z)
self.assertEqual(compiled_out, out)
self.assertExpectedInline(
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"):
l_y_ = L_y_
l_z_ = L_z_
l_x_ = L_x_
sum_1: "f32[]" = l_x_.sum()
gt: "b8[]" = sum_1 > 0; sum_1 = None
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_, s1, s0, s0, l_z_]); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None
getitem_5: "f32[u0, s1]" = cond[0]
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
ret: "f32[u0, s1]" = cond[0]; cond = None
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None
return (sub,)
class cond_true_0(torch.nn.Module):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_
s1_1 = s1
add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None
getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None
clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None
return (clone,)
class cond_false_0(torch.nn.Module):
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
l_x__1 = l_x_
s1_1 = s1
mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None
getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None
clone: "f32[2, s1]" = getitem.clone(); getitem = None
return (clone,)
""", # noqa: B950
)
@parametrize("dynamic", [True, False])
@parametrize("backend", ["eager", "aot_eager"])
def test_cond_mismatched_branch_strided_output(self, dynamic, backend):
class M(torch.nn.Module):
def forward(self, x, y):
def true_fn(x, y):
return (
(x.swapaxes(-1, 0) + 1)
.unsqueeze(1)
.expand(-1, 5, -1, -1, -1, -1, -1),
torch.empty_strided((3, 3), (0, 1)),
)
def false_fn(x, y):
return (
(y.swapaxes(-1, 0) + 1)
.unsqueeze(1)
.expand(-1, 4, -1, -1, -1, -1, -1),
torch.empty_strided((4, 5), (0, 1)),
)
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x, y))
return y.sum() + ret[0]
m = M()
x, y = torch.randn(1, 6, 1, 5, 4, 3), torch.randn(1, 4, 5, 1, 3, 8)
out = m(x, y)
compiled_out = torch.compile(
m, backend=backend, dynamic=dynamic, fullgraph=True
)(x, y)
self.assertEqual(compiled_out, out)
_hop_schema_test_schema_types = [ _hop_schema_test_schema_types = [
"bool", "bool",

View File

@ -126,12 +126,12 @@ class CondModels:
def true_fn(x, y): def true_fn(x, y):
z1 = x + y z1 = x + y
z2 = x - y z2 = x - y
return z1[2:], z2[:, 4:] return z1[2:], z2[:, 4:].contiguous()
def false_fn(x, y): def false_fn(x, y):
z1 = x - y z1 = x - y
z2 = x + y z2 = x + y
return z1[2:], z2[:, 4:] return z1[2:], z2[:, 4:].contiguous()
return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]]) return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
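The .contiguous() calls above are presumably needed because z2[:, 4:] is a view whose strides are not a simple cumulative product of its sizes, which the new stride-merging logic in cond_fake_tensor_mode rejects. A minimal sketch with illustrative shapes (not taken from the test):

import torch

z2 = torch.randn(6, 8)                # contiguous, stride (8, 1)
view = z2[:, 4:]                      # shape (6, 4) but stride stays (8, 1): not dense in memory
dense = view.contiguous()             # shape (6, 4), stride (4, 1)
print(view.stride(), dense.stride())  # (8, 1) (4, 1)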

View File

@ -935,13 +935,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
if not same_treespec.as_python_constant(): if not same_treespec.as_python_constant():
unimplemented("Expected branches to return the same pytree structure.") unimplemented("Expected branches to return the same pytree structure.")
check_meta_consistency_vt(
true_r.unpack_var_sequence(tx),
false_r.unpack_var_sequence(tx),
"true_fn_output",
"false_fn_output",
)
( (
true_graph, true_graph,
false_graph, false_graph,

View File

@ -3,7 +3,7 @@
import contextlib import contextlib
import logging import logging
import warnings import warnings
from typing import Any, Callable, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torch._subclasses.functional_tensor import torch._subclasses.functional_tensor
@ -30,7 +30,7 @@ from torch._higher_order_ops.utils import (
validate_subgraph_args_types, validate_subgraph_args_types,
) )
from torch._ops import HigherOrderOperator from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import ( from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode, _temp_remove_metadata_torch_function_mode,
@ -39,7 +39,6 @@ from torch.fx.experimental.proxy_tensor import (
ProxyTorchDispatchMode, ProxyTorchDispatchMode,
track_tensor_tree, track_tensor_tree,
) )
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode from torch.utils._python_dispatch import _get_current_dispatch_mode
from .utils import _from_fun, _maybe_fake_prop_ignore_unbacked, create_fw_bw_graph from .utils import _from_fun, _maybe_fake_prop_ignore_unbacked, create_fw_bw_graph
@ -269,55 +268,6 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
f"\n false branch returns {len(flat_false_outs)} item(s)" f"\n false branch returns {len(flat_false_outs)} item(s)"
) )
for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
# Note that we need to skip the check for requires_grad because we're after
# the autograd key during tracing, so the requires_grad attribute of the tensors
# is no longer accurate. See Note [invariants for node meta 'val']
def _same_meta_except_requires_grad(true_out, false_out):
if true_out is None and false_out is None:
return True
elif true_out is None or false_out is None:
# Consider the following case:
# def true_fn(x, y):
# return x * y
#
# def false_fn(x, y):
# return x.sin()
#
# We'll get the following graphs for backward:
# def backward_true_fn(x, y, grad_out):
# return grad_out * y, grad_out * x
#
# def backward_false_fn(x, y, grad_out):
# return grad_out, None
#
# This suggests that when we make_fx the backward graph,
# the output graph would produce outputs with mismatched metadata, which is undesirable.
#
# Ideally, we should provide an optional type to indicate that one of the branches might
# return None. But we'll just let it pass for now and let downstream/runtime handle it.
#
# Note that this corner case should **only** happen when the user wants to trace the backward graph because
# if it's the forward, dynamo will error.
return True
true_meta = true_out.meta.get("tensor_meta", None)
false_meta = false_out.meta.get("tensor_meta", None)
return (
true_meta.shape == false_meta.shape
and true_meta.dtype == false_meta.dtype
and true_meta.stride == false_meta.stride
)
if not _same_meta_except_requires_grad(true_out, false_out):
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
)
i, true_name = unique_graph_id(proxy_mode, prefix="true_graph") i, true_name = unique_graph_id(proxy_mode, prefix="true_graph")
false_name = f"false_graph_{i}" false_name = f"false_graph_{i}"
@ -429,30 +379,248 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols() ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
with mode, ignore_fresh_unbacked: with mode, ignore_fresh_unbacked:
true_outs = true_fn(*operands) flat_true_outs, true_out_spec = pytree.tree_flatten(true_fn(*operands))
flat_true_outs = pytree.tree_leaves(true_outs) flat_false_outs, false_out_spec = pytree.tree_flatten(false_fn(*operands))
flat_false_outs = pytree.tree_leaves(false_fn(*operands)) if true_out_spec != false_out_spec:
if len(flat_true_outs) != len(flat_false_outs): raise RuntimeError(
raise RuntimeError("Unmatched number of outputs from cond() branches.") "Unmatched output spec from torch.cond branches: "
f"true branch tree_spec {true_out_spec} vs false branch tree_spec {false_out_spec}."
)
merged_outs = []
for true_out, false_out in zip(flat_true_outs, flat_false_outs): for true_out, false_out in zip(flat_true_outs, flat_false_outs):
if true_out is None or false_out is None: merged_outs.append(_merge_tensors(true_out, false_out, mode))
if true_out is None and false_out is None: return pytree.tree_unflatten(merged_outs, true_out_spec)
def check_tensor_meta_match(
t1: torch.Tensor, t2: torch.Tensor, attr_names: tuple[str, ...], msg_prefix: str
) -> None:
def _get_attr_maybe_call(t: torch.Tensor, attr_name: str) -> Any:
attr = getattr(t, attr_name)
if callable(attr):
return attr()
return attr
for attr_name in attr_names:
lattr = _get_attr_maybe_call(t1, attr_name)
rattr = _get_attr_maybe_call(t2, attr_name)
torch._check(
lattr == rattr,
lambda: f"{msg_prefix} expected same {attr_name} but got {lattr} and {rattr}.",
)
def _merge_tensors(
a: Optional[torch.Tensor], b: Optional[torch.Tensor], mode: FakeTensorMode
):
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
if a is None or b is None:
assert a is None and b is None, (a, b)
return None
assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))
# Note: we don't check size, stride because
# they'll be merged with unbacked symints if they differ.
_meta_to_check = {
"dtype",
"device",
"layout",
"dim",
"is_quantized",
"is_conj",
"is_sparse",
"storage_offset",
}
check_tensor_meta_match(
a,
b,
tuple(_meta_to_check),
msg_prefix="When merging two branches' output in torch.cond, ",
)
# NYI
assert not a.is_quantized and not b.is_quantized
assert not a.is_sparse and not b.is_sparse
assert not a.is_conj() and not b.is_conj()
"""
Step 1: create unbacked symints for sizes that are different
along the same axis. For example:
a.size is [s0, 4, s0, 5, 4, 5]
b.size is [s1, 4, s2, 8, 4, 7]
merged_size will be [u0, 4, u1, u2, 4, u3], where
u0 has range [min(s0, s1), max(s0, s1)]
u1 has range [min(s0, s2), max(s0, s2)]
u2 has range [5, 8]
u3 has range [5, 7]
"""
merged_size: list[Union[int, torch.SymInt]] = []
for s0, s1 in zip(a.size(), b.size()):
if SymIntEqByExpr(s0) == SymIntEqByExpr(s1):
merged_size.append(s0)
else:
def min_max(s0, s1):
def _bound(s0, lower_bound: bool):
if isinstance(s0, int):
return s0
r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
s0.node.expr,
torch.utils._sympy.value_ranges.ValueRanges.unknown(),
)
return r.lower if lower_bound else r.upper
return min(_bound(s0, True), _bound(s1, True)), max(
_bound(s0, False), _bound(s1, False)
)
assert mode.shape_env is not None
new_size = mode.shape_env.create_unbacked_symint()
mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
merged_size.append(new_size)
"""
This follows the logic in symbolic_shapes._compute_symbolic_stride
Step 2: merge the strides. A tensor's strides are cumulative products of its sizes,
which form a non-descending sequence that may be permuted by view ops.
Case 1: No size is 1. In this case, strides have unique values.
For example, suppose we have a tensor with:
size [3, 4, 3, 5, 4, 5],
stride (1200, 300, 1, 12, 3, 60),
merged_size [u0, u1, u2, u3, u4, u5].
We visit the strides in ascending order: 1, 3, 12, 60, 300, 1200. At each step, we check whether
the current stride already has an expression and bind the next stride by setting
stride_expr[next_stride] = current_stride_expr * current_size_expr
1st round:
current_stride is 1, current_size is 3, so next_stride is 1 * 3 = 3,
current_stride_expr is set to 1, current_size_expr is u2, so stride_expr[3] is therefore 1 * u2 = u2
2nd round:
current_stride is 3, current_size is 4, so next_stride is 3 * 4 = 12,
current_stride_expr is stride_expr[3] i.e. u2, current_size_expr is u4, so stride_expr[12] = u2 * u4
...
Case 2: At least one dimension has size 1, which can produce duplicate values in strides.
In this case, theoretically, we cannot uniquely determine the stride expressions because
accessing stride_expr with the same key in a different order causes the final stride expression
to differ.
Suppose we have:
size: (3, 1)
stride: (1, 1)
merged_size: (u0, u1)
The stride expr could either be (u1, 1) or (1, u0) depending on whether we start with u1 or u0.
For this reason, we break ties by sorting via descending index so we always get (u1, 1).
Note that the backend might optimize the strides anyway, so this is usually not a problem as long
as the two branches match. See relevant discussions in https://github.com/pytorch/pytorch/issues/142024.
Case 3: A dim has stride 0. A 0 stride doesn't participate in the cumulative multiplication of
sizes, so it's always treated as a constant even if its corresponding size is turned into an unbacked symint.
Suppose we have:
size: (3, 3)
stride: (0, 1)
merged_size: (u0, u1)
The merged stride would be (0, 1)
"""
def _bound_stride(
a_ex_size: torch.Size,
b_ex_size: torch.Size,
a_ex_stride: tuple[int, ...],
b_ex_stride: tuple[int, ...],
merged_size: list[Union[int, torch.SymInt]],
) -> list[Union[int, torch.SymInt]]:
from torch._inductor.ir import get_stride_order
a_sorted_stride_idx = get_stride_order(a_ex_stride, mode.shape_env)
b_sorted_stride_idx = get_stride_order(b_ex_stride, mode.shape_env)
a_stride_li: list[Optional[tuple[Union[int, torch.SymInt], int]]] = [
None
] * len(a_ex_stride)
b_stride_li: list[Optional[tuple[Union[int, torch.SymInt], int]]] = [
None
] * len(b_ex_stride)
for i, idx in enumerate(a_sorted_stride_idx):
a_stride_li[idx] = (a_ex_stride[i], -i)
for i, idx in enumerate(b_sorted_stride_idx):
b_stride_li[idx] = (b_ex_stride[i], -i)
for a_pair, b_pair in zip(a_stride_li, b_stride_li):
assert a_pair is not None and b_pair is not None
_, a_idx = a_pair
_, b_idx = b_pair
if a_idx != b_idx:
raise RuntimeError(
f"The sorted order of strides of the two branches' output doesn't match."
f"this indicates the contiguousness of the two branches are different. "
f"True branch has stride {a_ex_stride} but false branch has stride {b_ex_stride}."
f"Consider using contiguous() to make the two branches have the same contiguousness."
)
def _maybe_expr(s: Union[int, torch.SymInt]):
if isinstance(s, int):
return s
return s.node.expr
a_stride_expr: dict[Any, Union[int, torch.SymInt]] = {}
b_stride_expr: dict[Any, Union[int, torch.SymInt]] = {}
merged_strides: list[Union[int, torch.SymInt]] = [None] * len(a_ex_stride) # type: ignore[list-item]
for a_pair, b_pair in zip(a_stride_li, b_stride_li):
assert a_pair is not None and b_pair is not None
a_val, neg_i = a_pair
b_val, _ = b_pair
i = -neg_i
if a_val == 0:
assert b_val == 0, (a_val, b_val)
merged_strides[i] = 0
continue continue
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected both branches to return None:" if _maybe_expr(a_val) in a_stride_expr:
f"\n {true_fn.__name__} returns {true_out}" a_expr = a_stride_expr[_maybe_expr(a_val)]
f"\n {false_fn.__name__} returns {false_out}" assert (
) b_stride_expr[_maybe_expr(b_val)] == a_expr
true_meta = _extract_tensor_metadata(true_out) ), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}"
false_meta = _extract_tensor_metadata(false_out) merged_strides[i] = a_expr
if true_meta != false_meta: else:
raise torch._dynamo.exc.CondOpArgsMismatchError( if a_val == 1:
f"Expected each tensor to have same metadata but got:" assert b_val == 1
f"\n {true_fn.__name__} returns {true_meta}" a_stride_expr[_maybe_expr(a_val)] = 1
f"\n {false_fn.__name__} returns {false_meta}" b_stride_expr[_maybe_expr(b_val)] = 1
) merged_strides[i] = 1
return true_outs else:
# If we cannot find the expr of a_val in a_stride_expr, it means
# the strides are not a simple cumulative multiplication of the sizes.
# In this case, we cannot determine the stride expressions from the new
# shapes, so we error out and hint users to call contiguous().
raise RuntimeError(
f"It seems one of cond's outputs has strides that are not a simple cumulative multiplication of its sizes. "
f"This could be because cond returns a slice of a tensor, which is not dense in memory. "
f"True branch has size {a_ex_size}, stride {a_ex_stride} and false branch has size {b_ex_size}, "
f"stride {b_ex_stride}. Hint: call .contiguous() on the branch outputs. "
)
nxt_merged_stride_expr = merged_strides[i] * merged_size[i]
a_stride_expr[_maybe_expr(a_val * a_ex_size[i])] = nxt_merged_stride_expr
b_stride_expr[_maybe_expr(b_val * b_ex_size[i])] = nxt_merged_stride_expr
return merged_strides
merged_stride: list[Union[int, torch.SymInt]] = _bound_stride(
a.size(), b.size(), a.stride(), b.stride(), merged_size
)
with mode:
return torch.empty_strided(
merged_size, merged_stride, dtype=a.dtype, device=a.device
)
@cond_op.py_functionalize_impl @cond_op.py_functionalize_impl
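Finally, a small usage sketch of the new _merge_tensors helper (mirroring the test_merge_tensors test above; this is a private API, so the snippet is purely illustrative): dims whose sizes differ across branches become unbacked symints constrained to the [min, max] range of the two inputs.

import torch
from torch._higher_order_ops.cond import _merge_tensors
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with fake_mode:
    t1 = torch.empty_strided((3,), (1,))
    t2 = torch.empty_strided((4,), (1,))

merged = _merge_tensors(t1, t2, fake_mode)
print(tuple(merged.size()))    # (u0,) -- an unbacked symint with range [3, 4]
print(tuple(merged.stride()))  # (1,)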