Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[cond] support output sizes mismatch in front end (#147130)
This PR finishes https://github.com/pytorch/pytorch/pull/137615 by addressing the TODOs and comments left there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147130
Approved by: https://github.com/zou3519
This commit is contained in:
parent de80b6f0d3
commit 824474cb35
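For context, the following is an illustrative sketch (not part of the diff) of the behavior this change enables, adapted from the new test_cond_mismatched_branch_output test added below; the module M, the shapes, and the eager backend choice come from that test and are only examples. Branches of torch.cond may now return tensors whose sizes differ along a dimension; the mismatched dimension is traced as an unbacked symint (u0) instead of raising torch._dynamo.exc.UncapturedHigherOrderOpError.

import torch

# Sketch adapted from test_cond_mismatched_branch_output in this diff.
class M(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            # shape [x.size(0) - 2, 4]; clone so both branches have storage_offset 0
            return (x + 1)[2:].clone()

        def false_fn(x):
            # shape [2, 4]
            return (x - 1)[:2].clone()

        # The two branch outputs differ in size along dim 0.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

x = torch.randn(5, 4)
eager_out = M()(x)
# backend="eager" mirrors the new test; aot_eager is exercised there as well.
compiled_out = torch.compile(M(), backend="eager", fullgraph=True)(x)
assert torch.allclose(eager_out, compiled_out)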
@@ -1897,14 +1897,18 @@ def forward(self, x, y):
|
|||
def forward(self, x):
|
||||
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
l_x_ = arg0
|
||||
size = l_x_.size()
|
||||
getitem = size[0]; size = None
|
||||
le = getitem <= 2; getitem = None
|
||||
sym_size_int = torch.ops.aten.sym_size.int(l_x_, 0)
|
||||
le = sym_size_int <= 2; sym_size_int = None
|
||||
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None
|
||||
getitem_3 = cond[0]
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(getitem_3, 0); getitem_3 = None
|
||||
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
|
||||
ge = sym_size_int_1 >= 2; sym_size_int_1 = None
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 2 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
getitem_2 = cond[0]; cond = None
|
||||
return pytree.tree_unflatten([getitem_2], self._out_spec)""",
|
||||
return pytree.tree_unflatten([getitem_2], self._out_spec)""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
out_graph.cond_true_0.code.strip(),
|
||||
|
|
@@ -1922,12 +1926,8 @@ def forward(self, l_x_):
|
|||
getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
|
||||
return (getitem,)""",
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same metadata but found",
|
||||
):
|
||||
# True branch and false branch return tensors of different shape
|
||||
torch._dynamo.export(mod)(torch.randn(3, 2))
|
||||
# We could successfully export branches that return different sizes
|
||||
torch._dynamo.export(mod)(torch.randn(3, 2))
|
||||
|
||||
# We specialize into one of the branches since the predicate is a Python boolean.
|
||||
test_x = torch.randn(3, 2)
|
||||
|
|
@@ -3334,8 +3334,8 @@ def forward(self, x):
|
|||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same number of outputs but got",
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"Unmatched output spec from torch.cond branches",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
f_mismatch_return_length,
|
||||
|
|
@@ -3354,8 +3354,8 @@ def forward(self, x):
|
|||
|
||||
example_inputs = (torch.rand(5),)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same metadata but found",
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"When merging two branches' output in torch.cond",
|
||||
):
|
||||
torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(
|
||||
*example_inputs,
|
||||
|
|
|
|||
|
|
@@ -6961,11 +6961,9 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
|||
return torch.cond(x.sum() > 0, true_fn, false_fn)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
output_mismatch_test(x)
|
||||
output_mismatch_test(x)
|
||||
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
torch.compile(output_mismatch_test)(x)
|
||||
torch.compile(output_mismatch_test, backend="eager")(x)
|
||||
|
||||
def test_non_aliasing_util(self):
|
||||
from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing
|
||||
|
|
|
|||
|
|
@@ -968,7 +968,7 @@ def forward(self, pred_1):
|
|||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
return result
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
|
||||
gm = make_fx(f, tracing_mode="real", _allow_non_fake_inputs=True)(pred)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
|
|
@@ -5379,8 +5379,8 @@ def forward(self, arg0_1):
|
|||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same number of outputs but got",
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"Unmatched output spec from torch.cond branches",
|
||||
):
|
||||
make_fx(f)(x, torch.tensor(False))
|
||||
|
||||
|
|
@@ -5396,8 +5396,8 @@ def forward(self, arg0_1):
|
|||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same metadata but found",
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"When merging two branches' output in torch.cond",
|
||||
):
|
||||
make_fx(f)(x, torch.tensor(False))
|
||||
|
||||
|
|
@@ -5552,8 +5552,8 @@ def forward(self, arg0_1):
|
|||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same number of outputs but got",
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"Unmatched output spec from torch.cond branches",
|
||||
):
|
||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||
|
||||
|
|
@@ -5569,8 +5569,8 @@ def forward(self, arg0_1):
|
|||
|
||||
x = torch.randn(4)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Expected true_fn_output and false_fn_output to have same metadata but found",
|
||||
torch._dynamo.exc.TorchRuntimeError,
|
||||
"When merging two branches' output in torch.cond",
|
||||
):
|
||||
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
||||
|
||||
|
|
@@ -7547,6 +7547,215 @@ class GraphModule(torch.nn.Module):
|
|||
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
||||
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
|
||||
|
||||
@skipIfTorchDynamo(
|
||||
"Skip because _merge_tensors is not intended for dynamo to compile"
|
||||
)
|
||||
def test_merge_tensors(self):
|
||||
from torch._higher_order_ops.cond import _merge_tensors
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
# The shapes and strides are from randomly generated pairs of tensors with swapaxes applied
|
||||
valid_test_cases = [
|
||||
# [(size1, stride1), (size2, stride2), (expected_size, expected_stride)]
|
||||
[((3,), (1,)), ((4,), (1,)), ("(u0,)", "(1,)")],
|
||||
[((1, 3), (3, 1)), ((3, 2), (2, 1)), ("(u0, u1)", "(u1, 1)")],
|
||||
[((2, 1), (1, 1)), ((7, 3), (3, 1)), ("(u0, u1)", "(u1, 1)")],
|
||||
[((5, 5), (1, 5)), ((4, 5), (1, 4)), ("(u0, 5)", "(1, u0)")],
|
||||
[
|
||||
((7, 3, 1), (1, 7, 1)),
|
||||
((4, 3, 3), (3, 12, 1)),
|
||||
("(u0, 3, u1)", "(u1, u0*u1, 1)"),
|
||||
],
|
||||
[
|
||||
((5, 7, 4), (7, 1, 35)),
|
||||
((7, 4, 4), (4, 1, 28)),
|
||||
("(u0, u1, 4)", "(u1, 1, u0*u1)"),
|
||||
],
|
||||
[
|
||||
((1, 6, 3, 2), (36, 1, 6, 18)),
|
||||
((4, 2, 2, 6), (24, 1, 2, 4)),
|
||||
("(u0, u1, u2, u3)", "(u1*u2*u3, 1, u1, u1*u2)"),
|
||||
],
|
||||
[
|
||||
((6, 1, 6, 3), (18, 1, 1, 6)),
|
||||
((2, 1, 3, 4), (12, 1, 1, 3)),
|
||||
("(u0, 1, u1, u2)", "(u1*u2, 1, 1, u1)"),
|
||||
],
|
||||
[
|
||||
((3, 1, 2, 4, 1), (8, 8, 4, 1, 1)),
|
||||
((2, 4, 1, 4, 1), (16, 4, 4, 1, 1)),
|
||||
("(u0, u1, u2, 4, 1)", "(4*u1*u2, 4*u2, 4, 1, 1)"),
|
||||
],
|
||||
]
|
||||
|
||||
def _inner(case):
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
|
||||
(size1, stride1), (size2, stride2), (merged_size, merged_stride) = case
|
||||
with fake_mode:
|
||||
t1 = torch.empty_strided(size1, stride1)
|
||||
t2 = torch.empty_strided(size2, stride2)
|
||||
out = _merge_tensors(t1, t2, fake_mode)
|
||||
self.assertEqual(str(tuple(out.size())), merged_size)
|
||||
self.assertEqual(str(tuple(out.stride())), merged_stride)
|
||||
|
||||
for case in valid_test_cases:
|
||||
_inner(case)
|
||||
|
||||
# The shapes and strides are from randomly generated pairs of tensors with swapaxes applied
|
||||
invalid_test_cases = [
|
||||
# [(size1, stride1), (size2, stride2)]
|
||||
[((1,), (1,)), ((1,), (0,))],
|
||||
[
|
||||
((1, 3), (1, 1)),
|
||||
((5, 6), (6, 1)),
|
||||
], # t1 is not contiguous, t2 is contiguous
|
||||
[
|
||||
((2, 1), (1, 1)),
|
||||
((7, 3), (1, 3)),
|
||||
], # t1 is contiguous, t2 is not contiguous
|
||||
[
|
||||
((5, 4), (4, 1)),
|
||||
((5, 5), (1, 5)),
|
||||
], # t1 is contiguous, t2 is not contiguous
|
||||
[((7, 3, 1), (1, 7, 1)), ((4, 3, 3), (9, 1, 3))], # layout is different
|
||||
[((5, 7, 4), (7, 1, 35)), ((7, 4, 4), (4, 28, 1))], # layout is different
|
||||
[
|
||||
((1, 6, 3, 2), (36, 1, 6, 18)),
|
||||
((4, 1, 1, 6), (1, 4, 4, 4)),
|
||||
], # layout is different
|
||||
[
|
||||
((6, 1, 6, 3), (18, 1, 1, 6)),
|
||||
((1, 1, 1, 1), (1, 1, 1, 1)),
|
||||
], # layout is different
|
||||
[
|
||||
((6, 1, 1, 6, 3), (3, 18, 18, 18, 1)),
|
||||
((5, 1, 2, 1, 1), (2, 10, 1, 10, 1)),
|
||||
], # layout is different
|
||||
]
|
||||
for case in invalid_test_cases:
|
||||
with self.assertRaisesRegex(Exception, r"."):
|
||||
_inner(case)
|
||||
|
||||
@parametrize("dynamic", [True, False])
|
||||
@parametrize("backend", ["eager", "aot_eager"])
|
||||
def test_cond_mismatched_branch_output(self, dynamic, backend):
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
a = y.shape[0]
|
||||
b = z.shape[0]
|
||||
|
||||
def true_fn(x):
|
||||
# clone the outputs so branches have the same storage_offset
|
||||
return (x + a)[2:].clone()
|
||||
|
||||
def false_fn(x):
|
||||
# clone the outputs so branches have the same storage_offset
|
||||
return (x + b * z)[:2].clone()
|
||||
|
||||
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
|
||||
return y.sum() - ret
|
||||
|
||||
m = M()
|
||||
x, y, z = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)
|
||||
out = m(x, y, z)
|
||||
if not (backend == "eager" and dynamic and not TEST_WITH_CROSSREF):
|
||||
compiled_out = torch.compile(
|
||||
m, backend=backend, dynamic=dynamic, fullgraph=True
|
||||
)(x, y, z)
|
||||
self.assertEqual(compiled_out, out)
|
||||
else:
|
||||
bk = EagerAndRecordGraphs()
|
||||
compiled_out = torch.compile(
|
||||
m, backend=bk, dynamic=dynamic, fullgraph=True
|
||||
)(x, y, z)
|
||||
self.assertEqual(compiled_out, out)
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", L_z_: "f32[s0, s1]", L_x_: "f32[s0, s1]"):
|
||||
l_y_ = L_y_
|
||||
l_z_ = L_z_
|
||||
l_x_ = L_x_
|
||||
|
||||
sum_1: "f32[]" = l_x_.sum()
|
||||
gt: "b8[]" = sum_1 > 0; sum_1 = None
|
||||
|
||||
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_, s1, s0, s0, l_z_]); gt = cond_true_0 = cond_false_0 = l_x_ = s1 = s0 = l_z_ = None
|
||||
|
||||
getitem_5: "f32[u0, s1]" = cond[0]
|
||||
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
|
||||
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
||||
|
||||
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
||||
ret: "f32[u0, s1]" = cond[0]; cond = None
|
||||
|
||||
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
|
||||
sub: "f32[u0, s1]" = sum_2 - ret; sum_2 = ret = None
|
||||
return (sub,)
|
||||
|
||||
class cond_true_0(torch.nn.Module):
|
||||
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
|
||||
l_x__1 = l_x_
|
||||
s1_1 = s1
|
||||
|
||||
add: "f32[s0, s1]" = l_x__1 + s0_true_branch; l_x__1 = s0_true_branch = None
|
||||
getitem: "f32[s0 - 2, s1]" = add[slice(2, None, None)]; add = None
|
||||
clone: "f32[s0 - 2, s1]" = getitem.clone(); getitem = None
|
||||
return (clone,)
|
||||
|
||||
class cond_false_0(torch.nn.Module):
|
||||
def forward(self, l_x_, s1, s0_true_branch, getitem_2_false_branch, l_z__false_branch):
|
||||
l_x__1 = l_x_
|
||||
s1_1 = s1
|
||||
|
||||
mul: "f32[s0, s1]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
|
||||
add: "f32[s0, s1]" = l_x__1 + mul; l_x__1 = mul = None
|
||||
getitem: "f32[2, s1]" = add[slice(None, 2, None)]; add = None
|
||||
clone: "f32[2, s1]" = getitem.clone(); getitem = None
|
||||
return (clone,)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
@parametrize("dynamic", [True, False])
|
||||
@parametrize("backend", ["eager", "aot_eager"])
|
||||
def test_cond_mismatched_branch_strided_output(self, dynamic, backend):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
def true_fn(x, y):
|
||||
return (
|
||||
(x.swapaxes(-1, 0) + 1)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, 5, -1, -1, -1, -1, -1),
|
||||
torch.empty_strided((3, 3), (0, 1)),
|
||||
)
|
||||
|
||||
def false_fn(x, y):
|
||||
return (
|
||||
(y.swapaxes(-1, 0) + 1)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, 4, -1, -1, -1, -1, -1),
|
||||
torch.empty_strided((4, 5), (0, 1)),
|
||||
)
|
||||
|
||||
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x, y))
|
||||
return y.sum() + ret[0]
|
||||
|
||||
m = M()
|
||||
x, y = torch.randn(1, 6, 1, 5, 4, 3), torch.randn(1, 4, 5, 1, 3, 8)
|
||||
out = m(x, y)
|
||||
compiled_out = torch.compile(
|
||||
m, backend=backend, dynamic=dynamic, fullgraph=True
|
||||
)(x, y)
|
||||
self.assertEqual(compiled_out, out)
|
||||
|
||||
|
||||
_hop_schema_test_schema_types = [
|
||||
"bool",
|
||||
|
|
|
|||
|
|
@@ -126,12 +126,12 @@ class CondModels:
|
|||
def true_fn(x, y):
|
||||
z1 = x + y
|
||||
z2 = x - y
|
||||
return z1[2:], z2[:, 4:]
|
||||
return z1[2:], z2[:, 4:].contiguous()
|
||||
|
||||
def false_fn(x, y):
|
||||
z1 = x - y
|
||||
z2 = x + y
|
||||
return z1[2:], z2[:, 4:]
|
||||
return z1[2:], z2[:, 4:].contiguous()
|
||||
|
||||
return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
|
||||
|
||||
|
|
|
|||
|
|
@@ -935,13 +935,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
if not same_treespec.as_python_constant():
|
||||
unimplemented("Expected branches to return the same pytree structure.")
|
||||
|
||||
check_meta_consistency_vt(
|
||||
true_r.unpack_var_sequence(tx),
|
||||
false_r.unpack_var_sequence(tx),
|
||||
"true_fn_output",
|
||||
"false_fn_output",
|
||||
)
|
||||
|
||||
(
|
||||
true_graph,
|
||||
false_graph,
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch._subclasses.functional_tensor
|
||||
|
|
@@ -30,7 +30,7 @@ from torch._higher_order_ops.utils import (
|
|||
validate_subgraph_args_types,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_metadata_torch_function_mode,
|
||||
|
|
@@ -39,7 +39,6 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
from .utils import _from_fun, _maybe_fake_prop_ignore_unbacked, create_fw_bw_graph
|
||||
|
|
@@ -269,55 +268,6 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
|||
f"\n false branch returns {len(flat_false_outs)} item(s)"
|
||||
)
|
||||
|
||||
for i in range(0, len(flat_true_outs)):
|
||||
true_out = flat_true_outs[i]
|
||||
false_out = flat_false_outs[i]
|
||||
|
||||
# Note that we need to skip the check for requires_grad because we're
|
||||
# after the autograd key during tracing, so the requires_grad attribute of the tensors
|
||||
# is no longer accurate. See Note [invariants for node meta 'val']
|
||||
def _same_meta_except_requires_grad(true_out, false_out):
|
||||
if true_out is None and false_out is None:
|
||||
return True
|
||||
elif true_out is None or false_out is None:
|
||||
# Consider the following case:
|
||||
# def true_fn(x, y):
|
||||
# return x * y
|
||||
#
|
||||
# def false_fn(x, y):
|
||||
# return x.sin()
|
||||
#
|
||||
# We'll get the following graphs for backward:
|
||||
# def backward_true_fn(x, y, grad_out):
|
||||
# return grad_out * y, grad_out * x
|
||||
#
|
||||
# def backward_false_fn(x, y, grad_out):
|
||||
# return grad_out, None
|
||||
#
|
||||
# This suggests that when we make_fx into the backward graph,
|
||||
# the output graph would produce outputs with mismatched metadata, which is undesirable.
|
||||
#
|
||||
# Ideally, we should provide an optional type to indicate that one of the branches might
|
||||
# return None. But we'll just let it pass for now and let downstream/runtime handle it.
|
||||
#
|
||||
# Note that this corner case should **only** happen when the user wants to trace the backward graph because
|
||||
# if it's forward, dynamo will error.
|
||||
return True
|
||||
true_meta = true_out.meta.get("tensor_meta", None)
|
||||
false_meta = false_out.meta.get("tensor_meta", None)
|
||||
return (
|
||||
true_meta.shape == false_meta.shape
|
||||
and true_meta.dtype == false_meta.dtype
|
||||
and true_meta.stride == false_meta.stride
|
||||
)
|
||||
|
||||
if not _same_meta_except_requires_grad(true_out, false_out):
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
||||
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
|
||||
)
|
||||
|
||||
i, true_name = unique_graph_id(proxy_mode, prefix="true_graph")
|
||||
|
||||
false_name = f"false_graph_{i}"
|
||||
|
|
@@ -429,30 +379,248 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
|
|||
ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
|
||||
|
||||
with mode, ignore_fresh_unbacked:
|
||||
true_outs = true_fn(*operands)
|
||||
flat_true_outs = pytree.tree_leaves(true_outs)
|
||||
flat_false_outs = pytree.tree_leaves(false_fn(*operands))
|
||||
if len(flat_true_outs) != len(flat_false_outs):
|
||||
raise RuntimeError("Unmatched number of outputs from cond() branches.")
|
||||
flat_true_outs, true_out_spec = pytree.tree_flatten(true_fn(*operands))
|
||||
flat_false_outs, false_out_spec = pytree.tree_flatten(false_fn(*operands))
|
||||
if true_out_spec != false_out_spec:
|
||||
raise RuntimeError(
|
||||
"Unmatched output spec from torch.cond branches: "
|
||||
f"true branch tree_spec {true_out_spec} vs false branch tree_spec {false_out_spec}."
|
||||
)
|
||||
|
||||
merged_outs = []
|
||||
for true_out, false_out in zip(flat_true_outs, flat_false_outs):
|
||||
if true_out is None or false_out is None:
|
||||
if true_out is None and false_out is None:
|
||||
merged_outs.append(_merge_tensors(true_out, false_out, mode))
|
||||
return pytree.tree_unflatten(merged_outs, true_out_spec)
|
||||
|
||||
|
||||
def check_tensor_meta_match(
|
||||
t1: torch.Tensor, t2: torch.Tensor, attr_names: tuple[str, ...], msg_prefix: str
|
||||
) -> None:
|
||||
def _get_attr_maybe_call(t: torch.Tensor, attr_name: str) -> Any:
|
||||
attr = getattr(t, attr_name)
|
||||
if callable(attr):
|
||||
return attr()
|
||||
return attr
|
||||
|
||||
for attr_name in attr_names:
|
||||
lattr = _get_attr_maybe_call(t1, attr_name)
|
||||
rattr = _get_attr_maybe_call(t2, attr_name)
|
||||
torch._check(
|
||||
lattr == rattr,
|
||||
lambda: f"{msg_prefix} expected same {attr_name} but got {lattr} and {rattr}.",
|
||||
)
|
||||
|
||||
|
||||
def _merge_tensors(
|
||||
a: Optional[torch.Tensor], b: Optional[torch.Tensor], mode: FakeTensorMode
|
||||
):
|
||||
from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr
|
||||
|
||||
if a is None or b is None:
|
||||
assert a is None and b is None, (a, b)
|
||||
return None
|
||||
|
||||
assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))
|
||||
|
||||
# Note: we don't check size, stride because
|
||||
# they'll be merged with unbacked symints if they differ.
|
||||
_meta_to_check = {
|
||||
"dtype",
|
||||
"device",
|
||||
"layout",
|
||||
"dim",
|
||||
"is_quantized",
|
||||
"is_conj",
|
||||
"is_sparse",
|
||||
"storage_offset",
|
||||
}
|
||||
check_tensor_meta_match(
|
||||
a,
|
||||
b,
|
||||
tuple(_meta_to_check),
|
||||
msg_prefix="When merging two branches' output in torch.cond, ",
|
||||
)
|
||||
# NYI
|
||||
assert not a.is_quantized and not b.is_quantized
|
||||
assert not a.is_sparse and not b.is_sparse
|
||||
assert not a.is_conj() and not b.is_conj()
|
||||
|
||||
"""
|
||||
Step 1: create unbacked symints for sizes that are different
|
||||
along the same axis. For example:
|
||||
a.size is [s0, 4, s0, 5, 4, 5]
|
||||
b.size is [s1, 4, s2, 8, 4, 7]
|
||||
merged_size will be [u0, 4, u1, u2, 4, u3], where
|
||||
u0 has range [min(s0, s1), max(s0, s1)]
|
||||
u1 has range [min(s0, s2), max(s0, s2)]
|
||||
u2 has range [5, 8]
|
||||
u3 has range [5, 7]
|
||||
"""
|
||||
merged_size: list[Union[int, torch.SymInt]] = []
|
||||
for s0, s1 in zip(a.size(), b.size()):
|
||||
if SymIntEqByExpr(s0) == SymIntEqByExpr(s1):
|
||||
merged_size.append(s0)
|
||||
else:
|
||||
|
||||
def min_max(s0, s1):
|
||||
def _bound(s0, lower_bound: bool):
|
||||
if isinstance(s0, int):
|
||||
return s0
|
||||
r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
|
||||
s0.node.expr,
|
||||
torch.utils._sympy.value_ranges.ValueRanges.unknown(),
|
||||
)
|
||||
return r.lower if lower_bound else r.upper
|
||||
|
||||
return min(_bound(s0, True), _bound(s1, True)), max(
|
||||
_bound(s0, False), _bound(s1, False)
|
||||
)
|
||||
|
||||
assert mode.shape_env is not None
|
||||
new_size = mode.shape_env.create_unbacked_symint()
|
||||
mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
|
||||
merged_size.append(new_size)
|
||||
|
||||
"""
|
||||
This follows the logic in symbolic_shapes._compute_symbolic_stride
|
||||
Step 2: A tensor's strides are an accumulative multiplication of its sizes, forming a permuted
|
||||
(due to view ops) non-descending sequence.
|
||||
|
||||
Case 1: No size is 1. In this case, strides have unique values.
|
||||
For example, suppose we have a tensor with:
|
||||
size [3, 4, 3, 5, 4, 5],
|
||||
stride (1200, 300, 1, 12, 3, 60),
|
||||
merged_size [u0, u1, u2, u3, u4, u5].
|
||||
|
||||
We visit the strides in ascending order: 1, 3, 12, 60, 300, 1200. In each step, we check whether
|
||||
the current stride is bound or not, and bind the next stride by setting:
|
||||
stride_expr[next_stride] = current_stride_expr * current_size_expr
|
||||
1st round:
|
||||
current_stride is 1, current_size is 3, so next_stride is 1 * 3 = 3,
|
||||
current_stride_expr is set to 1, current_size_expr is u2, so stride_expr[3] is therefore 1 * u2 = u2
|
||||
2nd round:
|
||||
current_stride is 3, current_size is 4, so next_stride is 3 * 4 = 12,
|
||||
current_stride_expr is stride_expr[3] i.e. u2, current_size_expr is u4, so stride_expr[12] = u2 * u4
|
||||
...
|
||||
|
||||
Case 2: At least one dimension has size 1, which can produce duplicates in strides.
|
||||
In this case, theoretically, we cannot uniquely determine the expr of strides because
|
||||
accessing stride_expr with the same key in a different order causes the final stride expression
|
||||
to be different.
|
||||
|
||||
Suppose we have:
|
||||
size: (3, 1)
|
||||
stride: (1, 1)
|
||||
merged_size: (u0, u1)
|
||||
|
||||
The stride expr could either be (u1, 1) or (1, u0) depending on whether we start with u1 or u0.
|
||||
For this reason, we try to break ties by sorting via descending index so we always get (u1, 1).
|
||||
|
||||
Note that the backend might optimize the strides anyway, so this is usually not a problem as long
|
||||
as the two branches match. See relevant discussions in https://github.com/pytorch/pytorch/issues/142024.
|
||||
|
||||
Case 3: A dim has stride 0. A 0 stride doesn't participate in the accumulative multiplication of
|
||||
sizes, so it is always treated as a constant even if its corresponding size is turned into an unbacked symint.
|
||||
|
||||
Suppose we have:
|
||||
size: (3, 3)
|
||||
stride: (0, 1)
|
||||
merged_size: (u0, u1)
|
||||
|
||||
The merged stride would be (0, 1)
|
||||
"""
|
||||
|
||||
def _bound_stride(
|
||||
a_ex_size: torch.Size,
|
||||
b_ex_size: torch.Size,
|
||||
a_ex_stride: tuple[int, ...],
|
||||
b_ex_stride: tuple[int, ...],
|
||||
merged_size: list[Union[int, torch.SymInt]],
|
||||
) -> list[Union[int, torch.SymInt]]:
|
||||
from torch._inductor.ir import get_stride_order
|
||||
|
||||
a_sorted_stride_idx = get_stride_order(a_ex_stride, mode.shape_env)
|
||||
b_sorted_stride_idx = get_stride_order(b_ex_stride, mode.shape_env)
|
||||
|
||||
a_stride_li: list[Optional[tuple[Union[int, torch.SymInt], int]]] = [
|
||||
None
|
||||
] * len(a_ex_stride)
|
||||
b_stride_li: list[Optional[tuple[Union[int, torch.SymInt], int]]] = [
|
||||
None
|
||||
] * len(b_ex_stride)
|
||||
for i, idx in enumerate(a_sorted_stride_idx):
|
||||
a_stride_li[idx] = (a_ex_stride[i], -i)
|
||||
for i, idx in enumerate(b_sorted_stride_idx):
|
||||
b_stride_li[idx] = (b_ex_stride[i], -i)
|
||||
|
||||
for a_pair, b_pair in zip(a_stride_li, b_stride_li):
|
||||
assert a_pair is not None and b_pair is not None
|
||||
_, a_idx = a_pair
|
||||
_, b_idx = b_pair
|
||||
|
||||
if a_idx != b_idx:
|
||||
raise RuntimeError(
|
||||
f"The sorted order of strides of the two branches' output doesn't match."
|
||||
f"this indicates the contiguousness of the two branches are different. "
|
||||
f"True branch has stride {a_ex_stride} but false branch has stride {b_ex_stride}."
|
||||
f"Consider using contiguous() to make the two branches have the same contiguousness."
|
||||
)
|
||||
|
||||
def _maybe_expr(s: Union[int, torch.SymInt]):
|
||||
if isinstance(s, int):
|
||||
return s
|
||||
return s.node.expr
|
||||
|
||||
a_stride_expr: dict[Any, Union[int, torch.SymInt]] = {}
|
||||
b_stride_expr: dict[Any, Union[int, torch.SymInt]] = {}
|
||||
merged_strides: list[Union[int, torch.SymInt]] = [None] * len(a_ex_stride) # type: ignore[list-item]
|
||||
for a_pair, b_pair in zip(a_stride_li, b_stride_li):
|
||||
assert a_pair is not None and b_pair is not None
|
||||
a_val, neg_i = a_pair
|
||||
b_val, _ = b_pair
|
||||
|
||||
i = -neg_i
|
||||
if a_val == 0:
|
||||
assert b_val == 0, (a_val, b_val)
|
||||
merged_strides[i] = 0
|
||||
continue
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
f"Expected both branches to return None:"
|
||||
f"\n {true_fn.__name__} returns {true_out}"
|
||||
f"\n {false_fn.__name__} returns {false_out}"
|
||||
)
|
||||
true_meta = _extract_tensor_metadata(true_out)
|
||||
false_meta = _extract_tensor_metadata(false_out)
|
||||
if true_meta != false_meta:
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_meta}"
|
||||
f"\n {false_fn.__name__} returns {false_meta}"
|
||||
)
|
||||
return true_outs
|
||||
|
||||
if _maybe_expr(a_val) in a_stride_expr:
|
||||
a_expr = a_stride_expr[_maybe_expr(a_val)]
|
||||
assert (
|
||||
b_stride_expr[_maybe_expr(b_val)] == a_expr
|
||||
), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}"
|
||||
merged_strides[i] = a_expr
|
||||
else:
|
||||
if a_val == 1:
|
||||
assert b_val == 1
|
||||
a_stride_expr[_maybe_expr(a_val)] = 1
|
||||
b_stride_expr[_maybe_expr(b_val)] = 1
|
||||
merged_strides[i] = 1
|
||||
else:
|
||||
# If we cannot find the expr of a_val in a_stride_expr, it means
|
||||
# the strides are not a simple accumulative multiplication of sizes.
|
||||
# In this case, we cannot determine the expr of strides from the new
|
||||
# shapes so we error out and hint users to call contiguous().
|
||||
raise RuntimeError(
|
||||
f"It seems one of cond's output stride is not a simple accumulative multiplication of sizes. "
|
||||
f"This could be because cond returns a slice of a tensor, which is not dense in memory. "
|
||||
f"True branch has size {a_ex_size}, stride {a_ex_stride} and false branch has size {b_ex_size} "
|
||||
f"stride {b_ex_stride}. Hint: can call t.contiguous(). "
|
||||
)
|
||||
nxt_merged_stride_expr = merged_strides[i] * merged_size[i]
|
||||
a_stride_expr[_maybe_expr(a_val * a_ex_size[i])] = nxt_merged_stride_expr
|
||||
b_stride_expr[_maybe_expr(b_val * b_ex_size[i])] = nxt_merged_stride_expr
|
||||
return merged_strides
|
||||
|
||||
merged_stride: list[Union[int, torch.SymInt]] = _bound_stride(
|
||||
a.size(), b.size(), a.stride(), b.stride(), merged_size
|
||||
)
|
||||
|
||||
with mode:
|
||||
return torch.empty_strided(
|
||||
merged_size, merged_stride, dtype=a.dtype, device=a.device
|
||||
)
|
||||
|
||||
|
||||
@cond_op.py_functionalize_impl
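The size and stride merging described in the docstrings above can be exercised directly; below is an illustrative sketch (not part of the diff), adapted from the new test_merge_tensors test added in this change. It uses the private helper _merge_tensors that this change introduces, so treat it as an internal detail rather than a public API. Mismatched sizes become unbacked symints, and strides are rebuilt from the merged sizes.

import torch
from torch._higher_order_ops.cond import _merge_tensors
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Sketch adapted from test_merge_tensors in this diff.
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with fake_mode:
    # Sizes differ in both dims (1 vs 3 and 3 vs 2); both tensors are contiguous.
    t1 = torch.empty_strided((1, 3), (3, 1))
    t2 = torch.empty_strided((3, 2), (2, 1))

out = _merge_tensors(t1, t2, fake_mode)
# Mismatched dims become unbacked symints; the stride follows the merged sizes.
print(tuple(out.size()))    # (u0, u1)
print(tuple(out.stride()))  # (u1, 1)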