[cond] inlining into one of the branches when pred is a python constant (#130493)

Reland https://github.com/pytorch/pytorch/pull/128709.

When the input predicate is a Python constant, we now specialize into the corresponding branch and warn users that torch.cond is not preserving the dynamism. Previously, we baked the constant True/False into the cond operator itself, which could be confusing. With this PR, cond instead inlines the selected branch whenever the predicate is a constant.
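A minimal sketch of the new eager behavior (illustrative only, not part of the diff):

```python
import torch

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

x = torch.randn(3, 2)

# x.shape[0] > 4 is a plain Python bool here (no dynamic shapes involved), so
# torch.cond logs a warning and simply runs the false branch.
out_specialized = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])

# x.sum() > 4 is a boolean tensor, so the conditional stays data-dependent and
# both branches are preserved when traced.
out_dynamic = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
```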

We additionally let the cond operator's proxy node use its default name instead of overriding it with "conditional". This makes testing de-serialized graphs easier, since the node name now matches the op (see the sketch below).
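For example (a hedged sketch; `Foo` here is a stand-in module, not from the diff), the exported graph now refers to the node as `cond` rather than the previously hard-coded `conditional`:

```python
import torch
from torch.export import export

class Foo(torch.nn.Module):
    def forward(self, x):
        def true_fn(y):
            return y.cos()

        def false_fn(y):
            return y.sin()

        return torch.cond(x.sum() > 4, true_fn, false_fn, [x])

ep = export(Foo(), (torch.randn(3, 2),))
# The printed code contains a line of the form
#   cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]); ...
print(ep.graph_module.code)
```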

Test Plan:
The predicate in some existing tests is the result of a shape comparison. When no dynamic shape is involved, that predicate is a Python bool. To fix these tests, we either change the predicate to a data-dependent tensor or update the test to check that cond specializes into one of the branches (see the sketch below).
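A hedged sketch of the tensor-predicate pattern the updated tests use (hypothetical functions, mirroring the fullgraph test in the diff):

```python
import torch

def true_fn(x):
    return x + 1

def false_fn(x):
    return x - 1

def fn(pred, x):
    return torch.cond(pred, true_fn, false_fn, [x])

opt_fn = torch.compile(fn, fullgraph=True)
inp = torch.ones(3, 3)

# Boolean-tensor predicates keep both branches in a single compiled graph,
# so flipping the predicate does not force a recompile.
assert torch.allclose(fn(torch.tensor(True), inp), opt_fn(torch.tensor(True), inp))
assert torch.allclose(fn(torch.tensor(False), inp), opt_fn(torch.tensor(False), inp))
```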

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130493
Approved by: https://github.com/BoyuanFeng
Yidi Wu 2024-07-11 13:21:54 -07:00 committed by PyTorch MergeBot
parent 0bf9a091ec
commit 741c1710e8
8 changed files with 212 additions and 119 deletions


@ -1912,13 +1912,10 @@ def forward(self, l_x_):
):
# True branch and false branch return tensors of different shape
torch._dynamo.export(mod)(torch.randn(3, 2))
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
# True branch and false branch return tensors of different shape
test_x = torch.randn(3, 2)
mod(test_x)
# We specialize into one of the branches since predicate is a python boolean.
test_x = torch.randn(3, 2)
mod(test_x)
def test_export_with_map_cond(self):
from functorch.experimental.control_flow import cond, map


@ -1406,7 +1406,7 @@ def forward(self, child, const_unused):
def false_fn(x):
return (x - 1).sum()
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
mod_for_eager = Foo()
@ -6145,12 +6145,16 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])
cnt = CompileCounter()
opt_test = torch.compile(test, backend=cnt)
opt_test = torch.compile(test, backend=cnt, fullgraph=True)
inp = torch.ones(3, 3)
self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
true_pred = torch.Tensor([True])
false_pred = torch.Tensor([False])
self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
self.assertEqual(cnt.frame_count, 1)
self.assertTrue(
torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
)
self.assertEqual(cnt.frame_count, 1)
self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
self.assertEqual(cnt.frame_count, 2)
def test_cond_with_invalid_kwargs(self):
from torch._higher_order_ops.cond import cond_op


@ -804,7 +804,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
return x.sin()
def forward(self, x):
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])
example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()
@ -3616,7 +3616,7 @@ def forward(self, x):
):
torch.export.export(exported_v2.module(), (torch.randn(2, 2),))
def test_export_cond(self):
def test_export_cond_symbool_pred(self):
class A(torch.nn.Module):
def __init__(self):
super().__init__()
@ -3639,10 +3639,20 @@ def forward(self, x):
return cond(x.shape[0] > 4, true_fn, false_fn, [x])
dim0 = torch.export.Dim("dim0", min=3)
inp = torch.ones(6, 4)
ep = export(
Foo(),
(inp,),
ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, b_a_buffer, x):
sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
gt = sym_size_int_1 > 4; sym_size_int_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
getitem = cond[0]; cond = None
return (getitem,)""",
)
self.assertTrue(
torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
@ -4988,7 +4998,7 @@ graph():
def false_fn(x):
return self.linear(x).sin()
return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return torch.cond(x.sum() > 4, true_fn, false_fn, [x])
class CondExport(torch.nn.Module):
def __init__(self):
@ -5005,10 +5015,12 @@ graph():
"""\
def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
cos = torch.ops.aten.cos.default(x)
sum_1 = torch.ops.aten.sum.default(x)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
return (add,)""",
)
@ -5103,8 +5115,8 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
def forward(self, b_pred, b_t, x, y):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
getitem = cond[0]; cond = None
return (getitem,)""",
) # noqa: B950


@ -68,7 +68,7 @@ class TestVerifier(TestCase):
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
f = Foo()
@ -87,7 +87,7 @@ class TestVerifier(TestCase):
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y
return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
f = Foo()


@ -4237,7 +4237,7 @@ def forward(self, arg0_1, arg1_1):
return x.cos()
return torch.cond(
y.cos().shape[0] > 5, true_true_fn, true_false_fn, [y.cos()]
y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
)
def false_fn(x):
@ -4245,7 +4245,7 @@ def forward(self, arg0_1, arg1_1):
z.add_(6)
return z.sin()
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)
inp = torch.randn(2, 2)
@ -4254,10 +4254,12 @@ def forward(self, arg0_1, arg1_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@ -4270,11 +4272,13 @@ def forward(self, arg0_1):
sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
cos = torch.ops.aten.cos.default(add)
sum_1 = torch.ops.aten.sum.default(cos); cos = None
gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None
cos_1 = torch.ops.aten.cos.default(add); add = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [cos_1]); true_graph_0 = false_graph_0 = cos_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
@ -4317,7 +4321,7 @@ def forward(self, arg0_1):
+ control_flow.map(f, z, r).sum()
)
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
return (a + 3, a + 4)
inps = [torch.randn(2, 2), torch.ones(2)]
@ -4326,10 +4330,12 @@ def forward(self, arg0_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1, arg1_1]); true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@ -4434,7 +4440,7 @@ def forward(self, arg0_1, arg1_1):
z.add_(6)
return z.sin()
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)
inp = torch.randn(2, 2)
@ -4443,10 +4449,12 @@ def forward(self, arg0_1, arg1_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@ -4867,7 +4875,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
y.add_(6)
return x.sin()
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)
inp = torch.randn(3, 4)
@ -4876,10 +4884,12 @@ def forward(self, arg0_1, arg1_1, arg2_1):
gm.code.strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950


@ -877,7 +877,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
f(x, torch.tensor(True), torch.tensor(True)),
)
def test_cond_functionalized_hah(self):
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
y.add_(4)
@ -894,7 +894,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
all_ops_in_true_branch = []
@ -904,9 +906,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
def test_cond_accepts_torch_function_as_inputs(self):
@ -925,8 +924,8 @@ def forward(self, a_1, b_1):
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
@ -973,9 +972,9 @@ def forward(self, arg0_1, arg1_1):
z = torch.add(y, y)
return z
symbolic_traced_graph = self._check_tracing(f, (torch.ones(4), True))[
"symbolic"
]
symbolic_traced_graph = self._check_tracing(
f, (torch.ones(4), torch.Tensor([True]))
)["symbolic"]
graph_shape_env = symbolic_traced_graph.shape_env
def _node_shape_env_iter(gm):
@ -1021,15 +1020,14 @@ def forward(self, arg0_1, arg1_1):
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
gm_true_true_branch = graph_module.true_graph_0.true_graph_0
graph_module1 = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module1(*example_inputs), f(*example_inputs))
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
all_ops = []
for node in gm_true_true_branch.graph.nodes:
@ -1057,8 +1055,7 @@ def forward(self, arg0_1, arg1_1):
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_input_mutation_on_true_branch(self):
def test_cond_functionalized_input_mutation_on_true_brancte(self):
def true_fn(x):
view_x = x.view(x.shape)
view_x.add_(1)
@ -1072,19 +1069,33 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [4, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
sin = torch.ops.aten.sin.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
return sum_1""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_input_mutation_on_false_branch(self):
def true_fn(x):
return x.sin().sum()
@ -1099,19 +1110,33 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
cos = torch.ops.aten.cos.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(cos); cos = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
return sum_1""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_output_alias_input(self):
def true_fn(x):
return x
@ -1125,22 +1150,27 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 5),)
functional_f = torch.func.functionalize(f)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
return view""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)
with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_nested_input_mutation(self):
def true_true_fn(x):
x.add_(4)
@ -1161,19 +1191,14 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
def true_true_fn(x):
x.add_(4)
@ -1197,15 +1222,12 @@ def forward(self, arg0_1, arg1_1):
try:
example_input_func = to_fun_old(example_input)
torch._enable_functionalization(reapply_views=False)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
f(example_input_func)
f(example_input_func)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f)(example_input_func)
make_fx(f, tracing_mode="symbolic")(example_input_func)
finally:
torch._disable_functionalization()
@ -1223,7 +1245,7 @@ def forward(self, arg0_1, arg1_1):
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f_wrapper(f))(example_input_func)
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
@ -1236,7 +1258,7 @@ def forward(self, arg0_1, arg1_1):
return view_x
def f(x):
pred = x.shape[0] == 4
pred = x.sum() > 0
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(5, 5)
@ -1278,7 +1300,7 @@ def forward(self, arg0_1, arg1_1):
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(f_wrapper(f))(example_input)
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
def test_cond_functionalized_aot_func_check_functional(self):
def true_fn(x):
@ -1316,7 +1338,7 @@ def forward(self, arg0_1, arg1_1):
return wrapper
result_gm = make_fx(f_wrapper(f))(example_input)
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
for node in result_gm.true_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
@ -1382,12 +1404,12 @@ def forward(self, arg0_1, arg1_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = conditional_1[0]; conditional_1 = None
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)
@ -1555,12 +1577,12 @@ def forward(self, arg0_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = conditional_1[0]; conditional_1 = None
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)
@ -1891,7 +1913,7 @@ def forward(self, arg0_1):
):
functional_f(*example_inputs)
def test_cond_autograd_fail(self):
def test_cond_autograd_succeed_when_pred_is_constant(self):
def true_fn(x):
return x.cos()
@ -1901,6 +1923,27 @@ def forward(self, arg0_1):
def f(x, y):
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
)
# Because x.shape[0] > 4 can be statically evaluated to False, we can run
# the backward.
f(*example_inputs).sum().backward()
# Ensure no error is thrown when not running backward
f(*example_inputs)
def test_cond_autograd_fail(self):
def true_fn(x):
return x.cos()
def false_fn(x):
return x.sin()
def f(x, y):
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
@ -2029,8 +2072,8 @@ def forward(self, x_1):
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -2102,18 +2145,20 @@ def forward(self, x_1):
# expected branches takes [x, a, b] as input
inp = torch.randn(2, 3)
gm = make_fx(foo)(inp)
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
@ -2263,8 +2308,8 @@ def forward(self, pred_1, x_1):
def forward(self, arg0_1, arg1_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
return [getitem]""", # noqa: B950
)
@ -2305,7 +2350,7 @@ def forward(self, arg0_1, arg1_1):
counters.clear()
def foo(x, true_fn, false_fn):
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
return cond(x.sum() < 0, true_fn, false_fn, (x,))
inp = torch.ones(3, 4)
exp_out = inp.sin()
@ -2347,8 +2392,8 @@ def forward(self, x_1):
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)


@ -632,6 +632,18 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: cond(pred, true_fn, false_fn, operands)",
)
# Specialize into one of the branches since pred is constant
if type(args[0]) is ConstantVariable:
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
)
if args[0].as_python_constant():
return args[1].call_function(tx, args[3].unpack_var_sequence(tx), {})
else:
return args[2].call_function(tx, args[3].unpack_var_sequence(tx), {})
# predicate
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
unimplemented(


@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import contextlib
import logging
import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
@ -32,6 +34,8 @@ from torch.fx.experimental.proxy_tensor import (
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode
log = logging.getLogger(__name__)
@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
@ -103,10 +107,19 @@ def cond(pred, true_fn, false_fn, operands):
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
"""
if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)
if isinstance(pred, (bool, int, float)):
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
)
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)
def _validate_input(pred, true_fn, false_fn, operands):
if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
@ -200,7 +213,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="conditional"
"call_function", func_overload, proxy_args, {}
)
# At this point, we're *guaranteed* that whether an output came from the