Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[cond] inlining into one of the branches when pred is a python constant (#130493)
Reland of https://github.com/pytorch/pytorch/pull/128709.

When the input predicate is a Python constant, torch.cond now specializes into one of the branches and warns users that it is not preserving the dynamism. Previously we baked the constant True/False into the cond operator, which could be confusing; with this PR, cond simply specializes into the corresponding branch when the predicate is a constant.

We additionally stop overriding the cond operator's node name and use the default naming, which allows better testing on the de-serialized graph.

Test Plan: In some existing tests the predicate is the result of a shape comparison. When no dynamic shape is involved, that predicate is a plain Python bool. To fix those tests, we either make the predicate a data-dependent tensor or check that cond is specialized into one of the branches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130493
Approved by: https://github.com/BoyuanFeng
This commit is contained in:
parent 0bf9a091ec
commit 741c1710e8
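For orientation, here is a minimal illustrative sketch (not part of the commit) of the behavior described above: a plain Python bool predicate now makes torch.cond warn and run only the matching branch, while a data-dependent tensor predicate still produces a torch.ops.higher_order.cond node that preserves both branches. The module name M and the tensor shapes below are made up for the example.

import torch
from torch import nn

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

class M(nn.Module):
    def forward(self, x):
        # Tensor predicate: data-dependent, so both branches are kept
        # behind a torch.ops.higher_order.cond call in the exported graph.
        return torch.cond(x.sum() > 4, true_fn, false_fn, [x])

x = torch.randn(3, 2)

# Python bool predicate (x.shape[0] is a plain int here): after this PR,
# cond logs a warning and simply calls true_fn instead of baking True
# into the graph.
out = torch.cond(x.shape[0] > 2, true_fn, false_fn, [x])

# Tensor predicate: export keeps both branches.
ep = torch.export.export(M(), (x,))
print(ep.graph_module.code)

This mirrors the test updates in the diff below, where shape-based predicates are either replaced with sum-based tensor predicates or the tests now assert that cond specialized into a single branch.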
@@ -1912,13 +1912,10 @@ def forward(self, l_x_):
):
# True branch and false branch return tensors of different shape
torch._dynamo.export(mod)(torch.randn(3, 2))
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
# True branch and false branch return tensors of different shape
test_x = torch.randn(3, 2)
mod(test_x)

# We specialize into one of the branches since predicate is a python boolean.
test_x = torch.randn(3, 2)
mod(test_x)

def test_export_with_map_cond(self):
from functorch.experimental.control_flow import cond, map
@@ -1406,7 +1406,7 @@ def forward(self, child, const_unused):
def false_fn(x):
return (x - 1).sum()

return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])

mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
mod_for_eager = Foo()

@@ -6145,12 +6145,16 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])

cnt = CompileCounter()
opt_test = torch.compile(test, backend=cnt)
opt_test = torch.compile(test, backend=cnt, fullgraph=True)
inp = torch.ones(3, 3)
self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
true_pred = torch.Tensor([True])
false_pred = torch.Tensor([False])
self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
self.assertEqual(cnt.frame_count, 1)
self.assertTrue(
torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
)
self.assertEqual(cnt.frame_count, 1)
self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
self.assertEqual(cnt.frame_count, 2)

def test_cond_with_invalid_kwargs(self):
from torch._higher_order_ops.cond import cond_op
@@ -804,7 +804,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
return x.sin()

def forward(self, x):
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])

example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()

@@ -3616,7 +3616,7 @@ def forward(self, x):
):
torch.export.export(exported_v2.module(), (torch.randn(2, 2),))

def test_export_cond(self):
def test_export_cond_symbool_pred(self):
class A(torch.nn.Module):
def __init__(self):
super().__init__()

@@ -3639,10 +3639,20 @@ def forward(self, x):

return cond(x.shape[0] > 4, true_fn, false_fn, [x])

dim0 = torch.export.Dim("dim0", min=3)
inp = torch.ones(6, 4)
ep = export(
Foo(),
(inp,),
ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, b_a_buffer, x):
sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
gt = sym_size_int_1 > 4; sym_size_int_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
getitem = cond[0]; cond = None
return (getitem,)""",
)
self.assertTrue(
torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))

@@ -4988,7 +4998,7 @@ graph():
def false_fn(x):
return self.linear(x).sin()

return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
return torch.cond(x.sum() > 4, true_fn, false_fn, [x])

class CondExport(torch.nn.Module):
def __init__(self):

@@ -5005,10 +5015,12 @@ graph():
"""\
def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
cos = torch.ops.aten.cos.default(x)
sum_1 = torch.ops.aten.sum.default(x)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
return (add,)""",
)

@@ -5103,8 +5115,8 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
def forward(self, b_pred, b_t, x, y):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
getitem = cond[0]; cond = None
return (getitem,)""",
) # noqa: B950
@@ -68,7 +68,7 @@ class TestVerifier(TestCase):
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y

return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])

f = Foo()

@@ -87,7 +87,7 @@ class TestVerifier(TestCase):
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x - y

return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])

f = Foo()
@@ -4237,7 +4237,7 @@ def forward(self, arg0_1, arg1_1):
return x.cos()

return torch.cond(
y.cos().shape[0] > 5, true_true_fn, true_false_fn, [y.cos()]
y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
)

def false_fn(x):

@@ -4245,7 +4245,7 @@ def forward(self, arg0_1, arg1_1):
z.add_(6)
return z.sin()

a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)

inp = torch.randn(2, 2)

@@ -4254,10 +4254,12 @@ def forward(self, arg0_1, arg1_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950

@@ -4270,11 +4272,13 @@ def forward(self, arg0_1):
sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
cos = torch.ops.aten.cos.default(add)
sum_1 = torch.ops.aten.sum.default(cos); cos = None
gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None
cos_1 = torch.ops.aten.cos.default(add); add = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [cos_1]); true_graph_0 = false_graph_0 = cos_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)

@@ -4317,7 +4321,7 @@ def forward(self, arg0_1):
+ control_flow.map(f, z, r).sum()
)

a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
return (a + 3, a + 4)

inps = [torch.randn(2, 2), torch.ones(2)]

@@ -4326,10 +4330,12 @@ def forward(self, arg0_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1, arg1_1]); true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950

@@ -4434,7 +4440,7 @@ def forward(self, arg0_1, arg1_1):
z.add_(6)
return z.sin()

a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)

inp = torch.randn(2, 2)

@@ -4443,10 +4449,12 @@ def forward(self, arg0_1, arg1_1):
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950

@@ -4867,7 +4875,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
y.add_(6)
return x.sin()

a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
return (a + 3, a + 4)

inp = torch.randn(3, 4)

@@ -4876,10 +4884,12 @@ def forward(self, arg0_1, arg1_1, arg2_1):
gm.code.strip(),
"""\
def forward(self, arg0_1):
sum_1 = torch.ops.aten.sum.default(arg0_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)""", # noqa: B950
@@ -877,7 +877,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
f(x, torch.tensor(True), torch.tensor(True)),
)

def test_cond_functionalized_hah(self):
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
y.add_(4)

@@ -894,7 +894,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

all_ops_in_true_branch = []

@@ -904,9 +906,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):

self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))

graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

def test_cond_accepts_torch_function_as_inputs(self):

@@ -925,8 +924,8 @@ def forward(self, a_1, b_1):
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(

@@ -973,9 +972,9 @@ def forward(self, arg0_1, arg1_1):
z = torch.add(y, y)
return z

symbolic_traced_graph = self._check_tracing(f, (torch.ones(4), True))[
"symbolic"
]
symbolic_traced_graph = self._check_tracing(
f, (torch.ones(4), torch.Tensor([True]))
)["symbolic"]
graph_shape_env = symbolic_traced_graph.shape_env

def _node_shape_env_iter(gm):

@@ -1021,15 +1020,14 @@ def forward(self, arg0_1, arg1_1):
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

gm_true_true_branch = graph_module.true_graph_0.true_graph_0

graph_module1 = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module1(*example_inputs), f(*example_inputs))
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

all_ops = []
for node in gm_true_true_branch.graph.nodes:
@@ -1057,8 +1055,7 @@ def forward(self, arg0_1, arg1_1):
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_input_mutation_on_true_branch(self):
def test_cond_functionalized_input_mutation_on_true_brancte(self):
def true_fn(x):
view_x = x.view(x.shape)
view_x.add_(1)

@@ -1072,19 +1069,33 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])

example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [4, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
sin = torch.ops.aten.sin.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
return sum_1""",
)

# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)

# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_input_mutation_on_false_branch(self):
def true_fn(x):
return x.sin().sum()

@@ -1099,19 +1110,33 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])

example_inputs = (torch.ones(5, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
cos = torch.ops.aten.cos.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(cos); cos = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
return sum_1""",
)

# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)

# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_output_alias_input(self):
def true_fn(x):
return x
@@ -1125,22 +1150,27 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])

example_inputs = (torch.ones(5, 5),)
functional_f = torch.func.functionalize(f)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
return view""",
)

# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)

with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)

# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_nested_input_mutation(self):
def true_true_fn(x):
x.add_(4)

@@ -1161,19 +1191,14 @@ def forward(self, arg0_1, arg1_1):
return cond(pred, true_fn, false_fn, [x])

example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
functional_f(*example_inputs)

with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f))(*example_inputs)
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)

# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
def true_true_fn(x):
x.add_(4)

@@ -1197,15 +1222,12 @@ def forward(self, arg0_1, arg1_1):
try:
example_input_func = to_fun_old(example_input)
torch._enable_functionalization(reapply_views=False)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
f(example_input_func)
f(example_input_func)

with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f)(example_input_func)
make_fx(f, tracing_mode="symbolic")(example_input_func)
finally:
torch._disable_functionalization()

@@ -1223,7 +1245,7 @@ def forward(self, arg0_1, arg1_1):
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f_wrapper(f))(example_input_func)
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)

# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo

@@ -1236,7 +1258,7 @@ def forward(self, arg0_1, arg1_1):
return view_x

def f(x):
pred = x.shape[0] == 4
pred = x.sum() > 0
return cond(pred, true_fn, false_fn, [x])

example_input = torch.ones(5, 5)

@@ -1278,7 +1300,7 @@ def forward(self, arg0_1, arg1_1):
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(f_wrapper(f))(example_input)
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)

def test_cond_functionalized_aot_func_check_functional(self):
def true_fn(x):

@@ -1316,7 +1338,7 @@ def forward(self, arg0_1, arg1_1):

return wrapper

result_gm = make_fx(f_wrapper(f))(example_input)
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
for node in result_gm.true_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
@@ -1382,12 +1404,12 @@ def forward(self, arg0_1, arg1_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = conditional_1[0]; conditional_1 = None
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)

@@ -1555,12 +1577,12 @@ def forward(self, arg0_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = conditional_1[0]; conditional_1 = None
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)

@@ -1891,7 +1913,7 @@ def forward(self, arg0_1):
):
functional_f(*example_inputs)

def test_cond_autograd_fail(self):
def test_cond_autograd_succeed_when_pred_is_constant(self):
def true_fn(x):
return x.cos()

@@ -1901,6 +1923,27 @@ def forward(self, arg0_1):
def f(x, y):
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])

example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
)
# Due to x.shape[0] can be statically evaluated to be False, we can evaluate
# the backward.
f(*example_inputs).sum().backward()

# Ensure no error is thrown when not running backward
f(*example_inputs)

def test_cond_autograd_fail(self):
def true_fn(x):
return x.cos()

def false_fn(x):
return x.sin()

def f(x, y):
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])

example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),

@@ -2029,8 +2072,8 @@ def forward(self, x_1):
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)

@@ -2102,18 +2145,20 @@ def forward(self, x_1):
# expected branches takes [x, a, b] as input
inp = torch.randn(2, 3)

gm = make_fx(foo)(inp)
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)

self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(

@@ -2263,8 +2308,8 @@ def forward(self, pred_1, x_1):
def forward(self, arg0_1, arg1_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
return [getitem]""", # noqa: B950
)

@@ -2305,7 +2350,7 @@ def forward(self, arg0_1, arg1_1):
counters.clear()

def foo(x, true_fn, false_fn):
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
return cond(x.sum() < 0, true_fn, false_fn, (x,))

inp = torch.ones(3, 4)
exp_out = inp.sin()

@@ -2347,8 +2392,8 @@ def forward(self, x_1):
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = conditional[0]; conditional = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@@ -632,6 +632,18 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"Expected 4 arguments but got {len(args)}.\n"
f"Usage: cond(pred, true_fn, false_fn, operands)",
)

# Specialize into one of the branches since pred is constant
if type(args[0]) is ConstantVariable:
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
)
if args[0].as_python_constant():
return args[1].call_function(tx, args[3].unpack_var_sequence(tx), {})
else:
return args[2].call_function(tx, args[3].unpack_var_sequence(tx), {})

# predicate
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
unimplemented(
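As a rough sketch of what the Dynamo-side change above means for torch.compile users (assumed example, not taken from the commit): a Python bool predicate is treated as a constant and the chosen branch is inlined into the compiled graph, so changing the constant leads to a separate compilation, whereas a boolean tensor predicate keeps both branches in a single graph, as exercised by the ActivationCheckpointingTests hunk earlier in this diff.

import torch

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

def f(pred, x):
    return torch.cond(pred, true_fn, false_fn, [x])

compiled = torch.compile(f, fullgraph=True)
x = torch.ones(3, 3)

compiled(torch.tensor(True), x)   # one graph containing higher_order.cond
compiled(torch.tensor(False), x)  # reuses that same compiled graph
compiled(True, x)                 # constant pred: only the true branch is traced
compiled(False, x)                # different constant, so a new compilation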
@@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import contextlib

import logging

import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree

@@ -32,6 +34,8 @@ from torch.fx.experimental.proxy_tensor import (
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode

log = logging.getLogger(__name__)


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):

@@ -103,10 +107,19 @@ def cond(pred, true_fn, false_fn, operands):
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.

"""

if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)

if isinstance(pred, (bool, int, float)):
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
)
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)

def _validate_input(pred, true_fn, false_fn, operands):
if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")

@@ -200,7 +213,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="conditional"
"call_function", func_overload, proxy_args, {}
)

# At this point, we're *guaranteed* that whether an output came from the