mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[cond] inlining into one of the branches when pred is a python constant (#128709)"
This reverts commitfe3e6878c4. Reverted https://github.com/pytorch/pytorch/pull/128709 on behalf of https://github.com/ydwu4 due to causing error on truck due to a land racing:fe3e6878c4([comment](https://github.com/pytorch/pytorch/pull/128709#issuecomment-2221104043))
This commit is contained in:
parent
b4b7477d3f
commit
0beeac35fa
|
|
@ -1912,8 +1912,11 @@ def forward(self, l_x_):
|
|||
):
|
||||
# True branch and false branch return tensors of different shape
|
||||
torch._dynamo.export(mod)(torch.randn(3, 2))
|
||||
|
||||
# We specialize into one of the branches since predicate is a python boolean.
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile",
|
||||
):
|
||||
# True branch and false branch return tensors of different shape
|
||||
test_x = torch.randn(3, 2)
|
||||
mod(test_x)
|
||||
|
||||
|
|
|
|||
|
|
@ -1408,7 +1408,7 @@ def forward(self, child, const_unused):
|
|||
def false_fn(x):
|
||||
return (x - 1).sum()
|
||||
|
||||
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
|
||||
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
|
||||
mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
|
||||
mod_for_eager = Foo()
|
||||
|
|
@ -6147,16 +6147,12 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
|||
return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])
|
||||
|
||||
cnt = CompileCounter()
|
||||
opt_test = torch.compile(test, backend=cnt, fullgraph=True)
|
||||
opt_test = torch.compile(test, backend=cnt)
|
||||
inp = torch.ones(3, 3)
|
||||
true_pred = torch.Tensor([True])
|
||||
false_pred = torch.Tensor([False])
|
||||
self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertTrue(
|
||||
torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
|
||||
)
|
||||
self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_cond_with_invalid_kwargs(self):
|
||||
from torch._higher_order_ops.cond import cond_op
|
||||
|
|
|
|||
|
|
@ -801,7 +801,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
|||
return x.sin()
|
||||
|
||||
def forward(self, x):
|
||||
return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])
|
||||
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
|
||||
|
||||
example_inputs = (torch.randn(1, 3, 3, 3),)
|
||||
m = CondBranchClassMethod()
|
||||
|
|
@ -3603,7 +3603,7 @@ def forward(self, x):
|
|||
|
||||
# https://github.com/pytorch/pytorch/issues/129939
|
||||
@testing.expectedFailureNonStrict
|
||||
def test_export_cond_symbool_pred(self):
|
||||
def test_export_cond(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -3626,20 +3626,10 @@ def forward(self, x):
|
|||
|
||||
return cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
|
||||
dim0 = torch.export.Dim("dim0", min=3)
|
||||
inp = torch.ones(6, 4)
|
||||
ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
|
||||
self.assertExpectedInline(
|
||||
ep.graph_module.code.strip(),
|
||||
"""\
|
||||
def forward(self, b_a_buffer, x):
|
||||
sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
|
||||
gt = sym_size_int_1 > 4; sym_size_int_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
|
||||
getitem = cond[0]; cond = None
|
||||
return (getitem,)""",
|
||||
ep = export(
|
||||
Foo(),
|
||||
(inp,),
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
|
||||
|
|
@ -5017,7 +5007,7 @@ graph():
|
|||
def false_fn(x):
|
||||
return self.linear(x).sin()
|
||||
|
||||
return torch.cond(x.sum() > 4, true_fn, false_fn, [x])
|
||||
return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
|
||||
class CondExport(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -5034,12 +5024,10 @@ graph():
|
|||
"""\
|
||||
def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
||||
cos = torch.ops.aten.cos.default(x)
|
||||
sum_1 = torch.ops.aten.sum.default(x)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
|
||||
return (add,)""",
|
||||
)
|
||||
|
|
@ -5134,8 +5122,8 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||
def forward(self, b_pred, b_t, x, y):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return (getitem,)""",
|
||||
) # noqa: B950
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class TestVerifier(TestCase):
|
|||
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x - y
|
||||
|
||||
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
|
||||
return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
|
||||
|
||||
f = Foo()
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ class TestVerifier(TestCase):
|
|||
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x - y
|
||||
|
||||
return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
|
||||
return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
|
||||
|
||||
f = Foo()
|
||||
|
||||
|
|
|
|||
|
|
@ -4237,7 +4237,7 @@ def forward(self, arg0_1, arg1_1):
|
|||
return x.cos()
|
||||
|
||||
return torch.cond(
|
||||
y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
|
||||
y.cos().shape[0] > 5, true_true_fn, true_false_fn, [y.cos()]
|
||||
)
|
||||
|
||||
def false_fn(x):
|
||||
|
|
@ -4245,7 +4245,7 @@ def forward(self, arg0_1, arg1_1):
|
|||
z.add_(6)
|
||||
return z.sin()
|
||||
|
||||
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
|
||||
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
return (a + 3, a + 4)
|
||||
|
||||
inp = torch.randn(2, 2)
|
||||
|
|
@ -4254,12 +4254,10 @@ def forward(self, arg0_1, arg1_1):
|
|||
str(gm.code).strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
sum_1 = torch.ops.aten.sum.default(arg0_1)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
add = torch.ops.aten.add.Tensor(getitem, 3)
|
||||
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
|
||||
return (add, add_1)""", # noqa: B950
|
||||
|
|
@ -4272,13 +4270,11 @@ def forward(self, arg0_1):
|
|||
sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
|
||||
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
|
||||
cos = torch.ops.aten.cos.default(add)
|
||||
sum_1 = torch.ops.aten.sum.default(cos); cos = None
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None
|
||||
cos_1 = torch.ops.aten.cos.default(add); add = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [cos_1]); true_graph_0 = false_graph_0 = cos_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return (getitem,)""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -4321,7 +4317,7 @@ def forward(self, arg0_1):
|
|||
+ control_flow.map(f, z, r).sum()
|
||||
)
|
||||
|
||||
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
|
||||
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
|
||||
return (a + 3, a + 4)
|
||||
|
||||
inps = [torch.randn(2, 2), torch.ones(2)]
|
||||
|
|
@ -4330,12 +4326,10 @@ def forward(self, arg0_1):
|
|||
str(gm.code).strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
sum_1 = torch.ops.aten.sum.default(arg0_1)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1, arg1_1]); true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
add = torch.ops.aten.add.Tensor(getitem, 3)
|
||||
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
|
||||
return (add, add_1)""", # noqa: B950
|
||||
|
|
@ -4440,7 +4434,7 @@ def forward(self, arg0_1, arg1_1):
|
|||
z.add_(6)
|
||||
return z.sin()
|
||||
|
||||
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
|
||||
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
return (a + 3, a + 4)
|
||||
|
||||
inp = torch.randn(2, 2)
|
||||
|
|
@ -4449,12 +4443,10 @@ def forward(self, arg0_1, arg1_1):
|
|||
str(gm.code).strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
sum_1 = torch.ops.aten.sum.default(arg0_1)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
add = torch.ops.aten.add.Tensor(getitem, 3)
|
||||
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
|
||||
return (add, add_1)""", # noqa: B950
|
||||
|
|
@ -4875,7 +4867,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
y.add_(6)
|
||||
return x.sin()
|
||||
|
||||
a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
|
||||
a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
return (a + 3, a + 4)
|
||||
|
||||
inp = torch.randn(3, 4)
|
||||
|
|
@ -4884,12 +4876,10 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
sum_1 = torch.ops.aten.sum.default(arg0_1)
|
||||
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
add = torch.ops.aten.add.Tensor(getitem, 3)
|
||||
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
|
||||
return (add, add_1)""", # noqa: B950
|
||||
|
|
|
|||
|
|
@ -877,7 +877,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
|||
f(x, torch.tensor(True), torch.tensor(True)),
|
||||
)
|
||||
|
||||
def test_cond_functionalized(self):
|
||||
def test_cond_functionalized_hah(self):
|
||||
def true_fn(x):
|
||||
y = x.sin()
|
||||
y.add_(4)
|
||||
|
|
@ -894,9 +894,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
|||
functional_f = torch.func.functionalize(f)
|
||||
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
||||
|
||||
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||
|
||||
all_ops_in_true_branch = []
|
||||
|
|
@ -906,6 +904,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
|||
|
||||
self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
|
||||
|
||||
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||
|
||||
def test_cond_accepts_torch_function_as_inputs(self):
|
||||
|
|
@ -924,8 +925,8 @@ def forward(self, a_1, b_1):
|
|||
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
|
|
@ -972,9 +973,9 @@ def forward(self, arg0_1, arg1_1):
|
|||
z = torch.add(y, y)
|
||||
return z
|
||||
|
||||
symbolic_traced_graph = self._check_tracing(
|
||||
f, (torch.ones(4), torch.Tensor([True]))
|
||||
)["symbolic"]
|
||||
symbolic_traced_graph = self._check_tracing(f, (torch.ones(4), True))[
|
||||
"symbolic"
|
||||
]
|
||||
graph_shape_env = symbolic_traced_graph.shape_env
|
||||
|
||||
def _node_shape_env_iter(gm):
|
||||
|
|
@ -1020,14 +1021,15 @@ def forward(self, arg0_1, arg1_1):
|
|||
functional_f = torch.func.functionalize(f)
|
||||
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
||||
|
||||
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||
|
||||
gm_true_true_branch = graph_module.true_graph_0.true_graph_0
|
||||
|
||||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||
graph_module1 = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
self.assertEqual(graph_module1(*example_inputs), f(*example_inputs))
|
||||
|
||||
all_ops = []
|
||||
for node in gm_true_true_branch.graph.nodes:
|
||||
|
|
@ -1055,7 +1057,8 @@ def forward(self, arg0_1, arg1_1):
|
|||
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
def test_cond_functionalized_input_mutation_on_true_brancte(self):
|
||||
@xfailIfTorchDynamo
|
||||
def test_cond_functionalized_input_mutation_on_true_branch(self):
|
||||
def true_fn(x):
|
||||
view_x = x.view(x.shape)
|
||||
view_x.add_(1)
|
||||
|
|
@ -1069,33 +1072,19 @@ def forward(self, arg0_1, arg1_1):
|
|||
return cond(pred, true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.ones(4, 5),)
|
||||
# torch.cond inlines into one of the branches because the predicate
|
||||
# is a constant.
|
||||
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
view = torch.ops.aten.view.default(x_1, [4, 5])
|
||||
add = torch.ops.aten.add.Tensor(view, 1); view = None
|
||||
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
|
||||
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
|
||||
sin = torch.ops.aten.sin.default(view_2); view_2 = None
|
||||
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
||||
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
|
||||
return sum_1""",
|
||||
)
|
||||
|
||||
# torch.cond triggers the check of the branches because the predicate
|
||||
# is a SymBool.
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
functional_f(*example_inputs)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
def test_cond_functionalized_input_mutation_on_false_branch(self):
|
||||
def true_fn(x):
|
||||
return x.sin().sum()
|
||||
|
|
@ -1110,33 +1099,19 @@ def forward(self, x_1):
|
|||
return cond(pred, true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.ones(5, 5),)
|
||||
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
# torch.cond inlines into one of the branches because the predicate
|
||||
# is a constant.
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
view = torch.ops.aten.view.default(x_1, [5, 5])
|
||||
add = torch.ops.aten.add.Tensor(view, 1); view = None
|
||||
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
|
||||
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
|
||||
cos = torch.ops.aten.cos.default(view_2); view_2 = None
|
||||
sum_1 = torch.ops.aten.sum.default(cos); cos = None
|
||||
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
|
||||
return sum_1""",
|
||||
)
|
||||
|
||||
# torch.cond triggers the check of the branches because the predicate
|
||||
# is a SymBool.
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
functional_f(*example_inputs)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
def test_cond_functionalized_output_alias_input(self):
|
||||
def true_fn(x):
|
||||
return x
|
||||
|
|
@ -1150,27 +1125,22 @@ def forward(self, x_1):
|
|||
return cond(pred, true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.ones(5, 5),)
|
||||
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
# torch.cond inlines into one of the branches because the predicate
|
||||
# is a constant.
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
|
||||
return view""",
|
||||
)
|
||||
functional_f = torch.func.functionalize(f)
|
||||
|
||||
# torch.cond triggers the check of the branches because the predicate
|
||||
# is a SymBool.
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
UnsupportedAliasMutationException,
|
||||
"One of torch.cond branch might be aliasing",
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
functional_f(*example_inputs)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException,
|
||||
"One of torch.cond branch might be aliasing",
|
||||
):
|
||||
make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
def test_cond_functionalized_nested_input_mutation(self):
|
||||
def true_true_fn(x):
|
||||
x.add_(4)
|
||||
|
|
@ -1191,14 +1161,19 @@ def forward(self, x_1):
|
|||
return cond(pred, true_fn, false_fn, [x])
|
||||
|
||||
example_inputs = (torch.ones(4, 5),)
|
||||
functional_f = torch.func.functionalize(f)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
||||
*example_inputs
|
||||
)
|
||||
functional_f(*example_inputs)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(torch.func.functionalize(f))(*example_inputs)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
|
||||
def true_true_fn(x):
|
||||
x.add_(4)
|
||||
|
|
@ -1222,12 +1197,15 @@ def forward(self, x_1):
|
|||
try:
|
||||
example_input_func = to_fun_old(example_input)
|
||||
torch._enable_functionalization(reapply_views=False)
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
f(example_input_func)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(f, tracing_mode="symbolic")(example_input_func)
|
||||
make_fx(f)(example_input_func)
|
||||
finally:
|
||||
torch._disable_functionalization()
|
||||
|
||||
|
|
@ -1245,7 +1223,7 @@ def forward(self, x_1):
|
|||
with self.assertRaisesRegex(
|
||||
UnsupportedAliasMutationException, "One of torch.cond branch"
|
||||
):
|
||||
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
|
||||
make_fx(f_wrapper(f))(example_input_func)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
|
|
@ -1258,7 +1236,7 @@ def forward(self, x_1):
|
|||
return view_x
|
||||
|
||||
def f(x):
|
||||
pred = x.sum() > 0
|
||||
pred = x.shape[0] == 4
|
||||
return cond(pred, true_fn, false_fn, [x])
|
||||
|
||||
example_input = torch.ones(5, 5)
|
||||
|
|
@ -1300,7 +1278,7 @@ def forward(self, x_1):
|
|||
UnsupportedAliasMutationException,
|
||||
"One of torch.cond branch might be aliasing",
|
||||
):
|
||||
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
||||
make_fx(f_wrapper(f))(example_input)
|
||||
|
||||
def test_cond_functionalized_aot_func_check_functional(self):
|
||||
def true_fn(x):
|
||||
|
|
@ -1338,7 +1316,7 @@ def forward(self, x_1):
|
|||
|
||||
return wrapper
|
||||
|
||||
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
||||
result_gm = make_fx(f_wrapper(f))(example_input)
|
||||
for node in result_gm.true_graph_0.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
self.assertTrue(not node.target._schema.is_mutable)
|
||||
|
|
@ -1404,12 +1382,12 @@ def forward(self, x_1):
|
|||
def forward(self, x_1, pred_1, pred2_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
||||
getitem_1 = conditional_1[0]; conditional_1 = None
|
||||
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
||||
return add""", # noqa: B950
|
||||
)
|
||||
|
|
@ -1577,12 +1555,12 @@ def forward(self, arg0_1):
|
|||
def forward(self, x_1, pred_1, pred2_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
||||
getitem_1 = conditional_1[0]; conditional_1 = None
|
||||
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
||||
return add""", # noqa: B950
|
||||
)
|
||||
|
|
@ -1913,27 +1891,6 @@ def forward(self, arg0_1):
|
|||
):
|
||||
functional_f(*example_inputs)
|
||||
|
||||
def test_cond_autograd_succeed_when_pred_is_constant(self):
|
||||
def true_fn(x):
|
||||
return x.cos()
|
||||
|
||||
def false_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def f(x, y):
|
||||
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
|
||||
|
||||
example_inputs = (
|
||||
torch.ones(3, 2, 4, requires_grad=True),
|
||||
torch.ones(4, requires_grad=True),
|
||||
)
|
||||
# Due to x.shape[0] can be statically evaluated to be False, we can evaluate
|
||||
# the backward.
|
||||
f(*example_inputs).sum().backward()
|
||||
|
||||
# Ensure no error is thrown when not running backward
|
||||
f(*example_inputs)
|
||||
|
||||
def test_cond_autograd_fail(self):
|
||||
def true_fn(x):
|
||||
return x.cos()
|
||||
|
|
@ -1942,7 +1899,7 @@ def forward(self, arg0_1):
|
|||
return x.sin()
|
||||
|
||||
def f(x, y):
|
||||
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])
|
||||
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
|
||||
|
||||
example_inputs = (
|
||||
torch.ones(3, 2, 4, requires_grad=True),
|
||||
|
|
@ -2072,8 +2029,8 @@ def forward(self, x_1):
|
|||
eq = sym_size_int == 4; sym_size_int = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2145,20 +2102,18 @@ def forward(self, x_1):
|
|||
# expected branches takes [x, a, b] as input
|
||||
inp = torch.randn(2, 3)
|
||||
|
||||
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
|
||||
gm = make_fx(foo)(inp)
|
||||
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
||||
eq = sym_size_int == 4; sym_size_int = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
_tensor_constant1 = self._tensor_constant1
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
|
|
@ -2308,8 +2263,8 @@ def forward(self, pred_1, x_1):
|
|||
def forward(self, arg0_1, arg1_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return [getitem]""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2350,7 +2305,7 @@ def forward(self, arg0_1, arg1_1):
|
|||
counters.clear()
|
||||
|
||||
def foo(x, true_fn, false_fn):
|
||||
return cond(x.sum() < 0, true_fn, false_fn, (x,))
|
||||
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
|
||||
|
||||
inp = torch.ones(3, 4)
|
||||
exp_out = inp.sin()
|
||||
|
|
@ -2392,8 +2347,8 @@ def forward(self, x_1):
|
|||
eq = sym_size_int == 4; sym_size_int = None
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
getitem = cond[0]; cond = None
|
||||
conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
|
||||
getitem = conditional[0]; conditional = None
|
||||
return getitem""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -632,18 +632,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
f"Expected 4 arguments but got {len(args)}.\n"
|
||||
f"Usage: cond(pred, true_fn, false_fn, operands)",
|
||||
)
|
||||
|
||||
# Specialize into one of the branches since pred is constant
|
||||
if type(args[0]) is ConstantVariable:
|
||||
log.warning(
|
||||
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
|
||||
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
|
||||
)
|
||||
if args[0].as_python_constant():
|
||||
return args[1].call_function(tx, args[3].unpack_var_sequence(tx), {})
|
||||
else:
|
||||
return args[2].call_function(tx, args[3].unpack_var_sequence(tx), {})
|
||||
|
||||
# predicate
|
||||
if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
|
||||
unimplemented(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch._subclasses.functional_tensor
|
||||
import torch.utils._pytree as pytree
|
||||
|
|
@ -34,8 +32,6 @@ from torch.fx.experimental.proxy_tensor import (
|
|||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@exposed_in("torch")
|
||||
def cond(pred, true_fn, false_fn, operands):
|
||||
|
|
@ -107,19 +103,10 @@ def cond(pred, true_fn, false_fn, operands):
|
|||
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
|
||||
|
||||
"""
|
||||
|
||||
if torch.compiler.is_dynamo_compiling():
|
||||
return cond_op(pred, true_fn, false_fn, operands)
|
||||
|
||||
if isinstance(pred, (bool, int, float)):
|
||||
log.warning(
|
||||
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
|
||||
" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
|
||||
)
|
||||
if pred:
|
||||
return true_fn(*operands)
|
||||
else:
|
||||
return false_fn(*operands)
|
||||
|
||||
def _validate_input(pred, true_fn, false_fn, operands):
|
||||
if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
|
||||
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
|
||||
|
|
@ -213,7 +200,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
|||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
|
||||
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function", func_overload, proxy_args, {}
|
||||
"call_function", func_overload, proxy_args, {}, name="conditional"
|
||||
)
|
||||
|
||||
# At this point, we're *guaranteed* that whether an output came from the
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user