Revert "[Autograd] Cond Higher-Order Operation (#126911)"
This reverts commit f7058b735e. Reverted https://github.com/pytorch/pytorch/pull/126911 on behalf of https://github.com/clee2000 due to broke lint and functorch/test_aotdispatch f7058b735e. Probably a landrace since both the test and lint passed on PR ([comment](https://github.com/pytorch/pytorch/pull/126911#issuecomment-2237703182))
This commit is contained in:
parent
686b7f046a
commit
fb3674b1f4
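For orientation, the removed tests below exercised autograd through the `cond` higher-order op roughly as in this minimal sketch. This is a hedged illustration: it only runs on builds that still include #126911 (i.e. before this revert), and the import path is assumed from how the tests call `cond` directly.

import torch
from functorch.experimental.control_flow import cond  # assumed import; the tests below call `cond` directly

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4, requires_grad=True)
pred = torch.tensor(True)

out = cond(pred, true_fn, false_fn, (x,))                        # forward through the cond HOP
(grad,) = torch.autograd.grad(out, (x,), torch.ones_like(out))   # the backward path added by #126911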
@@ -257,751 +257,6 @@ class TestControlFlow(TestCase):
|
|||
result = cond(pred, true_fn, false_fn, [x])
|
||||
self.assertEqual(result, torch.cos(x))
|
||||
|
||||
def test_cond_autograd_simple(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def false_fn(x):
|
||||
return x.cos()
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
|
||||
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_cond_autograd_complex(self):
|
||||
def true_fn(x):
|
||||
return torch.abs((x**2).sin())
|
||||
|
||||
def false_fn(x):
|
||||
return (x + 42).cos()
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_nested(self):
|
||||
class Nested(torch.nn.Module):
|
||||
def forward(self, p0, p1, p2, a, b, c):
|
||||
def true_fn(x0, y0, z0):
|
||||
def true_true_fn(x1, y1, z1):
|
||||
return (x1 - y1 * z1) * 3.14
|
||||
|
||||
def true_false_fn(x1, y1, z1):
|
||||
def true_false_true_fn(x2, y2, z2):
|
||||
return (x2 * y2 * z2) / 2.71
|
||||
|
||||
def true_false_false_fn(x2, y2, z2):
|
||||
return (x2 + y2 + z2) * 1.23
|
||||
|
||||
return torch.cond(
|
||||
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
|
||||
)
|
||||
|
||||
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
|
||||
|
||||
def false_fn(x0, y0, z0):
|
||||
def false_true_fn(x1, y1, z1):
|
||||
def false_true_true_fn(x2, y2, z2):
|
||||
return (x2 - y2 - z2) + 1.23
|
||||
|
||||
def false_true_false_fn(x2, y2, z2):
|
||||
return (x2 / y2 / z2) - 3.14
|
||||
|
||||
return torch.cond(
|
||||
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
|
||||
)
|
||||
|
||||
def false_false_fn(x1, y1, z1):
|
||||
return (x1 - y1 * z1) / 2.71
|
||||
|
||||
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
|
||||
|
||||
return torch.cond(p0, true_fn, false_fn, [a, b, c])
|
||||
|
||||
nn_module = Nested()
|
||||
|
||||
def true_fn(x):
|
||||
return nn_module(
|
||||
torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x
|
||||
)
|
||||
|
||||
def false_fn(x):
|
||||
return nn_module(
|
||||
torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x
|
||||
)
|
||||
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_mixed_require_grad(self):
|
||||
def true_fn(x, y, z):
|
||||
return x * y * z
|
||||
|
||||
def false_fn(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
y = torch.randn(4, requires_grad=False)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
result = cond(pred, true_fn, false_fn, (x, y, x))
|
||||
self.assertEqual(result, fn(x, y, x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x, y, z):
|
||||
result = cond(pred, true_fn, false_fn, (x, y, z))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic")(pred, x, y, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1, y_1, z_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_grad_through_cond(self):
|
||||
nn_module = torch.nn.Linear(4, 4)
|
||||
|
||||
def true_fn(x):
|
||||
return nn_module(x)
|
||||
|
||||
def false_fn(X):
|
||||
return x * nn_module(x)
|
||||
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (nn_module.weight,), grad_out)
|
||||
expected_grads = torch.autograd.grad(
|
||||
fn(
|
||||
x,
|
||||
),
|
||||
(nn_module.weight,),
|
||||
grad_out,
|
||||
)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (nn_module.weight,), grad_out)
|
||||
|
||||
# need to set _allow_non_fake_inputs = True because model parameters don't
|
||||
# get fakified.
|
||||
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_param_constant0 = self._param_constant0
|
||||
_param_constant1 = self._param_constant1
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
_param_constant0_1 = self._param_constant0
|
||||
_param_constant1_1 = self._param_constant1
|
||||
_tensor_constant0_1 = self._tensor_constant0
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]
|
||||
getitem_3 = cond_1[2]
|
||||
getitem_4 = cond_1[3]; cond_1 = None
|
||||
return (getitem_2,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_cond_in_forloop(self):
|
||||
def for_loop_fake(x):
|
||||
for i in range(3):
|
||||
x = x * x + 1
|
||||
return x
|
||||
|
||||
def for_loop_test(x):
|
||||
for i in range(3):
|
||||
pred = i < 3
|
||||
|
||||
def true_fn(x):
|
||||
return x * x + 1
|
||||
|
||||
def false_fn(x):
|
||||
return x
|
||||
|
||||
x = cond(pred, true_fn, false_fn, (x,))
|
||||
|
||||
return x
|
||||
|
||||
x = torch.ones(4, requires_grad=True)
|
||||
x_new = for_loop_test(x)
|
||||
x_exp = for_loop_fake(x)
|
||||
|
||||
self.assertEqual(x_new, x_exp)
|
||||
|
||||
grad_out = torch.ones_like(x_new)
|
||||
grads = torch.autograd.grad(x_new, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(x_exp, (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(x):
|
||||
x_new = for_loop_test(x)
|
||||
grad_out = torch.ones_like(x_new)
|
||||
return torch.autograd.grad(x_new, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic")(x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, x_1):
|
||||
mul = torch.ops.aten.mul.Tensor(x_1, x_1)
|
||||
add = torch.ops.aten.add.Tensor(mul, 1); mul = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(add, add)
|
||||
add_1 = torch.ops.aten.add.Tensor(mul_1, 1); mul_1 = None
|
||||
mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1)
|
||||
add_2 = torch.ops.aten.add.Tensor(mul_2, 1); mul_2 = None
|
||||
ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False); add_2 = None
|
||||
mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1)
|
||||
mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1); ones_like = add_1 = None
|
||||
add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3); mul_4 = mul_3 = None
|
||||
mul_5 = torch.ops.aten.mul.Tensor(add_3, add)
|
||||
mul_6 = torch.ops.aten.mul.Tensor(add_3, add); add_3 = add = None
|
||||
add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5); mul_6 = mul_5 = None
|
||||
mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1)
|
||||
mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None
|
||||
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None
|
||||
return (add_5,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_pytree_not_all_inputs_used(self):
|
||||
def true_fn(x):
|
||||
return x["t"][0] + x["t"][1]["b"]
|
||||
|
||||
def false_fn(x):
|
||||
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
||||
|
||||
a = torch.randn(4, requires_grad=True)
|
||||
b = torch.randn(4, requires_grad=True)
|
||||
c = torch.randn(4, requires_grad=True)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
if pred:
|
||||
with self.assertRaisesRegex(Exception, r"."):
|
||||
grads = torch.autograd.grad(result, (a, b, c), grad_out)
|
||||
expected_grads = torch.autograd.grad(
|
||||
fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out
|
||||
)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, a, b, c):
|
||||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (a, b), grad_out)
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
|
||||
pred, a, b, c
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, a_1, b_1, c_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]
|
||||
getitem_3 = cond_1[2]; cond_1 = None
|
||||
return (getitem_1, getitem_2)""", # noqa: B950
|
||||
)
|
||||
# Forward
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
# Backward
|
||||
self.assertExpectedInline(
|
||||
gm.true_graph_1.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
||||
add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = None
|
||||
clone = torch.ops.aten.clone.default(arg0_1)
|
||||
clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
||||
return [clone, clone_1, None]""",
|
||||
)
|
||||
|
||||
def test_cond_autograd_pytree_input(self):
|
||||
def true_fn(x):
|
||||
return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
|
||||
|
||||
def false_fn(x):
|
||||
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
||||
|
||||
a = torch.randn(4, requires_grad=True)
|
||||
b = torch.randn(4, requires_grad=True)
|
||||
c = torch.randn(4, requires_grad=True)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (a, b), grad_out)
|
||||
expected_grads = torch.autograd.grad(
|
||||
fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out
|
||||
)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred):
|
||||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (a, b), grad_out)
|
||||
|
||||
# need to set _allow_non_fake_inputs = True because model parameters don't
|
||||
# get fakified.
|
||||
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
_tensor_constant1 = self._tensor_constant1
|
||||
_tensor_constant2 = self._tensor_constant2
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
_tensor_constant0_1 = self._tensor_constant0
|
||||
_tensor_constant1_1 = self._tensor_constant1
|
||||
_tensor_constant2_1 = self._tensor_constant2
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]
|
||||
getitem_3 = cond_1[2]; cond_1 = None
|
||||
return (getitem_1, getitem_2)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_cond_autograd_different_pytree_output(self):
|
||||
def true_fn(x):
|
||||
return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]]
|
||||
|
||||
def false_fn(x):
|
||||
return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]}
|
||||
|
||||
a = torch.randn(4, requires_grad=True)
|
||||
b = torch.randn(4, requires_grad=True)
|
||||
c = torch.randn(4, requires_grad=True)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||
"Cond doesn't work unless it is captured completely with torch.compile",
|
||||
):
|
||||
cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_same_pytree_output(self):
|
||||
def true_fn(x):
|
||||
return {"res": [x["t"][0], (x["t"][2][0],)]}
|
||||
|
||||
def false_fn(x):
|
||||
return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}
|
||||
|
||||
a = torch.randn(4, requires_grad=True)
|
||||
b = torch.randn(4, requires_grad=True)
|
||||
c = torch.randn(4, requires_grad=True)
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
result_exp = fn({"t": [a, {"b": b}, (c,)]})
|
||||
self.assertEqual(result, result_exp)
|
||||
|
||||
result_flat, _ = pytree.tree_flatten(result)
|
||||
result_exp_flat, _ = pytree.tree_flatten(result_exp)
|
||||
|
||||
grad_out = [torch.ones_like(g) for g in result_flat]
|
||||
expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out)
|
||||
grads = torch.autograd.grad(result_flat, (c,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred):
|
||||
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
||||
return result
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_tensor_constant0 = self._tensor_constant0
|
||||
_tensor_constant1 = self._tensor_constant1
|
||||
_tensor_constant2 = self._tensor_constant2
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
||||
getitem = cond[0]
|
||||
getitem_1 = cond[1]; cond = None
|
||||
view = torch.ops.aten.view.default(getitem, [4]); getitem = None
|
||||
view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None
|
||||
return {'res': [view, (view_1,)]}""", # noqa: B950
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
||||
def test_cond_autograd_torch_nn_module(self):
|
||||
nn_module_true = torch.nn.Linear(4, 4)
|
||||
|
||||
def true_fn(x):
|
||||
return nn_module_true(torch.abs((x**2).sin()))
|
||||
|
||||
nn_module_false = torch.nn.GRUCell(4, 4)
|
||||
|
||||
def false_fn(x):
|
||||
return nn_module_false((x + 42).cos())
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f)(pred, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
_param_constant0 = self._param_constant0
|
||||
_param_constant1 = self._param_constant1
|
||||
_param_constant2 = self._param_constant2
|
||||
_param_constant3 = self._param_constant3
|
||||
_param_constant4 = self._param_constant4
|
||||
_param_constant5 = self._param_constant5
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
_param_constant0_1 = self._param_constant0
|
||||
_param_constant1_1 = self._param_constant1
|
||||
_param_constant2_1 = self._param_constant2
|
||||
_param_constant3_1 = self._param_constant3
|
||||
_param_constant4_1 = self._param_constant4
|
||||
_param_constant5_1 = self._param_constant5
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None
|
||||
getitem_1 = cond_1[0]
|
||||
getitem_2 = cond_1[1]
|
||||
getitem_3 = cond_1[2]
|
||||
getitem_4 = cond_1[3]
|
||||
getitem_5 = cond_1[4]
|
||||
getitem_6 = cond_1[5]
|
||||
getitem_7 = cond_1[6]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_cond_autograd_user_nn_module(self):
|
||||
class User_nn_module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input * input
|
||||
|
||||
nn_module_true = User_nn_module()
|
||||
|
||||
def true_fn(x):
|
||||
return nn_module_true(torch.abs((x**2).sin()))
|
||||
|
||||
nn_module_false = torch.nn.ReLU(inplace=False)
|
||||
|
||||
def false_fn(x):
|
||||
return nn_module_false((x + 42).cos())
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f)(pred, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_cond_autograd_inner_fn(self):
|
||||
def true_fn(x):
|
||||
return torch.abs((x**2).sin())
|
||||
|
||||
def false_fn(x):
|
||||
def inner_fn(x):
|
||||
return x**2
|
||||
|
||||
return torch.abs(inner_fn(x).sin())
|
||||
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
pred = torch.tensor(False)
|
||||
fn = false_fn
|
||||
result_false = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result_false, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result_false)
|
||||
grads_false = torch.autograd.grad(result_false, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads_false)
|
||||
|
||||
pred = torch.tensor(True)
|
||||
fn = true_fn
|
||||
result_true = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result_true, fn(x))
|
||||
self.assertEqual(result_false, result_true)
|
||||
|
||||
grad_out = torch.ones_like(result_true)
|
||||
grads_true = torch.autograd.grad(result_true, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads_true)
|
||||
self.assertEqual(grads_false, grads_true)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f)(pred, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_cond_autograd_inner_tensor(self):
|
||||
def true_fn(x):
|
||||
return torch.abs((x**2).sin())
|
||||
|
||||
def false_fn(x):
|
||||
y = torch.ones(4, requires_grad=False) * 42
|
||||
return (x * y).cos()
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
||||
):
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, pred_1, x_1):
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
||||
getitem = cond[0]; cond = None
|
||||
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
||||
true_graph_1 = self.true_graph_1
|
||||
false_graph_1 = self.false_graph_1
|
||||
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
|
||||
getitem_1 = cond_1[0]; cond_1 = None
|
||||
return (getitem_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
||||
def test_cond_autograd_gpu(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def false_fn(x):
|
||||
return x.cos()
|
||||
|
||||
for pred, fn in zip(
|
||||
[torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")],
|
||||
[false_fn, true_fn],
|
||||
):
|
||||
x = torch.randn(4, requires_grad=True, device="cuda")
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
self.assertEqual(result, fn(x))
|
||||
|
||||
grad_out = torch.ones_like(result)
|
||||
grads = torch.autograd.grad(result, (x,), grad_out)
|
||||
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
||||
self.assertEqual(expected_grads, grads)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
||||
def test_map_gpu(self):
|
||||
def f(x, y):
|
||||
|
|
@@ -1221,74 +476,6 @@ class TestControlFlowTraced(TestCase):
|
|||
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
|
||||
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
|
||||
|
||||
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
||||
def test_cond_simple_with_linear_compile_check_graph(self):
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def false_fn(x):
|
||||
return x.cos()
|
||||
|
||||
x = torch.randn(4, requires_grad=True)
|
||||
|
||||
def f(pred, x):
|
||||
result = cond(pred, true_fn, false_fn, (x,))
|
||||
grad_out = torch.ones_like(result)
|
||||
return torch.autograd.grad(result, (x,), grad_out)
|
||||
|
||||
backend = EagerAndRecordGraphs()
|
||||
torch.compile(f, backend=backend)(torch.tensor(False), x)
|
||||
self.assertEqual(len(backend.graphs), 2)
|
||||
gm = backend.graphs[0]
|
||||
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
|
||||
l_pred_ = L_pred_
|
||||
l_x_ = L_x_
|
||||
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_x_]); l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None
|
||||
result = cond[0]; cond = None
|
||||
grad_out = torch.ones_like(result)
|
||||
return (result, grad_out)""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
gm.cond_true_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, l_x_):
|
||||
l_x__1 = l_x_
|
||||
sin = l_x__1.sin(); l_x__1 = None
|
||||
return (sin,)""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.cond_false_0.code.strip(),
|
||||
"""\
|
||||
def forward(self, l_x_):
|
||||
l_x__1 = l_x_
|
||||
cos = l_x__1.cos(); l_x__1 = None
|
||||
return (cos,)""", # noqa: B950
|
||||
)
|
||||
|
||||
backward_gm = backend.graphs[1]
|
||||
self.assertExpectedInline(
|
||||
backward_gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor):
|
||||
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
|
||||
l_ctx_pred = L_ctx_pred
|
||||
l_flat_grads_0_ = L_flat_grads_0_
|
||||
cond_true_0 = self.cond_true_0
|
||||
cond_false_0 = self.cond_false_0
|
||||
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None
|
||||
getitem = cond[0]; cond = None
|
||||
return (getitem,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_while_loop_nested_traced(self):
|
||||
fn, inp = WHILE_LOOP_TESTS["nested"]
|
||||
graphs = self._check_tracing(fn, inp)
|
||||
|
|
@@ -2724,7 +1911,7 @@ def forward(self, arg0_1):
|
|||
):
|
||||
functional_f(*example_inputs)
|
||||
|
||||
def test_cond_autograd_backward(self):
|
||||
def test_cond_autograd_succeed_when_pred_is_constant(self):
|
||||
def true_fn(x):
|
||||
return x.cos()
|
||||
|
||||
|
|
@@ -2738,14 +1925,32 @@ def forward(self, arg0_1):
|
|||
torch.ones(3, 2, 4, requires_grad=True),
|
||||
torch.ones(4, requires_grad=True),
|
||||
)
|
||||
# Because x.shape[0] can be statically evaluated to be False, we can evaluate
|
||||
# the backward.
|
||||
f(*example_inputs).sum().backward()
|
||||
|
||||
# Ensure no error is thrown when not running backward
|
||||
res = f(*example_inputs)
|
||||
f(*example_inputs)
|
||||
|
||||
def test_cond_autograd_fail(self):
|
||||
def true_fn(x):
|
||||
return x.cos()
|
||||
|
||||
def false_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def f(x, y):
|
||||
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [y])
|
||||
|
||||
example_inputs = (
|
||||
torch.ones(3, 2, 4, requires_grad=True),
|
||||
torch.ones(4, requires_grad=True),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "Autograd not implemented for cond"):
|
||||
f(*example_inputs).sum().backward()
|
||||
|
||||
# Ensure no error is thrown when not running backward
|
||||
res_compiled = torch.compile(f)(*example_inputs)
|
||||
self.assertEqual(res, res_compiled)
|
||||
f(*example_inputs)
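With the revert applied, test_cond_autograd_fail above documents the restored behavior: the forward through `cond` still runs, but requesting gradients raises. A hedged sketch of triggering that error eagerly (import path assumed from the test file; the message text comes from the restored `autograd_not_implemented` registration):

import torch
from functorch.experimental import control_flow  # assumed import, matching the test above

y = torch.ones(4, requires_grad=True)
out = control_flow.cond(torch.tensor(True), lambda t: t.cos(), lambda t: t.sin(), [y])
try:
    out.sum().backward()
except RuntimeError as err:
    print(err)  # expected on reverted builds: "Autograd not implemented for cond"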
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/126988
|
||||
@xfailIfTorchDynamo
|
||||
|
|
|
|||
|
|
@@ -13,29 +13,26 @@ from torch._C._functorch import (
|
|||
is_batchedtensor,
|
||||
maybe_get_bdim,
|
||||
)
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.utils import exposed_in
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
_set_compilation_env,
|
||||
autograd_not_implemented,
|
||||
reenter_make_fx,
|
||||
unique_graph_id,
|
||||
UnsupportedAliasMutationException,
|
||||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_temp_remove_pre_dispatch_torch_function_mode,
|
||||
disable_proxy_modes_tracing,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
from .utils import _from_fun, create_fw_bw_graph
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@@ -105,6 +102,8 @@ def cond(pred, true_fn, false_fn, operands):
|
|||
.. warning::
|
||||
Temporary Limitations:
|
||||
|
||||
- `cond` only supports **inference** right now. Autograd will be supported in the future.
|
||||
|
||||
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
|
||||
|
||||
"""
|
||||
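The warning above (restored by this revert) describes the pre-#126911 contract: inference only, and each branch must return a single Tensor with matching metadata. A minimal call that stays inside those limits (a sketch; `torch.cond` is the public entry point used elsewhere in this diff):

import torch

def true_fn(x, y):
    return (x + y).sum()   # a single Tensor, same shape/dtype as the false branch

def false_fn(x, y):
    return (x - y).sum()

x, y = torch.randn(3), torch.randn(3)          # inference only: no requires_grad
out = torch.cond(torch.tensor(True), true_fn, false_fn, (x, y))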
|
|
@@ -159,48 +158,6 @@ We're going to define a `cond_op` operation.
|
|||
In order to do this, we need implementations for each of the dispatch keys.
|
||||
"""
|
||||
cond_op = HigherOrderOperator("cond")
|
||||
cond_op.__module__ = "torch.ops.higher_order"
|
||||
|
||||
|
||||
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
|
||||
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
|
||||
|
||||
with suspend_functionalization(), disable_functional_mode():
|
||||
with disable_proxy_modes_tracing():
|
||||
fw_inputs = pytree.tree_map(_from_fun, operands)
|
||||
|
||||
fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs))
|
||||
if any(
|
||||
not isinstance(out, torch.Tensor)
|
||||
for out in fw_outputs_true
|
||||
if out is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Expect outputs of true_fn to only contains tensors or None. "
|
||||
f"Got types {[type(out) for out in fw_outputs_true]}."
|
||||
)
|
||||
fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs))
|
||||
if any(
|
||||
not isinstance(out, torch.Tensor)
|
||||
for out in fw_outputs_false
|
||||
if out is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Expect outputs of false_fn to only contains tensors or None. "
|
||||
f"Got types {[type(out) for out in fw_outputs_false]}."
|
||||
)
|
||||
|
||||
# TODO: There is a major issue that the create_fw_bw in the higher_order_op is invoked twice:
|
||||
# Once in the forward path (as it should) and once in the backward path, where it shouldn't be called
|
||||
# If we can get rid of the second invocation, it would simplify this function
|
||||
fw_true_graph, joint_true_graph = create_fw_bw_graph(
|
||||
true_fn, False, fw_inputs, fw_outputs_true
|
||||
)
|
||||
fw_false_graph, joint_false_graph = create_fw_bw_graph(
|
||||
false_fn, False, fw_inputs, fw_outputs_false
|
||||
)
|
||||
|
||||
return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph
|
||||
|
||||
|
||||
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
|
|
@@ -229,53 +186,14 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
|||
if len(flat_true_outs) != len(flat_false_outs):
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
f"Expected to return same number of outputs but got:"
|
||||
f"\n true branch returns {len(flat_true_outs)} item(s)"
|
||||
f"\n false branch returns {len(flat_false_outs)} item(s)"
|
||||
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
|
||||
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
|
||||
)
|
||||
|
||||
for i in range(0, len(flat_true_outs)):
|
||||
true_out = flat_true_outs[i]
|
||||
false_out = flat_false_outs[i]
|
||||
|
||||
# Note that we need to skip the check for requires_grad because we're already
|
||||
# after the autograd key during tracing, so the requires_grad attribute of the tensors
|
||||
# is no longer accurate. See Note [invariants for node meta 'val']
|
||||
def _same_meta_except_requires_grad(true_out, false_out):
|
||||
if true_out is None and false_out is None:
|
||||
return True
|
||||
elif true_out is None or false_out is None:
|
||||
# Consider the following case:
|
||||
# def true_fn(x, y):
|
||||
# return x * y
|
||||
#
|
||||
# def false_fn(x, y):
|
||||
# return x.sin()
|
||||
#
|
||||
# We'll get the following graphs for backward:
|
||||
# def backward_true_fn(x, y, grad_out):
|
||||
# return grad_out * y, grad_out * x
|
||||
#
|
||||
# def backward_false_fn(x, y, grad_out):
|
||||
# return grad_out, None
|
||||
#
|
||||
# This suggests that when we make_fx into the backward graph,
|
||||
# the output graph would produce outputs with mismatched metadata, which is undesirable.
|
||||
#
|
||||
# Ideally, we should provide an optional type to indicate that one of the branches might
|
||||
# return None. But we'll just let it pass for now and let downstream/runtime handle.
|
||||
#
|
||||
# Note that this corner case should **only** happen when the user wants to trace the backward graph because
|
||||
# if it's the forward, dynamo will error.
|
||||
return True
|
||||
true_meta = true_out.meta.get("tensor_meta", None)
|
||||
false_meta = false_out.meta.get("tensor_meta", None)
|
||||
return (
|
||||
true_meta.shape == false_meta.shape
|
||||
and true_meta.dtype == false_meta.dtype
|
||||
and true_meta.stride == false_meta.stride
|
||||
)
|
||||
|
||||
if not _same_meta_except_requires_grad(true_out, false_out):
|
||||
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
|
||||
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
||||
f"Expected each tensor to have same metadata but got:"
|
||||
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
||||
|
|
@@ -335,65 +253,9 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
|
|||
return false_fn(*operands)
|
||||
|
||||
|
||||
class CondAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
pred,
|
||||
fw_true_graph,
|
||||
fw_false_graph,
|
||||
joint_true_graph,
|
||||
joint_false_graph,
|
||||
*operands,
|
||||
):
|
||||
ctx._pred = pred
|
||||
ctx._joint_true_graph = joint_true_graph
|
||||
ctx._joint_false_graph = joint_false_graph
|
||||
ctx.save_for_backward(*operands)
|
||||
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
return cond_op(pred, fw_true_graph, fw_false_graph, operands)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *flat_grads):
|
||||
operands = ctx.saved_tensors
|
||||
|
||||
grads = cond_op(
|
||||
ctx._pred,
|
||||
ctx._joint_true_graph,
|
||||
ctx._joint_false_graph,
|
||||
flat_grads + operands,
|
||||
)
|
||||
return None, None, None, None, None, *grads
|
||||
|
||||
|
||||
@cond_op.py_impl(DispatchKey.Autograd)
|
||||
def cond_autograd(pred, true_fn, false_fn, operands):
|
||||
# A shortcut for the case where all inputs don't require gradient,
|
||||
# we skip tracing the forward and backward graph.
|
||||
if all(
|
||||
not t.requires_grad
|
||||
for t in pytree.tree_flatten((pred, operands))[0]
|
||||
if isinstance(t, torch.Tensor)
|
||||
):
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
return cond_op(pred, true_fn, false_fn, operands)
|
||||
|
||||
(
|
||||
fw_true_graph,
|
||||
fw_false_graph,
|
||||
joint_true_graph,
|
||||
joint_false_graph,
|
||||
) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
|
||||
flat_out = CondAutogradOp.apply(
|
||||
pred,
|
||||
fw_true_graph,
|
||||
fw_false_graph,
|
||||
joint_true_graph,
|
||||
joint_false_graph,
|
||||
*operands,
|
||||
)
|
||||
return flat_out
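The removed CondAutogradOp above follows the standard torch.autograd.Function recipe: stash non-tensor state on `ctx`, save tensors with `save_for_backward`, and re-enter the op below the Autograd dispatch key in `forward`. A self-contained toy version of the same recipe (not the cond implementation, just the pattern it uses):

import torch

class ScaleOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx._scale = scale          # non-tensor state goes straight on ctx
        ctx.save_for_backward(x)    # tensors go through save_for_backward
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors   # saved only to show the pattern; this toy backward does not need it
        # one gradient per forward input; None for the non-differentiable scale
        return grad_out * ctx._scale, None

x = torch.randn(4, requires_grad=True)
ScaleOp.apply(x, 3.0).sum().backward()
print(x.grad)  # all 3.0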
|
||||
cond_op.py_impl(DispatchKey.Autograd)(
|
||||
autograd_not_implemented(cond_op, deferred_error=True)
|
||||
)
|
||||
|
||||
|
||||
@cond_op.py_impl(ProxyTorchDispatchMode)
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@ import torch
|
|||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._dispatch.python import suspend_functionalization
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
|
||||
from torch._higher_order_ops.utils import (
|
||||
_has_potential_branch_input_alias,
|
||||
_has_potential_branch_input_mutation,
|
||||
|
|
@@ -12,20 +12,17 @@ from torch._higher_order_ops.utils import (
|
|||
)
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import disable_functional_mode
|
||||
from torch._subclasses.functional_tensor import (
|
||||
disable_functional_mode,
|
||||
FunctionalTensor,
|
||||
)
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
make_fx,
|
||||
ProxyTorchDispatchMode,
|
||||
track_tensor_tree,
|
||||
)
|
||||
from .utils import (
|
||||
_from_fun,
|
||||
_stack_pytree,
|
||||
_unstack_pytree,
|
||||
clone_outputs_aliasing_inputs,
|
||||
prepare_fw_with_masks,
|
||||
)
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
|
||||
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
|
||||
|
|
@@ -53,10 +50,49 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
|
|||
mapped_xs = args[:num_mapped_args]
|
||||
pos_args = args[num_mapped_args:]
|
||||
|
||||
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
|
||||
# Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
|
||||
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
|
||||
# added when required. Will encounter two problems if we don't suspend functionalization:
|
||||
#
|
||||
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
|
||||
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
|
||||
# However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
|
||||
# fetch the proxy for the inputs and to capture any operations on them.
|
||||
#
|
||||
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
|
||||
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
|
||||
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
|
||||
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
|
||||
# Instead, it will create _tensor_constant as output.
|
||||
|
||||
with suspend_functionalization(), disable_functional_mode():
|
||||
with disable_proxy_modes_tracing():
|
||||
|
||||
def _from_fun(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
if t.dtype != torch.bool:
|
||||
return torch.empty_strided(
|
||||
t.size(),
|
||||
t.stride(),
|
||||
dtype=t.dtype,
|
||||
requires_grad=t.requires_grad,
|
||||
)
|
||||
else:
|
||||
# clone of a functional tensor produces a functional tensor
|
||||
# but we want to avoid it so we clone a non-functional version
|
||||
maybe_unfunc_t = t
|
||||
if isinstance(t, FunctionalTensor):
|
||||
torch._sync(t)
|
||||
maybe_unfunc_t = from_fun(t)
|
||||
elif torch._is_functional_tensor(t):
|
||||
# need to handle both types of functionalization here:
|
||||
# these are the tensors that came from the user,
|
||||
# which could be either FunctionalTensorWrapper or FunctionalTensor
|
||||
torch._sync(t)
|
||||
maybe_unfunc_t = torch._from_functional_tensor(t)
|
||||
return maybe_unfunc_t.clone()
|
||||
return t
|
||||
|
||||
unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
|
||||
example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]
|
||||
|
||||
|
|
@@ -87,7 +123,16 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
|
|||
mapped_input = joint_mapped_args[:num_mapped_args]
|
||||
mapped_grads = joint_mapped_args[num_mapped_args:]
|
||||
|
||||
joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
|
||||
def fw_with_masks(*args):
|
||||
fw_out = f(*args)
|
||||
return fw_out, [
|
||||
True
|
||||
if isinstance(ret, torch.Tensor) and ret.requires_grad
|
||||
else False
|
||||
for ret in fw_out
|
||||
]
|
||||
|
||||
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
|
||||
_, grads = joint(
|
||||
list(mapped_input) + list(args),
|
||||
[
|
||||
|
|
@@ -99,7 +144,19 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
|
|||
|
||||
# In order to keep map functional for backward graph,
|
||||
# we clone outputs that are aliasing inputs
|
||||
maybe_clone = clone_outputs_aliasing_inputs(example_args)
|
||||
input_storage = {
|
||||
StorageWeakRef(arg._typed_storage())
|
||||
for arg in example_args
|
||||
if isinstance(arg, torch.Tensor)
|
||||
}
|
||||
|
||||
def maybe_clone(t):
|
||||
if (
|
||||
isinstance(t, torch.Tensor)
|
||||
and StorageWeakRef(t._typed_storage()) in input_storage
|
||||
):
|
||||
return t.clone()
|
||||
return t
|
||||
|
||||
return pytree.tree_map(maybe_clone, grads)
|
||||
|
||||
|
|
@@ -199,6 +256,46 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
|
|||
)
|
||||
|
||||
|
||||
def _unstack_pytree(xs):
|
||||
flat_xs, inspec = pytree.tree_flatten(xs)
|
||||
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
|
||||
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
||||
|
||||
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
||||
raise RuntimeError(
|
||||
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
|
||||
)
|
||||
|
||||
a = zip(*flat_xs)
|
||||
|
||||
pytrees = []
|
||||
for tuple in a:
|
||||
pytrees.append(pytree.tree_unflatten(tuple, inspec))
|
||||
return pytrees
|
||||
|
||||
|
||||
def _stack_pytree(pytrees):
|
||||
flat_out = []
|
||||
out_spec = None
|
||||
for pt in pytrees:
|
||||
flat_pt, out_spec = pytree.tree_flatten(pt)
|
||||
flat_out.append(flat_pt)
|
||||
assert out_spec is not None
|
||||
b = zip(*flat_out)
|
||||
stacked_out = []
|
||||
for leaves in b:
|
||||
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
|
||||
stacked_out.append(torch.stack(leaves))
|
||||
elif all(leaf is None for leaf in leaves):
|
||||
# Backward graph can return None output when forward inputs don't require grad.
|
||||
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
|
||||
# therefore we need to deal with None output.
|
||||
stacked_out.append(None) # type: ignore[arg-type]
|
||||
else:
|
||||
raise RuntimeError(f"Cannot stack {leaves}.")
|
||||
return pytree.tree_unflatten(stacked_out, out_spec)
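For readers unfamiliar with the _unstack_pytree/_stack_pytree helpers above, this is the round trip they implement, sketched here with only the public pytree utilities (the `batched` dict is a made-up example):

import torch
import torch.utils._pytree as pytree

batched = {"a": torch.arange(6.0).reshape(3, 2), "b": torch.ones(3, 4)}

# "Unstack": split the shared leading dimension (3) into three pytrees of slices.
flat, spec = pytree.tree_flatten(batched)
slices = [pytree.tree_unflatten([leaf[i] for leaf in flat], spec) for i in range(3)]

# "Stack": invert it by stacking matching leaves back along a new leading dimension.
stacked_leaves = [torch.stack(leaves) for leaves in zip(*(pytree.tree_flatten(s)[0] for s in slices))]
restacked = pytree.tree_unflatten(stacked_leaves, spec)
assert torch.equal(restacked["a"], batched["a"]) and torch.equal(restacked["b"], batched["b"])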
|
||||
|
||||
|
||||
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def map_dense(f, xs, pos_args):
|
||||
pytrees = []
|
||||
|
|
|
|||
|
|
@@ -93,15 +93,6 @@ def reenter_make_fx(fn):
|
|||
return wrapped
|
||||
|
||||
|
||||
def _maybe_reenter_make_fx(fn):
|
||||
from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
|
||||
|
||||
if _CURRENT_MAKE_FX_TRACER is not None:
|
||||
return reenter_make_fx(fn)
|
||||
else:
|
||||
return make_fx(fn)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_compilation_env():
|
||||
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
|
||||
|
|
@@ -219,169 +210,3 @@ def unique_graph_id(proxy_mode, prefix):
|
|||
else:
|
||||
next_name = candidate
|
||||
return i, next_name
|
||||
|
||||
|
||||
def _from_fun(t):
|
||||
from torch._functorch.aot_autograd import from_fun
|
||||
from torch._subclasses.functional_tensor import FunctionalTensor
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
if t.dtype != torch.bool:
|
||||
return torch.empty_strided(
|
||||
t.size(),
|
||||
t.stride(),
|
||||
dtype=t.dtype,
|
||||
requires_grad=t.requires_grad,
|
||||
)
|
||||
else:
|
||||
# clone of a functional tensor produces a functional tensor
|
||||
# but we want to avoid it so we clone a non-functional version
|
||||
maybe_unfunc_t = t
|
||||
if isinstance(t, FunctionalTensor):
|
||||
torch._sync(t)
|
||||
maybe_unfunc_t = from_fun(t)
|
||||
elif torch._is_functional_tensor(t):
|
||||
# need to handle both types of functionalization here:
|
||||
# these are the tensors that came from the user,
|
||||
# which could be either FunctionalTensorWrapper or FunctionalTensor
|
||||
torch._sync(t)
|
||||
maybe_unfunc_t = torch._from_functional_tensor(t)
|
||||
return maybe_unfunc_t.clone()
|
||||
return t
|
||||
|
||||
|
||||
def clone_outputs_aliasing_inputs(args):
|
||||
input_storage = {
|
||||
StorageWeakRef(arg._typed_storage())
|
||||
for arg in args
|
||||
if isinstance(arg, torch.Tensor)
|
||||
}
|
||||
|
||||
def maybe_clone(t):
|
||||
if (
|
||||
isinstance(t, torch.Tensor)
|
||||
and StorageWeakRef(t._typed_storage()) in input_storage
|
||||
):
|
||||
return t.clone()
|
||||
return t
|
||||
|
||||
return maybe_clone
|
||||
|
||||
|
||||
def prepare_fw_with_masks(fn):
|
||||
def fw_with_masks(*args):
|
||||
fw_out = fn(*args)
|
||||
return fw_out, [
|
||||
True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
|
||||
for ret in fw_out
|
||||
]
|
||||
|
||||
return fw_with_masks
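prepare_fw_with_masks pairs each forward output with a flag saying whether that output should receive a gradient, which is the contract create_joint expects. A small illustration of the masks it would compute for a hypothetical two-output function:

import torch

def fn(x, y):
    return x * 2, (y + 1).detach()   # the second output is detached, so it needs no grad

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

outs = fn(x, y)
masks = [isinstance(o, torch.Tensor) and o.requires_grad for o in outs]
print(masks)  # [True, False]: only the first output participates in the joint backward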
|
||||
|
||||
|
||||
# TODO: The parameter use_output_and_grad_bw is required because some operations
|
||||
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
|
||||
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
||||
|
||||
# Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
|
||||
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
|
||||
# added when required. Will encounter two problems if we don't suspend functionalization:
|
||||
#
|
||||
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
|
||||
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
|
||||
# However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
|
||||
# fetch the proxy for the inputs and to capture any operations on them.
|
||||
#
|
||||
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
|
||||
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
|
||||
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
|
||||
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
|
||||
# Instead, it will create _tensor_constant as output.
|
||||
|
||||
dummy_aot_config = AOTConfig(
|
||||
fw_compiler=None, # type: ignore[arg-type]
|
||||
bw_compiler=None, # type: ignore[arg-type]
|
||||
partition_fn=None, # type: ignore[arg-type]
|
||||
decompositions={},
|
||||
num_params_buffers=0,
|
||||
aot_id=0,
|
||||
keep_inference_input_mutations=False,
|
||||
)
|
||||
|
||||
example_grad = [_from_fun(out) for out in fw_outputs]
|
||||
num_grads = len(example_grad)
|
||||
fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)
|
||||
|
||||
def joint_fn(*joint_operands_grads):
|
||||
if use_output_and_grad_bw:
|
||||
grads = joint_operands_grads[0]
|
||||
inputs = joint_operands_grads[1][-1:]
|
||||
else:
|
||||
grads = joint_operands_grads[:num_grads]
|
||||
inputs = joint_operands_grads[num_grads:]
|
||||
|
||||
joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
|
||||
_, grads = joint(
|
||||
list(inputs),
|
||||
[grad for grad in grads if grad is not None and grad.requires_grad],
|
||||
)
|
||||
|
||||
# In order to keep map functional for backward graph,
|
||||
# we clone outputs that are aliasing inputs
|
||||
maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)
|
||||
|
||||
return pytree.tree_map(maybe_clone, grads)
|
||||
|
||||
if use_output_and_grad_bw:
|
||||
example_xs_out = list(fw_inputs) + list(fw_outputs)
|
||||
joint_graph = _maybe_reenter_make_fx(joint_fn)(
|
||||
(list(example_grad), list(example_xs_out))
|
||||
)
|
||||
else:
|
||||
example_xs_out = list(fw_inputs)
|
||||
joint_graph = _maybe_reenter_make_fx(joint_fn)(
|
||||
*(list(example_grad) + list(example_xs_out))
|
||||
)
|
||||
|
||||
return fw_graph, joint_graph
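The Note [HOP create fw_bw graph] at the top of create_fw_bw_graph explains why tracing happens in a clean environment with freshly constructed example inputs. Stripped of the joint-graph machinery, the core forward-tracing step looks roughly like this (a sketch; the toy `fn` and the explicit `empty_strided` mirror what _from_fun builds for non-bool tensors):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x):
    return (x * x).sin()

# A fresh example input that carries only size/stride/dtype/requires_grad,
# analogous to the tensors _from_fun constructs for graph capture.
example = torch.empty_strided((4,), (1,), dtype=torch.float32, requires_grad=True)
fw_graph = make_fx(fn)(example)
print(fw_graph.code)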
|
||||
|
||||
|
||||
def _unstack_pytree(xs):
|
||||
flat_xs, inspec = pytree.tree_flatten(xs)
|
||||
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
|
||||
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
||||
|
||||
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
||||
raise RuntimeError(
|
||||
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
|
||||
)
|
||||
|
||||
a = zip(*flat_xs)
|
||||
|
||||
pytrees = []
|
||||
for tuple in a:
|
||||
pytrees.append(pytree.tree_unflatten(tuple, inspec))
|
||||
return pytrees
|
||||
|
||||
|
||||
def _stack_pytree(pytrees):
|
||||
flat_out = []
|
||||
out_spec = None
|
||||
for pt in pytrees:
|
||||
flat_pt, out_spec = pytree.tree_flatten(pt)
|
||||
flat_out.append(flat_pt)
|
||||
assert out_spec is not None
|
||||
b = zip(*flat_out)
|
||||
stacked_out = []
|
||||
for leaves in b:
|
||||
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
|
||||
stacked_out.append(torch.stack(leaves))
|
||||
elif all(leaf is None for leaf in leaves):
|
||||
# Backward graph can return None output when forward inputs don't require grad.
|
||||
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
|
||||
# therefore we need to deal with None output.
|
||||
stacked_out.append(None) # type: ignore[arg-type]
|
||||
else:
|
||||
raise RuntimeError(f"Cannot stack {leaves}.")
|
||||
return pytree.tree_unflatten(stacked_out, out_spec)
|
||||
|
|
|
|||
|
|
@@ -315,7 +315,6 @@ def extract_val(val: _ExtractValType) -> _ExtractValType:
|
|||
|
||||
typing_extensions.assert_never(val)
|
||||
|
||||
# Note [invariants for node meta 'val']
|
||||
# What invariants do we have for the 'val' set on the FX node? It has accurate
|
||||
# metadata... but only for metadata that exists "below" all other subsystems
|
||||
# (most notably autograd, but also vmap, functorch transforms, etc). This means
|
||||
|
|
|
|||
|
|
@@ -91,13 +91,13 @@ def foo_impl_abstract(x, z):
|
|||
|
||||
def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(
|
||||
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
|
||||
make_tensor, device=device, dtype=dtype, requires_grad=False
|
||||
)
|
||||
yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
|
||||
|
||||
|
||||
def simple_cond(x):
|
||||
return torch.cond(x.sum() > 2, lambda x: x.cos(), lambda x: x.sin(), [x])
|
||||
return torch.cond(x.shape[0] > 2, lambda x: x.cos(), lambda x: x.sin(), [x])
|
||||
|
||||
|
||||
def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
|
|
@@ -196,23 +196,7 @@ hop_db = [
|
|||
check_batched_gradgrad=False,
|
||||
check_batched_forward_grad=False,
|
||||
check_inplace_batched_forward_grad=False,
|
||||
supports_autograd=True,
|
||||
# "torch.compile with aot_autograd does not currently support double backward."
|
||||
supports_gradgrad=False,
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
"TestEagerFusionOpInfo",
|
||||
"test_aot_autograd_exhaustive",
|
||||
active_if=IS_WINDOWS,
|
||||
),
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
"TestEagerFusionOpInfo",
|
||||
"test_aot_autograd_symbolic_exhaustive",
|
||||
active_if=IS_WINDOWS,
|
||||
),
|
||||
),
|
||||
supports_autograd=False,
|
||||
),
|
||||
OpInfo(
|
||||
name="while_loop",