Revert "[cond] inlining into one of the branches when pred is a python constant (#128709)"

This reverts commit fe3e6878c4.

Reverted https://github.com/pytorch/pytorch/pull/128709 on behalf of https://github.com/ydwu4 due to causing error on truck due to a land racing: fe3e6878c4 ([comment](https://github.com/pytorch/pytorch/pull/128709#issuecomment-2221104043))
This commit is contained in:
PyTorch MergeBot 2024-07-10 17:47:18 +00:00
parent b4b7477d3f
commit 0beeac35fa
8 changed files with 121 additions and 214 deletions

View File

@ -1912,10 +1912,13 @@ def forward(self, l_x_):
):
# True branch and false branch return tensors of different shape
torch._dynamo.export(mod)(torch.randn(3, 2))
# We specialize into one of the branches since predicate is a python boolean.
test_x = torch.randn(3, 2)
mod(test_x)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
# True branch and false branch return tensors of different shape
test_x = torch.randn(3, 2)
mod(test_x)
def test_export_with_map_cond(self):
from functorch.experimental.control_flow import cond, map

View File

@ -1408,7 +1408,7 @@ def forward(self, child, const_unused):
def false_fn(x):
return (x - 1).sum()
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
mod_for_eager = Foo()
@ -6147,16 +6147,12 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])
cnt = CompileCounter()
opt_test = torch.compile(test, backend=cnt, fullgraph=True)
opt_test = torch.compile(test, backend=cnt)
inp = torch.ones(3, 3)
true_pred = torch.Tensor([True])
false_pred = torch.Tensor([False])
self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
self.assertEqual(cnt.frame_count, 1)
self.assertTrue(
torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
)
self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
self.assertEqual(cnt.frame_count, 1)
self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
self.assertEqual(cnt.frame_count, 2)
def test_cond_with_invalid_kwargs(self):
from torch._higher_order_ops.cond import cond_op

View File

@ -801,7 +801,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
return x.sin()
def forward(self, x):
return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()
@ -3603,7 +3603,7 @@ def forward(self, x):
# https://github.com/pytorch/pytorch/issues/129939
@testing.expectedFailureNonStrict
def test_export_cond_symbool_pred(self):
def test_export_cond(self):
class A(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3626,20 +3626,10 @@ def forward(self, x):
return cond(x.shape[0] > 4, true_fn, false_fn, [x])
dim0 = torch.export.Dim("dim0", min=3)
inp = torch.ones(6, 4)
ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, b_a_buffer, x):
sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
gt = sym_size_int_1 > 4; sym_size_int_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
getitem = cond[0]; cond = None
return (getitem,)""",
ep = export(
Foo(),
(inp,),
)
self.assertTrue(
torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
@ -5017,7 +5007,7 @@ graph():
def false_fn(x):
return self.linear(x).sin()
return torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
class CondExport(torch.nn.Module):
def __init__(self):
@ -5034,12 +5024,10 @@ graph():
"""\
def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
cos = torch.ops.aten.cos.default(x)
sum_1 = torch.ops.aten.sum.default(x)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
getitem = conditional[0]; conditional = None
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
return (add,)""",
)
@ -5134,8 +5122,8 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
def forward(self, b_pred, b_t, x, y):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
getitem = conditional[0]; conditional = None
return (getitem,)""",
) # noqa: B950

View File

@ -68,7 +68,7 @@ class TestVerifier(TestCase):
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
f = Foo()
@ -87,7 +87,7 @@ class TestVerifier(TestCase):
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
f = Foo()

View File

@ -4237,7 +4237,7 @@ def forward(self, arg0_1, arg1_1):
return x.cos()
return torch.cond(
y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
y.cos().shape[0] > 5, true_true_fn, true_false_fn, [y.cos()]
)
def false_fn(x):
@ -4245,7 +4245,7 @@ def forward(self, arg0_1, arg1_1):
z.add_(6)
return z.sin()
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)
inp = torch.randn(2, 2)
@ -4254,12 +4254,10 @@ def forward(self, arg0_1, arg1_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@ -4272,13 +4270,11 @@ def forward(self, arg0_1):
sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
cos = torch.ops.aten.cos.default(add)
sum_1 = torch.ops.aten.sum.default(cos); cos = None
gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None
cos_1 = torch.ops.aten.cos.default(add); add = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [cos_1]); true_graph_0 = false_graph_0 = cos_1 = None
getitem = conditional[0]; conditional = None
return (getitem,)""", # noqa: B950
)
@ -4321,7 +4317,7 @@ def forward(self, arg0_1):
+ control_flow.map(f, z, r).sum()
)
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
return (a + 3, a + 4)
inps = [torch.randn(2, 2), torch.ones(2)]
@ -4330,12 +4326,10 @@ def forward(self, arg0_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1, arg1_1]); true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
getitem = conditional[0]; conditional = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@ -4440,7 +4434,7 @@ def forward(self, arg0_1, arg1_1):
z.add_(6)
return z.sin()
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)
inp = torch.randn(2, 2)
@ -4449,12 +4443,10 @@ def forward(self, arg0_1, arg1_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@ -4875,7 +4867,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
y.add_(6)
return x.sin()
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)
inp = torch.randn(3, 4)
@ -4884,12 +4876,10 @@ def forward(self, arg0_1, arg1_1, arg2_1):
gm.code.strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950

View File

@ -877,7 +877,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
f(x, torch.tensor(True), torch.tensor(True)),
)
def test_cond_functionalized(self):
def test_cond_functionalized_hah(self):
def true_fn(x):
y = x.sin()
y.add_(4)
@ -894,9 +894,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
all_ops_in_true_branch = []
@ -906,6 +904,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
def test_cond_accepts_torch_function_as_inputs(self):
@ -924,8 +925,8 @@ def forward(self, a_1, b_1):
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = conditional[0]; conditional = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
@ -972,9 +973,9 @@ def forward(self, arg0_1, arg1_1):
z = torch.add(y, y)
return z
symbolic_traced_graph = self._check_tracing(
f, (torch.ones(4), torch.Tensor([True]))
)["symbolic"]
symbolic_traced_graph = self._check_tracing(f, (torch.ones(4), True))[
"symbolic"
]
graph_shape_env = symbolic_traced_graph.shape_env
def _node_shape_env_iter(gm):
@ -1020,14 +1021,15 @@ def forward(self, arg0_1, arg1_1):
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
gm_true_true_branch = graph_module.true_graph_0.true_graph_0
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
graph_module1 = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module1(*example_inputs), f(*example_inputs))
all_ops = []
for node in gm_true_true_branch.graph.nodes:
@ -1055,7 +1057,8 @@ def forward(self, arg0_1, arg1_1):
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
# https://github.com/pytorch/pytorch/issues/126988
def test_cond_functionalized_input_mutation_on_true_brancte(self):
@xfailIfTorchDynamo
def test_cond_functionalized_input_mutation_on_true_branch(self):
def true_fn(x):
view_x = x.view(x.shape)
view_x.add_(1)
@ -1069,33 +1072,19 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [4, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
sin = torch.ops.aten.sin.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
return sum_1""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
functional_f(*example_inputs)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_input_mutation_on_false_branch(self):
def true_fn(x):
return x.sin().sum()
@ -1110,33 +1099,19 @@ def forward(self, x_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 5),)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
cos = torch.ops.aten.cos.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(cos); cos = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
return sum_1""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
functional_f(*example_inputs)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_output_alias_input(self):
def true_fn(x):
return x
@ -1150,27 +1125,22 @@ def forward(self, x_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 5),)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
return view""",
)
functional_f = torch.func.functionalize(f)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
functional_f(*example_inputs)
with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(torch.func.functionalize(f))(*example_inputs)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_nested_input_mutation(self):
def true_true_fn(x):
x.add_(4)
@ -1191,14 +1161,19 @@ def forward(self, x_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
functional_f(*example_inputs)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
def true_true_fn(x):
x.add_(4)
@ -1222,12 +1197,15 @@ def forward(self, x_1):
try:
example_input_func = to_fun_old(example_input)
torch._enable_functionalization(reapply_views=False)
f(example_input_func)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
f(example_input_func)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f, tracing_mode="symbolic")(example_input_func)
make_fx(f)(example_input_func)
finally:
torch._disable_functionalization()
@ -1245,7 +1223,7 @@ def forward(self, x_1):
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
make_fx(f_wrapper(f))(example_input_func)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
@ -1258,7 +1236,7 @@ def forward(self, x_1):
return view_x
def f(x):
pred = x.sum() > 0
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(5, 5)
@ -1300,7 +1278,7 @@ def forward(self, x_1):
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
make_fx(f_wrapper(f))(example_input)
def test_cond_functionalized_aot_func_check_functional(self):
def true_fn(x):
@ -1338,7 +1316,7 @@ def forward(self, x_1):
return wrapper
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
result_gm = make_fx(f_wrapper(f))(example_input)
for node in result_gm.true_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
@ -1404,12 +1382,12 @@ def forward(self, x_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = conditional[0]; conditional = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = conditional_1[0]; conditional_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)
@ -1577,12 +1555,12 @@ def forward(self, arg0_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = conditional[0]; conditional = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = conditional_1[0]; conditional_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)
@ -1913,27 +1891,6 @@ def forward(self, arg0_1):
):
functional_f(*example_inputs)
def test_cond_autograd_succeed_when_pred_is_constant(self):
def true_fn(x):
return x.cos()
def false_fn(x):
return x.sin()
def f(x, y):
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
)
# Due to x.shape[0] can be statically evaluated to be False, we can evaluate
# the backward.
f(*example_inputs).sum().backward()
# Ensure no error is thrown when not running backward
f(*example_inputs)
def test_cond_autograd_fail(self):
def true_fn(x):
return x.cos()
@ -1942,7 +1899,7 @@ def forward(self, arg0_1):
return x.sin()
def f(x, y):
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
@ -2072,8 +2029,8 @@ def forward(self, x_1):
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = conditional[0]; conditional = None
return getitem""", # noqa: B950
)
@ -2145,20 +2102,18 @@ def forward(self, x_1):
# expected branches takes [x, a, b] as input
inp = torch.randn(2, 3)
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
gm = make_fx(foo)(inp)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = conditional[0]; conditional = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
@ -2308,8 +2263,8 @@ def forward(self, pred_1, x_1):
def forward(self, arg0_1, arg1_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
return [getitem]""", # noqa: B950
)
@ -2350,7 +2305,7 @@ def forward(self, arg0_1, arg1_1):
counters.clear()
def foo(x, true_fn, false_fn):
return cond(x.sum() < 0, true_fn, false_fn, (x,))
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
inp = torch.ones(3, 4)
exp_out = inp.sin()
@ -2392,8 +2347,8 @@ def forward(self, x_1):
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = conditional[0]; conditional = None
return getitem""", # noqa: B950
)

View File

@ -632,18 +632,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: cond(pred, true_fn, false_fn, operands)",
)
# Specialize into one of the branches since pred is constant
if type(args[0]) is ConstantVariable:
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
)
if args[0].as_python_constant():
return args[1].call_function(tx, args[3].unpack_var_sequence(tx), {})
else:
return args[2].call_function(tx, args[3].unpack_var_sequence(tx), {})
# predicate
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
unimplemented(

View File

@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
import logging
import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
@ -34,8 +32,6 @@ from torch.fx.experimental.proxy_tensor import (
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode
log = logging.getLogger(__name__)
@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
@ -107,19 +103,10 @@ def cond(pred, true_fn, false_fn, operands):
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
"""
if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)
if isinstance(pred, (bool, int, float)):
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
)
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)
def _validate_input(pred, true_fn, false_fn, operands):
if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
@ -213,7 +200,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}
"call_function", func_overload, proxy_args, {}, name="conditional"
)
# At this point, we're *guaranteed* that whether an output came from the