Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Revert "[cond] inlining into one of the branches when pred is a python constant (#128709)"
This reverts commit fe3e6878c4. Reverted https://github.com/pytorch/pytorch/pull/128709 on behalf of https://github.com/ydwu4 due to causing an error on trunk due to a land race: fe3e6878c4 ([comment](https://github.com/pytorch/pytorch/pull/128709#issuecomment-2221104043))
This commit is contained in:
parent b4b7477d3f
commit 0beeac35fa
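Context for the diff below (not part of the commit itself): the reverted PR made torch.cond inline a single branch whenever the predicate is a plain Python constant, while a tensor or SymBool predicate keeps both branches captured. A minimal sketch of the two predicate kinds, assuming a recent PyTorch build where torch.cond is available:

import torch


def true_fn(x):
    return x.cos()


def false_fn(x):
    return x.sin()


x = torch.randn(4)

# Tensor predicate: its value is only known at runtime, so tracing/compiling
# must keep both branches as subgraphs of the cond op.
out_runtime_pred = torch.cond(x.sum() > 0, true_fn, false_fn, [x])

# Python-constant predicate: the value is known while tracing. The reverted
# PR specialized this case by inlining only the chosen branch; after the
# revert it is traced as a cond node with a literal predicate again.
out_constant_pred = torch.cond(x.shape[0] > 2, true_fn, false_fn, [x])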
@@ -1912,8 +1912,11 @@ def forward(self, l_x_):
 ):
 # True branch and false branch return tensors of different shape
 torch._dynamo.export(mod)(torch.randn(3, 2))
-# We specialize into one of the branches since predicate is a python boolean.
+with self.assertRaisesRegex(
+torch._dynamo.exc.UncapturedHigherOrderOpError,
+"Cond doesn't work unless it is captured completely with torch.compile",
+):
+# True branch and false branch return tensors of different shape
 test_x = torch.randn(3, 2)
 mod(test_x)
@@ -1408,7 +1408,7 @@ def forward(self, child, const_unused):
 def false_fn(x):
 return (x - 1).sum()

-return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
+return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])

 mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
 mod_for_eager = Foo()
@@ -6147,16 +6147,12 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
 return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])

 cnt = CompileCounter()
-opt_test = torch.compile(test, backend=cnt, fullgraph=True)
+opt_test = torch.compile(test, backend=cnt)
 inp = torch.ones(3, 3)
-true_pred = torch.Tensor([True])
-false_pred = torch.Tensor([False])
-self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
-self.assertEqual(cnt.frame_count, 1)
-self.assertTrue(
-torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
-)
+self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
 self.assertEqual(cnt.frame_count, 1)
+self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
+self.assertEqual(cnt.frame_count, 2)

 def test_cond_with_invalid_kwargs(self):
 from torch._higher_order_ops.cond import cond_op
@@ -801,7 +801,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
 return x.sin()

 def forward(self, x):
-return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])
+return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

 example_inputs = (torch.randn(1, 3, 3, 3),)
 m = CondBranchClassMethod()
@@ -3603,7 +3603,7 @@ def forward(self, x):

 # https://github.com/pytorch/pytorch/issues/129939
 @testing.expectedFailureNonStrict
-def test_export_cond_symbool_pred(self):
+def test_export_cond(self):
 class A(torch.nn.Module):
 def __init__(self):
 super().__init__()
@@ -3626,20 +3626,10 @@ def forward(self, x):

 return cond(x.shape[0] > 4, true_fn, false_fn, [x])

-dim0 = torch.export.Dim("dim0", min=3)
 inp = torch.ones(6, 4)
-ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}})
-self.assertExpectedInline(
-ep.graph_module.code.strip(),
-"""\
-def forward(self, b_a_buffer, x):
-sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
-gt = sym_size_int_1 > 4; sym_size_int_1 = None
-true_graph_0 = self.true_graph_0
-false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
-getitem = cond[0]; cond = None
-return (getitem,)""",
+ep = export(
+Foo(),
+(inp,),
 )
 self.assertTrue(
 torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
@@ -5017,7 +5007,7 @@ graph():
 def false_fn(x):
 return self.linear(x).sin()

-return torch.cond(x.sum() > 4, true_fn, false_fn, [x])
+return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])

 class CondExport(torch.nn.Module):
 def __init__(self):
@@ -5034,12 +5024,10 @@ graph():
 """\
 def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
 cos = torch.ops.aten.cos.default(x)
-sum_1 = torch.ops.aten.sum.default(x)
-gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
 return (add,)""",
 )
@@ -5134,8 +5122,8 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
 def forward(self, b_pred, b_t, x, y):
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
+conditional = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return (getitem,)""",
 ) # noqa: B950
@@ -68,7 +68,7 @@ class TestVerifier(TestCase):
 def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 return x - y

-return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
+return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

 f = Foo()
@@ -87,7 +87,7 @@ class TestVerifier(TestCase):
 def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 return x - y

-return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])
+return control_flow.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

 f = Foo()
@@ -4237,7 +4237,7 @@ def forward(self, arg0_1, arg1_1):
 return x.cos()

 return torch.cond(
-y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()]
+y.cos().shape[0] > 5, true_true_fn, true_false_fn, [y.cos()]
 )

 def false_fn(x):
@@ -4245,7 +4245,7 @@ def forward(self, arg0_1, arg1_1):
 z.add_(6)
 return z.sin()

-a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
+a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
 return (a + 3, a + 4)

 inp = torch.randn(2, 2)
@@ -4254,12 +4254,10 @@ def forward(self, arg0_1, arg1_1):
 str(gm.code).strip(),
 """\
 def forward(self, arg0_1):
-sum_1 = torch.ops.aten.sum.default(arg0_1)
-gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 add = torch.ops.aten.add.Tensor(getitem, 3)
 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
 return (add, add_1)""", # noqa: B950
@@ -4272,13 +4270,11 @@ def forward(self, arg0_1):
 sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
 add = torch.ops.aten.add.Tensor(sin, 5); sin = None
 cos = torch.ops.aten.cos.default(add)
-sum_1 = torch.ops.aten.sum.default(cos); cos = None
-gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None
 cos_1 = torch.ops.aten.cos.default(add); add = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [cos_1]); true_graph_0 = false_graph_0 = cos_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return (getitem,)""", # noqa: B950
 )
@@ -4321,7 +4317,7 @@ def forward(self, arg0_1):
 + control_flow.map(f, z, r).sum()
 )

-a = torch.cond(x.sum() > 4, true_fn, false_fn, [x, y])
+a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x, y])
 return (a + 3, a + 4)

 inps = [torch.randn(2, 2), torch.ones(2)]
@@ -4330,12 +4326,10 @@ def forward(self, arg0_1):
 str(gm.code).strip(),
 """\
 def forward(self, arg0_1, arg1_1):
-sum_1 = torch.ops.aten.sum.default(arg0_1)
-gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1, arg1_1]); true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 add = torch.ops.aten.add.Tensor(getitem, 3)
 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
 return (add, add_1)""", # noqa: B950
@@ -4440,7 +4434,7 @@ def forward(self, arg0_1, arg1_1):
 z.add_(6)
 return z.sin()

-a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
+a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
 return (a + 3, a + 4)

 inp = torch.randn(2, 2)
@@ -4449,12 +4443,10 @@ def forward(self, arg0_1, arg1_1):
 str(gm.code).strip(),
 """\
 def forward(self, arg0_1):
-sum_1 = torch.ops.aten.sum.default(arg0_1)
-gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 add = torch.ops.aten.add.Tensor(getitem, 3)
 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
 return (add, add_1)""", # noqa: B950
@@ -4875,7 +4867,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 y.add_(6)
 return x.sin()

-a = torch.cond(x.sum() > 4, true_fn, false_fn, [x])
+a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
 return (a + 3, a + 4)

 inp = torch.randn(3, 4)
@@ -4884,12 +4876,10 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 gm.code.strip(),
 """\
 def forward(self, arg0_1):
-sum_1 = torch.ops.aten.sum.default(arg0_1)
-gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 add = torch.ops.aten.add.Tensor(getitem, 3)
 add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
 return (add, add_1)""", # noqa: B950
@@ -877,7 +877,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
 f(x, torch.tensor(True), torch.tensor(True)),
 )

-def test_cond_functionalized(self):
+def test_cond_functionalized_hah(self):
 def true_fn(x):
 y = x.sin()
 y.add_(4)
@@ -894,9 +894,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
 functional_f = torch.func.functionalize(f)
 self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

-graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
-*example_inputs
-)
+graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
 self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

 all_ops_in_true_branch = []
@@ -906,6 +904,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):

 self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))

+graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
+*example_inputs
+)
 self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

 def test_cond_accepts_torch_function_as_inputs(self):
@@ -924,8 +925,8 @@ def forward(self, a_1, b_1):
 gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
+conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return getitem""", # noqa: B950
 )
 self.assertExpectedInline(
@@ -972,9 +973,9 @@ def forward(self, arg0_1, arg1_1):
 z = torch.add(y, y)
 return z

-symbolic_traced_graph = self._check_tracing(
-f, (torch.ones(4), torch.Tensor([True]))
-)["symbolic"]
+symbolic_traced_graph = self._check_tracing(f, (torch.ones(4), True))[
+"symbolic"
+]
 graph_shape_env = symbolic_traced_graph.shape_env

 def _node_shape_env_iter(gm):
@@ -1020,14 +1021,15 @@ def forward(self, arg0_1, arg1_1):
 functional_f = torch.func.functionalize(f)
 self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

-graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
-*example_inputs
-)
+graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
 self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

 gm_true_true_branch = graph_module.true_graph_0.true_graph_0

-self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
+graph_module1 = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
+*example_inputs
+)
+self.assertEqual(graph_module1(*example_inputs), f(*example_inputs))

 all_ops = []
 for node in gm_true_true_branch.graph.nodes:
@@ -1055,7 +1057,8 @@ def forward(self, arg0_1, arg1_1):
 self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

 # https://github.com/pytorch/pytorch/issues/126988
-def test_cond_functionalized_input_mutation_on_true_brancte(self):
+@xfailIfTorchDynamo
+def test_cond_functionalized_input_mutation_on_true_branch(self):
 def true_fn(x):
 view_x = x.view(x.shape)
 view_x.add_(1)
@@ -1069,33 +1072,19 @@ def forward(self, arg0_1, arg1_1):
 return cond(pred, true_fn, false_fn, [x])

 example_inputs = (torch.ones(4, 5),)
-# torch.cond inlines into one of the branches because the predicate
-# is a constant.
-gm = make_fx(torch.func.functionalize(f))(*example_inputs)
-self.assertExpectedInline(
-gm.code.strip(),
-"""\
-def forward(self, x_1):
-view = torch.ops.aten.view.default(x_1, [4, 5])
-add = torch.ops.aten.add.Tensor(view, 1); view = None
-view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
-view_2 = torch.ops.aten.view.default(view_1, [4, 5])
-sin = torch.ops.aten.sin.default(view_2); view_2 = None
-sum_1 = torch.ops.aten.sum.default(sin); sin = None
-copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
-return sum_1""",
-)
+functional_f = torch.func.functionalize(f)

-# torch.cond triggers the check of the branches because the predicate
-# is a SymBool.
 with self.assertRaisesRegex(
 UnsupportedAliasMutationException, "One of torch.cond branch"
 ):
-make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
-*example_inputs
-)
+functional_f(*example_inputs)
+with self.assertRaisesRegex(
+UnsupportedAliasMutationException, "One of torch.cond branch"
+):
+make_fx(torch.func.functionalize(f))(*example_inputs)

 # https://github.com/pytorch/pytorch/issues/126988
+@xfailIfTorchDynamo
 def test_cond_functionalized_input_mutation_on_false_branch(self):
 def true_fn(x):
 return x.sin().sum()
@@ -1110,33 +1099,19 @@ def forward(self, x_1):
 return cond(pred, true_fn, false_fn, [x])

 example_inputs = (torch.ones(5, 5),)
-gm = make_fx(torch.func.functionalize(f))(*example_inputs)
-# torch.cond inlines into one of the branches because the predicate
-# is a constant.
-self.assertExpectedInline(
-gm.code.strip(),
-"""\
-def forward(self, x_1):
-view = torch.ops.aten.view.default(x_1, [5, 5])
-add = torch.ops.aten.add.Tensor(view, 1); view = None
-view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
-view_2 = torch.ops.aten.view.default(view_1, [5, 5])
-cos = torch.ops.aten.cos.default(view_2); view_2 = None
-sum_1 = torch.ops.aten.sum.default(cos); cos = None
-copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = None
-return sum_1""",
-)
+functional_f = torch.func.functionalize(f)

-# torch.cond triggers the check of the branches because the predicate
-# is a SymBool.
 with self.assertRaisesRegex(
 UnsupportedAliasMutationException, "One of torch.cond branch"
 ):
-make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
-*example_inputs
-)
+functional_f(*example_inputs)
+with self.assertRaisesRegex(
+UnsupportedAliasMutationException, "One of torch.cond branch"
+):
+make_fx(torch.func.functionalize(f))(*example_inputs)

 # https://github.com/pytorch/pytorch/issues/126988
+@xfailIfTorchDynamo
 def test_cond_functionalized_output_alias_input(self):
 def true_fn(x):
 return x
@@ -1150,27 +1125,22 @@ def forward(self, x_1):
 return cond(pred, true_fn, false_fn, [x])

 example_inputs = (torch.ones(5, 5),)
-gm = make_fx(torch.func.functionalize(f))(*example_inputs)
-# torch.cond inlines into one of the branches because the predicate
-# is a constant.
-self.assertExpectedInline(
-gm.code.strip(),
-"""\
-def forward(self, x_1):
-view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
-return view""",
-)
+functional_f = torch.func.functionalize(f)

-# torch.cond triggers the check of the branches because the predicate
-# is a SymBool.
 with self.assertRaisesRegex(
-UnsupportedAliasMutationException, "One of torch.cond branch"
+UnsupportedAliasMutationException,
+"One of torch.cond branch might be aliasing",
 ):
-make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
-*example_inputs
-)
+functional_f(*example_inputs)
+with self.assertRaisesRegex(
+UnsupportedAliasMutationException,
+"One of torch.cond branch might be aliasing",
+):
+make_fx(torch.func.functionalize(f))(*example_inputs)

 # https://github.com/pytorch/pytorch/issues/126988
+@xfailIfTorchDynamo
 def test_cond_functionalized_nested_input_mutation(self):
 def true_true_fn(x):
 x.add_(4)
@@ -1191,14 +1161,19 @@ def forward(self, x_1):
 return cond(pred, true_fn, false_fn, [x])

 example_inputs = (torch.ones(4, 5),)
+functional_f = torch.func.functionalize(f)
 with self.assertRaisesRegex(
 UnsupportedAliasMutationException, "One of torch.cond branch"
 ):
-make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
-*example_inputs
-)
+functional_f(*example_inputs)
+with self.assertRaisesRegex(
+UnsupportedAliasMutationException, "One of torch.cond branch"
+):
+make_fx(torch.func.functionalize(f))(*example_inputs)

 # https://github.com/pytorch/pytorch/issues/126988
+@xfailIfTorchDynamo
 def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
 def true_true_fn(x):
 x.add_(4)
@@ -1222,12 +1197,15 @@ def forward(self, x_1):
 try:
 example_input_func = to_fun_old(example_input)
 torch._enable_functionalization(reapply_views=False)
+with self.assertRaisesRegex(
+UnsupportedAliasMutationException, "One of torch.cond branch"
+):
 f(example_input_func)

 with self.assertRaisesRegex(
 UnsupportedAliasMutationException, "One of torch.cond branch"
 ):
-make_fx(f, tracing_mode="symbolic")(example_input_func)
+make_fx(f)(example_input_func)
 finally:
 torch._disable_functionalization()
@@ -1245,7 +1223,7 @@ def forward(self, x_1):
 with self.assertRaisesRegex(
 UnsupportedAliasMutationException, "One of torch.cond branch"
 ):
-make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
+make_fx(f_wrapper(f))(example_input_func)

 # https://github.com/pytorch/pytorch/issues/126988
 @xfailIfTorchDynamo
@@ -1258,7 +1236,7 @@ def forward(self, x_1):
 return view_x

 def f(x):
-pred = x.sum() > 0
+pred = x.shape[0] == 4
 return cond(pred, true_fn, false_fn, [x])

 example_input = torch.ones(5, 5)
@@ -1300,7 +1278,7 @@ def forward(self, x_1):
 UnsupportedAliasMutationException,
 "One of torch.cond branch might be aliasing",
 ):
-make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
+make_fx(f_wrapper(f))(example_input)

 def test_cond_functionalized_aot_func_check_functional(self):
 def true_fn(x):
@@ -1338,7 +1316,7 @@ def forward(self, x_1):

 return wrapper

-result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
+result_gm = make_fx(f_wrapper(f))(example_input)
 for node in result_gm.true_graph_0.graph.nodes:
 if node.op == "call_function":
 self.assertTrue(not node.target._schema.is_mutable)
@@ -1404,12 +1382,12 @@ def forward(self, x_1):
 def forward(self, x_1, pred_1, pred2_1):
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
+conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 true_graph_1 = self.true_graph_1
 false_graph_1 = self.false_graph_1
-cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
+conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
-getitem_1 = cond_1[0]; cond_1 = None
+getitem_1 = conditional_1[0]; conditional_1 = None
 add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
 return add""", # noqa: B950
 )
@@ -1577,12 +1555,12 @@ def forward(self, arg0_1):
 def forward(self, x_1, pred_1, pred2_1):
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
+conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 true_graph_1 = self.true_graph_1
 false_graph_1 = self.false_graph_1
-cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
+conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
-getitem_1 = cond_1[0]; cond_1 = None
+getitem_1 = conditional_1[0]; conditional_1 = None
 add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
 return add""", # noqa: B950
 )
@@ -1913,27 +1891,6 @@ def forward(self, arg0_1):
 ):
 functional_f(*example_inputs)

-def test_cond_autograd_succeed_when_pred_is_constant(self):
-def true_fn(x):
-return x.cos()
-
-def false_fn(x):
-return x.sin()
-
-def f(x, y):
-return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
-
-example_inputs = (
-torch.ones(3, 2, 4, requires_grad=True),
-torch.ones(4, requires_grad=True),
-)
-# Due to x.shape[0] can be statically evaluated to be False, we can evaluate
-# the backward.
-f(*example_inputs).sum().backward()
-
-# Ensure no error is thrown when not running backward
-f(*example_inputs)
-
 def test_cond_autograd_fail(self):
 def true_fn(x):
 return x.cos()
@@ -1942,7 +1899,7 @@ def forward(self, arg0_1):
 return x.sin()

 def f(x, y):
-return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])
+return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])

 example_inputs = (
 torch.ones(3, 2, 4, requires_grad=True),
@@ -2072,8 +2029,8 @@ def forward(self, x_1):
 eq = sym_size_int == 4; sym_size_int = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
+conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return getitem""", # noqa: B950
 )
@@ -2145,20 +2102,18 @@ def forward(self, x_1):
 # expected branches takes [x, a, b] as input
 inp = torch.randn(2, 3)

-gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
+gm = make_fx(foo)(inp)

 self.assertExpectedInline(
 gm.code.strip(),
 """\
 def forward(self, x_1):
-sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
-eq = sym_size_int == 4; sym_size_int = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
 _tensor_constant0 = self._tensor_constant0
 _tensor_constant1 = self._tensor_constant1
-cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
+conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return getitem""", # noqa: B950
 )
 self.assertExpectedInline(
@@ -2308,8 +2263,8 @@ def forward(self, pred_1, x_1):
 def forward(self, arg0_1, arg1_1):
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
+conditional = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return [getitem]""", # noqa: B950
 )
@@ -2350,7 +2305,7 @@ def forward(self, arg0_1, arg1_1):
 counters.clear()

 def foo(x, true_fn, false_fn):
-return cond(x.sum() < 0, true_fn, false_fn, (x,))
+return cond(x.shape[0] == 4, true_fn, false_fn, (x,))

 inp = torch.ones(3, 4)
 exp_out = inp.sin()
@@ -2392,8 +2347,8 @@ def forward(self, x_1):
 eq = sym_size_int == 4; sym_size_int = None
 true_graph_0 = self.true_graph_0
 false_graph_0 = self.false_graph_0
-cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
+conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
-getitem = cond[0]; cond = None
+getitem = conditional[0]; conditional = None
 return getitem""", # noqa: B950
 )
@@ -632,18 +632,6 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
 f"Expected 4 arguments but got {len(args)}.\n"
 f"Usage: cond(pred, true_fn, false_fn, operands)",
 )

-# Specialize into one of the branches since pred is constant
-if type(args[0]) is ConstantVariable:
-log.warning(
-"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
-" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
-)
-if args[0].as_python_constant():
-return args[1].call_function(tx, args[3].unpack_var_sequence(tx), {})
-else:
-return args[2].call_function(tx, args[3].unpack_var_sequence(tx), {})
-
 # predicate
 if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
 unimplemented(
@@ -1,8 +1,6 @@
 # mypy: allow-untyped-defs
 import contextlib
-
-import logging

 import torch
 import torch._subclasses.functional_tensor
 import torch.utils._pytree as pytree
@@ -34,8 +32,6 @@ from torch.fx.experimental.proxy_tensor import (
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from torch.utils._python_dispatch import _get_current_dispatch_mode

-log = logging.getLogger(__name__)
-

 @exposed_in("torch")
 def cond(pred, true_fn, false_fn, operands):
@@ -107,19 +103,10 @@ def cond(pred, true_fn, false_fn, operands):
 - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.

 """

 if torch.compiler.is_dynamo_compiling():
 return cond_op(pred, true_fn, false_fn, operands)

-if isinstance(pred, (bool, int, float)):
-log.warning(
-"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
-" If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool."
-)
-if pred:
-return true_fn(*operands)
-else:
-return false_fn(*operands)

 def _validate_input(pred, true_fn, false_fn, operands):
 if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
 raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
@@ -213,7 +200,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
 proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

 out_proxy = proxy_mode.tracer.create_proxy(
-"call_function", func_overload, proxy_args, {}
+"call_function", func_overload, proxy_args, {}, name="conditional"
 )

 # At this point, we're *guaranteed* that whether an output came from the
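The last few hunks undo the two behavioral pieces of the original PR: the eager torch.cond fast path that ran one branch directly for a Python-constant predicate, and the matching Dynamo specialization; they also restore the "conditional" proxy name used when tracing the cond op. A rough sketch of the removed fast path, reconstructed from the deleted lines above (the helper name is illustrative, not PyTorch API):

def constant_pred_fast_path(pred, true_fn, false_fn, operands):
    # Mirrors the deleted branch in cond(): a bool/int/float predicate skipped
    # the cond op and ran the chosen branch eagerly, after a warning that only
    # one branch executes. Tensor and SymBool predicates were unaffected.
    if isinstance(pred, (bool, int, float)):
        return true_fn(*operands) if pred else false_fn(*operands)
    raise TypeError("non-constant predicates go through the cond op instead")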