Revert "[Autograd] Cond Higher-Order Operation (#126911)"

This reverts commit f7058b735e.

Reverted https://github.com/pytorch/pytorch/pull/126911 on behalf of https://github.com/clee2000 because it broke lint and functorch/test_aotdispatch (f7058b735e). Probably a land race, since both the test and lint passed on the PR ([comment](https://github.com/pytorch/pytorch/pull/126911#issuecomment-2237703182))
PyTorch MergeBot 2024-07-18 22:06:40 +00:00
parent 686b7f046a
commit fb3674b1f4
6 changed files with 143 additions and 1171 deletions


@@ -257,751 +257,6 @@ class TestControlFlow(TestCase):
result = cond(pred, true_fn, false_fn, [x])
self.assertEqual(result, torch.cos(x))
def test_cond_autograd_simple(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_complex(self):
def true_fn(x):
return torch.abs((x**2).sin())
def false_fn(x):
return (x + 42).cos()
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_nested(self):
class Nested(torch.nn.Module):
def forward(self, p0, p1, p2, a, b, c):
def true_fn(x0, y0, z0):
def true_true_fn(x1, y1, z1):
return (x1 - y1 * z1) * 3.14
def true_false_fn(x1, y1, z1):
def true_false_true_fn(x2, y2, z2):
return (x2 * y2 * z2) / 2.71
def true_false_false_fn(x2, y2, z2):
return (x2 + y2 + z2) * 1.23
return torch.cond(
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
)
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
def false_fn(x0, y0, z0):
def false_true_fn(x1, y1, z1):
def false_true_true_fn(x2, y2, z2):
return (x2 - y2 - z2) + 1.23
def false_true_false_fn(x2, y2, z2):
return (x2 / y2 / z2) - 3.14
return torch.cond(
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
)
def false_false_fn(x1, y1, z1):
return (x1 - y1 * z1) / 2.71
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
return torch.cond(p0, true_fn, false_fn, [a, b, c])
nn_module = Nested()
def true_fn(x):
return nn_module(
torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x
)
def false_fn(x):
return nn_module(
torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x
)
x = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_mixed_require_grad(self):
def true_fn(x, y, z):
return x * y * z
def false_fn(x, y, z):
return x + y + z
x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=False)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, (x, y, x))
self.assertEqual(result, fn(x, y, x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x, y, z):
result = cond(pred, true_fn, false_fn, (x, y, z))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x, y, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1, y_1, z_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_grad_through_cond(self):
nn_module = torch.nn.Linear(4, 4)
def true_fn(x):
return nn_module(x)
def false_fn(X):
return x * nn_module(x)
x = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (nn_module.weight,), grad_out)
expected_grads = torch.autograd.grad(
fn(
x,
),
(nn_module.weight,),
grad_out,
)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (nn_module.weight,), grad_out)
# need to set _allow_non_fake_inputs = True because model parameters don't
# get fakified.
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_param_constant0 = self._param_constant0
_param_constant1 = self._param_constant1
_tensor_constant0 = self._tensor_constant0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_tensor_constant0_1 = self._tensor_constant0
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]
getitem_4 = cond_1[3]; cond_1 = None
return (getitem_2,)""", # noqa: B950
)
def test_cond_in_forloop(self):
def for_loop_fake(x):
for i in range(3):
x = x * x + 1
return x
def for_loop_test(x):
for i in range(3):
pred = i < 3
def true_fn(x):
return x * x + 1
def false_fn(x):
return x
x = cond(pred, true_fn, false_fn, (x,))
return x
x = torch.ones(4, requires_grad=True)
x_new = for_loop_test(x)
x_exp = for_loop_fake(x)
self.assertEqual(x_new, x_exp)
grad_out = torch.ones_like(x_new)
grads = torch.autograd.grad(x_new, (x,), grad_out)
expected_grads = torch.autograd.grad(x_exp, (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(x):
x_new = for_loop_test(x)
grad_out = torch.ones_like(x_new)
return torch.autograd.grad(x_new, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
mul = torch.ops.aten.mul.Tensor(x_1, x_1)
add = torch.ops.aten.add.Tensor(mul, 1); mul = None
mul_1 = torch.ops.aten.mul.Tensor(add, add)
add_1 = torch.ops.aten.add.Tensor(mul_1, 1); mul_1 = None
mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1)
add_2 = torch.ops.aten.add.Tensor(mul_2, 1); mul_2 = None
ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False); add_2 = None
mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1)
mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1); ones_like = add_1 = None
add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3); mul_4 = mul_3 = None
mul_5 = torch.ops.aten.mul.Tensor(add_3, add)
mul_6 = torch.ops.aten.mul.Tensor(add_3, add); add_3 = add = None
add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5); mul_6 = mul_5 = None
mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1)
mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None
return (add_5,)""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_pytree_not_all_inputs_used(self):
def true_fn(x):
return x["t"][0] + x["t"][1]["b"]
def false_fn(x):
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
grad_out = torch.ones_like(result)
if pred:
with self.assertRaisesRegex(Exception, r"."):
grads = torch.autograd.grad(result, (a, b, c), grad_out)
expected_grads = torch.autograd.grad(
fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out
)
self.assertEqual(expected_grads, grads)
def f(pred, a, b, c):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (a, b), grad_out)
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
pred, a, b, c
)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, a_1, b_1, c_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = None
return (getitem_1, getitem_2)""", # noqa: B950
)
# Forward
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
# Backward
self.assertExpectedInline(
gm.true_graph_1.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None
clone = torch.ops.aten.clone.default(arg0_1)
clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
return [clone, clone_1, None]""",
)
def test_cond_autograd_pytree_input(self):
def true_fn(x):
return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
def false_fn(x):
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (a, b), grad_out)
expected_grads = torch.autograd.grad(
fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out
)
self.assertEqual(expected_grads, grads)
def f(pred):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (a, b), grad_out)
# need to set _allow_non_fake_inputs = True because model parameters don't
# get fakified.
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
_tensor_constant2 = self._tensor_constant2
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
_tensor_constant0_1 = self._tensor_constant0
_tensor_constant1_1 = self._tensor_constant1
_tensor_constant2_1 = self._tensor_constant2
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = None
return (getitem_1, getitem_2)""", # noqa: B950
)
def test_cond_autograd_different_pytree_output(self):
def true_fn(x):
return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]]
def false_fn(x):
return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]}
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_same_pytree_output(self):
def true_fn(x):
return {"res": [x["t"][0], (x["t"][2][0],)]}
def false_fn(x):
return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
result_exp = fn({"t": [a, {"b": b}, (c,)]})
self.assertEqual(result, result_exp)
result_flat, _ = pytree.tree_flatten(result)
result_exp_flat, _ = pytree.tree_flatten(result_exp)
grad_out = [torch.ones_like(g) for g in result_flat]
expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out)
grads = torch.autograd.grad(result_flat, (c,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
return result
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
_tensor_constant2 = self._tensor_constant2
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
getitem = cond[0]
getitem_1 = cond[1]; cond = None
view = torch.ops.aten.view.default(getitem, [4]); getitem = None
view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None
return {'res': [view, (view_1,)]}""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_torch_nn_module(self):
nn_module_true = torch.nn.Linear(4, 4)
def true_fn(x):
return nn_module_true(torch.abs((x**2).sin()))
nn_module_false = torch.nn.GRUCell(4, 4)
def false_fn(x):
return nn_module_false((x + 42).cos())
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_param_constant0 = self._param_constant0
_param_constant1 = self._param_constant1
_param_constant2 = self._param_constant2
_param_constant3 = self._param_constant3
_param_constant4 = self._param_constant4
_param_constant5 = self._param_constant5
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_param_constant2_1 = self._param_constant2
_param_constant3_1 = self._param_constant3
_param_constant4_1 = self._param_constant4
_param_constant5_1 = self._param_constant5
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]
getitem_4 = cond_1[3]
getitem_5 = cond_1[4]
getitem_6 = cond_1[5]
getitem_7 = cond_1[6]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_user_nn_module(self):
class User_nn_module(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return input * input
nn_module_true = User_nn_module()
def true_fn(x):
return nn_module_true(torch.abs((x**2).sin()))
nn_module_false = torch.nn.ReLU(inplace=False)
def false_fn(x):
return nn_module_false((x + 42).cos())
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_inner_fn(self):
def true_fn(x):
return torch.abs((x**2).sin())
def false_fn(x):
def inner_fn(x):
return x**2
return torch.abs(inner_fn(x).sin())
x = torch.randn(4, requires_grad=True)
pred = torch.tensor(False)
fn = false_fn
result_false = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result_false, fn(x))
grad_out = torch.ones_like(result_false)
grads_false = torch.autograd.grad(result_false, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads_false)
pred = torch.tensor(True)
fn = true_fn
result_true = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result_true, fn(x))
self.assertEqual(result_false, result_true)
grad_out = torch.ones_like(result_true)
grads_true = torch.autograd.grad(result_true, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads_true)
self.assertEqual(grads_false, grads_true)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_inner_tensor(self):
def true_fn(x):
return torch.abs((x**2).sin())
def false_fn(x):
y = torch.ones(4, requires_grad=False) * 42
return (x * y).cos()
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_cond_autograd_gpu(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
for pred, fn in zip(
[torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")],
[false_fn, true_fn],
):
x = torch.randn(4, requires_grad=True, device="cuda")
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_map_gpu(self):
def f(x, y):
@@ -1221,74 +476,6 @@ class TestControlFlowTraced(TestCase):
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
def test_cond_simple_with_linear_compile_check_graph(self):
from torch._dynamo.testing import EagerAndRecordGraphs
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
x = torch.randn(4, requires_grad=True)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
backend = EagerAndRecordGraphs()
torch.compile(f, backend=backend)(torch.tensor(False), x)
self.assertEqual(len(backend.graphs), 2)
gm = backend.graphs[0]
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
l_pred_ = L_pred_
l_x_ = L_x_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_x_]); l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None
result = cond[0]; cond = None
grad_out = torch.ones_like(result)
return (result, grad_out)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_true_0.code.strip(),
"""\
def forward(self, l_x_):
l_x__1 = l_x_
sin = l_x__1.sin(); l_x__1 = None
return (sin,)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_false_0.code.strip(),
"""\
def forward(self, l_x_):
l_x__1 = l_x_
cos = l_x__1.cos(); l_x__1 = None
return (cos,)""", # noqa: B950
)
backward_gm = backend.graphs[1]
self.assertExpectedInline(
backward_gm.code.strip(),
"""\
def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor):
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
l_ctx_pred = L_ctx_pred
l_flat_grads_0_ = L_flat_grads_0_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
def test_while_loop_nested_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested"]
graphs = self._check_tracing(fn, inp)
@@ -2724,7 +1911,7 @@ def forward(self, arg0_1):
):
functional_f(*example_inputs)
def test_cond_autograd_backward(self):
def test_cond_autograd_succeed_when_pred_is_constant(self):
def true_fn(x):
return x.cos()
@@ -2738,14 +1925,32 @@ def forward(self, arg0_1):
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
)
# Because the predicate on x.shape[0] can be statically evaluated to False, we can
# evaluate the backward.
f(*example_inputs).sum().backward()
# Ensure no error is thrown when not running backward
res = f(*example_inputs)
f(*example_inputs)
def test_cond_autograd_fail(self):
def true_fn(x):
return x.cos()
def false_fn(x):
return x.sin()
def f(x, y):
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
)
with self.assertRaisesRegex(RuntimeError, "Autograd not implemented for cond"):
f(*example_inputs).sum().backward()
# Ensure no error is thrown when not running backward
res_compiled = torch.compile(f)(*example_inputs)
self.assertEqual(res, res_compiled)
f(*example_inputs)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo


@@ -13,29 +13,26 @@ from torch._C._functorch import (
is_batchedtensor,
maybe_get_bdim,
)
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
unique_graph_id,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
_temp_remove_pre_dispatch_torch_function_mode,
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode
from .utils import _from_fun, create_fw_bw_graph
log = logging.getLogger(__name__)
@@ -105,6 +102,8 @@ def cond(pred, true_fn, false_fn, operands):
.. warning::
Temporary limitations:
- `cond` only supports **inference** right now. Autograd will be supported in the future.
- The **output** of branches must be a **single Tensor**. A pytree of tensors will be supported in the future.
"""
@@ -159,48 +158,6 @@ We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond_op = HigherOrderOperator("cond")
cond_op.__module__ = "torch.ops.higher_order"
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
fw_inputs = pytree.tree_map(_from_fun, operands)
fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs))
if any(
not isinstance(out, torch.Tensor)
for out in fw_outputs_true
if out is not None
):
raise RuntimeError(
"Expect outputs of true_fn to only contains tensors or None. "
f"Got types {[type(out) for out in fw_outputs_true]}."
)
fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs))
if any(
not isinstance(out, torch.Tensor)
for out in fw_outputs_false
if out is not None
):
raise RuntimeError(
"Expect outputs of false_fn to only contains tensors or None. "
f"Got types {[type(out) for out in fw_outputs_false]}."
)
# TODO: There is a major issue in that create_fw_bw in the higher_order_op is invoked twice:
# once in the forward path (as it should be) and once in the backward path, where it shouldn't be called.
# If we can get rid of the second invocation, it would simplify this function.
fw_true_graph, joint_true_graph = create_fw_bw_graph(
true_fn, False, fw_inputs, fw_outputs_true
)
fw_false_graph, joint_false_graph = create_fw_bw_graph(
false_fn, False, fw_inputs, fw_outputs_false
)
return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@@ -229,53 +186,14 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
if len(flat_true_outs) != len(flat_false_outs):
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected to return same number of outputs but got:"
f"\n true branch returns {len(flat_true_outs)} item(s)"
f"\n false branch returns {len(flat_false_outs)} item(s)"
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
)
for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
# Note that we need to skip the check for requires_grad because we're past the
# autograd key during tracing, so the requires_grad attribute of the tensors
# is no longer accurate. See Note [invariants for node meta 'val']
def _same_meta_except_requires_grad(true_out, false_out):
if true_out is None and false_out is None:
return True
elif true_out is None or false_out is None:
# Consider the following case:
# def true_fn(x, y):
# return x * y
#
# def false_fn(x, y):
# return x.sin()
#
# We'll get the following graphs for backward:
# def backward_true_fn(x, y, grad_out):
# return grad_out * y, grad_out * x
#
# def backward_false_fn(x, y, grad_out):
# return grad_out * x.cos(), None
#
# This suggests that when we make_fx the backward graph, the two branches'
# graphs would produce mismatched outputs: a tensor with metadata in one
# branch and None in the other, which is undesirable.
#
# Ideally, we should provide an optional type to indicate that one of the branches might
# return None. But we'll just let it pass for now and let downstream/runtime handle it.
#
# Note that this corner case should **only** happen when the user wants to trace the backward
# graph, because if it's the forward, dynamo will error.
return True
true_meta = true_out.meta.get("tensor_meta", None)
false_meta = false_out.meta.get("tensor_meta", None)
return (
true_meta.shape == false_meta.shape
and true_meta.dtype == false_meta.dtype
and true_meta.stride == false_meta.stride
)
if not _same_meta_except_requires_grad(true_out, false_out):
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
raise torch._dynamo.exc.CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
@@ -335,65 +253,9 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
return false_fn(*operands)
class CondAutogradOp(torch.autograd.Function):
@staticmethod
def forward(
ctx,
pred,
fw_true_graph,
fw_false_graph,
joint_true_graph,
joint_false_graph,
*operands,
):
ctx._pred = pred
ctx._joint_true_graph = joint_true_graph
ctx._joint_false_graph = joint_false_graph
ctx.save_for_backward(*operands)
with torch._C._AutoDispatchBelowAutograd():
return cond_op(pred, fw_true_graph, fw_false_graph, operands)
@staticmethod
def backward(ctx, *flat_grads):
operands = ctx.saved_tensors
grads = cond_op(
ctx._pred,
ctx._joint_true_graph,
ctx._joint_false_graph,
flat_grads + operands,
)
return None, None, None, None, None, *grads
@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
# A shortcut for the case where all inputs don't require gradient,
# we skip tracing the forward and backward graph.
if all(
not t.requires_grad
for t in pytree.tree_flatten((pred, operands))[0]
if isinstance(t, torch.Tensor)
):
with torch._C._AutoDispatchBelowAutograd():
return cond_op(pred, true_fn, false_fn, operands)
(
fw_true_graph,
fw_false_graph,
joint_true_graph,
joint_false_graph,
) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
flat_out = CondAutogradOp.apply(
pred,
fw_true_graph,
fw_false_graph,
joint_true_graph,
joint_false_graph,
*operands,
)
return flat_out
cond_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(cond_op, deferred_error=True)
)
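For illustration (not part of the diff): a simplified, hedged sketch of the "deferred error" idea this registration restores. It is not the actual implementation of autograd_not_implemented; the class below is hypothetical and only mimics the behavior where the forward result is produced normally and the error is raised only if backward reaches the op.
import torch

class _DeferredNotImplemented(torch.autograd.Function):
    # Hypothetical stand-in for autograd_not_implemented(cond_op, deferred_error=True).
    @staticmethod
    def forward(ctx, op_name, result):
        ctx.op_name = op_name
        # Clone so the returned tensor is a distinct autograd output of this Function.
        return result.clone()

    @staticmethod
    def backward(ctx, *grads):
        raise RuntimeError(f"Autograd not implemented for {ctx.op_name}")

x = torch.randn(4, requires_grad=True)
out = _DeferredNotImplemented.apply("cond", x.sin())
# out.sum().backward()  # would raise: Autograd not implemented for cond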
@cond_op.py_impl(ProxyTorchDispatchMode)


@@ -3,7 +3,7 @@ import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
@@ -12,20 +12,17 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch._subclasses.functional_tensor import (
disable_functional_mode,
FunctionalTensor,
)
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from .utils import (
_from_fun,
_stack_pytree,
_unstack_pytree,
clone_outputs_aliasing_inputs,
prepare_fw_with_masks,
)
from torch.multiprocessing.reductions import StorageWeakRef
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
@@ -53,10 +50,49 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
mapped_xs = args[:num_mapped_args]
pos_args = args[num_mapped_args:]
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
# Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
# added when required. We will encounter two problems if we don't suspend functionalization:
#
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
# However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
# fetch the proxy for the inputs and to capture any operations on them.
#
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
# Instead, it will create _tensor_constant as output.
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
def _from_fun(t):
if isinstance(t, torch.Tensor):
if t.dtype != torch.bool:
return torch.empty_strided(
t.size(),
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
)
else:
# clone of a functional tensor produces a functional tensor
# but we want to avoid it so we clone a non-functional version
maybe_unfunc_t = t
if isinstance(t, FunctionalTensor):
torch._sync(t)
maybe_unfunc_t = from_fun(t)
elif torch._is_functional_tensor(t):
# need to handle both types of functionalization here:
# these are the tensors that came from the user,
# which could be either FunctionalTensorWrapper or FunctionalTensor
torch._sync(t)
maybe_unfunc_t = torch._from_functional_tensor(t)
return maybe_unfunc_t.clone()
return t
unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]
@@ -87,7 +123,16 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
mapped_input = joint_mapped_args[:num_mapped_args]
mapped_grads = joint_mapped_args[num_mapped_args:]
joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
def fw_with_masks(*args):
fw_out = f(*args)
return fw_out, [
True
if isinstance(ret, torch.Tensor) and ret.requires_grad
else False
for ret in fw_out
]
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
_, grads = joint(
list(mapped_input) + list(args),
[
@@ -99,7 +144,19 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
# In order to keep map functional for backward graph,
# we clone outputs that are aliasing inputs
maybe_clone = clone_outputs_aliasing_inputs(example_args)
input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in example_args
if isinstance(arg, torch.Tensor)
}
def maybe_clone(t):
if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone()
return t
return pytree.tree_map(maybe_clone, grads)
@@ -199,6 +256,46 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
)
def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
raise RuntimeError(
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
)
a = zip(*flat_xs)
pytrees = []
for tuple in a:
pytrees.append(pytree.tree_unflatten(tuple, inspec))
return pytrees
def _stack_pytree(pytrees):
flat_out = []
out_spec = None
for pt in pytrees:
flat_pt, out_spec = pytree.tree_flatten(pt)
flat_out.append(flat_pt)
assert out_spec is not None
b = zip(*flat_out)
stacked_out = []
for leaves in b:
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
stacked_out.append(torch.stack(leaves))
elif all(leaf is None for leaf in leaves):
# The backward graph can return None outputs when forward inputs don't require grad.
# When we eagerly execute the backward graph, we need to call _stack_pytree on its output;
# therefore we need to deal with None outputs.
stacked_out.append(None) # type: ignore[arg-type]
else:
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
pytrees = []
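As a side note (not part of the diff): a small self-contained sketch of the unstack/stack round trip performed by the _unstack_pytree and _stack_pytree helpers above; the example pytree below is made up for illustration.
import torch
import torch.utils._pytree as pytree

# A batched pytree: every leaf shares leading dimension 3.
xs = {"a": torch.arange(6.0).reshape(3, 2), "b": (torch.ones(3, 4),)}

# Unstack: one pytree per element of the leading dimension.
flat, spec = pytree.tree_flatten(xs)
per_element = [pytree.tree_unflatten(list(leaves), spec) for leaves in zip(*flat)]

# Stack: zip the flattened per-element pytrees back together and stack each leaf.
stacked_leaves = [
    torch.stack(leaves)
    for leaves in zip(*(pytree.tree_flatten(pt)[0] for pt in per_element))
]
restacked = pytree.tree_unflatten(stacked_leaves, spec)
assert torch.equal(restacked["a"], xs["a"]) and torch.equal(restacked["b"][0], xs["b"][0])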


@@ -93,15 +93,6 @@ def reenter_make_fx(fn):
return wrapped
def _maybe_reenter_make_fx(fn):
from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
if _CURRENT_MAKE_FX_TRACER is not None:
return reenter_make_fx(fn)
else:
return make_fx(fn)
@contextmanager
def _set_compilation_env():
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
@@ -219,169 +210,3 @@ def unique_graph_id(proxy_mode, prefix):
else:
next_name = candidate
return i, next_name
def _from_fun(t):
from torch._functorch.aot_autograd import from_fun
from torch._subclasses.functional_tensor import FunctionalTensor
if isinstance(t, torch.Tensor):
if t.dtype != torch.bool:
return torch.empty_strided(
t.size(),
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
)
else:
# clone of a functional tensor produces a functional tensor
# but we want to avoid it so we clone a non-functional version
maybe_unfunc_t = t
if isinstance(t, FunctionalTensor):
torch._sync(t)
maybe_unfunc_t = from_fun(t)
elif torch._is_functional_tensor(t):
# need to handle both types of functionalization here:
# these are the tensors that came from the user,
# which could be either FunctionalTensorWrapper or FunctionalTensor
torch._sync(t)
maybe_unfunc_t = torch._from_functional_tensor(t)
return maybe_unfunc_t.clone()
return t
def clone_outputs_aliasing_inputs(args):
input_storage = {
StorageWeakRef(arg._typed_storage())
for arg in args
if isinstance(arg, torch.Tensor)
}
def maybe_clone(t):
if (
isinstance(t, torch.Tensor)
and StorageWeakRef(t._typed_storage()) in input_storage
):
return t.clone()
return t
return maybe_clone
def prepare_fw_with_masks(fn):
def fw_with_masks(*args):
fw_out = fn(*args)
return fw_out, [
True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
for ret in fw_out
]
return fw_with_masks
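For illustration (not part of the diff): what the prepare_fw_with_masks helper above returns for a simple two-output function. The helper is inlined below (it mirrors the code shown above, which this revert removes), and the function and tensors are made up.
import torch

def prepare_fw_with_masks(fn):
    # Mirrors the helper shown above, inlined so the sketch stands alone.
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            isinstance(ret, torch.Tensor) and ret.requires_grad for ret in fw_out
        ]
    return fw_with_masks

def fn(x, y):
    # The second output is detached, so it does not require grad.
    return x.sin(), (x * y).detach()

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
outs, masks = prepare_fw_with_masks(fn)(x, y)
print(masks)  # [True, False]: only the first output participates in autograd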
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
from torch._functorch.aot_autograd import AOTConfig, create_joint
# Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
# added when required. We will encounter two problems if we don't suspend functionalization:
#
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
# However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
# fetch the proxy for the inputs and to capture any operations on them.
#
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
# Instead, it will create _tensor_constant as output.
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
example_grad = [_from_fun(out) for out in fw_outputs]
num_grads = len(example_grad)
fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)
def joint_fn(*joint_operands_grads):
if use_output_and_grad_bw:
grads = joint_operands_grads[0]
inputs = joint_operands_grads[1][-1:]
else:
grads = joint_operands_grads[:num_grads]
inputs = joint_operands_grads[num_grads:]
joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
_, grads = joint(
list(inputs),
[grad for grad in grads if grad is not None and grad.requires_grad],
)
# In order to keep map functional for backward graph,
# we clone outputs that are aliasing inputs
maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)
return pytree.tree_map(maybe_clone, grads)
if use_output_and_grad_bw:
example_xs_out = list(fw_inputs) + list(fw_outputs)
joint_graph = _maybe_reenter_make_fx(joint_fn)(
(list(example_grad), list(example_xs_out))
)
else:
example_xs_out = list(fw_inputs)
joint_graph = _maybe_reenter_make_fx(joint_fn)(
*(list(example_grad) + list(example_xs_out))
)
return fw_graph, joint_graph
def _unstack_pytree(xs):
flat_xs, inspec = pytree.tree_flatten(xs)
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
raise RuntimeError(
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
)
a = zip(*flat_xs)
pytrees = []
for tuple in a:
pytrees.append(pytree.tree_unflatten(tuple, inspec))
return pytrees
def _stack_pytree(pytrees):
flat_out = []
out_spec = None
for pt in pytrees:
flat_pt, out_spec = pytree.tree_flatten(pt)
flat_out.append(flat_pt)
assert out_spec is not None
b = zip(*flat_out)
stacked_out = []
for leaves in b:
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
stacked_out.append(torch.stack(leaves))
elif all(leaf is None for leaf in leaves):
# The backward graph can return None outputs when forward inputs don't require grad.
# When we eagerly execute the backward graph, we need to call _stack_pytree on its output;
# therefore we need to deal with None outputs.
stacked_out.append(None) # type: ignore[arg-type]
else:
raise RuntimeError(f"Cannot stack {leaves}.")
return pytree.tree_unflatten(stacked_out, out_spec)


@@ -315,7 +315,6 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
typing_extensions.assert_never(val)
# Note [invariants for node meta 'val']
# What invariants do we have for the 'val' set on the FX node? It has accurate
# metadata... but only for metadata that exists "below" all other subsystems
# (most notably autograd, but also vmap, functorch transforms, etc). This means


@@ -91,13 +91,13 @@ def foo_impl_abstract(x, z):
def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
def simple_cond(x):
return torch.cond(x.sum() > 2, lambda x: x.cos(), lambda x: x.sin(), [x])
return torch.cond(x.shape[0] > 2, lambda x: x.cos(), lambda x: x.sin(), [x])
def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
@@ -196,23 +196,7 @@ hop_db = [
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
supports_autograd=True,
# "torch.compile with aot_autograd does not currently support double backward."
supports_gradgrad=False,
skips=(
DecorateInfo(
unittest.expectedFailure,
"TestEagerFusionOpInfo",
"test_aot_autograd_exhaustive",
active_if=IS_WINDOWS,
),
DecorateInfo(
unittest.expectedFailure,
"TestEagerFusionOpInfo",
"test_aot_autograd_symbolic_exhaustive",
active_if=IS_WINDOWS,
),
),
supports_autograd=False,
),
OpInfo(
name="while_loop",