# Owner(s): ["module: functorch"]
import contextlib
import functools
import unittest

import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._higher_order_ops.associative_scan import (
    _fake_associative_scan,
    associative_scan,
)
from torch._higher_order_ops.map import _fake_map
from torch._higher_order_ops.scan import _fake_scan, scan
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.while_loop import while_loop
from torch._subclasses.functional_tensor import (
    CppFunctionalizeAPI,
    FunctionalTensor,
    FunctionalTensorMode,
    PythonFunctionalizeAPI,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    decorateIf,
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    requires_cuda,
    run_tests,
    skipIfCrossRef,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
    if isinstance(t, torch.Tensor):
        return FunctionalTensor.to_functional(t)
    return t


def from_fun(t):
    if not isinstance(t, FunctionalTensor):
        # quick sanity assert
        if isinstance(t, torch.Tensor):
            assert not torch._is_functional_tensor(t)
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t.elem)


def to_fun_old(t):
    if isinstance(t, torch.Tensor) and not torch._is_functional_tensor(t):
        out = torch._to_functional_tensor(t)
        torch._mirror_autograd_meta_to(t, out)
        return out
    return t


def from_fun_old(t):
    # quick sanity assert
    if isinstance(t, torch.Tensor):
        assert torch._is_functional_tensor(t)
        torch._sync(t)
        return torch._from_functional_tensor(t)
    return t


def _fake_while_loop(cond_fn, body_fn, operands):
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


def compile_mode_helper(fct, compile_mode):
    if compile_mode == "compile":
        return torch.compile(fct, fullgraph=True, dynamic=False)
    elif compile_mode == "compile_dynamic_shape":
        return torch.compile(fct, fullgraph=True, dynamic=True)
    elif compile_mode == "eager":
        return torch.compile(fct, fullgraph=True, backend="eager")
    else:
        return fct


ALIAS_FN = [
    lambda x: x,
    lambda x: x.view(-1),
    lambda x: x.reshape(-1),
    lambda x: x.squeeze(0),
    lambda x: x.unsqueeze(0),
    lambda x: x.transpose(0, 1),
    lambda x: x.flatten(),
    lambda x: x.expand(1, *x.size()),
]

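# Returns the combine function registered under `name` for the scan /
# associative_scan tests below. With associative=False, the chosen function is
# wrapped so that its result is duplicated into the (carry, output) pair that
# `scan` expects.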
def get_scan_combine_fn(name, associative=True, parameters=None):
|
|
def add(x: torch.Tensor, y: torch.Tensor):
|
|
return x + y
|
|
|
|
def adds(x: torch.Tensor, y: torch.Tensor):
|
|
return x + x, y + y
|
|
|
|
def mul(x: torch.Tensor, y: torch.Tensor):
|
|
return x * y
|
|
|
|
def div(x: torch.Tensor, y: torch.Tensor):
|
|
return x / y
|
|
|
|
def s5_operator(x: torch.Tensor, y: torch.Tensor):
|
|
A_i, Bu_i = x
|
|
A_j, Bu_j = y
|
|
return A_j * A_i, A_j * Bu_i + Bu_j
|
|
|
|
def different_input_size_operator(x: torch.Tensor, y: torch.Tensor):
|
|
x_o, dA_o, dB_o, C_o, y_o = x
|
|
x_n, dA_n, dB_n, C_n, y_n = y
|
|
|
|
x_new = x_n + x_o
|
|
y_new = torch.einsum("bdn,bn->bd", x_new, C_n)
|
|
|
|
return x_new, dA_n + 0.0, dB_n + 0.0, C_n + 0.0, y_new
|
|
|
|
def tuple_fct(x, y):
|
|
return (x[0] + y[0], x[1] * y[1])
|
|
|
|
def complex_pointwise(x, y):
|
|
return {
|
|
"i": x["i"] * y["i"],
|
|
"j": (
|
|
[x["j"][0][0] * y["j"][0][0]],
|
|
[{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
}
|
|
|
|
def non_pointwise(x: torch.Tensor, y: torch.Tensor):
|
|
W = torch.diag(torch.ones(2, device=x.device))
|
|
return x @ W + y @ W
|
|
|
|
def RNN(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ parameters[0] + parameters[1]
|
|
h_new = torch.tanh(c_new + x @ parameters[2] + parameters[3])
|
|
return h_new, h_new.clone()
|
|
|
|
def fct_c1_no_grad(x: torch.Tensor, y: torch.Tensor):
|
|
h_new = torch.tanh(x[0] + x[1] + y)
|
|
c2 = x[1] + y
|
|
with torch.no_grad():
|
|
c1 = x[0] + y
|
|
return (c1, c2), h_new
|
|
|
|
if name == "add":
|
|
fct = add
|
|
elif name == "adds":
|
|
fct = adds
|
|
elif name == "mul":
|
|
fct = mul
|
|
elif name == "div":
|
|
fct = div
|
|
elif name == "s5_operator":
|
|
fct = s5_operator
|
|
elif name == "different_input_size_operator":
|
|
fct = different_input_size_operator
|
|
elif name == "tuple_fct":
|
|
fct = tuple_fct
|
|
elif name == "complex_pointwise":
|
|
fct = complex_pointwise
|
|
elif name == "non_pointwise":
|
|
fct = non_pointwise
|
|
elif name == "RNN":
|
|
fct = RNN
|
|
elif name == "fct_c1_no_grad":
|
|
fct = fct_c1_no_grad
|
|
else:
|
|
raise ValueError("Combine_fn name unknown!")
|
|
|
|
if not associative:
|
|
return lambda x, y: (fct(x, y), fct(x, y))
|
|
else:
|
|
return fct
|
|
|
|
|
|
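# Builds the WHILE_LOOP_TESTS table used below: each entry maps a test name to
# a (callable, example_inputs) pair, where the callable is either a plain
# function or an nn.Module that exercises torch while_loop.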
def _while_loop_tests():
|
|
def simple(x):
|
|
def cond_fn(x):
|
|
return x.sum() < 10
|
|
|
|
def body_fn(x):
|
|
return (x + 1,)
|
|
|
|
return while_loop(cond_fn, body_fn, (x,))
|
|
|
|
def simple_with_mutation(x):
|
|
def cond_fn(x):
|
|
y = x.clone().add_(1).add_(-1)
|
|
return y.sum() < 10
|
|
|
|
def body_fn(x):
|
|
y = x.clone().add_(1).add_(-1)
|
|
return (y + 1,)
|
|
|
|
return while_loop(cond_fn, body_fn, (x,))
|
|
|
|
def nested(out_iter, it, y):
|
|
def cond_fn(out_iter, it, y):
|
|
return it.sum() < 10
|
|
|
|
def body_fn(out_iter, it, y):
|
|
return (out_iter.clone(), it + y, y + 1)
|
|
|
|
def outer_cond_fn(out_iter, it, y):
|
|
return out_iter.sum() < 2
|
|
|
|
def outer_body_fn(out_iter, it, y):
|
|
out_iter, it, y = while_loop(cond_fn, body_fn, (out_iter, it, y))
|
|
return (out_iter + 1, it, y)
|
|
|
|
return while_loop(outer_cond_fn, outer_body_fn, (out_iter, it, y))
|
|
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, ci, cj, a, b):
|
|
def cond_fn(i1, j1, x1, y1):
|
|
return i1 > 0
|
|
|
|
def body_fn(i1, j1, x1, y1):
|
|
def cond_fn_nested(i2, j2, x2, y2):
|
|
return j2 > 0
|
|
|
|
def body_fn_nested(i2, j2, x2, y2):
|
|
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71
|
|
|
|
i1, j1, x1, y1 = while_loop(
|
|
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
|
|
)
|
|
return i1 - 1, j1.clone(), x1 * 2, y1 / 2
|
|
|
|
return while_loop(cond_fn, body_fn, (ci, cj, a, b))
|
|
|
|
class SimpleWithLinear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2)
|
|
self.dec = torch.nn.Buffer(torch.tensor(1))
|
|
|
|
def forward(self, iter, x):
|
|
def cond_fn(it, x):
|
|
return it - self.dec > 0
|
|
|
|
def body_fn(it, x):
|
|
return it - 1, self.linear(x)
|
|
|
|
return while_loop(cond_fn, body_fn, (iter, x))
|
|
|
|
class SimpleWithPytreeCarry(torch.nn.Module):
|
|
def forward(self, it, pytree_input):
|
|
def cond_fn(it, pytree_input):
|
|
return it > 0
|
|
|
|
def body_fn(it, pytree_input):
|
|
x = pytree_input[0][0]
|
|
y = pytree_input[1]["x"]
|
|
z = pytree_input[1]["y"]
|
|
new_x = y.sin()
|
|
new_y = z.cos()
|
|
new_z = x + 1
|
|
return it - 1, ([new_x], {"x": new_y, "y": new_z})
|
|
|
|
return while_loop(cond_fn, body_fn, (it, pytree_input))
|
|
|
|
class NestedWithLinear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.mod = SimpleWithLinear()
|
|
self.outer_linear = torch.nn.Linear(2, 2)
|
|
self.dec = torch.nn.Buffer(torch.tensor(1))
|
|
|
|
def forward(self, iter, x):
|
|
def cond_fn(it, x):
|
|
return it - self.dec > 0
|
|
|
|
def body_fn(it, x):
|
|
return it - 1, self.outer_linear(self.mod(it, x)[1])
|
|
|
|
return while_loop(cond_fn, body_fn, (iter, x))
|
|
|
|
class PytreeIntCarry(torch.nn.Module):
|
|
def forward(self, x):
|
|
a = x.shape[0]
|
|
b = x.shape[1]
|
|
|
|
def cond_fn(shapes, const_int_dict, x):
|
|
a, b = shapes
|
|
c1, c2, c3 = const_int_dict["int_carry"]
|
|
return c1 * c2 * c3 < a * b
|
|
|
|
def body_fn(shapes, const_int_dict, x):
|
|
a, b = shapes
|
|
c1, c2, c3 = const_int_dict["int_carry"]
|
|
return (
|
|
[a + 1, b + 1],
|
|
{"int_carry": (c1 + 1, c2 + 1, c3 + 1)},
|
|
x + 1,
|
|
)
|
|
|
|
carry = ([a, b], {"int_carry": (2, 2, 3)}, x.sin())
|
|
out_shapes, out_it, out_x = while_loop(cond_fn, body_fn, carry)
|
|
out_inc = pytree.tree_map(lambda x: x + 1, out_it)
|
|
out_add = pytree.tree_map(lambda x: x + out_x, out_it)
|
|
return (out_shapes, out_inc, out_add, out_x)
|
|
|
|
class IntCarry(torch.nn.Module):
|
|
def forward(self, x):
|
|
def cond_fn(it, x):
|
|
return it < x.shape[0]
|
|
|
|
def body_fn(it, x):
|
|
x_clone = x.clone()
|
|
# Need these checks to select from x
|
|
torch._check(it >= 0)
|
|
torch._check(it < x.shape[0])
|
|
x_clone.select(0, it).copy_(x_clone.select(0, it) + it)
|
|
return it + 1, x_clone
|
|
|
|
            # We invoke the hop directly to avoid triggering dynamo tracing
|
|
out_it, out_x = torch.ops.higher_order.while_loop(
|
|
cond_fn, body_fn, (0, x), tuple()
|
|
)
|
|
            # We need torch._check so that out_it can be used in the torch.ones call below
|
|
torch._check(out_it > 0)
|
|
return (
|
|
out_it + 1,
|
|
out_it + out_x,
|
|
out_it < x.shape[0],
|
|
torch.ones(out_it * 2),
|
|
)
|
|
|
|
class ConstAndSymIntOutput(torch.nn.Module):
|
|
def forward(self, t):
|
|
a = t.shape[0]
|
|
b = t.shape[1]
|
|
|
|
def cond_fn(a, b, c1, c2, c3, c0, u0, x):
|
|
return c1 * c2 * c3 < a * b
|
|
|
|
def body_fn(a, b, c1, c2, c3, c0, u0, x):
|
|
return b, c1, c2, c3, a, 0, u0 + 1, x + 1
|
|
|
|
carry = (a, b, 1, 1, 1, a + 1, t.sum().to(torch.int64).item(), t.sin())
|
|
out_it = torch.ops.higher_order.while_loop(cond_fn, body_fn, carry, tuple())
|
|
out_inc = pytree.tree_map(lambda x: x + 1, out_it)
|
|
out_add = pytree.tree_map(lambda x: x + t, out_it)
|
|
return out_inc, out_add
|
|
|
|
nested2 = Nested()
|
|
simple_with_linear = SimpleWithLinear()
|
|
simple_with_pytree_carry = SimpleWithPytreeCarry()
|
|
nested_with_linear = NestedWithLinear()
|
|
int_carry = IntCarry()
|
|
pytree_int_carry = PytreeIntCarry()
|
|
const_and_symint_output = ConstAndSymIntOutput()
|
|
|
|
x = torch.zeros(1)
|
|
y = torch.zeros(1)
|
|
z = torch.zeros(1)
|
|
return {
|
|
"simple": (simple, (x,)),
|
|
"nested": (nested, (x, y, z)),
|
|
"nested2": (
|
|
nested2,
|
|
(torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2)),
|
|
),
|
|
"simple_with_mutation": (simple_with_mutation, (x,)),
|
|
"simple_with_linear": (
|
|
simple_with_linear,
|
|
(torch.tensor(3), torch.randn(2, 2)),
|
|
),
|
|
"nested_with_linear": (
|
|
nested_with_linear,
|
|
(torch.tensor(3), torch.randn(2, 2)),
|
|
),
|
|
"simple_with_pytree_carry": (
|
|
simple_with_pytree_carry,
|
|
(
|
|
torch.tensor(3),
|
|
([torch.randn(3, 3)], {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}),
|
|
),
|
|
),
|
|
"int_carry": (int_carry, (torch.randn(2, 3, requires_grad=True),)),
|
|
"pytree_int_carry": (
|
|
pytree_int_carry,
|
|
(torch.randn(2, 3, requires_grad=True),),
|
|
),
|
|
"const_and_symint_output": (
|
|
const_and_symint_output,
|
|
(torch.randn(2, 3, requires_grad=True),),
|
|
),
|
|
}
|
|
|
|
|
|
WHILE_LOOP_TESTS = _while_loop_tests()
|
|
|
|
|
|
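# Walks every submodule graph of `gm` and collects the requested meta fields
# for the nodes whose names appear in `node_names`.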
def collect_meta_for_filtered_nodes(
|
|
gm: torch.fx.GraphModule, node_names, meta_field_name
|
|
):
|
|
ret = []
|
|
for mod in gm.modules():
|
|
for node in mod.graph.nodes:
|
|
if node.name in node_names:
|
|
for field_name in meta_field_name:
|
|
ret.append(node.meta.get(field_name))
|
|
return ret
|
|
|
|
|
|
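# Plain-function, callable-object, and nn.Module variants of the same
# reduction, for tests that pass different kinds of callables to the
# control-flow ops.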
def reduce_func(*operands):
|
|
acc = 0
|
|
for operand in operands:
|
|
acc += operand
|
|
return acc
|
|
|
|
|
|
class ReduceObj:
|
|
def __call__(self, *operands):
|
|
return reduce_func(*operands)
|
|
|
|
|
|
class ReduceMod(torch.nn.Module):
|
|
def _reduce(self, *operands):
|
|
return reduce_func(*operands)
|
|
|
|
def forward(self, *operands):
|
|
return self._reduce(*operands)
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfNoDynamoSupport
|
|
class TestControlFlow(TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.reset()
|
|
super().setUp()
|
|
|
|
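    # Compares gradients of `result` against `result_exp` with respect to
    # `params`: both pytrees are flattened and backpropagated with ones-like
    # cotangents before the gradients are compared.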
def check_autograd(self, result, result_exp, params):
|
|
params_flatten = pytree.tree_leaves(params)
|
|
result_flatten = pytree.tree_leaves(result)
|
|
result_exp_flatten = pytree.tree_leaves(result_exp)
|
|
grad_exp_init = [torch.ones_like(el) for el in result_exp_flatten]
|
|
expected_grads = torch.autograd.grad(
|
|
result_exp_flatten, params_flatten, grad_exp_init
|
|
)
|
|
grad_init = [torch.ones_like(el) for el in result_flatten]
|
|
grads = torch.autograd.grad(result_flatten, params_flatten, grad_init)
|
|
self.assertEqual(grads, expected_grads, atol=6e-05, rtol=6e-06)
|
|
|
|
def test_cond_no_trace(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
x = torch.randn(4)
|
|
result = cond(False, true_fn, false_fn, [x])
|
|
self.assertEqual(result, torch.cos(x))
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_cond_gpu(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
x = torch.randn(4, device="cuda")
|
|
pred = torch.tensor(False, device="cuda")
|
|
result = cond(pred, true_fn, false_fn, [x])
|
|
self.assertEqual(result, torch.cos(x))
|
|
|
|
def test_cond_autograd_simple(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_complex(self):
|
|
def true_fn(x):
|
|
return torch.abs((x**2).sin())
|
|
|
|
def false_fn(x):
|
|
return (x + 42).cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_nested(self):
|
|
class Nested(torch.nn.Module):
|
|
def forward(self, p0, p1, p2, a, b, c):
|
|
def true_fn(x0, y0, z0):
|
|
def true_true_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) * 3.14
|
|
|
|
def true_false_fn(x1, y1, z1):
|
|
def true_false_true_fn(x2, y2, z2):
|
|
return (x2 * y2 * z2) / 2.71
|
|
|
|
def true_false_false_fn(x2, y2, z2):
|
|
return (x2 + y2 + z2) * 1.23
|
|
|
|
return torch.cond(
|
|
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
|
|
|
|
def false_fn(x0, y0, z0):
|
|
def false_true_fn(x1, y1, z1):
|
|
def false_true_true_fn(x2, y2, z2):
|
|
return (x2 - y2 - z2) + 1.23
|
|
|
|
def false_true_false_fn(x2, y2, z2):
|
|
return (x2 / y2 / z2) - 3.14
|
|
|
|
return torch.cond(
|
|
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
|
|
)
|
|
|
|
def false_false_fn(x1, y1, z1):
|
|
return (x1 - y1 * z1) / 2.71
|
|
|
|
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
|
|
|
|
return torch.cond(p0, true_fn, false_fn, [a, b, c])
|
|
|
|
nn_module = Nested()
|
|
|
|
def true_fn(x):
|
|
return nn_module(
|
|
torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x
|
|
)
|
|
|
|
def false_fn(x):
|
|
return nn_module(
|
|
torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x
|
|
)
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_mixed_require_grad(self):
|
|
def true_fn(x, y, z):
|
|
return x * y * z
|
|
|
|
def false_fn(x, y, z):
|
|
return x + y + z
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
y = torch.randn(4, requires_grad=False)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, (x, y, x))
|
|
self.assertEqual(result, fn(x, y, x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x, y, z):
|
|
result = cond(pred, true_fn, false_fn, (x, y, z))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x, y, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1, y_1, z_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (z_1, y_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = z_1 = y_1 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_grad_through_cond(self):
|
|
nn_module = torch.nn.Linear(4, 4)
|
|
|
|
def true_fn(x):
|
|
return nn_module(x)
|
|
|
|
        def false_fn(x):
            return x * nn_module(x)
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (nn_module.weight,), grad_out)
|
|
expected_grads = torch.autograd.grad(
|
|
fn(
|
|
x,
|
|
),
|
|
(nn_module.weight,),
|
|
grad_out,
|
|
)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (nn_module.weight,), grad_out)
|
|
|
|
# need to set _allow_non_fake_inputs = True because model parameters don't
|
|
# get fakified.
|
|
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_param_constant0 = self._param_constant0
|
|
_param_constant1 = self._param_constant1
|
|
_tensor_constant0 = self._tensor_constant0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, sym_size_int, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
_param_constant0_1 = self._param_constant0
|
|
_param_constant1_1 = self._param_constant1
|
|
_tensor_constant0_1 = self._tensor_constant0
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; getitem_1 = None
|
|
getitem_2 = cond_1[1]
|
|
getitem_3 = cond_1[2]; getitem_3 = None
|
|
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
|
|
return (getitem_2,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_in_forloop(self):
|
|
def for_loop_fake(x):
|
|
for i in range(3):
|
|
x = x * x + 1
|
|
return x
|
|
|
|
def for_loop_test(x):
|
|
for i in range(3):
|
|
pred = i < 3
|
|
|
|
def true_fn(x):
|
|
return x * x + 1
|
|
|
|
def false_fn(x):
|
|
return x
|
|
|
|
x = cond(pred, true_fn, false_fn, (x,))
|
|
|
|
return x
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
x_new = for_loop_test(x)
|
|
x_exp = for_loop_fake(x)
|
|
|
|
self.assertEqual(x_new, x_exp)
|
|
|
|
grad_out = torch.ones_like(x_new)
|
|
grads = torch.autograd.grad(x_new, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(x_exp, (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(x):
|
|
x_new = for_loop_test(x)
|
|
grad_out = torch.ones_like(x_new)
|
|
return torch.autograd.grad(x_new, (x,), grad_out)
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic")(x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
mul = torch.ops.aten.mul.Tensor(x_1, x_1)
|
|
add = torch.ops.aten.add.Tensor(mul, 1); mul = None
|
|
mul_1 = torch.ops.aten.mul.Tensor(add, add)
|
|
add_1 = torch.ops.aten.add.Tensor(mul_1, 1); mul_1 = None
|
|
mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1)
|
|
add_2 = torch.ops.aten.add.Tensor(mul_2, 1); mul_2 = None
|
|
ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False); add_2 = None
|
|
mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1)
|
|
mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1); ones_like = add_1 = None
|
|
add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3); mul_4 = mul_3 = None
|
|
mul_5 = torch.ops.aten.mul.Tensor(add_3, add)
|
|
mul_6 = torch.ops.aten.mul.Tensor(add_3, add); add_3 = add = None
|
|
add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5); mul_6 = mul_5 = None
|
|
mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1)
|
|
mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None
|
|
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None
|
|
return (add_5,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_pytree_not_all_inputs_used(self):
|
|
def true_fn(x):
|
|
return x["t"][0] + x["t"][1]["b"]
|
|
|
|
def false_fn(x):
|
|
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
if pred:
|
|
with self.assertRaisesRegex(Exception, r"."):
|
|
grads = torch.autograd.grad(result, (a, b, c), grad_out)
|
|
expected_grads = torch.autograd.grad(
|
|
fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out
|
|
)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, a, b, c):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (a, b), grad_out)
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
|
|
pred, a, b, c
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, a_1, b_1, c_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(c_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]
|
|
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
|
|
return (getitem_1, getitem_2)""", # noqa: B950
|
|
)
|
|
# Forward
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (add,)""",
|
|
)
|
|
# Backward
|
|
self.assertExpectedInline(
|
|
gm.true_graph_1.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = add = None
|
|
clone = torch.ops.aten.clone.default(arg6_1)
|
|
clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None
|
|
zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
|
|
return [clone, clone_1, zeros_like]""",
|
|
)
|
|
|
|
def test_cond_autograd_pytree_input(self):
|
|
def true_fn(x):
|
|
return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
|
|
|
|
def false_fn(x):
|
|
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (a, b), grad_out)
|
|
expected_grads = torch.autograd.grad(
|
|
fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out
|
|
)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (a, b), grad_out)
|
|
|
|
# need to set _allow_non_fake_inputs = True because model parameters don't
|
|
# get fakified.
|
|
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_tensor_constant0 = self._tensor_constant0
|
|
_tensor_constant1 = self._tensor_constant1
|
|
_tensor_constant2 = self._tensor_constant2
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
_tensor_constant0_1 = self._tensor_constant0
|
|
_tensor_constant1_1 = self._tensor_constant1
|
|
_tensor_constant2_1 = self._tensor_constant2
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]
|
|
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
|
|
return (getitem_1, getitem_2)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_different_pytree_output(self):
|
|
def true_fn(x):
|
|
return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]]
|
|
|
|
def false_fn(x):
|
|
return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]}
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_same_pytree_output(self):
|
|
def true_fn(x):
|
|
return {"res": [x["t"][0].clone(), (x["t"][2][0].clone(),)]}
|
|
|
|
def false_fn(x):
|
|
return {"res": [x["t"][1]["b"].clone(), (x["t"][2][0].clone(),)]}
|
|
|
|
a = torch.randn(4, requires_grad=True)
|
|
b = torch.randn(4, requires_grad=True)
|
|
c = torch.randn(4, requires_grad=True)
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
result_exp = fn({"t": [a, {"b": b}, (c,)]})
|
|
self.assertEqual(result, result_exp)
|
|
|
|
result_flat, _ = pytree.tree_flatten(result)
|
|
result_exp_flat, _ = pytree.tree_flatten(result_exp)
|
|
|
|
grad_out = [torch.ones_like(g) for g in result_flat]
|
|
expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out)
|
|
grads = torch.autograd.grad(result_flat, (c,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred):
|
|
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
|
|
return result
|
|
|
|
gm = make_fx(f, tracing_mode="real", _allow_non_fake_inputs=True)(pred)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_tensor_constant0 = self._tensor_constant0
|
|
_tensor_constant1 = self._tensor_constant1
|
|
_tensor_constant2 = self._tensor_constant2
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
|
|
getitem = cond[0]
|
|
getitem_1 = cond[1]; cond = None
|
|
return {'res': [getitem, (getitem_1,)]}""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
|
|
def test_cond_autograd_torch_nn_module(self):
|
|
nn_module_true = torch.nn.Linear(4, 4)
|
|
|
|
def true_fn(x):
|
|
return nn_module_true(torch.abs((x**2).sin()))
|
|
|
|
nn_module_false = torch.nn.GRUCell(4, 4)
|
|
|
|
def false_fn(x):
|
|
return nn_module_false((x + 42).cos())
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_param_constant0 = self._param_constant0
|
|
_param_constant1 = self._param_constant1
|
|
_param_constant2 = self._param_constant2
|
|
_param_constant3 = self._param_constant3
|
|
_param_constant4 = self._param_constant4
|
|
_param_constant5 = self._param_constant5
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
_param_constant0_1 = self._param_constant0
|
|
_param_constant1_1 = self._param_constant1
|
|
_param_constant2_1 = self._param_constant2
|
|
_param_constant3_1 = self._param_constant3
|
|
_param_constant4_1 = self._param_constant4
|
|
_param_constant5_1 = self._param_constant5
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = ones_like = None
|
|
getitem_1 = cond_1[0]
|
|
getitem_2 = cond_1[1]; getitem_2 = None
|
|
getitem_3 = cond_1[2]; getitem_3 = None
|
|
getitem_4 = cond_1[3]; getitem_4 = None
|
|
getitem_5 = cond_1[4]; getitem_5 = None
|
|
getitem_6 = cond_1[5]; getitem_6 = None
|
|
getitem_7 = cond_1[6]; cond_1 = getitem_7 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_user_nn_module(self):
|
|
class User_nn_module(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, input):
|
|
return input * input
|
|
|
|
nn_module_true = User_nn_module()
|
|
|
|
def true_fn(x):
|
|
return nn_module_true(torch.abs((x**2).sin()))
|
|
|
|
nn_module_false = torch.nn.ReLU(inplace=False)
|
|
|
|
def false_fn(x):
|
|
return nn_module_false((x + 42).cos())
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_inner_fn(self):
|
|
def true_fn(x):
|
|
return torch.abs((x**2).sin())
|
|
|
|
def false_fn(x):
|
|
def inner_fn(x):
|
|
return x**2
|
|
|
|
return torch.abs(inner_fn(x).sin())
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
pred = torch.tensor(False)
|
|
fn = false_fn
|
|
result_false = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result_false, fn(x))
|
|
|
|
grad_out = torch.ones_like(result_false)
|
|
grads_false = torch.autograd.grad(result_false, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads_false)
|
|
|
|
pred = torch.tensor(True)
|
|
fn = true_fn
|
|
result_true = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result_true, fn(x))
|
|
self.assertEqual(result_false, result_true)
|
|
|
|
grad_out = torch.ones_like(result_true)
|
|
grads_true = torch.autograd.grad(result_true, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads_true)
|
|
self.assertEqual(grads_false, grads_true)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_autograd_inner_tensor(self):
|
|
def true_fn(x):
|
|
return torch.abs((x**2).sin())
|
|
|
|
def false_fn(x):
|
|
y = torch.ones(4, requires_grad=False) * 42
|
|
return (x * y).cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
|
|
):
|
|
x = torch.randn(4, requires_grad=True)
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
gm = make_fx(f)(pred, x)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
return (getitem_1,)""", # noqa: B950
|
|
)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_cond_autograd_gpu(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
for pred, fn in zip(
|
|
[torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")],
|
|
[false_fn, true_fn],
|
|
):
|
|
x = torch.randn(4, requires_grad=True, device="cuda")
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
self.assertEqual(result, fn(x))
|
|
|
|
grad_out = torch.ones_like(result)
|
|
grads = torch.autograd.grad(result, (x,), grad_out)
|
|
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
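    # Shared helper for the cond autograd tests: runs `cond_fct` on `operands`,
    # compares the forward result against calling the selected branch directly,
    # then compares the gradients and the tensor metadata of those gradients.
    # Returns the cond outputs and the grad-requiring inputs for further checks.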
def _test_cond_autograd(self, cond_fct, pred_fn, true_fn, false_fn, operands):
|
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
|
|
|
        # Helper that extracts the metadata from a tensor and sets its
        # requires_grad flag to False. This is needed because we compare the
        # metadata of the operands with that of the gradients.
|
|
def _extract_tensor_metadata_except_requires_grad(arg):
|
|
metadata = _extract_tensor_metadata(arg)
|
|
metadata = TensorMetadata(
|
|
metadata.shape,
|
|
metadata.dtype,
|
|
False,
|
|
metadata.stride,
|
|
metadata.memory_format,
|
|
metadata.is_quantized,
|
|
metadata.qparams,
|
|
)
|
|
return metadata
|
|
|
|
# Comparison of FWD path
|
|
cond_outputs = cond_fct(pred_fn(*operands), true_fn, false_fn, operands)
|
|
operands_forced_grad = [o.clone().detach() for o in operands]
|
|
for o in operands_forced_grad:
|
|
o.requires_grad = True
|
|
cond_outputs_exp = (
|
|
true_fn(*operands_forced_grad)
|
|
if pred_fn(*operands_forced_grad)
|
|
else false_fn(*operands_forced_grad)
|
|
)
|
|
self.assertEqual(cond_outputs, cond_outputs_exp)
|
|
|
|
# Comparison of BWD path
|
|
cond_inputs = [o for o in operands if o.requires_grad]
|
|
cond_inputs_exp = [o for o in operands_forced_grad if o.requires_grad]
|
|
|
|
        # Check whether at least some operands require grads
|
|
if len(cond_inputs) > 0:
|
|
grad_inputs = torch.autograd.grad(
|
|
cond_outputs, cond_inputs, allow_unused=True, retain_graph=True
|
|
)
|
|
grad_inputs_exp = torch.autograd.grad(
|
|
cond_outputs_exp,
|
|
cond_inputs_exp,
|
|
allow_unused=True,
|
|
materialize_grads=True,
|
|
)
|
|
|
|
grad_exp_masked = [
|
|
g for g, o in zip(grad_inputs_exp, operands) if o.requires_grad
|
|
]
|
|
self.assertEqual(grad_exp_masked, grad_inputs)
|
|
|
|
            # Extract and compare the metadata of the operands and the gradients
|
|
operands_metadata = [
|
|
_extract_tensor_metadata_except_requires_grad(o) for o in cond_inputs
|
|
]
|
|
grad_metadata = [
|
|
_extract_tensor_metadata_except_requires_grad(o) for o in grad_inputs
|
|
]
|
|
self.assertTrue(
|
|
all(op == g for op, g in zip(operands_metadata, grad_metadata))
|
|
)
|
|
|
|
return cond_outputs, cond_inputs
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["compile_dynamic_shape"])
|
|
@parametrize("scalar", [False])
|
|
def test_cond_autograd_zeros_unused_branch_complex_compile_fail(
|
|
self, compile_mode, scalar
|
|
):
|
|
device = torch.device("cuda")
|
|
cond_fct = compile_mode_helper(torch.cond, compile_mode)
|
|
|
|
autograd = [False, True, True, True, True]
|
|
|
|
if scalar:
|
|
# These operands work
|
|
x = torch.randn((), device=device, requires_grad=bool(autograd[0]))
|
|
w1 = torch.randn((), device=device, requires_grad=bool(autograd[1]))
|
|
b1 = torch.randn((), device=device, requires_grad=bool(autograd[2]))
|
|
w2 = torch.randn((), device=device, requires_grad=bool(autograd[3]))
|
|
b2 = torch.randn((), device=device, requires_grad=bool(autograd[4]))
|
|
else:
|
|
# These operands do not work
|
|
x = torch.randn(4, 5, device=device, requires_grad=bool(autograd[0]))
|
|
w1 = torch.randn(2, 4, device=device, requires_grad=bool(autograd[1]))
|
|
b1 = torch.randn(2, 1, device=device, requires_grad=bool(autograd[2]))
|
|
w2 = torch.randn(2, 4, device=device, requires_grad=bool(autograd[3]))
|
|
b2 = torch.randn(1, 5, device=device, requires_grad=bool(autograd[4]))
|
|
|
|
operands = [x, w1, b1, w2, b2]
|
|
|
|
def true_fn(x, w1, b1, w2, b2):
|
|
if scalar:
|
|
# This works
|
|
return ((w1 * x + b1),)
|
|
else:
|
|
# This does not work
|
|
return ((w1 @ x + b1).sum(),)
|
|
|
|
def false_fn(x, w1, b1, w2, b2):
|
|
if scalar:
|
|
# This works
|
|
return ((w2 * x + b2),)
|
|
else:
|
|
# This does not work
|
|
return ((w2 @ x + b2).sum(),)
|
|
|
|
def pred_fn(x, w1, b1, w2, b2):
|
|
return x.mean() > 0
|
|
|
|
cond_outputs, cond_inputs = self._test_cond_autograd(
|
|
cond_fct, pred_fn, true_fn, false_fn, operands
|
|
)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_map_gpu(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
xs = torch.ones(3, 2, 2, device="cuda")
|
|
y = torch.ones(2, device="cuda")
|
|
res = control_flow.map(f, xs, y)
|
|
expected = _fake_map(f, xs, y)
|
|
self.assertEqual(expected, res)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
def test_while_loop_gpu(self):
|
|
def cond_fn(x):
|
|
return x.sum() < 10
|
|
|
|
def body_fn(x):
|
|
return (x + 1,)
|
|
|
|
x = torch.zeros(1, device="cuda")
|
|
res = while_loop(cond_fn, body_fn, (x,))
|
|
expected = _fake_while_loop(cond_fn, body_fn, (x,))
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_map_illegal_inputs(self):
|
|
def f(x, y):
|
|
return x[0] + x[1] + y
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\.",
|
|
):
|
|
_ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2))
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, r"Leading dimensions of mapped xs cannot be 0\."
|
|
):
|
|
_ = control_flow.map(
|
|
f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2)
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Leading dimensions of mapped xs must be consistent\. "
|
|
r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\.",
|
|
):
|
|
_ = control_flow.map(
|
|
f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5)
|
|
)
|
|
|
|
def test_map_illegal_outputs(self):
|
|
def f(x, y):
|
|
return x.item()
|
|
|
|
def f1(x, y):
|
|
return y.size()
|
|
|
|
def f2(x, y):
|
|
return None
|
|
|
|
x = torch.ones([3])
|
|
y = torch.ones([1, 2, 3])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "map doesn't work unless it is captured completely"
|
|
):
|
|
control_flow.map(f, x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
# "Expected all leaves to be of torch.Tensor type.*",
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"map doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
control_flow.map(f1, x, y)
|
|
|
|
# return None is OK
|
|
control_flow.map(f2, x, y)
|
|
|
|
def test_map_list_in_out(self):
|
|
def f(x, y):
|
|
return [[x[0][0] + y]]
|
|
|
|
xs = [[torch.ones(3, 2, 2)]]
|
|
y = torch.ones(2)
|
|
res = control_flow.map(f, xs, y)
|
|
expected = _fake_map(f, xs, y)
|
|
self.assertEqual(len(res), 1)
|
|
self.assertEqual(len(res[0]), 1)
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_map_dict_in_out(self):
|
|
def f(x, y):
|
|
return {"c": x["a"]["b"] + y}
|
|
|
|
xs = {"a": {"b": torch.ones(3, 2, 2)}}
|
|
y = torch.ones(2)
|
|
res = control_flow.map(f, xs, y)
|
|
expected = _fake_map(f, xs, y)
|
|
self.assertEqual(len(res), 1)
|
|
self.assertTrue("c" in res)
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_map_autograd_simple(self):
|
|
def f(x, y):
|
|
return x.sin().cos() * y.cos().sin()
|
|
|
|
xs = torch.ones(3, 2, 2, requires_grad=True)
|
|
y = torch.ones(2, requires_grad=True)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
grad_out = torch.ones_like(res)
|
|
grads = torch.autograd.grad(res, (xs, y), grad_out)
|
|
expected_grads = torch.autograd.grad(expected_res, (xs, y), grad_out)
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_map_autograd_simple_partial_grad(self):
|
|
def f(x, y):
|
|
return x.sin().cos() * y.cos().sin()
|
|
|
|
xs = torch.ones(3, 2, 2, requires_grad=True)
|
|
# Disable the gradient computation for y
|
|
y = torch.ones(2, requires_grad=False)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
grad_out = torch.ones_like(res)
|
|
grads = torch.autograd.grad(res, (xs,), grad_out)
|
|
expected_grads = torch.autograd.grad(expected_res, (xs,), grad_out)
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_map_autograd_no_grad_output(self):
|
|
def f(x, y):
|
|
return x[0].sin().cos() + y, y.cos().sin()
|
|
|
|
xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)]
|
|
# Disable the gradient computation for y
|
|
y = torch.ones(2, requires_grad=False)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
grad_out = torch.ones_like(res[0])
|
|
grads = torch.autograd.grad(res[0], (xs[0],), grad_out)
|
|
expected_grads = torch.autograd.grad(expected_res[0], (xs[0],), grad_out)
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_map_autograd_nested_list(self):
|
|
import torch.utils._pytree as pytree
|
|
|
|
def f(x, y):
|
|
a, b = x
|
|
c, d = a
|
|
return [[b.sin() * c.cos()], d.sin() * y.cos()]
|
|
|
|
def fwbw(map_op, f, x, y):
|
|
z = map_op(f, x, y)
|
|
flat_x = pytree.tree_leaves(x)
|
|
flat_z = pytree.tree_leaves(z)
|
|
grads = torch.autograd.grad(
|
|
flat_z, flat_x, [torch.ones_like(z) for z in flat_z]
|
|
)
|
|
return z, grads
|
|
|
|
x = [
|
|
[
|
|
torch.randn(3, 2, 2, requires_grad=True),
|
|
torch.randn(3, 2, 1, requires_grad=True),
|
|
],
|
|
torch.ones(3, 1, 2, requires_grad=True),
|
|
]
|
|
y = torch.ones(1, requires_grad=True)
|
|
true_outs = fwbw(control_flow.map, f, x, y)
|
|
fake_outs = fwbw(_fake_map, f, x, y)
|
|
self.assertEqual(true_outs, fake_outs)
|
|
|
|
def test_map_autograd_higher_order(self):
|
|
from torch.autograd.functional import hessian as hes, jacobian as jac
|
|
|
|
def f(x, y):
|
|
return x.sin().cos() + y
|
|
|
|
def wrapper_jac(x, y):
|
|
return control_flow.map(f, x, y)
|
|
|
|
def wrapper_jac_fake(x, y):
|
|
return _fake_map(f, x, y)
|
|
|
|
def wrapper_hes(x, y):
|
|
return control_flow.map(f, x, y).sum()
|
|
|
|
def wrapper_hes_fake(x, y):
|
|
return _fake_map(f, x, y).sum()
|
|
|
|
for g_fct, (wrap, wrap_fake) in [
|
|
(jac, [wrapper_jac, wrapper_jac_fake]),
|
|
(hes, [wrapper_hes, wrapper_hes_fake]),
|
|
]:
|
|
xs = torch.ones(3, 2, 2, requires_grad=True)
|
|
# Disable the gradient computation for y
|
|
y = torch.ones(2, requires_grad=False)
|
|
res = control_flow.map(f, xs, y)
|
|
expected_res = _fake_map(f, xs, y)
|
|
self.assertEqual(expected_res, res)
|
|
|
|
expected_grads = g_fct(wrap_fake, (xs, y))
|
|
grads = g_fct(wrap, (xs, y))
|
|
self.assertEqual(expected_res, res)
|
|
self.assertEqual(expected_grads, grads)
|
|
|
|
def test_scan_y_less_ndim_then_dim(self):
|
|
def combine_fn(carry, x):
|
|
return carry @ x, (carry @ x).sum()
|
|
|
|
init = torch.randn(4, 3)
|
|
xs = torch.randn(3, 3, 2)
|
|
dim = 2
|
|
out = scan(combine_fn, init, xs, dim=dim)
|
|
exp_out = _fake_scan(combine_fn, init, xs, dim=dim)
|
|
self.assertEqual(out, exp_out)
|
|
|
|
    # TODO: provide an implementation for all compile modes and re-enable all tests
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_compile(self, reverse, compile_mode, device, autograd):
|
|
def add2(x: torch.Tensor, y: torch.Tensor):
|
|
return x * y, x + y
|
|
|
|
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
for op, op_pt, init in [
|
|
(
|
|
get_scan_combine_fn("add", False),
|
|
torch.cumsum,
|
|
torch.zeros(10, 2, device=device, requires_grad=autograd),
|
|
),
|
|
(
|
|
get_scan_combine_fn("mul", False),
|
|
torch.cumprod,
|
|
torch.ones(10, 2, device=device, requires_grad=autograd),
|
|
),
|
|
]:
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
if not reverse:
|
|
result_exp_PT = op_pt(x, 0)
|
|
self.assertEqual(result[1], result_exp_PT)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
# Jax Examples
|
|
x = torch.arange(0, 4, device=device, dtype=torch.int64)
|
|
init = torch.zeros(1, device=device, dtype=torch.int64)
|
|
cumsum1 = scan_fct(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
cumsum_exp = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init,
|
|
xs=x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
if not reverse:
|
|
self.assertEqual(
|
|
cumsum1[1],
|
|
torch.tensor([[0.0], [1.0], [3.0], [6.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64))
|
|
else:
|
|
self.assertEqual(
|
|
cumsum1[1],
|
|
torch.tensor([[6.0], [6.0], [5.0], [3.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64))
|
|
self.assertEqual(cumsum1, cumsum_exp)
|
|
|
|
        # Carry computation differs from the output computation
|
|
x = torch.arange(1, 5, device=device, dtype=torch.int64)
|
|
init = torch.ones(1, device=device, dtype=torch.int64)
|
|
result = scan_fct(add2, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(add2, init=init, xs=x, dim=0, reverse=reverse)
|
|
if not reverse:
|
|
self.assertEqual(
|
|
result[1],
|
|
torch.tensor([[2.0], [3.0], [5.0], [10.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64))
|
|
else:
|
|
self.assertEqual(
|
|
result[1],
|
|
torch.tensor([[25.0], [14.0], [7.0], [5.0]], dtype=torch.int64),
|
|
)
|
|
self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64))
|
|
self.assertEqual(result, result_exp)
|
|
|
|
# Non associative operation
|
|
x = torch.arange(
|
|
0, 5, device=device, dtype=torch.float32, requires_grad=autograd
|
|
)
|
|
init = torch.ones(1, device=device, dtype=torch.float32, requires_grad=autograd)
|
|
result = scan_fct(
|
|
get_scan_combine_fn("div", False),
|
|
init,
|
|
x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
result_exp = _fake_scan(
|
|
get_scan_combine_fn("div", False),
|
|
init=init,
|
|
xs=x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
# TODO: provide an implementation for all compile modes and re-enable all test
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize(
|
|
"dtype",
|
|
[
|
|
torch.float16,
|
|
torch.float32,
|
|
torch.int32,
|
|
torch.int64,
|
|
torch.complex64,
|
|
],
|
|
)
|
|
def test_scan_dtype(self, reverse, compile_mode, device, dtype):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# Check all outputs and carries on the correct device and with torch.float32
|
|
x = torch.randn(3, 10, 2, device=device).to(dtype=dtype)
|
|
op, init = (
|
|
get_scan_combine_fn("adds"),
|
|
torch.zeros(10, 2, device=device, dtype=dtype),
|
|
)
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(
|
|
[[r.device.type for r in res] for res in result],
|
|
[[device.type for _ in res] for res in result],
|
|
)
|
|
self.assertEqual(
|
|
[[r.dtype for r in res] for res in result],
|
|
[[dtype for _ in res] for res in result],
|
|
)
|
|
|
|
# Check all outputs and carries on the correct device and
|
|
# carry.dtype torch.float32 and output.dtype torch.float16
|
|
x = torch.randn(3, 10, 2, device=device).to(dtype=dtype)
|
|
op, init = (
|
|
get_scan_combine_fn("adds"),
|
|
torch.zeros(10, 2, device=device, dtype=torch.float32),
|
|
)
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(
|
|
[[r.dtype for r in res] for res in result],
|
|
[
|
|
[torch.float32 for _ in range(len(result[0]))],
|
|
[dtype for _ in range(len(result[1]))],
|
|
],
|
|
)
|
|
|
|
# Check all outputs and carries on the correct device and
|
|
# carry.dtype torch.int64 and output.dtype torch.float32
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
op, init = (
|
|
get_scan_combine_fn("adds"),
|
|
torch.zeros(10, 2, device=device, dtype=dtype),
|
|
)
|
|
result = scan_fct(op, init, x, dim=0, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(
|
|
[[r.dtype for r in res] for res in result],
|
|
[
|
|
[dtype for _ in range(len(result[0]))],
|
|
[torch.float32 for _ in range(len(result[1]))],
|
|
],
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_dim(self, reverse, compile_mode, device, autograd):
|
|
import random
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
num_dims = [random.randint(2, 5) for _ in range(5)]
|
|
for num_dim in num_dims:
|
|
shapes = [random.randint(1, 10) for _ in range(num_dim)]
|
|
rnd_scan_dim = random.randint(0, num_dim - 1)
|
|
x = torch.randn(*shapes, device=device, requires_grad=autograd)
|
|
init_shapes = shapes[:rnd_scan_dim] + shapes[rnd_scan_dim + 1 :]
|
|
|
|
for op, op_pt, init in [
|
|
(
|
|
get_scan_combine_fn("add", False),
|
|
torch.cumsum,
|
|
torch.zeros(*init_shapes, device=device, requires_grad=autograd),
|
|
),
|
|
(
|
|
get_scan_combine_fn("mul", False),
|
|
torch.cumprod,
|
|
torch.ones(*init_shapes, device=device, requires_grad=autograd),
|
|
),
|
|
]:
|
|
result = scan_fct(op, init, x, dim=rnd_scan_dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
op, init=init, xs=x, dim=rnd_scan_dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
if not reverse:
|
|
result_exp_PT = op_pt(x, rnd_scan_dim)
|
|
res_list = list(result)
|
|
res_list[1] = res_list[1].movedim(0, rnd_scan_dim)
|
|
self.assertEqual(res_list[1], result_exp_PT)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_binary_operator(self, reverse, compile_mode, device, autograd):
|
|
state_dim = 20
|
|
timesteps = 10
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
projected_inputs = torch.randn(
|
|
timesteps, state_dim, requires_grad=autograd, device=device
|
|
)
|
|
A = torch.randn(state_dim, requires_grad=autograd, device=device)
|
|
elements = (A.repeat((timesteps, 1)), projected_inputs)
|
|
init = tuple(
|
|
[
|
|
torch.ones_like(
|
|
torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1),
|
|
requires_grad=autograd,
|
|
)
|
|
]
|
|
+ [
|
|
torch.zeros_like(
|
|
torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1),
|
|
requires_grad=autograd,
|
|
)
|
|
]
|
|
)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("s5_operator", False),
|
|
init,
|
|
elements,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("s5_operator", False),
|
|
init=init,
|
|
xs=elements,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
init_flatten, _ = pytree.tree_flatten(init)
|
|
elements_flatten, _ = pytree.tree_flatten(elements)
|
|
|
|
result_flatten, _ = pytree.tree_flatten(result)
|
|
result_exp_flatten, _ = pytree.tree_flatten(expected_result)
|
|
grad_out = [torch.ones_like(el) for el in result_exp_flatten]
|
|
expected_grads = torch.autograd.grad(
|
|
result_exp_flatten, (*init_flatten, *elements_flatten), grad_out
|
|
)
|
|
grads = torch.autograd.grad(
|
|
result_flatten, (*init_flatten, *elements_flatten), grad_out
|
|
)
|
|
self.assertEqual(grads, expected_grads)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_tuple(self, reverse, compile_mode, device, autograd):
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
inp = (x, y)
|
|
init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp)
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
result_same = scan_fct(
|
|
get_scan_combine_fn("tuple_fct", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("tuple_fct", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result_same, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result_same, expected_result, (init, inp))
|
|
|
|
def fct_different_output_tuple(x, y):
|
|
return ((x[0] + y[0], x[1] * y[1]), (x[1] * y[1]))
|
|
|
|
inp = (x, y)
|
|
init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp)
|
|
|
|
result_diff = scan(
|
|
fct_different_output_tuple, init, inp, dim=0, reverse=reverse
|
|
)
|
|
expected_result = _fake_scan(
|
|
fct_different_output_tuple, init=init, xs=inp, dim=0, reverse=reverse
|
|
)
|
|
self.assertEqual(result_diff, expected_result)
|
|
self.assertEqual(result_diff[1], result_same[1][1])
|
|
|
|
if autograd:
|
|
self.check_autograd(result_diff, expected_result, (init, inp))
|
|
|
|
    def test_scan_wrong_pytree(self):
        # Init and input have same pytree
        def fct_wrong_pytree(x, y):
            return (
                {
                    "i": x["i"] * y["j"][0][0],
                    "k": torch.tensor(0.0),
                    "j": (
                        [x["j"][1][0]["o"].clone()],
                        [{"o": torch.sin(x["i"])}],
                    ),
                },
                {
                    "i": x["i"] * y["j"][0][0],
                    "k": torch.tensor(0.0),
                    "j": ([x["j"][1][0]["o"].clone()], [{"o": torch.sin(x["i"])}]),
                },
            )

        x = torch.randn(3, 2, 2)
        y = torch.randn(3, 2, 2)
        z = torch.randn(3, 2, 2)
        inp = {"i": x, "j": ([y], [{"o": z}])}
        inp_flat, inp_spec = pytree.tree_flatten(inp)
        init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat]
        init = pytree.tree_unflatten(init_flat, inp_spec)

        with self.assertRaisesRegex(
            # Should be
            # torch._dynamo.exc.UncapturedHigherOrderOpError,
            # r"The tree structure of the inits and the carries are not identical.*",
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Expected init and carry to have same number of outputs but got lhs.*",
        ):
            scan(fct_wrong_pytree, init, inp, dim=0)

    def test_scan_float_output(self):
        # Init and input have same pytree
        def fct_float_output(x, y):
            return 0.0, x + y

        x = torch.randn(3, 2, 2)
        init = torch._ops.ops.aten.slice(x, 0, 0, 1, 1)

        with self.assertRaisesRegex(
            # Should be:
            # torch._dynamo.exc.Unsupported,
            # "HigherOrderOperator body's output must consist of tensors or ints only"
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "scan must be captured completely.*",
        ):
            scan(fct_float_output, init, x, dim=0)

@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_complex_pytree(self, reverse, compile_mode, device, autograd):
|
|
# Init and input have same pytree
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
inp_flat, inp_spec = pytree.tree_flatten(inp)
|
|
init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat]
|
|
init = pytree.tree_unflatten(init_flat, inp_spec)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, inp))
|
|
|
|
    # TODO: Does not work because of the usage of vmap within associative_scan
    # The parameterization (T206899919) is commented out for the moment and the
    # test is marked with expected fail
    # Fails with: AssertionError: scan is not an OpOverload
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_scan_associative_scan(self):
|
|
combine_mode = "generic"
|
|
compile_mode_scan = "compile"
|
|
compile_mode_associative_scan = "none"
|
|
reverse = True
|
|
reverse_associative_scan = True
|
|
device = torch.device("cuda")
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode_scan)
|
|
associative_scan_fct = compile_mode_helper(
|
|
associative_scan, compile_mode_associative_scan
|
|
)
|
|
init = torch.randn(10, 5, device=device)
|
|
inp = torch.randn(3, 10, 5, device=device)
|
|
|
|
def body(x, y):
|
|
val = associative_scan_fct(
|
|
get_scan_combine_fn("add", True),
|
|
y,
|
|
0,
|
|
reverse=reverse_associative_scan,
|
|
combine_mode=combine_mode,
|
|
)
|
|
return x + y, x + val
|
|
|
|
result = scan_fct(body, init, inp, dim=0, reverse=reverse)
|
|
expected_result = _fake_scan(
|
|
body,
|
|
init,
|
|
inp,
|
|
0,
|
|
reverse=reverse,
|
|
)
|
|
|
|
self.assertEqual(result, expected_result)
|
|
|
|
# TODO: provide an implementation for all compile modes and re-enable all test
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device, autograd):
|
|
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
init = torch.randn(3, 2, device=device, requires_grad=autograd)
|
|
|
|
for ind in range(2):
|
|
# Chain with matmul
|
|
def chain_fct(inp):
|
|
W = torch.ones(2, 5, device=device)
|
|
o = scan(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
return o[ind] @ W
|
|
|
|
fct_cmp = compile_mode_helper(chain_fct, compile_mode)
|
|
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)[ind] @ torch.ones(2, 5, device=device)
|
|
result = fct_cmp(inp)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, inp))
|
|
|
|
# TODO: provide an implementation for all compile modes and re-enable all test
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_downstream_scan_scan_dim(
|
|
self, compile_mode, reverse, device, autograd
|
|
):
|
|
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
init = torch.randn(3, 2, device=device, requires_grad=autograd)
|
|
|
|
# Chain with scan on different dim
|
|
init2 = torch.randn(1, 10, 2, device=device, requires_grad=autograd)
|
|
|
|
def chain_fct_different_dim(inp):
|
|
o1 = scan(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
o1 = pytree.tree_map(lambda t: t.movedim(0, 1), o1)
|
|
o2 = scan(
|
|
get_scan_combine_fn("add", False),
|
|
init2,
|
|
o1[1],
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
return o2
|
|
|
|
fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode)
|
|
|
|
xs = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init,
|
|
xs=inp,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)[1]
|
|
xs = pytree.tree_map(lambda t: t.movedim(0, 1), xs)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("add", False),
|
|
init=init2,
|
|
xs=xs,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
result = fct_cmp(inp)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, init2, inp))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_non_pointwise(self, reverse, compile_mode, device, autograd):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
|
init = torch.randn(10, 2, device=device, requires_grad=autograd)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("non_pointwise", False),
|
|
init=init,
|
|
xs=x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("non_pointwise", False),
|
|
init,
|
|
x,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, expected_result, (init, x))
|
|
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_scan_compile_cnt(self, reverse, device):
|
|
dim = 1
|
|
|
|
from torch._dynamo.testing import CompileCounter
|
|
|
|
# Tests rely on automatic_dynamic = True
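        # With automatic dynamic shapes, the first size change on the scanned dim
        # triggers one recompilation and marks that dim dynamic; subsequent
        # size-only changes on the same dim are expected to reuse the cached graph,
        # which is what the frame_count assertions below check.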
|
|
with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
|
|
cnt = CompileCounter()
|
|
x = torch.randn(3, 2, 5, device=device)
|
|
init = torch.randn(3, 5, device=device)
|
|
# First compilation step
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
x = torch.randn(3, 20, 5, device=device)
|
|
init = torch.randn(3, 5, device=device)
|
|
# Recompilation due to first different size
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
x = torch.randn(3, 40, 5, device=device)
|
|
init = torch.randn(3, 5, device=device)
|
|
# No recompilation, because of dynamic shape
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
x = torch.randn(3, 40, 5, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation because of dim change
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=2,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
x = torch.randn(3, 40, 20, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation due to first different size on new dim
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=2,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
x = torch.randn(3, 40, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# No recompilation, because of dynamic shape on new dim
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=2,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
x = torch.randn(3, 60, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation because of dim change
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 5)
|
|
|
|
x = torch.randn(3, 60, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# Recompilation because of reverse change
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=not reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 6)
|
|
|
|
x = torch.randn(3, 60, 40, device=device)
|
|
init = torch.randn(3, 40, device=device)
|
|
# No recompilation, as nothing changed
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=not reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 6)
|
|
|
|
x = torch.randn(3, 120, 80, device=device)
|
|
init = torch.randn(3, 80, device=device)
|
|
# No recompilation, final test
|
|
torch.compile(scan, backend=cnt)(
|
|
get_scan_combine_fn("add", False),
|
|
init,
|
|
x,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(cnt.frame_count, 6)
|
|
|
|
    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_init_scanned_0(self):
        # Only init and no input
        x = torch.randn(3, 1, 2, device=torch.device("cpu"))
        init = torch.randn(3, 2, device=torch.device("cpu"))
        dim = 1

        # Scan dimension is 0
        init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
        inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1)
        with self.assertRaisesRegex(
            RuntimeError,
            "All xs leaves must at least have.*",
        ):
            scan(
                get_scan_combine_fn("add", False),
                init,
                inp,
                dim=dim,
            )

    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_init_non_tensor(self):
        x = torch.randn(3, 1, 2, device=torch.device("cpu"))
        dim = 1

        # Init is a float and not a tensor
        init = 1.0
        with self.assertRaisesRegex(RuntimeError, "All init leaves must be a Tensor.*"):
            scan(get_scan_combine_fn("add", False), init, x, dim=dim, reverse=False)

    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_init_wrong_shape(self):
        scan_fct = compile_mode_helper(scan, "none")

        # Only init and no input
        x = torch.randn(3, 1, 2)
        dim = 1

        # Init wrong shape (Other dim different)
        init = torch.randn(1, 2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Expected init and carry to have same metadata.*",
        ):
            scan_fct(
                get_scan_combine_fn("add", False),
                init,
                x,
                dim=dim,
            )

    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_init_wrong_pytree_init_longer_carry(self):
        def init_longer_carry(x: torch.Tensor, y: torch.Tensor):
            return x[0] + 1.0, y + 1.0

        scan_fct = compile_mode_helper(scan, "none")

        # Only init and no input
        x = torch.randn(3, 1, 2)
        dim = 1

        # Init wrong pytree
        init = (
            torch._ops.ops.aten.slice(x, dim, 0, 1, 1),
            torch._ops.ops.aten.slice(x, dim, 0, 1, 1),
        )

        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Expected init and carry to have same number of outputs but got lhs.*",
        ):
            scan_fct(init_longer_carry, init, x, dim=dim)

    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_init_wrong_pytree_init_shorter_carry(self):
        def init_shorter_carry(x: torch.Tensor, y: torch.Tensor):
            return (x + 1, x + 2), x + 3

        scan_fct = compile_mode_helper(scan, "none")

        # Only init and no input
        x = torch.randn(3, 1, 2)
        dim = 1

        # Init wrong pytree
        init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)

        with self.assertRaisesRegex(
            # Should be
            # torch._dynamo.exc.Unsupported,
            # The tree structure of the inits and the carries are not identical!
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Expected init and carry to have same number of outputs but got lhs.*",
        ):
            scan_fct(init_shorter_carry, init, x, dim=dim)

    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_init_wrong_pytree_carry_shape(self):
        def wrong_carry_shape(x: torch.Tensor, y: torch.Tensor):
            return x[0, :], x + 3

        scan_fct = compile_mode_helper(scan, "none")

        # Only init and no input
        x = torch.randn(3, 1, 2)
        dim = 1

        # Init wrong pytree
        init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)

        with self.assertRaisesRegex(
            # Should be
            # torch._dynamo.exc.Unsupported,
            # "Encountered aliasing during higher order op tracing for HOP.*"
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "scan must be captured completely with torch.compile.*",
        ):
            scan_fct(wrong_carry_shape, init, x, dim=dim)

    @skipIfTorchDynamo("don't test compile on compile")
    def test_scan_one_return(self):
        def no_carry(x: torch.Tensor, y: torch.Tensor):
            return x + 3

        scan_fct = compile_mode_helper(scan, "none")

        # Only init and no input
        x = torch.randn(3, 1, 2)
        dim = 1

        # Init wrong pytree
        init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)

        with self.assertRaisesRegex(
            # Should be
            # torch._dynamo.exc.Unsupported,
            # combine_fn needs to produce two pytrees, one for the carries and one for the outputs.
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "scan must be captured completely with.*",
        ):
            scan_fct(no_carry, init, x, dim=dim)

@skipIfTorchDynamo("don't test compile on compile")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_init(self, reverse, compile_mode, device, autograd):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# Only init and no input
|
|
x = torch.randn(3, 1, 2, device=device, requires_grad=autograd)
|
|
dim = 1
|
|
op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum)
|
|
|
|
# Only init given
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1)
|
|
result = scan_fct(op, init, [], dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=[], dim=dim, reverse=reverse)
|
|
result_init = scan_fct(op, init, [], dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0], init)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init,))
|
|
|
|
x = torch.randn(3, 5, 2, device=device, requires_grad=autograd)
|
|
dim = 0
|
|
|
|
op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum)
|
|
inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1)
|
|
|
|
# Init tensor scalar
|
|
init = torch.ones(1, device=device, requires_grad=autograd)
|
|
|
|
def add_scalar_carry(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 1.0, x + y
|
|
|
|
result_init = scan_fct(add_scalar_carry, init, inp, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
add_scalar_carry, init=init, xs=inp, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0], torch.tensor([3.0], device=device))
|
|
|
|
if autograd:
|
|
self.check_autograd(result_init, result_exp, (init, inp))
|
|
|
|
# Init tensor entirely different shape than inp
|
|
init = torch.randn(7, 8, device=device, requires_grad=autograd)
|
|
|
|
def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 1.0, x[: y.shape[0], : y.shape[1]] + y
|
|
|
|
result_init = scan_fct(add_scalar_carry2, init, inp, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
add_scalar_carry2, init=init, xs=inp, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result_init, result_exp)
|
|
|
|
# Init with two timestep on dim axis. Should work as y has always 1 on dim axis and
|
|
# hence automatic broadcasting should work
|
|
# I.e., the input shape is 2x5x2, but the carry at each iteration is 2x5x2,
|
|
# thus the output of each iteration is 2x5x2, which results in the total output
|
|
# to be 4x5x2
|
|
init = torch._ops.ops.aten.slice(x, dim, 0, 2, 1)
|
|
result_init = scan_fct(op, init, inp, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=inp, dim=dim, reverse=reverse)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0].shape, torch.Size([2, 5, 2]))
|
|
|
|
if autograd:
|
|
self.check_autograd(result_init, result_exp, (init, inp))
|
|
|
|
init = torch.tile(init, (1, 2, 1))
|
|
|
|
def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor):
|
|
return x + 1.0, x[:, :1, :] + y
|
|
|
|
result_init = scan_fct(
|
|
add_scalar_carry_sliced_out, init, inp, dim=dim, reverse=reverse
|
|
)
|
|
result_exp = _fake_scan(
|
|
add_scalar_carry_sliced_out, init=init, xs=inp, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result_init, result_exp)
|
|
self.assertEqual(result_init[0].shape, torch.Size([2, 10, 2]))
|
|
self.assertEqual(result_init[1].shape, torch.Size([2, 2, 5, 2]))
|
|
|
|
if autograd:
|
|
self.check_autograd(result_init, result_exp, (init, inp))
|
|
|
|
# Correct case
|
|
op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum)
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
init = torch.zeros(3, 2, device=device, requires_grad=autograd)
|
|
dim = 2
|
|
|
|
result = scan_fct(op, init, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(op, init=init, xs=x, dim=dim, reverse=reverse)
|
|
|
|
self.assertEqual(result, result_exp)
|
|
if not reverse:
|
|
result_exp_PT = op_pt(x, dim)
|
|
result = list(result)
|
|
result[1] = pytree.tree_map(lambda t: torch.movedim(t, 0, dim), result[1])
|
|
self.assertEqual(result[1], result_exp_PT)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (init, x))
|
|
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_scan_init_wrong_pytree_complex(self, reverse, device):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
z = torch.randn(3, 2, 2, device=device)
|
|
|
|
# Wrong pytree fed to the function
|
|
init = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, 0, 1, 1),
|
|
"j": (
|
|
{"a": torch._ops.ops.aten.slice(x, 0, 0, 1, 1)},
|
|
[torch._ops.ops.aten.slice(y, 0, 0, 1, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, 0, 1, 1)}],
|
|
),
|
|
}
|
|
inp = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, 0, None, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, 0, None, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, 0, None, 1)}],
|
|
),
|
|
}
|
|
with self.assertRaisesRegex(
|
|
Exception,
|
|
".*",
|
|
):
|
|
scan(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_init_pytree_complex(self, reverse, compile_mode, device, autograd):
|
|
def fct_pointwise_different_output(x, y):
|
|
return (
|
|
{
|
|
"i": x["i"] * y["i"],
|
|
"j": (
|
|
[x["j"][0][0] * y["j"][0][0]],
|
|
[{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
(
|
|
y["i"] * 2,
|
|
{
|
|
"o": x["i"] * y["i"],
|
|
"j": (
|
|
[x["j"][0][0] * y["j"][0][0]],
|
|
[{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
),
|
|
)
|
|
|
|
def fct_pointwise_different_carry(x, y):
|
|
return (
|
|
{
|
|
"i": x["i"] * y["i"],
|
|
"j": (
|
|
x["i"] * 2,
|
|
[x["j"][1][0] * y["j"][0][0]],
|
|
[{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
(
|
|
y["i"] * 2,
|
|
{
|
|
"o": x["i"] * y["i"] + x["j"][0][0],
|
|
"j": (
|
|
[x["j"][1][0] * y["j"][0][0]],
|
|
[{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}],
|
|
),
|
|
},
|
|
),
|
|
)
|
|
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
|
|
|
if reverse:
|
|
init_start, init_end = -1, None
|
|
inp_start, inp_end = 0, -1
|
|
else:
|
|
init_start, init_end = 0, 1
|
|
inp_start, inp_end = 1, None
|
|
|
|
# Regular case
|
|
init = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}],
|
|
),
|
|
}
|
|
inp = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}],
|
|
),
|
|
}
|
|
result = scan_fct(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
expected_result = _fake_scan(
|
|
get_scan_combine_fn("complex_pointwise", False),
|
|
init,
|
|
inp,
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
init_flat = pytree.tree_leaves(init)
|
|
inp_flat = pytree.tree_leaves(inp)
|
|
self.check_autograd(result, expected_result, (*init_flat, *inp_flat))
|
|
|
|
# Pytree of output is different
|
|
result = scan_fct(
|
|
fct_pointwise_different_output, init, inp, dim=0, reverse=reverse
|
|
)
|
|
expected_result = _fake_scan(
|
|
fct_pointwise_different_output, init=init, xs=inp, dim=0, reverse=reverse
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
# Pytree of carry is different
|
|
init = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1),
|
|
"j": (
|
|
torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1),
|
|
[torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}],
|
|
),
|
|
}
|
|
inp = {
|
|
"i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1),
|
|
"j": (
|
|
[torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)],
|
|
[{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}],
|
|
),
|
|
}
|
|
result = scan_fct(
|
|
fct_pointwise_different_carry, init, inp, dim=0, reverse=reverse
|
|
)
|
|
expected_result = _fake_scan(
|
|
fct_pointwise_different_carry, init=init, xs=inp, dim=0, reverse=reverse
|
|
)
|
|
self.assertEqual(result, expected_result)
|
|
|
|
if autograd:
|
|
init_flat = pytree.tree_leaves(init)
|
|
inp_flat = pytree.tree_leaves(inp)
|
|
self.check_autograd(result, expected_result, (*init_flat, *inp_flat))
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@skipIfNoDynamoSupport
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_scan_pytree_output(self):
|
|
x = torch.randn(3, 10, 2, device=torch.device("cpu"))
|
|
init = torch.randn(1, 10, 2, device=torch.device("cpu"))
|
|
|
|
def f(fct, init, xs):
|
|
return scan(fct, init, xs, dim=0, reverse=True)
|
|
|
|
def combine_fn(init, x):
|
|
a, b = (init[0] + x, init[1] - x)
|
|
return (a, b), a - b
|
|
|
|
# Check graph
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(combine_fn, (init, init.clone()), x)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(gm.print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_init_0_: "f32[1, 10, 2]", L_init_1_: "f32[1, 10, 2]", L_xs_: "f32[3, 10, 2]"):
|
|
l_init_0_ = L_init_0_
|
|
l_init_1_ = L_init_1_
|
|
l_xs_ = L_xs_
|
|
|
|
elem: "f32[3, 10, 2]" = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
|
|
flip: "f32[3, 10, 2]" = torch.flip(elem, [0]); elem = None
|
|
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_0_, l_init_1_], [flip], []); scan_combine_fn_0 = l_init_0_ = l_init_1_ = flip = None
|
|
getitem: "f32[1, 10, 2]" = scan[0]
|
|
getitem_1: "f32[1, 10, 2]" = scan[1]
|
|
out: "f32[3, 1, 10, 2]" = scan[2]; scan = None
|
|
|
|
out_1: "f32[3, 1, 10, 2]" = out.flip([0]); out = None
|
|
return (getitem, getitem_1, out_1)
|
|
|
|
class scan_combine_fn_0(torch.nn.Module):
|
|
def forward(self, child: "f32[1, 10, 2]", child_1: "f32[1, 10, 2]", child_2: "f32[10, 2]"):
|
|
a: "f32[1, 10, 2]" = child + child_2; child = None
|
|
b: "f32[1, 10, 2]" = child_1 - child_2; child_1 = child_2 = None
|
|
|
|
child_3: "f32[1, 10, 2]" = a - b
|
|
return [a, b, child_3]
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_RNN(self, compile_mode, autograd):
|
|
dim = 1
|
|
device = torch.device("cpu")
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
rnn = torch.nn.RNN(
|
|
input_size=5,
|
|
hidden_size=7,
|
|
batch_first=True,
|
|
)
|
|
rnn = rnn.to(device=device)
|
|
x = torch.randn(3, 10, 5, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
W_ih = rnn.weight_ih_l0.T.clone()
|
|
b_ih = rnn.bias_ih_l0.clone()
|
|
W_hh = rnn.weight_hh_l0.T.clone()
|
|
b_hh = rnn.bias_hh_l0.clone()
|
|
|
|
if not autograd:
|
|
W_ih = W_ih.detach()
|
|
b_ih = b_ih.detach()
|
|
W_hh = W_hh.detach()
|
|
b_hh = b_hh.detach()
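            # Detach the cloned RNN parameters so the closed-over additional inputs
            # do not require grad when autograd is not being exercised.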
|
|
|
|
expected_result = rnn(x, torch.unsqueeze(h, 0))
|
|
expected_result_out = expected_result[0]
|
|
expected_result_state = expected_result[1][0, :]
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("RNN", True, parameters=[W_ih, b_ih, W_hh, b_hh]),
|
|
h,
|
|
x,
|
|
dim=dim,
|
|
reverse=False,
|
|
)
|
|
result_cmp = [result[0], torch.movedim(result[1], 0, dim)]
|
|
self.assertEqual(result_cmp[0], expected_result_state)
|
|
self.assertEqual(result_cmp[1], expected_result_out)
|
|
|
|
if autograd:
|
|
result_flat = pytree.tree_leaves(result)
|
|
result_exp_flat = [expected_result_state, expected_result_out]
|
|
|
|
grad_out_expected = [torch.ones_like(r) for r in result_exp_flat]
|
|
expected_grads = torch.autograd.grad(
|
|
result_exp_flat,
|
|
(
|
|
h,
|
|
x,
|
|
rnn.weight_ih_l0,
|
|
rnn.bias_ih_l0,
|
|
rnn.weight_hh_l0,
|
|
rnn.bias_hh_l0,
|
|
),
|
|
grad_out_expected,
|
|
)
|
|
expected_add_input_grads = list(expected_grads[2:])
|
|
expected_grads = expected_grads[:2]
|
|
|
|
grad_out = [torch.ones_like(r) for r in result]
|
|
grads = torch.autograd.grad(
|
|
result_flat, (h, x, W_ih, b_ih, W_hh, b_hh), grad_out
|
|
)
|
|
add_input_grads = list(grads[2:])
|
|
add_input_grads[0] = add_input_grads[0].T
|
|
add_input_grads[2] = add_input_grads[2].T
|
|
grads = grads[:2]
|
|
self.assertEqual(grads, expected_grads)
|
|
self.assertEqual(add_input_grads, expected_add_input_grads)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize(
|
|
"partial_grad", ["xs", "init", "additional_inputs", "complex", "random"]
|
|
)
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_scan_closure_RNN_partial_autograd(
|
|
self, reverse, compile_mode, partial_grad, device
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# The first two booleans are the xs
|
|
# The second two are the inits
|
|
# The last four are the additional_inputs
|
|
autograds = []
|
|
|
|
if partial_grad == "xs":
|
|
# xs tests
|
|
autograds.append([True, False, True, True, True, True, True, True])
|
|
autograds.append([False, False, True, True, True, True, True, True])
|
|
elif partial_grad == "init":
|
|
# init tests
|
|
autograds.append([True, True, False, True, True, True, True, True])
|
|
autograds.append([True, True, False, False, True, True, True, True])
|
|
elif partial_grad == "additional_inputs":
|
|
# additional input tests
|
|
autograds.append([True, True, True, True, False, True, False, True])
|
|
autograds.append([True, True, True, True, False, False, False, False])
|
|
elif partial_grad == "complex":
|
|
# complex cases
|
|
autograds.append([True, False, False, False, False, False, False, True])
|
|
autograds.append([False, False, True, True, False, False, False, True])
|
|
elif partial_grad == "random":
|
|
# random tests
|
|
import random
|
|
|
|
for _ in range(5):
|
|
autograds.append([bool(random.randint(0, 1)) for _ in range(8)])
|
|
|
|
for autograd in autograds:
|
|
x = torch.randn(3, 10, 5, device=device, requires_grad=autograd[0])
|
|
x1 = torch.randn(3, 10, 5, device=device, requires_grad=autograd[1])
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd[2])
|
|
h_1 = torch.randn(3, 7, device=device, requires_grad=autograd[3])
|
|
W_ih = torch.randn(5, 7, device=device, requires_grad=autograd[4])
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd[5])
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd[6])
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd[7])
|
|
|
|
params = [
|
|
p
|
|
for p, a in zip([x, x1, h, h_1, W_ih, b_ih, W_hh, b_hh], autograd)
|
|
if a
|
|
]
|
|
|
|
def RNN(x: torch.Tensor, y: torch.Tensor):
|
|
c_new_0 = x[0] + 1
|
|
c_new_1 = x[1] + 1
|
|
h_new = (
|
|
torch.tanh(c_new_1 + x[0] @ W_hh + b_hh)
|
|
+ y[0] @ W_ih
|
|
+ y[1] @ W_ih
|
|
+ b_ih
|
|
+ x[1]
|
|
)
|
|
return (c_new_0, c_new_1), h_new
|
|
|
|
inits = (h, h_1)
|
|
result = scan_fct(RNN, inits, (x, x1), dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(RNN, (h, h_1), (x, x1), dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
result_flat = pytree.tree_leaves(result)
|
|
result_exp_flat = pytree.tree_leaves(result_exp)
|
|
exp_grad_mask = [
|
|
True if r.requires_grad else False for r in result_exp_flat
|
|
]
|
|
self.check_autograd(
|
|
[r for r, m in zip(result_flat, exp_grad_mask) if m],
|
|
[r for r, m in zip(result_exp_flat, exp_grad_mask) if m],
|
|
params,
|
|
)
|
|
|
|
@requires_cuda
|
|
@skipIfTorchDynamo("not a dynamo test")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@parametrize("layers", [1, 2, 3])
|
|
@parametrize("device", ["cpu", "cuda"])
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_scan_multiple_layers_gradient(self, layers, device):
|
|
import torch.nn as nn
|
|
|
|
torch.manual_seed(1)
|
|
|
|
LAYERS = layers
|
|
BATCH_SIZE = 2
|
|
SEQ_LEN = 5
|
|
FEATURE_DIM = 10
|
|
DEVICE = device
|
|
|
|
class RNNLoop(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = nn.ModuleList(
|
|
[nn.Linear(FEATURE_DIM * 2, FEATURE_DIM) for _ in range(LAYERS)]
|
|
)
|
|
self.num_layers = LAYERS
|
|
|
|
def forward(self, initial, inputs_sequence):
|
|
B, T, _ = inputs_sequence.shape
|
|
hs_list = initial
|
|
all_out = []
|
|
for t in range(T):
|
|
input = inputs_sequence[:, t, :]
|
|
for li, layer in enumerate(self.layers):
|
|
input_concat = torch.cat((hs_list[li], input), dim=-1)
|
|
update = layer(input_concat)
|
|
hs_list[li] = hs_list[li] + update
|
|
input = hs_list[li]
|
|
|
|
all_out.append(input)
|
|
|
|
return torch.stack(all_out, dim=1)
|
|
|
|
class RNNScanList(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = nn.ModuleList(
|
|
[nn.Linear(FEATURE_DIM * 2, FEATURE_DIM) for _ in range(LAYERS)]
|
|
)
|
|
self.num_layers = LAYERS
|
|
|
|
def forward(self, initial, input_sequence):
|
|
def step(carry, input):
|
|
hs_list = carry[:]
|
|
for li, layer in enumerate(self.layers):
|
|
h_prev_li = hs_list[li]
|
|
input_concat = torch.cat((h_prev_li, input), dim=-1)
|
|
update = layer(input_concat)
|
|
h_curr_li = h_prev_li + update
|
|
hs_list[li] = h_curr_li
|
|
input = h_curr_li
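                    # Note: hs_list[li] and input refer to the same tensor here, so the
                    # return below clones them to avoid aliased carries/outputs, which
                    # scan's tracing rejects.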
|
|
return [t.clone() for t in hs_list], input.clone()
|
|
|
|
_, all_outputs_scan = scan(step, initial, input_sequence, dim=1)
|
|
return all_outputs_scan.transpose(0, 1)
|
|
|
|
class RNNScanTensor(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layers = nn.ModuleList(
|
|
[nn.Linear(FEATURE_DIM * 2, FEATURE_DIM) for _ in range(LAYERS)]
|
|
)
|
|
self.num_layers = LAYERS
|
|
|
|
def forward(self, initial, input_sequence):
|
|
def step(carry_tensor, xs_input):
|
|
input = xs_input
|
|
hs_tensor = carry_tensor
|
|
for li, layer in enumerate(self.layers):
|
|
current_h_prev_li_slice = hs_tensor[:, li, :]
|
|
input_concat = torch.cat(
|
|
(current_h_prev_li_slice, input), dim=-1
|
|
)
|
|
update = layer(input_concat)
|
|
h_curr_li = current_h_prev_li_slice + update
|
|
hs_tensor = hs_tensor.clone()
|
|
hs_tensor[:, li, :] = h_curr_li
|
|
input = h_curr_li
|
|
return hs_tensor.clone(), input.clone()
|
|
|
|
hs_stacked = torch.stack(initial, dim=1)
|
|
_, all_outputs_scan = scan(step, hs_stacked, input_sequence, dim=1)
|
|
return all_outputs_scan.transpose(0, 1)
|
|
|
|
def run_test_and_get_grads_loss(model, initial_hs, inputs):
|
|
for param in model.parameters():
|
|
if param.grad is not None:
|
|
param.grad.zero_()
|
|
|
|
current_initial_hs = [
|
|
h.detach().clone().requires_grad_(h.requires_grad) for h in initial_hs
|
|
]
|
|
current_inputs = (
|
|
inputs.detach().clone().requires_grad_(inputs.requires_grad)
|
|
)
|
|
|
|
out = model(current_initial_hs, current_inputs)
|
|
loss = out.sum()
|
|
loss.backward()
|
|
|
|
layer_grads = []
|
|
for layer in model.layers:
|
|
layer_grads.append(layer.weight.grad.clone())
|
|
|
|
return layer_grads, loss
|
|
|
|
torch.manual_seed(0)
|
|
|
|
initial_hs_template = [
|
|
torch.zeros(
|
|
BATCH_SIZE, FEATURE_DIM, requires_grad=True, dtype=torch.float32
|
|
).to(DEVICE)
|
|
for _ in range(LAYERS)
|
|
]
|
|
inputs_template = torch.randn(
|
|
BATCH_SIZE, SEQ_LEN, FEATURE_DIM, requires_grad=True, dtype=torch.float32
|
|
).to(DEVICE)
|
|
|
|
# Test 3 models: RNNScanList, RNNScanTensor, RNNLoop
|
|
models = [
|
|
("ScanList", RNNScanList),
|
|
("ScanTensor", RNNScanTensor),
|
|
("Loop", RNNLoop),
|
|
]
|
|
|
|
for model_name, model_class in models:
|
|
# Create uncompiled model
|
|
model_uc = model_class().to(DEVICE)
|
|
uncompiled_grads, uncompiled_loss = run_test_and_get_grads_loss(
|
|
model_uc, initial_hs_template, inputs_template
|
|
)
|
|
|
|
# Create compiled model with same weights
|
|
model_to_compile = model_class().to(DEVICE)
|
|
model_to_compile.load_state_dict(model_uc.state_dict())
|
|
compiled_model = torch.compile(model_to_compile)
|
|
compiled_grads, compiled_loss = run_test_and_get_grads_loss(
|
|
compiled_model, initial_hs_template, inputs_template
|
|
)
|
|
|
|
# Compare gradients for each layer
|
|
for i, (uncompiled_grad, compiled_grad) in enumerate(
|
|
zip(uncompiled_grads, compiled_grads)
|
|
):
|
|
self.assertEqual(
|
|
uncompiled_grad,
|
|
compiled_grad,
|
|
)
|
|
|
|
# Compare losses
|
|
self.assertEqual(
|
|
uncompiled_loss,
|
|
compiled_loss,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_init_carries_unequal_grad(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
h2 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
result_exp = _fake_scan(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
# TODO: Ideally we should be able to select the results that require gradients like this
|
|
# [leaf for leaf in pytree.tree_leaves(result) if leaf.requires_grad == True]
|
|
# However, for the scan operator this does not work, as all outputs always have
|
|
# grad_fn=<ScanAutogradOpBackward>
|
|
res_req_grad_flat = pytree.tree_leaves(result)[1:]
|
|
res_exp_req_grad_flat = pytree.tree_leaves(result_exp)[1:]
|
|
self.check_autograd(res_req_grad_flat, res_exp_req_grad_flat, (x, h2))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_init_carries_equal_grad(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 7, device=device, requires_grad=False)
|
|
h2 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
result = scan_fct(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
result_exp = _fake_scan(
|
|
get_scan_combine_fn("fct_c1_no_grad", True),
|
|
(h1, h2),
|
|
x,
|
|
dim=dim,
|
|
reverse=reverse,
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
# TODO: Ideally we should be able to select the results that require gradients like this
|
|
# [leaf for leaf in pytree.tree_leaves(result) if leaf.requires_grad == True]
|
|
# However, for the scan operator this does not work, as all outputs always have
|
|
# grad_fn=<ScanAutogradOpBackward>
|
|
res_req_grad_flat = pytree.tree_leaves(result)[1:]
|
|
res_exp_req_grad_flat = pytree.tree_leaves(result_exp)[1:]
|
|
self.check_autograd(res_req_grad_flat, res_exp_req_grad_flat, (x, h2))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_for_out(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
h2 = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
|
|
def fct_ys_no_grad(x: torch.Tensor, y: torch.Tensor):
|
|
c1 = x[0] + y
|
|
c2 = x[1] + y
|
|
with torch.no_grad():
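                # The ys are produced under no_grad, so only the carries take part
                # in the gradient checks below.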
|
|
h_new = torch.tanh(x[0] + x[1] + y)
|
|
return (c1, c2), h_new
|
|
|
|
result = scan_fct(fct_ys_no_grad, (h1, h2), x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(fct_ys_no_grad, (h1, h2), x, dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result[0], result_exp[0], (x, h1, h2))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_additional_inputs_partial(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W_ih = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def fct_no_grad_bhh_Whh(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_ih + b_ih + x
|
|
|
|
h_new = c_new + 1
|
|
with torch.no_grad():
|
|
h_new_no_grad = torch.tanh(x @ W_hh + b_hh)
|
|
h_new2 = h_new + h_new_no_grad
|
|
|
|
return c_new, h_new2
|
|
|
|
result = scan_fct(fct_no_grad_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(fct_no_grad_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result[1], result_exp[1], (h, x, W_ih, b_ih))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_with_no_grad_additional_inputs_all(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W_ih = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def fct_no_grad_bih_Wih_bhh_Whh(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = x + y
|
|
h_new = c_new + x
|
|
with torch.no_grad():
|
|
c_new_no_grad = y @ W_ih + b_ih
|
|
h_new_no_grad = torch.tanh(x @ W_hh + b_hh)
|
|
c_new2 = c_new + c_new_no_grad
|
|
h_new2 = h_new + h_new_no_grad
|
|
return c_new2, h_new2
|
|
|
|
result = scan_fct(fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result[1], result_exp[1], (h, x))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_combine_fn_carries_ys_same_grad(
|
|
self, reverse, compile_mode, device, autograd
|
|
):
|
|
dim = 1
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
x = torch.randn(3, 10, 7, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W_ih = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_ih = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_hh = torch.randn(7, 7, device=device, requires_grad=autograd)
|
|
b_hh = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def fct_no_grad_bih_Wih_bhh_Whh(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = x + y
|
|
h_new = c_new + 1
|
|
with torch.no_grad():
|
|
c_new_no_grad = y @ W_ih + b_ih
|
|
h_new_no_grad = torch.tanh(x @ W_hh + b_hh)
|
|
c_new2 = c_new + c_new_no_grad
|
|
h_new2 = h_new + h_new_no_grad
|
|
return c_new2, h_new2
|
|
|
|
result = scan_fct(fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse)
|
|
result_exp = _fake_scan(
|
|
fct_no_grad_bih_Wih_bhh_Whh, h, x, dim=dim, reverse=reverse
|
|
)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result[1], result_exp[1], (h, x))
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("autograd", [False, True])
|
|
def test_scan_closure_nested(self, reverse, compile_mode, device, autograd):
|
|
scan_fct = compile_mode_helper(scan, compile_mode)
|
|
|
|
# Simple non-nested case
|
|
x = torch.randn(3, 20, 5, device=device, requires_grad=autograd)
|
|
h = torch.randn(3, 7, device=device, requires_grad=autograd)
|
|
W = torch.randn(5, 7, device=device, requires_grad=autograd)
|
|
b = torch.randn(7, device=device, requires_grad=autograd)
|
|
|
|
def f1(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W + b
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
result = scan_fct(f1, h, x, dim=1, reverse=reverse)
|
|
result_exp = _fake_scan(f1, h, x, dim=1, reverse=reverse)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
if autograd:
|
|
self.check_autograd(result, result_exp, (h, x, W, b))
|
|
|
|
# Nested case
|
|
def chain_fct(fct, f_1, f_2, xs, h_1, h_2):
|
|
o1 = fct(
|
|
f_1,
|
|
h_1,
|
|
xs,
|
|
dim=1,
|
|
reverse=reverse,
|
|
)
|
|
o2 = fct(
|
|
f_2,
|
|
h_2,
|
|
o1[1],
|
|
dim=0,
|
|
reverse=reverse,
|
|
)
|
|
return o2
|
|
|
|
x1 = torch.ones(3, 20, 5, device=device, requires_grad=autograd)
|
|
h1 = torch.zeros(3, 7, device=device, requires_grad=autograd)
|
|
h2 = torch.zeros(3, 3, device=device, requires_grad=autograd)
|
|
W_1 = torch.randn(5, 7, device=device, requires_grad=autograd)
|
|
b_1 = torch.randn(7, device=device, requires_grad=autograd)
|
|
W_2 = torch.randn(7, 3, device=device, requires_grad=autograd)
|
|
b_2 = torch.randn(3, device=device, requires_grad=autograd)
|
|
|
|
def f1(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_1 + b_1
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
def f2(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_2 + b_2
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
result1 = chain_fct(scan_fct, f1, f2, x1, h1, h2)
|
|
expected_result = chain_fct(_fake_scan, f1, f2, x1, h1, h2)
|
|
self.assertEqual(result1, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(result1, expected_result, (h1, h2, x1, W_1, b_1))
|
|
|
|
# Complex case
|
|
x1 = torch.randn(3, 20, 3, device=device, requires_grad=autograd)
|
|
h1 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
h2 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
W_1 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
b_1 = torch.randn(3, device=device, requires_grad=autograd)
|
|
W_2 = torch.randn(3, 3, device=device, requires_grad=autograd)
|
|
b_2 = torch.randn(3, device=device, requires_grad=autograd)
|
|
|
|
def f1(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_1 + b_1
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
def f2(x: torch.Tensor, y: torch.Tensor):
|
|
c_new = y @ W_2 + b_2 * b_1 + y @ W_1
|
|
h_new = torch.tanh(c_new + x)
|
|
return c_new, h_new
|
|
|
|
result1 = chain_fct(scan_fct, f1, f2, x1, h1, h2)
|
|
expected_result = chain_fct(_fake_scan, f1, f2, x1, h1, h2)
|
|
self.assertEqual(result1, expected_result)
|
|
|
|
if autograd:
|
|
self.check_autograd(
|
|
result1, expected_result, (h1, h2, x1, W_1, b_1, W_2, b_2)
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
def test_scan_simple_graph_wrong_dtype(self):
|
|
def add_wrong_dtype(x: torch.Tensor, y: torch.Tensor):
|
|
return torch.ones_like(x + y, dtype=torch.int64), x + y
|
|
|
|
x = torch.randn(3, 10, 2, device=torch.device("cpu"))
|
|
init = torch.randn(1, 10, 2, device=torch.device("cpu"))
|
|
|
|
def f(fct, init, xs):
|
|
return scan(fct, init, xs, dim=0, reverse=True)
|
|
|
|
# Wrong dtype
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected init and carry to have same metadata.*",
|
|
):
|
|
f(add_wrong_dtype, init, x)
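        # The regex above targets scan's metadata check: every leaf of the carry
        # returned by combine_fn must match the corresponding leaf of `init` in
        # shape, dtype, and device. An illustrative combine_fn that would pass the
        # check for the float inputs used here (sketch only, not defined in this
        # file) is:
        #
        #     def add_same_dtype(carry, x):
        #         out = carry + x          # stays float32, same shape as init
        #         return out, out.clone()  # (next carry, per-step output)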
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_scan_simple_graph(self):
|
|
x = torch.randn(3, 10, 2, device=torch.device("cpu"))
|
|
init = torch.randn(1, 10, 2, device=torch.device("cpu"))
|
|
|
|
def f(fct, init, xs):
|
|
return scan(fct, init, xs, dim=0, reverse=True)
|
|
|
|
# Correct case
|
|
gm = make_fx(f, tracing_mode="symbolic")(
|
|
get_scan_combine_fn("add", False), init, x
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, fct_1, init_1, xs_1):
|
|
permute = torch.ops.aten.permute.default(xs_1, [0, 1, 2])
|
|
flip = torch.ops.aten.flip.default(permute, [0]); permute = None
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 1)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(init_1, 2)
|
|
sym_size_int_3 = torch.ops.aten.sym_size.int(xs_1, 1)
|
|
sym_size_int_4 = torch.ops.aten.sym_size.int(xs_1, 2); xs_1 = None
|
|
scan_combine_graph_0 = self.scan_combine_graph_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [flip], (sym_size_int_1, sym_size_int_2, sym_size_int_3, sym_size_int_4)); scan_combine_graph_0 = init_1 = flip = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = sym_size_int_4 = None
|
|
getitem = scan[0]
|
|
getitem_1 = scan[1]; scan = None
|
|
flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None
|
|
return (getitem, flip_1)""", # noqa: B950
|
|
)
|
|
|
|
# Check graph
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(get_scan_combine_fn("add", False), init, x)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor):
|
|
l_init_ = L_init_
|
|
l_xs_ = L_xs_
|
|
elem = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
flip = torch.flip(elem, [0]); elem = None
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [flip], []); scan_combine_fn_0 = l_init_ = flip = None
|
|
carry = scan[0]
|
|
out = scan[1]; scan = None
|
|
out_1 = out.flip([0]); out = None
|
|
return (carry, out_1)""", # noqa: B950
|
|
)
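        # Both captured graphs above show how reverse=True is realized: the xs are
        # flipped along the scan dimension before torch.ops.higher_order.scan is
        # invoked, and the stacked outputs are flipped back afterwards, so the
        # combine_fn itself never sees the reversal.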
|
|
|
|
    @requires_cuda
    def test_scan_input_mutation(self):
        device = torch.device("cuda")

        def fct_input_mutation(x, y):
            x.add_(1)
            return x + y, x + y + 2

        x = torch.randn(3, 2, 2, device=device)
        init = torch.randn(2, 2, device=device)

        with self.assertRaisesRegex(
            # Should be
            # torch._dynamo.exc.Unsupported,
            # "Encountered aliasing during higher order op tracing for HOP.*"
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "scan must be captured completely with torch.compile.*",
        ):
            scan(fct_input_mutation, init, x, dim=0)
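        # scan, like the other higher-order ops, rejects combine_fns that mutate
        # their inputs, because the body is traced into a functional graph. An
        # out-of-place rewrite of the combine_fn above (illustrative sketch only,
        # not used by the test) would be:
        #
        #     def fct_no_mutation(x, y):
        #         x = x + 1  # fresh tensor instead of the in-place x.add_(1)
        #         return x + y, x + y + 2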
|
|
|
|
@requires_cuda
|
|
def test_scan_input_carry_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_output_alias(x, y):
|
|
return (x[0], x[1] + y[1]), (x[1] + y[1] + 1, x[1] + y[1] + 2)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_input_output_alias, init, inp, dim=0)
|
|
|
|
@requires_cuda
|
|
def test_scan_input_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_output_alias(x, y):
|
|
return (x[0] + 1, x[1] + y[1]), (x[1], x[1] + y[1] + 2)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_input_output_alias, init, inp, dim=0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_scan_carry_carry_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_carry_carry_alias(x, y):
|
|
c = x[0] + y[1]
|
|
return (c, c), (x[0] + y[1], x[0] + y[1] + 1)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_carry_carry_alias, init, inp, dim=0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_scan_carry_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_carry_output_alias(x, y):
|
|
c = x[0] + y[1]
|
|
return (x[0] + y[1], c), (c, x[0] + y[1] + 1)
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
init = (torch.randn(2, 2, device=device), torch.randn(2, 2, device=device))
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"scan must be captured completely with torch.compile.*",
|
|
):
|
|
scan(fct_carry_output_alias, init, inp, dim=0)
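        # As with input mutation above, scan also rejects combine_fns whose output
        # leaves alias their inputs or each other; every carry/output leaf must be
        # a fresh tensor. Returning e.g. `c.clone()` for one of the two aliased
        # leaves (an illustrative fix, not exercised here) would avoid the error.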
|
|
|
|
|
|
class AssociativeScanModels:
    @staticmethod
    def get_scan_fct(compile_mode, combine_mode):
        # Compile the associative_scan according to the provided compile_mode
        if compile_mode != "fake":
            assoc_scan_comp = compile_mode_helper(associative_scan, compile_mode)

            def scan_fct(combine_fn, xs, dim, reverse):
                return assoc_scan_comp(combine_fn, xs, dim, reverse, combine_mode)

        else:
            scan_fct = _fake_associative_scan
        return scan_fct
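    # Usage note (illustrative): get_scan_fct("eager", "generic") returns a wrapper
    # with the calling convention scan_fct(combine_fn, xs, dim, reverse), while
    # get_scan_fct("fake", <anything>) returns the eager reference
    # _fake_associative_scan, which accepts the same positional arguments.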
|
|
|
|
class CombineFn(torch.nn.Module):
|
|
def __init__(self, combine_fn, dim, reverse, combine_mode, compile_mode):
|
|
super().__init__()
|
|
|
|
self.scan_fct = AssociativeScanModels.get_scan_fct(
|
|
compile_mode, combine_mode
|
|
)
|
|
self.combine_fn = combine_fn
|
|
self.dim = dim
|
|
self.reverse = reverse
|
|
|
|
def forward(self, inputs):
|
|
results = self.scan_fct(self.combine_fn, inputs, self.dim, self.reverse)
|
|
return results
|
|
|
|
class Simple(torch.nn.Module):
|
|
def __init__(self, dim, reverse, combine_mode, compile_mode):
|
|
super().__init__()
|
|
|
|
kwargs = {
|
|
"dim": dim,
|
|
"reverse": reverse,
|
|
"combine_mode": combine_mode,
|
|
"compile_mode": compile_mode,
|
|
}
|
|
self.combine_fns = [
|
|
AssociativeScanModels.CombineFn(
|
|
get_scan_combine_fn("add", True), **kwargs
|
|
),
|
|
AssociativeScanModels.CombineFn(
|
|
get_scan_combine_fn("mul", True), **kwargs
|
|
),
|
|
]
|
|
|
|
def forward(self, inputs):
|
|
results = []
|
|
for combine_fn in self.combine_fns:
|
|
results.append(combine_fn(inputs))
|
|
return results
|
|
|
|
class ChainFn(torch.nn.Module):
|
|
def __init__(self, combine_fn, dim, reverse, combine_mode, compile_mode):
|
|
super().__init__()
|
|
|
|
chain_len = len(combine_fn)
|
|
kwargs = {
|
|
"combine_fn": combine_fn,
|
|
"dim": dim,
|
|
"reverse": reverse,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
|
|
# Prepare the kwargs as a list.
|
|
self.nested_tuple = []
|
|
for ind in range(chain_len):
|
|
kwargs_el = {}
|
|
for key, val in kwargs.items():
|
|
                # If val is a list with one entry per chained combine_fn, use the
                # entry for this stage; otherwise reuse the same value for every
                # stage.
|
|
if type(val) == list and len(val) == chain_len:
|
|
kwargs_el[key] = val[ind]
|
|
else:
|
|
kwargs_el[key] = val
|
|
|
|
scan_fct = AssociativeScanModels.get_scan_fct(
|
|
compile_mode, kwargs_el["combine_mode"]
|
|
)
|
|
combine_fn = kwargs_el["combine_fn"]
|
|
del kwargs_el["combine_fn"]
|
|
del kwargs_el["combine_mode"]
|
|
self.nested_tuple.append((combine_fn, scan_fct, kwargs_el))
|
|
|
|
def forward(self, inputs):
|
|
results = inputs
|
|
for combine_fn, scan_fct, kwargs in self.nested_tuple:
|
|
results = combine_fn(scan_fct, results, **kwargs)
|
|
return results
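        # Example (illustrative, mirroring the chained tests below):
        # ChainFn(combine_fn=[f_1, f_2], dim=[1, 0], reverse=False,
        # combine_mode="generic", compile_mode="eager") runs two scan stages,
        # taking `dim` per stage (1, then 0), while scalar arguments such as
        # `reverse` are reused unchanged for every stage.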
|
|
|
|
    class NestedFn(torch.nn.Module):
        def forward(self, scan_fct, inputs, **kwargs):
            combine_fn = kwargs["combine_fn"]

            # Remove combine_fn from kwargs
            del kwargs["combine_fn"]

            results = scan_fct(combine_fn, inputs, **kwargs)

            return results
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfNoDynamoSupport
|
|
class AssociativeScanTests(TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.reset()
|
|
super().setUp()
|
|
|
|
def _run_test(self, model, model_fake, inputs):
|
|
result = model(inputs)
|
|
result_exp = model_fake(inputs)
|
|
self.assertEqual(result, result_exp)
|
|
|
|
# Return the result of the functions under test for further investigations
|
|
return result
|
|
|
|
def _prepare_fake_kwargs(self, original_kwargs):
|
|
kwargs_fake = original_kwargs.copy()
|
|
kwargs_fake["compile_mode"] = "fake"
|
|
return kwargs_fake
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_compile(
|
|
self, combine_mode, reverse, compile_mode, device
|
|
):
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
results = self._run_test(
|
|
model=AssociativeScanModels.Simple(**kwargs),
|
|
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
if not reverse:
|
|
results_torch = []
|
|
for op_pt in [torch.cumsum, torch.cumprod]:
|
|
results_torch.append(op_pt(x, 0))
|
|
self.assertEqual(results, results_torch)
|
|
|
|
# Jax Examples
|
|
x = torch.arange(0, 4, device=device)
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("add", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
result = self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
if not reverse:
|
|
results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
|
|
else:
|
|
results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
|
|
|
|
self.assertEqual(result, results_torch)
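        # For reference: with an add combine_fn, associative_scan computes an
        # inclusive prefix sum, so torch.arange(0, 4) scans to [0, 1, 3, 6]; the
        # reverse=True variant scans from the other end, giving [6, 6, 5, 3],
        # matching the expected values asserted above.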
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_dim(self, combine_mode, compile_mode, reverse, device):
|
|
import random
|
|
|
|
random.seed(1234)
|
|
|
|
num_dims = [random.randint(2, 5) for _ in range(4)]
|
|
for num_dim in num_dims:
|
|
# To avoid triggering automatic dynamic shape
|
|
torch._dynamo.reset()
|
|
shapes = [random.randint(1, 9) for _ in range(num_dim)]
|
|
rnd_scan_dim = random.randint(0, num_dim - 1)
|
|
x = torch.randn(*shapes, device=device)
|
|
|
|
kwargs = {
|
|
"dim": rnd_scan_dim,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
results = self._run_test(
|
|
model=AssociativeScanModels.Simple(**kwargs),
|
|
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
if not reverse:
|
|
results_torch = []
|
|
for op_pt in [torch.cumsum, torch.cumprod]:
|
|
                    results_torch.append(op_pt(x, rnd_scan_dim))
|
|
self.assertEqual(results, results_torch)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_dim_shape_failure(self, compile_mode, combine_mode):
|
|
num_dims = [2]
|
|
for num_dim in num_dims:
|
|
shapes = [9 for _ in range(num_dim)]
|
|
rnd_scan_dim = 0
|
|
x = torch.randn(*shapes, device=torch.device("cuda"))
|
|
|
|
kwargs = {
|
|
"dim": rnd_scan_dim,
|
|
"reverse": True,
|
|
"compile_mode": "compile",
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.Simple(**kwargs),
|
|
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_tuple(self, compile_mode, combine_mode, reverse, device):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("tuple_fct", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_expand_in_combine_fn(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
|
|
def combine_fn(x, y):
|
|
return x * torch.sum(y, -1).expand(x.shape)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_non_contiguous_tensor(
|
|
self, compile_mode, reverse, device
|
|
):
|
|
x = torch.arange(30, device=device).view(10, 3).t()
|
|
assert not x.is_contiguous()
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("add", True),
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_complex_pytree(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
z = torch.randn(3, 2, 2, device=device)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("complex_pointwise", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@skipIfTorchDynamo("don't test compile on compile")
|
|
@skipIfNoDynamoSupport
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_associative_scan_pytree_output(self):
|
|
x = (
|
|
(
|
|
torch.randn(3, 10, 2, device=torch.device("cpu")),
|
|
(torch.randn(3, 10, 2, device=torch.device("cpu")),),
|
|
),
|
|
torch.randn(3, 10, 2, device=torch.device("cpu")),
|
|
)
|
|
|
|
def f(fct, xs):
|
|
return associative_scan(
|
|
fct, xs, dim=0, reverse=True, combine_mode="generic"
|
|
)
|
|
|
|
def combine_fn(x: torch.Tensor, y: torch.Tensor):
|
|
a, b = (x[0][0] + y[1], x[0][1][0] - y[1])
|
|
return (a, (b,)), a - b
|
|
|
|
# Check graph
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(combine_fn, x)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(gm.print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs_1_: "f32[3, 10, 2]"):
|
|
l_xs_0_0_ = L_xs_0_0_
|
|
l_xs_0_1_0_ = L_xs_0_1_0_
|
|
l_xs_1_ = L_xs_1_
|
|
|
|
elem: "f32[3, 10, 2]" = torch.movedim(l_xs_0_0_, 0, 0); l_xs_0_0_ = None
|
|
elem_1: "f32[3, 10, 2]" = torch.movedim(l_xs_0_1_0_, 0, 0); l_xs_0_1_0_ = None
|
|
elem_2: "f32[3, 10, 2]" = torch.movedim(l_xs_1_, 0, 0); l_xs_1_ = None
|
|
|
|
elem_3: "f32[3, 10, 2]" = torch.flip(elem, [0]); elem = None
|
|
elem_4: "f32[3, 10, 2]" = torch.flip(elem_1, [0]); elem_1 = None
|
|
elem_5: "f32[3, 10, 2]" = torch.flip(elem_2, [0]); elem_2 = None
|
|
|
|
child: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 0, -1, 2)
|
|
child_1: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 0, -1, 2)
|
|
child_2: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 0, -1, 2)
|
|
|
|
child_3: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 1, None, 2)
|
|
child_4: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 1, None, 2)
|
|
child_5: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 1, None, 2)
|
|
|
|
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
|
_add_batch_dim_1: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_1, 0, 1); child_1 = None
|
|
_add_batch_dim_2: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_2, 0, 1); child_2 = _add_batch_dim_2 = None
|
|
_add_batch_dim_3: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_3, 0, 1); child_3 = _add_batch_dim_3 = None
|
|
_add_batch_dim_4: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_4, 0, 1); child_4 = _add_batch_dim_4 = None
|
|
_add_batch_dim_5: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_5, 0, 1); child_5 = None
|
|
|
|
a: "f32[10, 2]" = _add_batch_dim + _add_batch_dim_5; _add_batch_dim = None
|
|
b: "f32[10, 2]" = _add_batch_dim_1 - _add_batch_dim_5; _add_batch_dim_1 = _add_batch_dim_5 = None
|
|
|
|
child_6: "f32[10, 2]" = a - b
|
|
|
|
child_7: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(a, 1, 1, 0); a = None
|
|
child_8: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(b, 1, 1, 0); b = None
|
|
child_9: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 1, 0); child_6 = None
|
|
|
|
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
child_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 2, None, 2)
|
|
child_11: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 2, None, 2)
|
|
child_12: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 2, None, 2)
|
|
|
|
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
|
|
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None
|
|
|
|
_add_batch_dim_6: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_7, 0, 1)
|
|
_add_batch_dim_7: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_8, 0, 1)
|
|
_add_batch_dim_8: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_9, 0, 1); _add_batch_dim_8 = None
|
|
_add_batch_dim_9: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_10, 0, 1); child_10 = _add_batch_dim_9 = None
|
|
_add_batch_dim_10: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_11, 0, 1); child_11 = _add_batch_dim_10 = None
|
|
_add_batch_dim_11: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_12, 0, 1); child_12 = None
|
|
|
|
a_1: "f32[10, 2]" = _add_batch_dim_6 + _add_batch_dim_11; _add_batch_dim_6 = None
|
|
b_1: "f32[10, 2]" = _add_batch_dim_7 - _add_batch_dim_11; _add_batch_dim_7 = _add_batch_dim_11 = None
|
|
|
|
child_13: "f32[10, 2]" = a_1 - b_1
|
|
|
|
child_14: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(a_1, 1, 1, 0); a_1 = None
|
|
child_15: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(b_1, 1, 1, 0); b_1 = None
|
|
child_16: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(child_13, 1, 1, 0); child_13 = None
|
|
|
|
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
|
|
slice_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 0, 1); elem_3 = None
|
|
cat: "f32[2, 10, 2]" = torch.cat([slice_10, child_14], dim = 0); slice_10 = child_14 = None
|
|
slice_11: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 0, 1); elem_4 = None
|
|
cat_1: "f32[2, 10, 2]" = torch.cat([slice_11, child_15], dim = 0); slice_11 = child_15 = None
|
|
slice_12: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 0, 1); elem_5 = None
|
|
cat_2: "f32[2, 10, 2]" = torch.cat([slice_12, child_16], dim = 0); slice_12 = child_16 = None
|
|
|
|
b_2: "f32[2, 10, 2]" = torch._C._nn.pad(child_7, [0, 0, 0, 0, 0, 1], 'constant', None); child_7 = None
|
|
|
|
stacked: "f32[2, 2, 10, 2]" = torch.stack([cat, b_2], dim = 1); cat = b_2 = None
|
|
|
|
interleaved: "f32[4, 10, 2]" = torch.flatten(stacked, start_dim = 0, end_dim = 1); stacked = None
|
|
|
|
interleaved_1: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved, 0, 0, 3); interleaved = None
|
|
|
|
b_3: "f32[2, 10, 2]" = torch._C._nn.pad(child_8, [0, 0, 0, 0, 0, 1], 'constant', None); child_8 = None
|
|
|
|
stacked_1: "f32[2, 2, 10, 2]" = torch.stack([cat_1, b_3], dim = 1); cat_1 = b_3 = None
|
|
|
|
interleaved_2: "f32[4, 10, 2]" = torch.flatten(stacked_1, start_dim = 0, end_dim = 1); stacked_1 = None
|
|
|
|
interleaved_3: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved_2, 0, 0, 3); interleaved_2 = None
|
|
|
|
b_4: "f32[2, 10, 2]" = torch._C._nn.pad(child_9, [0, 0, 0, 0, 0, 1], 'constant', None); child_9 = None
|
|
|
|
stacked_2: "f32[2, 2, 10, 2]" = torch.stack([cat_2, b_4], dim = 1); cat_2 = b_4 = None
|
|
|
|
interleaved_4: "f32[4, 10, 2]" = torch.flatten(stacked_2, start_dim = 0, end_dim = 1); stacked_2 = None
|
|
|
|
interleaved_5: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved_4, 0, 0, 3); interleaved_4 = None
|
|
|
|
child_17: "f32[3, 10, 2]" = interleaved_1.flip([0]); interleaved_1 = None
|
|
child_18: "f32[3, 10, 2]" = interleaved_3.flip([0]); interleaved_3 = None
|
|
child_19: "f32[3, 10, 2]" = interleaved_5.flip([0]); interleaved_5 = None
|
|
|
|
movedim_3: "f32[3, 10, 2]" = torch.movedim(child_17, 0, 0); child_17 = None
|
|
movedim_4: "f32[3, 10, 2]" = torch.movedim(child_18, 0, 0); child_18 = None
|
|
movedim_5: "f32[3, 10, 2]" = torch.movedim(child_19, 0, 0); child_19 = None
|
|
return (movedim_3, movedim_4, movedim_5)
|
|
""", # noqa: B950
|
|
)
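        # The captured graph above illustrates the "generic" combine_mode: each
        # parallel combination step is expressed with the functorch batching
        # helpers (_vmap_increment_nesting / _add_batch_dim / _remove_batch_dim),
        # i.e. the user combine_fn is vmapped over slices of the scanned dimension
        # and the partial results are then interleaved back together.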
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_downstream_scan_matmul(
|
|
self, combine_mode, compile_mode, reverse, device
|
|
):
|
|
def first_chain_fct(scan_fct, inp, **kwargs):
|
|
o = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o
|
|
|
|
def second_chain_fct(scan_fct, inp, **kwargs):
|
|
W = torch.ones(2, 5, device=device)
|
|
return inp @ W
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
kwargs = {
|
|
"dim": 1,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": [first_chain_fct, second_chain_fct],
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.ChainFn(**kwargs),
|
|
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_downstream_scan_scan(
|
|
self, combine_mode, compile_mode, reverse, device
|
|
):
|
|
def first_chain_fct(scan_fct, inp, **kwargs):
|
|
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o1
|
|
|
|
def second_chain_fct(scan_fct, inp, **kwargs):
|
|
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o2
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 1,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": [first_chain_fct, second_chain_fct],
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.ChainFn(**kwargs),
|
|
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse_first", [False, True])
|
|
@parametrize("same_direction", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_downstream_scan_scan_different_dim(
|
|
self, combine_mode, compile_mode, reverse_first, same_direction, device
|
|
):
|
|
reverse_second = reverse_first if same_direction else not reverse_first
|
|
|
|
def first_chain_fct(scan_fct, inp, **kwargs):
|
|
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o1
|
|
|
|
def second_chain_fct(scan_fct, inp, **kwargs):
|
|
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
|
return o2
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": [1, 0],
|
|
"reverse": [reverse_first, reverse_second],
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": [first_chain_fct, second_chain_fct],
|
|
"combine_mode": [combine_mode, combine_mode],
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.ChainFn(**kwargs),
|
|
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
    # TODO: Does not work because of the use of vmap within associative_scan
    # TODO: Re-enable the additional parameters once this issue has been resolved
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_nested(self):
|
|
combine_mode = "pointwise"
|
|
compile_mode = "eager"
|
|
reverse_first = False
|
|
same_direction = False
|
|
device = torch.device("cuda")
|
|
|
|
reverse_second = reverse_first if same_direction else not reverse_first
|
|
|
|
def first_nested_fct(x, y):
|
|
y_new = associative_scan(
|
|
second_nested_fct,
|
|
y,
|
|
0,
|
|
reverse=reverse_second,
|
|
combine_mode=combine_mode,
|
|
)
|
|
return x + y_new
|
|
|
|
def first_nested_fct_fake(x, y):
|
|
y_new = _fake_associative_scan(
|
|
second_nested_fct, y, 0, reverse=reverse_second
|
|
)
|
|
return x + y_new
|
|
|
|
def second_nested_fct(x, y):
|
|
return x * y
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse_first,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": first_nested_fct,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
kwargs_fake["combine_fn"] = first_nested_fct_fake
|
|
self._run_test(
|
|
model=AssociativeScanModels.NestedFn(**kwargs),
|
|
model_fake=AssociativeScanModels.NestedFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("loop_type", ["for"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_loop_in_combine_fn(
|
|
self, compile_mode, loop_type, reverse, device
|
|
):
|
|
def combine_fn(x, y):
|
|
cnt = torch.zeros_like(y[0, :])
|
|
if loop_type == "while":
|
|
|
|
def cond_fn(ind, loop_val):
|
|
return (loop_val < 5)[0]
|
|
|
|
def body_fn(ind, loop_val):
|
|
return ind + 1, loop_val + torch.abs(ind)
|
|
|
|
new_ind, cnt = torch.while_loop(
|
|
cond_fn=cond_fn,
|
|
body_fn=body_fn,
|
|
carried_inputs=(
|
|
torch.zeros(1, dtype=torch.int32, device=cnt.device),
|
|
cnt,
|
|
),
|
|
)
|
|
else:
|
|
for ind in range(10):
|
|
cnt += torch.abs(y[ind])
|
|
return x * cnt
|
|
|
|
inp = torch.randn(3, 10, 1, device=device) * 2
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
    # TODO: Does not work because of the use of vmap within associative_scan
    # TODO: Re-enable the additional parameters once this issue has been resolved
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_loop_in_combine_fn_failure(self):
|
|
compile_mode = "none"
|
|
loop_type = "while"
|
|
reverse = False
|
|
device = torch.device("cuda")
|
|
|
|
def combine_fn(x, y):
|
|
_cnt = torch.zeros_like(y[0, :])
|
|
if loop_type == "while":
|
|
|
|
def cond_fn(ind, loop_val):
|
|
return (loop_val < 5)[0]
|
|
|
|
def body_fn(ind, loop_val):
|
|
return ind + 1, loop_val + torch.abs(ind)
|
|
|
|
inp = torch.randn(3, 10, 1, device=device) * 2
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
),
|
|
)
|
|
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
|
|
def combine_fn(x, y):
|
|
val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
|
|
return x * val
|
|
|
|
inp = torch.randn(3, 10, 1, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
    # TODO: Does not work because of the use of vmap within associative_scan
    # TODO: Re-enable the additional parameters once this issue has been resolved
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@unittest.expectedFailure
|
|
def test_associative_scan_map_in_combine_fn(self):
|
|
compile_mode = "none"
|
|
reverse = False
|
|
device = torch.device("cuda")
|
|
|
|
def combine_fn(x, y):
|
|
def body(x, y):
|
|
return x + y
|
|
|
|
y_init = y[0]
|
|
y_new = control_flow.map(body, y, y_init)
|
|
return x * y_new
|
|
|
|
inp = torch.randn(3, 10, 1, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_vmap_in_combine_fn(self, compile_mode, reverse, device):
|
|
def combine_fn(x, y):
|
|
def body(x):
|
|
return x**2
|
|
|
|
mapped_body = torch.vmap(body, 0, 0)
|
|
y_new = mapped_body(y)
|
|
return x + y_new
|
|
|
|
inp = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": combine_fn,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
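        # torch.vmap(body, 0, 0) maps `body` over dim 0 of its input and stacks the
        # per-slice results back along dim 0, so for the elementwise body above,
        # mapped_body(y) is equivalent to
        # torch.stack([body(row) for row in y.unbind(0)]), i.e. simply y ** 2.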
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of associative_scan and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["device"] == torch.device("cpu")),
|
|
)
|
|
def test_associative_scan_non_pointwise_generic(
|
|
self, reverse, compile_mode, device
|
|
):
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("non_pointwise", True),
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=x,
|
|
)
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combination of combine_mode=pointwise and device=cpu
|
|
    # as the current implementation of pointwise only supports CUDA devices
|
|
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
|
# as the current implementation does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (
|
|
params["combine_mode"] == "pointwise"
|
|
and (
|
|
params["device"] == torch.device("cpu")
|
|
or params["compile_mode"] == "compile_dynamic_shape"
|
|
or torch.version.hip
|
|
)
|
|
),
|
|
)
|
|
def test_associative_scan_binary_operator(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
state_dim = 20
|
|
timesteps = 10
|
|
projected_inputs = torch.randn(
|
|
timesteps, state_dim, requires_grad=True, device=device
|
|
)
|
|
A = torch.randn(state_dim, requires_grad=True, device=device)
|
|
elements = (A.repeat((timesteps, 1)), projected_inputs)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("s5_operator", True),
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=elements,
|
|
)
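        # The "s5_operator" combine_fn used here (defined with the other helpers
        # earlier in this file) is the standard associative operator for linear
        # recurrences, roughly
        #     (A_i, b_i) . (A_j, b_j) = (A_j * A_i, A_j * b_i + b_j),
        # which is what makes the sequential state-space update amenable to an
        # associative scan.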
|
|
|
|
@skipIfRocm(msg="Unsupported on ROCM yet")
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_different_input_size(self, compile_mode, reverse, device):
|
|
batch = 5
|
|
hidden_dim = 3
|
|
length = 10
|
|
dstate = 7
|
|
|
|
deltaA = torch.randn(
|
|
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
|
|
)
|
|
deltaB_u = torch.randn(
|
|
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
|
|
)
|
|
C = torch.randn((batch, dstate, length), requires_grad=True, device=device)
|
|
x = torch.randn(
|
|
(batch, hidden_dim, length, dstate), requires_grad=True, device=device
|
|
)
|
|
y = torch.randn((batch, hidden_dim, length), requires_grad=True, device=device)
|
|
elements = (x, deltaA, deltaB_u, C, y)
|
|
|
|
kwargs = {
|
|
"dim": 2,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": get_scan_combine_fn("different_input_size_operator", True),
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=elements,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_different_input_size_wrong_dim(self):
|
|
batch = 5
|
|
hidden_dim = 3
|
|
length = 10
|
|
dstate = 7
|
|
|
|
deltaA = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
deltaB_u = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
C = torch.randn((batch, dstate, length), device=torch.device("cuda"))
|
|
x = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
y = torch.randn(
|
|
(batch, hidden_dim, length, dstate), device=torch.device("cuda")
|
|
)
|
|
elements = (x, deltaA, deltaB_u, C, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"All xs leaves must at least have.*",
|
|
):
|
|
associative_scan(
|
|
get_scan_combine_fn("different_input_size_operator", True),
|
|
elements,
|
|
3,
|
|
combine_mode="pointwise",
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_simple(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
H = torch.rand(2, device=device)
|
|
|
|
def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
|
|
return x * H + y * 2
|
|
|
|
def fct_freevars2(x: torch.Tensor, y: torch.Tensor):
|
|
return x * H + y * H
|
|
|
|
H1 = torch.rand(1, device=device)
|
|
H2 = torch.rand(1, device=device)
|
|
|
|
def fct_freevars3(x: torch.Tensor, y: torch.Tensor):
|
|
return x * H1 + y * H2
|
|
|
|
inp = torch.randn(3, 2, 2, device=device)
|
|
|
|
for fct, param in [
|
|
(fct_freevars1, (H,)),
|
|
(fct_freevars2, (H,)),
|
|
(fct_freevars3, (H1, H2)),
|
|
]:
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_nested(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
H1 = torch.rand(4, 5, device=device)
|
|
H2 = torch.rand(4, 1, device=device)
|
|
|
|
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
|
def inner(xi):
|
|
return xi * H2
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
|
|
def inner(xi):
|
|
return xi * H2
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
H1_i = torch.rand(4, 5, device=device)
|
|
|
|
# TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error:
|
|
# RuntimeError: vmap: called random operation while in randomness error mode.
|
|
# Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
|
|
def fct_nested_inside(x: torch.Tensor, y: torch.Tensor):
|
|
# H2_i = torch.rand(4, 1, device=device)
|
|
H2_i = torch.ones(4, 1, device=device) * 42
|
|
|
|
def inner(xi):
|
|
return xi * H2_i
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor):
|
|
# H2_i = torch.rand(4, 1, device=device)
|
|
H2_i = torch.ones(4, 1, device=device) * 42
|
|
|
|
def inner(xi):
|
|
return xi * H2_i
|
|
|
|
ret = inner(y)
|
|
return x + ret * H1
|
|
|
|
inp = torch.randn(3, 4, 5, device=device)
|
|
|
|
for fct, fct_fake, param in [
|
|
(fct_nested_outside, fct_nested_outside_fake, (H1, H2)),
|
|
(fct_nested_inside, fct_nested_inside_fake, (H1_i,)),
|
|
]:
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
kwargs_fake["combine_fn"] = fct_fake
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_fct(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
def additional_fct_no_add_inp(x, y):
|
|
return x * y
|
|
|
|
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
|
ret = additional_fct_no_add_inp(y, y)
|
|
return x + ret
|
|
|
|
inp = torch.randn(3, 4, 5, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_nested_outside,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device):
|
|
def additional_fct_no_add_inp(x, y):
|
|
return x * y
|
|
|
|
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
|
ret = associative_scan(
|
|
additional_fct_no_add_inp, y, 1, combine_mode="generic"
|
|
)
|
|
return x + ret
|
|
|
|
def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor):
|
|
ret = _fake_associative_scan(additional_fct_no_add_inp, y, 1)
|
|
return x + ret
|
|
|
|
inp = torch.randn(3, 4, 5, device=device)
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_nested_outside,
|
|
"combine_mode": "generic",
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
kwargs_fake["combine_fn"] = fct_nested_outside_fake
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_shape_check(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
H = torch.eye(2, device=device, requires_grad=True)
|
|
|
|
def fct_freevars(x: torch.Tensor, y: torch.Tensor):
|
|
return x @ H + y
|
|
|
|
inp = torch.randn(2, 2, 3, device=device, requires_grad=True)
|
|
|
|
kwargs = {
|
|
"dim": 2,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_freevars,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
|
|
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
|
@parametrize("reverse", [False, True])
|
|
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
|
@parametrize("combine_mode", ["pointwise", "generic"])
|
|
# Skipping the combine_mode=pointwise
|
|
# as the current implementation of associative_scan lowering
|
|
# does not support lifted arguments
|
|
@decorateIf(
|
|
unittest.skip,
|
|
lambda params: (params["combine_mode"] == "pointwise"),
|
|
)
|
|
def test_associative_scan_freevars_pytree(
|
|
self, compile_mode, combine_mode, reverse, device
|
|
):
|
|
xf = torch.randn(2, 2, device=device, requires_grad=True)
|
|
yf = torch.randn(2, 2, device=device, requires_grad=True)
|
|
zf = torch.randn(2, 2, device=device, requires_grad=True)
|
|
inpf = {"i": xf, "j": ([yf], [{"o": zf}])}
|
|
|
|
def fct_pointwise(x, y):
|
|
return {
|
|
"i": (x["i"] * y["i"]) + inpf["i"],
|
|
"j": (
|
|
[(x["j"][0][0] * y["j"][0][0]) + inpf["j"][0][0]],
|
|
[
|
|
{
|
|
"o": (x["j"][1][0]["o"] + y["j"][1][0]["o"])
|
|
+ inpf["j"][1][0]["o"]
|
|
}
|
|
],
|
|
),
|
|
}
|
|
|
|
x = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
|
y = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
|
z = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
|
|
kwargs = {
|
|
"dim": 0,
|
|
"reverse": reverse,
|
|
"compile_mode": compile_mode,
|
|
"combine_fn": fct_pointwise,
|
|
"combine_mode": combine_mode,
|
|
}
|
|
kwargs_fake = self._prepare_fake_kwargs(kwargs)
|
|
self._run_test(
|
|
model=AssociativeScanModels.CombineFn(**kwargs),
|
|
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
|
inputs=inp,
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
def test_associative_scan_sparse_tensor(self):
|
|
x = torch.tensor(
|
|
[[[0.0, 0], [1.0, 2.0]], [[0.0, 0], [3.0, 4.0]], [[0.0, 0], [5.0, 6.0]]]
|
|
).to_sparse()
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"xs leaves must dense Tensors.*",
|
|
):
|
|
associative_scan(
|
|
get_scan_combine_fn("add", True), x, 0, combine_mode="generic"
|
|
)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_combine_fn_wrong_meta_in_combine_fn(self):
|
|
device = torch.device("cuda")
|
|
B, N, C, H, W = 3, 3, 2, 3, 3
|
|
x = torch.randn(B, N, C, H, W, device=device)
|
|
|
|
def fct_wrong_dtype(x, y):
|
|
return (x + y).to(torch.int64)
|
|
|
|
def fct_wrong_device(x, y):
|
|
return (x + y).to(
|
|
torch.device("cpu") if device.type == "cuda" else torch.device("cuda")
|
|
)
|
|
|
|
def fct_wrong_stride(x, y):
|
|
return (x + y).to(memory_format=torch.channels_last)
|
|
|
|
for fct in [fct_wrong_dtype, fct_wrong_device, fct_wrong_stride]:
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected initial_xs and combine_fn_output to have same metadata.*",
|
|
):
|
|
associative_scan(fct, x, 0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
def test_associative_scan_wrong_pytree(self):
|
|
def fct_wrong_pytree(x, y):
|
|
return {
|
|
"i": x["i"] * y["j"][0][0],
|
|
"k": torch.tensor(0.0),
|
|
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
|
|
}
|
|
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(3, 2, 2)
|
|
z = torch.randn(3, 2, 2)
|
|
inp = {"i": x, "j": ([y], [{"o": z}])}
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError,
|
|
"Combin_fn received wrong number of arguments.*",
|
|
):
|
|
associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic")
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_non_pointwise(self):
|
|
device = torch.device("cuda")
|
|
x = torch.randn(3, 10, 2, device=device)
|
|
with self.assertRaisesRegex(
|
|
# Should be:
|
|
RuntimeError,
|
|
r"For combine_mode='pointwise', the combine_fn needs to be pointwise",
|
|
):
|
|
associative_scan(
|
|
get_scan_combine_fn("non_pointwise", True),
|
|
x,
|
|
0,
|
|
combine_mode="pointwise",
|
|
)
|
|
|
|
@requires_cuda
|
|
def test_associative_scan_input_mutation(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_mutation(x, y):
|
|
x.add_(1)
|
|
return x + y
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"associative_scan must be captured completely with torch.compile.*",
|
|
):
|
|
associative_scan(fct_input_mutation, x, 0)
|
|
|
|
@requires_cuda
|
|
def test_associative_scan_input_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_input_output_alias(x, y):
|
|
return x[0], x[1] + y[1]
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"associative_scan must be captured completely with torch.compile.*",
|
|
):
|
|
associative_scan(fct_input_output_alias, inp, 0)
|
|
|
|
@unittest.skipIf(not SM70OrLater, "triton")
|
|
@requires_cuda
|
|
def test_associative_scan_output_output_alias(self):
|
|
device = torch.device("cuda")
|
|
|
|
def fct_output_output_alias(x, y):
|
|
c = x[0] + y[1]
|
|
return c, c
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
y = torch.randn(3, 2, 2, device=device)
|
|
inp = (x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"associative_scan must be captured completely with torch.compile.*",
|
|
):
|
|
associative_scan(fct_output_output_alias, inp, 0)
|
|
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
|
|
@skipIfNoDynamoSupport
|
|
class TestControlFlowTraced(TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.reset()
|
|
super().setUp()
|
|
|
|
def _check_tracing(self, fn, args, allow_non_fake_inputs=False):
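        # Trace `fn` with make_fx under "symbolic", "real" and "fake" tracing
        # modes, check every traced graph reproduces the eager result on
        # `args`, and return the graphs keyed by tracing mode.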
|
|
graphs = {}
|
|
eager_res = fn(*args)
|
|
for tracing_mode in ["symbolic", "real", "fake"]:
|
|
graph = make_fx(
|
|
fn,
|
|
tracing_mode=tracing_mode,
|
|
_allow_non_fake_inputs=allow_non_fake_inputs,
|
|
)(*args)
|
|
graphs[tracing_mode] = graph
|
|
self.assertEqual(graph(*args), eager_res)
|
|
return graphs
|
|
|
|
def _check_compile(self, fn, args, *, dynamic=False, backend="eager"):
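        # Compile `fn` with the given backend (optionally with dynamic shapes)
        # and check the compiled output matches eager.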
|
|
eager_res = fn(*args)
|
|
compiled_fn = torch.compile(fn, backend=backend, dynamic=dynamic)
|
|
self.assertEqual(compiled_fn(*args), eager_res)
|
|
|
|
def _check_export(self, fn, args, *, strict=False, dynamic_shapes=None):
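        # Export `fn` with torch.export, run the exported module on `args`,
        # check it matches eager, and return the ExportedProgram.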
|
|
eg_out = fn(*args)
|
|
ep = torch.export.export(fn, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
ep_out = ep.module()(*args)
|
|
self.assertEqual(eg_out, ep_out)
|
|
return ep
|
|
|
|
def test_cond_traced_not_nested(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f)(x, torch.tensor(False))
|
|
result_true = graph.forward(x, torch.tensor(True))
|
|
result_false = graph.forward(x, torch.tensor(False))
|
|
self.assertFalse(torch.allclose(result_true, result_false))
|
|
self.assertEqual(result_true, torch.sin(x))
|
|
self.assertEqual(result_false, torch.cos(x))
|
|
|
|
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
|
|
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
|
|
|
|
    @skipIfTorchDynamo("Graph is not captured by backend if tested with dynamo")
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_cond_simple_with_linear_compile_check_graph(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
|
|
def f(pred, x):
|
|
result = cond(pred, true_fn, false_fn, (x,))
|
|
grad_out = torch.ones_like(result)
|
|
return torch.autograd.grad(result, (x,), grad_out)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(f, backend=backend)(torch.tensor(False), x)
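        # Two graphs are recorded: the forward graph of `f`, and a second graph
        # for the backward of the cond that torch.autograd.grad triggers.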
|
|
self.assertEqual(len(backend.graphs), 2)
|
|
gm = backend.graphs[0]
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
|
|
l_pred_ = L_pred_
|
|
l_x_ = L_x_
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (l_x_,)); l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None
|
|
result = cond[0]; cond = None
|
|
grad_out = torch.ones_like(result)
|
|
return (result, grad_out)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_ctx_saved_tensors_0_: "f32[4]", L_ctx_pred: "b8[]", L_args_1_: "f32[4]"):
|
|
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
|
|
l_ctx_pred = L_ctx_pred
|
|
l_args_1_ = L_args_1_
|
|
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, (l_args_1_, l_ctx_saved_tensors_0_)); l_ctx_pred = cond_true_0 = cond_false_0 = l_args_1_ = l_ctx_saved_tensors_0_ = None
|
|
getitem: "f32[4]" = cond[0]; cond = None
|
|
return (getitem,)
|
|
|
|
class cond_true_0(torch.nn.Module):
|
|
def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"):
|
|
l_args_1__1 = l_args_1_
|
|
l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
|
|
|
|
sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); sin = None
|
|
|
|
cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
|
|
|
|
mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, cos); l_args_1__1 = cos = None
|
|
return (mul,)
|
|
|
|
class cond_false_0(torch.nn.Module):
|
|
def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"):
|
|
l_args_1__1 = l_args_1_
|
|
l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
|
|
|
|
cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); cos = None
|
|
|
|
sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
|
|
|
|
neg: "f32[4]" = torch.ops.aten.neg.default(sin); sin = None
|
|
|
|
mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, neg); l_args_1__1 = neg = None
|
|
return (mul,)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_op_mismatch_in_meta(self):
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, c, a, b):
|
|
def cond_fn(c, a, b):
|
|
return c > 0
|
|
|
|
def body_fn(c, a, b):
|
|
return c - 1, a.nonzero(), b.nonzero()
|
|
|
|
return torch.ops.higher_order.while_loop(
|
|
cond_fn,
|
|
body_fn,
|
|
(c, a, b),
|
|
tuple(),
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Expected carried_inputs and body_output to have same metadata but found",
|
|
):
|
|
make_fx(Mod(), tracing_mode="fake")(
|
|
torch.tensor(
|
|
0,
|
|
),
|
|
torch.randn(2, 3),
|
|
torch.randn(2, 3),
|
|
)
|
|
|
|
def test_while_loop_nested_traced(self):
|
|
fn, inp = WHILE_LOOP_TESTS["nested"]
|
|
graphs = self._check_tracing(fn, inp)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, out_iter_1, it_1, y_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (out_iter_1, it_1, y_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = out_iter_1 = it_1 = y_1 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]; while_loop = None
|
|
return (getitem, getitem_1, getitem_2)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 2); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]; while_loop = None
|
|
add = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
|
|
return (add, getitem_1, getitem_2)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_pytree_carry(self):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_pytree_carry"]
|
|
backend = EagerAndRecordGraphs()
|
|
expected_res = fn(*inp)
|
|
compiled_res = torch.compile(fn, backend=backend)(*inp)
|
|
self.assertEqual(expected_res, compiled_res)
|
|
        # When tested with torch dynamo, the graph is not captured because
|
|
# it's traced together with the code before torch.compile
|
|
if not TEST_WITH_TORCHDYNAMO:
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_pytree_input_1_x_ : torch.Tensor, L_pytree_input_1_y_ : torch.Tensor):
|
|
l_it_ = L_it_
|
|
l_pytree_input_0_0_ = L_pytree_input_0_0_
|
|
l_pytree_input_1_x_ = L_pytree_input_1_x_
|
|
l_pytree_input_1_y_ = L_pytree_input_1_y_
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
value = while_loop[2]
|
|
value_1 = while_loop[3]; while_loop = None
|
|
return (getitem, getitem_1, value, value_1)""", # noqa: B950
|
|
)
|
|
|
|
def _wrap_with_functionalize(self, fn, func_type):
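        # Wrap `fn` with the requested functionalization API (C++, Python
        # FunctionalTensor, or functorch). Only the Python variant needs an
        # active FunctionalTensorMode, which is returned alongside the wrapped
        # function (None otherwise).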
|
|
mode = None
|
|
if func_type == "cpp":
|
|
fn = CppFunctionalizeAPI().functionalize(fn)
|
|
elif func_type == "python":
|
|
fn = PythonFunctionalizeAPI().functionalize(fn)
|
|
mode = FunctionalTensorMode()
|
|
elif func_type == "functorch":
|
|
fn = torch.func.functionalize(fn)
|
|
else:
|
|
assert func_type == "no"
|
|
return fn, mode
|
|
|
|
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
|
|
def test_while_loop_simple_functionalize_check_graph(self, func_type):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"]
|
|
fn, mode = self._wrap_with_functionalize(fn, func_type)
|
|
mode = mode if mode is not None else contextlib.nullcontext()
|
|
with mode:
|
|
graphs = self._check_tracing(fn, inp)
|
|
if func_type == "no":
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, x_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
|
|
getitem = while_loop[0]; while_loop = None
|
|
return (getitem,)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None
|
|
add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None
|
|
sum_1 = torch.ops.aten.sum.default(add__1); add__1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None
|
|
add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None
|
|
add = torch.ops.aten.add.Tensor(add__1, 1); add__1 = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
elif func_type == "python":
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = None
|
|
getitem = while_loop[0]; while_loop = None
|
|
return (getitem,)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None
|
|
return (add_2,)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].code.strip("\n"),
|
|
"""\
|
|
def forward(self, x_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
|
|
getitem = while_loop[0]; while_loop = None
|
|
return (getitem,)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None
|
|
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
|
|
return lt
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
|
|
add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None
|
|
return (add_2,)
|
|
""",
|
|
)
|
|
|
|
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
|
|
# - "simple_with_linear" and "nested_with_linear" doesn't work because parameters and buffers
|
|
# are not inputs so they're not wrapped by functionalization and tracing.
|
|
#
|
|
# - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output"
|
|
# because tensors are real but we unspecialize the ints with unbacked symints causing
|
|
# data dependent errors.
|
|
# Since this is not the common use path, we skip them for now.
|
|
@parametrize(
|
|
"while_loop_test",
|
|
set(WHILE_LOOP_TESTS.keys())
|
|
- {
|
|
"simple_with_linear",
|
|
"nested_with_linear",
|
|
"int_carry",
|
|
"pytree_int_carry",
|
|
"const_and_symint_output",
|
|
},
|
|
)
|
|
def test_while_loop_functionalize(self, func_type, while_loop_test):
|
|
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
|
|
fn, mode = self._wrap_with_functionalize(fn, func_type)
|
|
mode = mode if mode is not None else contextlib.nullcontext()
|
|
with mode:
|
|
self._check_tracing(fn, inp)
|
|
|
|
# - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output"
|
|
# because tensors are real but we unspecialize the ints with unbacked symints causing
|
|
# data dependent errors.
|
|
# Since this is not the common use path, we skip them for now.
|
|
@parametrize(
|
|
"while_loop_test",
|
|
set(WHILE_LOOP_TESTS.keys())
|
|
- {"int_carry", "pytree_int_carry", "const_and_symint_output"},
|
|
)
|
|
def test_while_loop_tracing(self, while_loop_test):
|
|
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
|
|
allow_non_fake_inputs = (
|
|
False
|
|
if while_loop_test not in ("simple_with_linear", "nested_with_linear")
|
|
else True
|
|
)
|
|
self._check_tracing(fn, inp, allow_non_fake_inputs)
|
|
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
@parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
|
|
def test_while_loop_compile(self, backend, while_loop_test):
|
|
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
|
|
self._check_compile(fn, inp, backend=backend)
|
|
|
|
    @skipIfTorchDynamo("Graph is not captured by backend if tested with dynamo")
|
|
@skipIfCrossRef # Arg order changes with cross ref
|
|
def test_while_loop_simple_with_linear_compile_check_graph(self):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
|
|
backend = EagerAndRecordGraphs()
|
|
torch.compile(fn, backend=backend)(*inp)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
gm = backend.graphs[0]
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter):
|
|
l_iter_ = L_iter_
|
|
l_x_ = L_x_
|
|
l_self_buffers_dec_ = L_self_buffers_dec_
|
|
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
|
|
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]; while_loop = None
|
|
return (getitem, getitem_1)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.cond_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, child : torch.Tensor, child_1 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
|
|
sub = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None
|
|
gt = sub > 0; sub = None
|
|
return gt""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.body_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, child_2 : torch.Tensor, child_3 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
|
|
child = child_2 - 1; child_2 = None
|
|
child_4 = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
|
|
return (child, child_4)""", # noqa: B950
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
|
|
l_iter_ = L_iter_
|
|
l_x_ = L_x_
|
|
l__self___dec = self.L__self___dec
|
|
l__self___linear_weight = self.L__self___linear_weight
|
|
l__self___linear_bias = self.L__self___linear_bias
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]; while_loop = None
|
|
return (getitem, getitem_1)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.cond_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
|
|
sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None
|
|
gt = sub > 0; sub = None
|
|
return gt""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.body_fn_0.code.strip(),
|
|
"""\
|
|
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
|
|
child = l_iter_ - 1; l_iter_ = None
|
|
child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
|
|
return (child, child_1)""", # noqa: B950
|
|
)
|
|
|
|
def test_while_loop_nested2_traced(self):
|
|
fn, inp = WHILE_LOOP_TESTS["nested2"]
|
|
graphs = self._check_tracing(fn, inp)
|
|
gm = graphs["symbolic"]
|
|
outer_body = gm.while_loop_body_graph_0
|
|
inner_body = outer_body.while_loop_body_graph_0
|
|
inner_cond = outer_body.while_loop_cond_graph_0
|
|
self.assertExpectedInline(
|
|
gm.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(arg3_1, 1)
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(arg2_1, 1)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(arg2_1, 0)
|
|
sym_size_int_3 = torch.ops.aten.sym_size.int(arg3_1, 0)
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]
|
|
getitem_3 = while_loop[3]; while_loop = None
|
|
return (getitem, getitem_1, getitem_2, getitem_3)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
outer_body.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), (arg7_1, arg7_1, arg7_1, arg7_1)); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = arg7_1 = None
|
|
getitem = while_loop[0]
|
|
getitem_1 = while_loop[1]
|
|
getitem_2 = while_loop[2]
|
|
getitem_3 = while_loop[3]; while_loop = None
|
|
sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None
|
|
clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None
|
|
mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None
|
|
div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None
|
|
return (sub, clone, mul, div)
|
|
""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
inner_body.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
|
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
|
|
sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
|
|
add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None
|
|
sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71); arg3_1 = None
|
|
return (clone, sub, add, sub_1)
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
inner_cond.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1):
|
|
gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None
|
|
return gt
|
|
""",
|
|
)
|
|
|
|
def test_cond_nested_traced(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(x, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [x])
|
|
return x + z
|
|
|
|
def false_fn(x, _):
|
|
return x.cos()
|
|
|
|
def f(x, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [x, pred2])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
|
|
|
|
result_true_true = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(True)
|
|
) # True + True -> x * x
|
|
result_true_false = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(False)
|
|
        ) # True + False -> x + x
|
|
result_false_true = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(True)
|
|
) # False + either -> cos
|
|
result_false_false = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
) # False + either -> cos
|
|
|
|
self.assertNotEqual(result_true_true, result_true_false)
|
|
self.assertFalse(torch.allclose(result_false_true, result_true_true))
|
|
|
|
self.assertEqual(result_false_true, result_false_false)
|
|
|
|
self.assertEqual(result_true_true, (x * x) + x)
|
|
self.assertEqual(result_true_false, x + x + x)
|
|
|
|
self.assertEqual(result_false_true, torch.cos(x))
|
|
|
|
graph = make_fx(f, tracing_mode="symbolic")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
self.assertEqual(
|
|
graph(x, torch.tensor(True), torch.tensor(True)),
|
|
f(x, torch.tensor(True), torch.tensor(True)),
|
|
)
|
|
|
|
def test_cond_functionalized(self):
|
|
def true_fn(x):
|
|
y = x.sin()
|
|
y.add_(4)
|
|
return x.sin().max() + y.sum()
|
|
|
|
def false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
all_ops_in_true_branch = []
|
|
for node in graph_module.true_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
all_ops_in_true_branch.append(node.target)
|
|
|
|
self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
|
|
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
def test_cond_accepts_torch_function_as_inputs(self):
|
|
a = torch.randn(3, 4)
|
|
b = torch.randn(3, 4)
|
|
|
|
def f(a, b):
|
|
return cond(a.sum() > 0, torch.add, torch.mul, (a, b))
|
|
|
|
gm = self._check_tracing(f, (a, b))["symbolic"]
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, a_1, b_1):
|
|
sum_1 = torch.ops.aten.sum.default(a_1)
|
|
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
|
|
sym_size_int = torch.ops.aten.sym_size.int(a_1, 1)
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(b_1, 0)
|
|
sym_size_int_2 = torch.ops.aten.sym_size.int(b_1, 1)
|
|
sym_size_int_3 = torch.ops.aten.sym_size.int(a_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (add,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.false_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (mul,)""",
|
|
)
|
|
|
|
def test_cond_retrace_functionalized(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(x):
|
|
return cond(x.all(), true_fn, false_fn, (x,))
|
|
|
|
inp = torch.ones(1, 2)
|
|
gm_non_functional = make_fx(f, tracing_mode="real")(inp)
|
|
gm_functional = make_fx(
|
|
torch.func.functionalize(gm_non_functional), tracing_mode="real"
|
|
)(inp)
|
|
self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))
|
|
|
|
def test_cond_subgraph_same_shape_env_as_parent(self):
|
|
def true_fn(x):
|
|
return x.sin() + 10
|
|
|
|
def false_fn(x):
|
|
return x.cos() - 20
|
|
|
|
def f(x, pred):
|
|
y = cond(pred, true_fn, false_fn, [x])
|
|
z = torch.add(y, y)
|
|
return z
|
|
|
|
symbolic_traced_graph = self._check_tracing(
|
|
f, (torch.ones(4), torch.Tensor([True]))
|
|
)["symbolic"]
|
|
graph_shape_env = symbolic_traced_graph.shape_env
|
|
|
|
def _node_shape_env_iter(gm):
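            # Yield the ShapeEnv attached to each call_function node's example
            # value (handling tuple and SymInt vals) so the test can check that
            # they all alias the parent graph's ShapeEnv.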
|
|
            for node in gm.graph.nodes:
|
|
if node.op == "call_function":
|
|
val = node.meta.get("val")
|
|
if isinstance(val, tuple):
|
|
for v in val:
|
|
yield v.fake_mode.shape_env
|
|
elif isinstance(val, torch.SymInt):
|
|
yield val.node.shape_env
|
|
else:
|
|
yield val.fake_mode.shape_env
|
|
|
|
for shape_env in _node_shape_env_iter(symbolic_traced_graph):
|
|
self.assertTrue(shape_env is graph_shape_env)
|
|
|
|
for shape_env in _node_shape_env_iter(symbolic_traced_graph.true_graph_0):
|
|
self.assertTrue(shape_env is graph_shape_env)
|
|
|
|
for shape_env in _node_shape_env_iter(symbolic_traced_graph.false_graph_0):
|
|
self.assertTrue(shape_env is graph_shape_env)
|
|
|
|
def test_cond_functionalized_nested(self):
|
|
def true_true_fn(x):
|
|
y = x.cos()
|
|
y.add_(4)
|
|
return x.sin().max() + y.sin().max()
|
|
|
|
def true_false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def true_fn(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_true_fn, true_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return x.sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
gm_true_true_branch = graph_module.true_graph_0.true_graph_0
|
|
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
all_ops = []
|
|
for node in gm_true_true_branch.graph.nodes:
|
|
if node.op == "call_function":
|
|
all_ops.append(node.target)
|
|
|
|
self.assertFalse(any(op._schema.is_mutable for op in all_ops))
|
|
|
|
def test_cond_functionalized_data_dependent_pred(self):
|
|
def true_fn(x):
|
|
return x.sin().sum()
|
|
|
|
def false_fn(x):
|
|
return x.cos().sum()
|
|
|
|
def f(x):
|
|
pred = x.nonzero().shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
|
|
|
|
def test_cond_functionalized_input_mutation_on_true_branch(self):
|
|
def true_fn(x):
|
|
view_x = x.view(x.shape)
|
|
view_x.add_(1)
|
|
return view_x.sin().sum()
|
|
|
|
def false_fn(x):
|
|
return x.cos().sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
# torch.cond inlines into one of the branches because the predicate
|
|
# is a constant.
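        # (Here `pred = x.shape[0] == 4` evaluates to the Python bool True for
        # the (4, 5) example input, so only true_fn is traced below.)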
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
view = torch.ops.aten.view.default(x_1, [4, 5])
|
|
add = torch.ops.aten.add.Tensor(view, 1); view = None
|
|
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
|
|
sin = torch.ops.aten.sin.default(view_2); view_2 = None
|
|
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
|
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None
|
|
return sum_1""",
|
|
)
|
|
|
|
# torch.cond triggers the check of the branches because the predicate
|
|
# is a SymBool.
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"cond_true might be modifying the input!",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_input_mutation_on_false_branch(self):
|
|
def true_fn(x):
|
|
return x.sin().sum()
|
|
|
|
def false_fn(x):
|
|
view_x = x.view(x.shape)
|
|
view_x.add_(1)
|
|
return view_x.cos().sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(5, 5),)
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
# torch.cond inlines into one of the branches because the predicate
|
|
# is a constant.
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
view = torch.ops.aten.view.default(x_1, [5, 5])
|
|
add = torch.ops.aten.add.Tensor(view, 1); view = None
|
|
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
|
|
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
|
|
cos = torch.ops.aten.cos.default(view_2); view_2 = None
|
|
sum_1 = torch.ops.aten.sum.default(cos); cos = None
|
|
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None
|
|
return sum_1""",
|
|
)
|
|
|
|
# torch.cond triggers the check of the branches because the predicate
|
|
# is a SymBool.
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"cond_false might be modifying the input!",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_output_alias_input(self):
|
|
def true_fn(x):
|
|
return x.clone()
|
|
|
|
def false_fn(x):
|
|
view_x = x.view(x.shape)
|
|
return view_x
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(5, 5),)
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
# torch.cond inlines into one of the branches because the predicate
|
|
# is a constant.
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
|
|
return view""",
|
|
)
|
|
|
|
# torch.cond triggers the check of the branches because the predicate
|
|
# is a SymBool.
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_nested_input_mutation(self):
|
|
def true_true_fn(x):
|
|
x.add_(4)
|
|
return x.sin().max()
|
|
|
|
def true_false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def true_fn(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_true_fn, true_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return x.sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_inputs = (torch.ones(4, 5),)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"cond_true might be modifying the input!",
|
|
):
|
|
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
|
|
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
|
|
def true_true_fn(x):
|
|
x.add_(4)
|
|
return x.sin().max()
|
|
|
|
def true_false_fn(x):
|
|
return x.cos().min()
|
|
|
|
def true_fn(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_true_fn, true_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return x.sum()
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_input = torch.ones(4, 5)
|
|
try:
|
|
example_input_func = to_fun_old(example_input)
|
|
torch._enable_functionalization(reapply_views=False)
|
|
f(example_input_func)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f, tracing_mode="symbolic")(example_input_func)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
return func(*args, **kwargs)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
|
|
|
|
def test_cond_functionalized_input_aliasing_with_aot_func(self):
|
|
def true_fn(x):
|
|
return x
|
|
|
|
def false_fn(x):
|
|
view_x = x.view(x.shape)
|
|
return view_x
|
|
|
|
def f(x):
|
|
pred = x.sum() > 0
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_input = torch.ones(5, 5)
|
|
try:
|
|
example_input_func = to_fun_old(example_input)
|
|
torch._enable_functionalization(reapply_views=False)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
f(example_input_func)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(
|
|
lambda x: torch._to_functional_tensor(x)
|
|
if isinstance(x, torch.Tensor)
|
|
else x,
|
|
args,
|
|
)
|
|
func_kwargs = pytree.tree_map(
|
|
lambda x: torch._to_functional_tensor(x)
|
|
if isinstance(x, torch.Tensor)
|
|
else x,
|
|
kwargs,
|
|
)
|
|
return func(*func_args, **func_kwargs)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
|
|
|
def test_cond_functionalized_aot_func_check_functional(self):
|
|
def true_fn(x):
|
|
return x.cos()
|
|
|
|
def false_fn(x):
|
|
y = x.sin()
|
|
y.add_(5)
|
|
return y
|
|
|
|
def f(x):
|
|
pred = x.shape[0] == 4
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
example_input = torch.ones(5, 5)
|
|
|
|
def f_wrapper(func):
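            # Enable the legacy functionalization mode, convert tensor
            # arguments to functional tensors, run `func`, and unwrap the
            # results back to regular tensors.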
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
args,
|
|
)
|
|
func_kwargs = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
kwargs,
|
|
)
|
|
return pytree.tree_map(
|
|
from_fun_old, func(*func_args, **func_kwargs)
|
|
)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
|
|
for node in result_gm.true_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
|
|
for node in result_gm.false_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
|
|
self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))
|
|
|
|
def test_cond_nested_traced_other_inputs(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(k, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [k])
|
|
return torch.add(torch.tensor([0.25, 0.25]), z)
|
|
|
|
def false_fn(k, _):
|
|
return k.cos()
|
|
|
|
def f(k, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [k, pred2])
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
|
|
|
|
a = torch.tensor([1.0, 1.0])
|
|
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
|
|
|
|
b = torch.tensor([2.0, 2.0])
|
|
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
|
|
|
|
def test_cond_nested_traced_multi(self):
|
|
def true_a(y):
|
|
return y * y
|
|
|
|
def false_a(y):
|
|
return y + y
|
|
|
|
def true_b(y, z):
|
|
return y + z
|
|
|
|
def false_b(y, z):
|
|
return y * z
|
|
|
|
def f(x, pred, pred2):
|
|
a_out = cond(pred, true_a, false_a, [x])
|
|
b_out = cond(pred2, true_b, false_b, [x, x])
|
|
return a_out + b_out
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
|
|
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, x_1, pred_1, pred2_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
|
return add""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graph.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
return (mul,)""",
|
|
)
|
|
|
|
def test_raise_error_on_mismatch_type_size(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return (x, x)
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f)(x, torch.tensor(False))
|
|
|
|
def test_raise_error_on_mismatch_tensor_size(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return torch.zeros([10, 10])
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"When merging two branches' output in torch.cond",
|
|
):
|
|
make_fx(f)(x, torch.tensor(False))
|
|
|
|
def test_cond_traced_not_nested_fake_tensor(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
|
result_true = graph.forward(x, torch.tensor(True))
|
|
result_false = graph.forward(x, torch.tensor(False))
|
|
self.assertFalse(torch.allclose(result_true, result_false))
|
|
self.assertEqual(result_true, torch.sin(x))
|
|
self.assertEqual(result_false, torch.cos(x))
|
|
|
|
def test_cond_nested_traced_fake_tensor(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(x, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [x])
|
|
return x + z
|
|
|
|
def false_fn(x, _):
|
|
return x.cos()
|
|
|
|
def f(x, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [x, pred2])
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f, tracing_mode="fake")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
|
|
result_true_true = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(True)
|
|
) # True + True -> x * x
|
|
result_true_false = graph.forward(
|
|
x, torch.tensor(True), torch.tensor(False)
|
|
        ) # True + False -> x + x
|
|
result_false_true = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(True)
|
|
) # False + either -> cos
|
|
result_false_false = graph.forward(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
) # False + either -> cos
|
|
|
|
self.assertNotEqual(result_true_true, result_true_false)
|
|
self.assertFalse(torch.allclose(result_false_true, result_true_true))
|
|
|
|
self.assertEqual(result_false_true, result_false_false)
|
|
|
|
self.assertEqual(result_true_true, (x * x) + x)
|
|
self.assertEqual(result_true_false, x + x + x)
|
|
|
|
self.assertEqual(result_false_true, torch.cos(x))
|
|
|
|
def test_cond_nested_traced_other_inputs_fake_tensor(self):
|
|
def true_nested(y):
|
|
return y * y
|
|
|
|
def false_nested(y):
|
|
return y + y
|
|
|
|
def true_fn(k, pred2):
|
|
z = cond(pred2, true_nested, false_nested, [k])
|
|
return torch.add(torch.tensor([0.25, 0.25]), z)
|
|
|
|
def false_fn(k, _):
|
|
return k.cos()
|
|
|
|
def f(k, pred, pred2):
|
|
return cond(pred, true_fn, false_fn, [k, pred2])
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
graph = make_fx(f, tracing_mode="fake")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
|
|
a = torch.tensor([1.0, 1.0])
|
|
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
|
|
|
|
b = torch.tensor([2.0, 2.0])
|
|
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
|
|
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
|
|
|
|
def test_cond_nested_traced_multi_fake_tensor(self):
|
|
def true_a(y):
|
|
return y * y
|
|
|
|
def false_a(y):
|
|
return y + y
|
|
|
|
def true_b(y, z):
|
|
return y + z
|
|
|
|
def false_b(y, z):
|
|
return y * z
|
|
|
|
def f(x, pred, pred2):
|
|
a_out = cond(pred, true_a, false_a, [x])
|
|
b_out = cond(pred2, true_b, false_b, [x, x])
|
|
return a_out + b_out
|
|
|
|
x = torch.randn(4)
|
|
graph = make_fx(f, tracing_mode="fake")(
|
|
x, torch.tensor(False), torch.tensor(False)
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, x_1, pred_1, pred2_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
|
|
getitem = cond[0]; cond = None
|
|
true_graph_1 = self.true_graph_1
|
|
false_graph_1 = self.false_graph_1
|
|
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
|
|
getitem_1 = cond_1[0]; cond_1 = None
|
|
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
|
return add""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
graph.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1):
|
|
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
|
|
return (mul,)""",
|
|
)
|
|
|
|
def test_raise_error_on_mismatch_type_size_fake_tensor(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return (x, x)
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
|
|
|
def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return torch.zeros([10, 10])
|
|
|
|
def f(x, y):
|
|
return cond(y, true_fn, false_fn, [x])
|
|
|
|
x = torch.randn(4)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"When merging two branches' output in torch.cond",
|
|
):
|
|
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
|
|
|
|
def check_map_count(self, gm, op_count):
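        # Assert that exactly `op_count` torch.ops.higher_order.map_impl nodes
        # appear across `gm` and all of its submodules.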
|
|
i = 0
|
|
for m in gm.modules():
|
|
for node in m.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.higher_order.map_impl
|
|
):
|
|
i += 1
|
|
self.assertEqual(i, op_count)
|
|
|
|
def test_tracing_map_real(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
def g(xs, y):
|
|
return control_flow.map(f, xs, y)
|
|
|
|
gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2))
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(2)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_symbolic_simple(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
def g(xs, y):
|
|
return control_flow.map(f, xs, y)
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4))
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(2)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_symbolic_list(self):
|
|
def f(x, y):
|
|
return [x[0][0] + y, x[1] * y]
|
|
|
|
def g(xs, y, z):
|
|
out = control_flow.map(f, xs, y)
|
|
return out[0] + z, out[1] * z
|
|
|
|
example_x = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)]
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
example_x, torch.ones(5), torch.ones(5)
|
|
)
|
|
x = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)]
|
|
y = torch.randn(6)
|
|
z = torch.ones(6)
|
|
res = gm(x, y, z)
|
|
self.assertEqual(res, g(x, y, z))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_symbolic_dict(self):
|
|
def f(x, y):
|
|
return {"d": x["b"]["a"] + y, "e": x["c"] * y}
|
|
|
|
def g(xs, y, z):
|
|
out = control_flow.map(f, xs, y)
|
|
return {"f": out["d"] + z, "g": out["e"] * z}
|
|
|
|
example_x = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)}
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
example_x, torch.ones(5), torch.ones(5)
|
|
)
|
|
x = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)}
|
|
y = torch.randn(6)
|
|
z = torch.ones(6)
|
|
res = gm(x, y, z)
|
|
self.assertEqual(res, g(x, y, z))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_tracing_map_autograd_symbolic_simple(self):
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
def g(xs, y):
|
|
out = control_flow.map(f, xs, y)
|
|
return torch.autograd.grad(out, (xs, y), torch.ones_like(out))
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)
|
|
)
|
|
x = torch.randn(4, 5, 6, requires_grad=True)
|
|
y = torch.randn(6, requires_grad=True)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
def test_tracing_map_autograd_symbolic_list(self):
|
|
import torch.utils._pytree as pytree
|
|
|
|
def f(x, y):
|
|
return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]
|
|
|
|
def g(xs, y):
|
|
out = control_flow.map(f, xs, y)
|
|
flat_out = pytree.tree_leaves(out)
|
|
flat_inp = pytree.tree_leaves((xs, y))
|
|
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
|
|
return torch.autograd.grad(
|
|
flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
|
|
)
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
[torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
|
|
torch.ones(5, requires_grad=True),
|
|
)
|
|
x = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)]
|
|
y = torch.randn(6, requires_grad=True)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
def test_tracing_map_autograd_symbolic_dict(self):
|
|
def f(x, y):
|
|
return [x["a"] + y, x["b"] * y]
|
|
|
|
def g(xs, y):
|
|
out = control_flow.map(f, xs, y)
|
|
flat_out = pytree.tree_leaves(out)
|
|
flat_inp = pytree.tree_leaves((xs, y))
|
|
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
|
|
return torch.autograd.grad(
|
|
flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
|
|
)
|
|
|
|
traced_x = {
|
|
"a": torch.ones(3, 4, 5, requires_grad=True),
|
|
"b": torch.ones(3, 4, 5, requires_grad=True),
|
|
}
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
traced_x, torch.ones(5, requires_grad=True)
|
|
)
|
|
x = {
|
|
"a": torch.randn(4, 5, 6, requires_grad=True),
|
|
"b": torch.ones(4, 5, 6, requires_grad=True),
|
|
}
|
|
y = torch.randn(6, requires_grad=True)
|
|
res = gm(x, y)
|
|
self.assertEqual(res, g(x, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
def test_tracing_map_autograd_aot_functionalized(self):
|
|
def inner(x, y):
|
|
z = x - 1
|
|
z.add_(1)
|
|
return z * y
|
|
|
|
def f(xs, y):
|
|
res = control_flow.map(inner, xs, y)
|
|
grads = torch.autograd.grad(res, (xs, y), torch.ones_like(res))
|
|
return grads
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
return pytree.tree_map(from_fun_old, func(*args, **kwargs))
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
example_inputs = (
|
|
torch.ones(3, 2, 4, requires_grad=True),
|
|
torch.ones(2, 4, requires_grad=True),
|
|
)
|
|
gm = make_fx(f, tracing_mode="symbolic")(*example_inputs)
|
|
fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs)
|
|
xs = torch.ones(3, 4, 5, requires_grad=True)
|
|
y = torch.ones(4, 5, requires_grad=True)
|
|
|
|
self.assertEqual(gm(xs, y), f(xs, y))
|
|
|
|
def count_mutable(gm):
|
|
c = 0
|
|
for node in gm.graph.nodes:
|
|
if node.op == "call_function":
|
|
if node.target == torch.ops.higher_order.map_impl:
|
|
c += count_mutable(getattr(gm, str(node.args[0])))
|
|
elif schema := getattr(node.target, "_schema", None):
|
|
c += int(schema.is_mutable)
|
|
return c
|
|
|
|
self.assertEqual(count_mutable(fgm), 0)
|
|
# One for forward, one for recomputation logic in backward
|
|
self.assertEqual(count_mutable(gm), 2)
|
|
|
|
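    # Functionalization tests: the in-place add_ inside map_fn should be
    # functionalized away, leaving no ops with a mutable schema in the traced
    # body graph.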
def test_map_functionalized(self):
|
|
def map_fn(x, y):
|
|
z = x + y
|
|
z.add_(4)
|
|
return z
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
|
|
|
|
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
|
|
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
|
|
|
gm = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
|
|
*example_inputs
|
|
)
|
|
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
|
|
|
for node in gm.body_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_map_functionalized_aot_func(self):
|
|
def map_fn(x, y):
|
|
z = x + y
|
|
z.add_(4)
|
|
return z
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
return pytree.tree_map(from_fun_old, func(*args, **kwargs))
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
|
|
gm = make_fx(f_wrapper(f))(*example_inputs)
|
|
|
|
for node in gm.body_graph_0.graph.nodes:
|
|
if node.op == "call_function":
|
|
self.assertTrue(not node.target._schema.is_mutable)
|
|
|
|
self.assertEqual(gm(*example_inputs), f(*example_inputs))
|
|
|
|
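    # Mutating map's arguments inside the mapped function is rejected under
    # functionalization: this test mutates the broadcast argument y, the next
    # one mutates the mapped element x, and both are expected to raise.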
def test_map_functionalized_arg_mutation(self):
|
|
def map_fn(x, y):
|
|
y.add_(4)
|
|
return x + y
|
|
|
|
def f(xs, y):
|
|
return control_flow.map(map_fn, xs, y)
|
|
|
|
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.TorchRuntimeError,
|
|
"map might be modifying the input!",
|
|
):
|
|
functional_f(*example_inputs)
|
|
|
|
    def test_map_functionalized_elem_mutation(self):
        def map_fn(x, y):
            x.add_(4)
            return x + y

        def f(xs, y):
            return control_flow.map(map_fn, xs, y)

        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError, "map might be modifying the input!"
        ):
            functional_f(*example_inputs)

def test_cond_autograd_backward(self):
|
|
def true_fn(x):
|
|
return x.cos()
|
|
|
|
def false_fn(x):
|
|
return x.sin()
|
|
|
|
def f(x, y):
|
|
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
|
|
|
|
example_inputs = (
|
|
torch.ones(3, 2, 4, requires_grad=True),
|
|
torch.ones(4, requires_grad=True),
|
|
)
|
|
f(*example_inputs).sum().backward()
|
|
|
|
# Ensure no error is thrown when not running backward
|
|
res = f(*example_inputs)
|
|
|
|
        # Likewise ensure no error is thrown for the compiled version (still without backward)
|
|
res_compiled = torch.compile(f)(*example_inputs)
|
|
self.assertEqual(res, res_compiled)
|
|
|
|
def test_map_functionalized_elem_alias(self):
|
|
def map_fn(x):
|
|
x.view(x.shape)
|
|
return x
|
|
|
|
def f(xs):
|
|
return control_flow.map(map_fn, xs)
|
|
|
|
example_inputs = (torch.ones(3, 2, 4),)
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"map doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
functional_f(*example_inputs)
|
|
|
|
def test_nested_map_cond_real(self):
|
|
def true_fn(x, y):
|
|
return x * y
|
|
|
|
def false_fn(x, y):
|
|
return x + y
|
|
|
|
def f(x, pred, y):
|
|
return cond(pred, true_fn, false_fn, [x, y])
|
|
|
|
def g(pred, xs, y):
|
|
return control_flow.map(f, xs, pred, y)
|
|
|
|
gm = make_fx(g, tracing_mode="real")(
|
|
torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
|
|
)
|
|
pred = torch.tensor(False)
|
|
x = torch.randn(3, 2, 4)
|
|
y = torch.randn(4)
|
|
res = gm(pred, x, y)
|
|
self.assertEqual(res, g(pred, x, y))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_nested_map_cond_symbolic(self):
|
|
def true_fn(x, y):
|
|
return x * y
|
|
|
|
def false_fn(x, y):
|
|
return x + y
|
|
|
|
def f(x, pred, y):
|
|
return cond(pred, true_fn, false_fn, [x, y])
|
|
|
|
def g(pred, xs, y):
|
|
return control_flow.map(f, xs, pred, y)
|
|
|
|
gm = make_fx(g, tracing_mode="symbolic")(
|
|
torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
|
|
)
|
|
pred = torch.tensor(False)
|
|
x = torch.randn(3, 2, 2)
|
|
y = torch.randn(2)
|
|
res = gm(pred, x, y)
|
|
self.assertEqual(res, g(pred, x, y))
|
|
self.check_map_count(gm, 1)
|
|
|
|
def test_nested_cond_map_cond_symbolic(self):
|
|
def true_fn(x, y):
|
|
return x * y
|
|
|
|
def false_fn(x, y):
|
|
return x + y
|
|
|
|
def f(x, pred, y):
|
|
return cond(pred, true_fn, false_fn, [x, y])
|
|
|
|
def g(pred, xs, y):
|
|
return control_flow.map(f, xs, pred, y)
|
|
|
|
def main_true_fn(pred, xs, y):
|
|
return g(pred, xs, y) * 2
|
|
|
|
def main_false_fn(pred, xs, y):
|
|
return g(pred, xs, y) + 1
|
|
|
|
def main(p, pred, xs, y):
|
|
return cond(p, main_true_fn, main_false_fn, [pred, xs, y])
|
|
|
|
gm = make_fx(main, tracing_mode="symbolic")(
|
|
torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
|
|
)
|
|
p = torch.tensor(False)
|
|
pred = torch.tensor(False)
|
|
xs = torch.randn(3, 2, 2)
|
|
y = torch.randn(2)
|
|
res = gm(p, pred, xs, y)
|
|
self.assertEqual(res, main(p, pred, xs, y))
|
|
self.check_map_count(gm, 2)
|
|
|
|
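    # cond with a predicate computed from a symbolic size: the traced graph
    # should not specialize on that size (no guards are added), so the same
    # graph works for inputs where either branch is taken.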
def test_cond_with_sym_pred(self):
|
|
def true_fn(x):
|
|
return x + x
|
|
|
|
def false_fn(x):
|
|
return x * x
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
|
|
# The symbols in make_fx's shape_env should not be specialized.
|
|
self.assertEqual(len(gm.shape_env.guards), 0)
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 4
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
|
|
# We expect the traced graph module to work even if input size changes.
|
|
x = torch.ones(4, 3, 2)
|
|
self.assertEqual(gm(x), true_fn(x))
|
|
self.assertEqual(foo(x), true_fn(x))
|
|
|
|
def test_cond_with_unbacked_sym_pred(self):
|
|
def foo(x):
|
|
def true_fn(x):
|
|
return x + x
|
|
|
|
def false_fn(x):
|
|
return x * x
|
|
|
|
az = x.nonzero()
|
|
return cond(az.shape[0] > 3, true_fn, false_fn, (x,))
|
|
|
|
gm = make_fx(foo, tracing_mode="symbolic")(torch.randn(7))
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
nonzero = torch.ops.aten.nonzero.default(x_1)
|
|
sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0); nonzero = None
|
|
gt = sym_size_int > 3; sym_size_int = None
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 0)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x_1, sym_size_int_1)); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
|
|
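    # Helpers: verify that tensors closed over by cond branches are lifted to
    # explicit subgraph inputs, by counting placeholders in each child graph
    # module and by re-running after mutating the closed-over tensors.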
def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
|
|
assert isinstance(args, (tuple, list))
|
|
self.assertEqual(f(*args), exp_res)
|
|
gm = make_fx(f)(*args)
|
|
self.assertEqual(gm(*args), exp_res)
|
|
|
|
def cnt_placeholder(gm):
|
|
return len([node for node in gm.graph.nodes if node.op == "placeholder"])
|
|
|
|
placeholder_cnts = [cnt_placeholder(mod) for mod in gm.children()]
|
|
self.assertTrue(all(cnt == exp_arg_num for cnt in placeholder_cnts))
|
|
|
|
def _check_closure_correctly_lifted_with_mutation(
|
|
self, f, closures_to_be_mutated, *, args, exp_arg_num
|
|
):
|
|
exp_res = f(*args)
|
|
self._check_closure_correctly_lifted(
|
|
f, args=args, exp_res=exp_res, exp_arg_num=exp_arg_num
|
|
)
|
|
|
|
for closure in closures_to_be_mutated:
|
|
closure.add(-1)
|
|
new_exp_res = f(*args)
|
|
|
|
self._check_closure_correctly_lifted(
|
|
f, args=args, exp_res=new_exp_res, exp_arg_num=exp_arg_num
|
|
)
|
|
|
|
def test_cond_with_tensor_closure(self):
|
|
a = torch.ones(2, 3)
|
|
b = torch.ones(2, 3) + 1
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
        # expected branches take [x, a, b] as input
|
|
inp = torch.randn(2, 3)
|
|
self._check_closure_correctly_lifted_with_mutation(
|
|
foo, (a, b), args=(inp,), exp_arg_num=3
|
|
)
|
|
|
|
def test_cond_with_tensor_closure_graph_module(self):
|
|
a = torch.ones(2, 3)
|
|
b = torch.ones(2, 3) + 1
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
        # expected branches take [x, a, b] as input
|
|
inp = torch.randn(2, 3)
|
|
|
|
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
|
|
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 4
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
_tensor_constant0 = self._tensor_constant0
|
|
_tensor_constant1 = self._tensor_constant1
|
|
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
|
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
|
return (add,)""",
|
|
)
|
|
|
|
def test_cond_with_module_param_closure(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.register_parameter(
|
|
"param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
|
|
)
|
|
self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1)
|
|
|
|
my_mode = Mod()
|
|
|
|
def true_fn(x):
|
|
return x + my_mode.param
|
|
|
|
def false_fn(x):
|
|
return x + my_mode.buffer
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inp = torch.ones(2, 3)
|
|
        # expected: both branches take (x, param, buffer)
|
|
self._check_closure_correctly_lifted_with_mutation(
|
|
foo, (my_mode.param, my_mode.buffer), args=(inp,), exp_arg_num=3
|
|
)
|
|
|
|
def test_cond_with_module_python_scalar_closure(self):
|
|
def foo(x):
|
|
a = torch.ones(1, 1)
|
|
b = 1
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b
|
|
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inp = torch.ones(2, 3)
|
|
res = inp + 1
|
|
# python scalar b is not lifted as input, so both branches take (x, a)
|
|
self._check_closure_correctly_lifted(
|
|
foo, args=(inp,), exp_res=res, exp_arg_num=2
|
|
)
|
|
|
|
def test_cond_nested_with_closure(self):
|
|
a = torch.ones(1, 1)
|
|
b = torch.ones(1, 1) + 1
|
|
|
|
def inner_true_fn(x):
|
|
return x + a
|
|
|
|
def inner_false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
def true_fn(x):
|
|
return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
|
|
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inp = torch.ones(2, 3)
|
|
        # For the top-level cond, each branch takes 3 arguments (x, a, b). Dynamo
        # should realize that the nonlocal variables are the same for the true and
        # false branches, so it should de-dupe them.
        # For the second-level conds, each branch also takes (x, a, b).
|
|
self._check_closure_correctly_lifted_with_mutation(
|
|
foo, (a, b), args=(inp,), exp_arg_num=3
|
|
)
|
|
|
|
def test_cond_nested_with_closure_graph_module(self):
|
|
a = torch.ones(1, 1)
|
|
b = torch.ones(1, 1) + 1
|
|
|
|
def inner_true_fn(x):
|
|
return x + a
|
|
|
|
def inner_false_fn(x):
|
|
return x + b
|
|
|
|
def foo(x):
|
|
def true_fn(x):
|
|
return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
|
|
|
|
def false_fn(x):
|
|
return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
|
|
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
def test_map_unfunc_boolean_tensor_for_nested_map_cond(self):
|
|
def map_fn(pred, x):
|
|
def fn(x, pred):
|
|
return control_flow.cond(pred, lambda x: x * 2, lambda x: x / 2, (x,))
|
|
|
|
return control_flow.map(fn, x, pred)
|
|
|
|
def f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
args,
|
|
)
|
|
func_kwargs = pytree.tree_map(
|
|
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
|
|
kwargs,
|
|
)
|
|
return pytree.tree_map(
|
|
from_fun_old, func(*func_args, **func_kwargs)
|
|
)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
gm = make_fx(f_wrapper(map_fn))(
|
|
torch.tensor(True), torch.ones([2, 3], requires_grad=False)
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, pred_1, x_1):
|
|
unbind = torch.ops.aten.unbind.int(x_1)
|
|
getitem = unbind[0]; getitem = None
|
|
getitem_1 = unbind[1]; unbind = getitem_1 = None
|
|
body_graph_0 = self.body_graph_0
|
|
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
|
|
getitem_2 = map_impl[0]; map_impl = None
|
|
return getitem_2""",
|
|
)
|
|
self.assertExpectedInline(
|
|
gm.body_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1):
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, (arg0_1,)); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""", # noqa: B950
|
|
)
|
|
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
|
|
def true_fn(x):
|
|
return x + x.cos()
|
|
|
|
def false_fn(x):
|
|
return x * x.sin()
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
|
|
|
|
inp = torch.randn([4, 3])
|
|
gm, _ = torch._dynamo.export(foo)(inp)
|
|
|
|
def run_with_interpreter(*args):
|
|
with torch.fx.traceback.preserve_node_meta():
|
|
return torch.fx.Interpreter(gm).run(*args)
|
|
|
|
new_gm = make_fx(run_with_interpreter)(inp)
|
|
|
|
checked_ops = {"add", "mul", "sin", "cos"}
|
|
checked_meta = ["source_fn_stack", "stack_trace"]
|
|
all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
|
|
new_source_fns = collect_meta_for_filtered_nodes(
|
|
new_gm, checked_ops, checked_meta
|
|
)
|
|
self.assertEqual(all_source_fns, new_source_fns)
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_TORCHDYNAMO,
|
|
"triggers cache limit for foo and changes unique_graphs count.",
|
|
)
|
|
def test_cond_no_dynamo_cache_limit(self):
|
|
torch._dynamo.reset()
|
|
counters = torch._dynamo.utils.counters
|
|
counters.clear()
|
|
|
|
def foo(x, true_fn, false_fn):
|
|
return cond(x.sum() < 0, true_fn, false_fn, (x,))
|
|
|
|
inp = torch.ones(3, 4)
|
|
exp_out = inp.sin()
|
|
iter_n = torch._dynamo.config.recompile_limit + 1
|
|
|
|
# Need functions that cause recompilations
|
|
def get_dummy_fns(str):
|
|
def dummy_cos(x):
|
|
return x.cos() + len(str) - len(str)
|
|
|
|
def dummy_sin(x):
|
|
return x.sin() + len(str) - len(str)
|
|
|
|
return dummy_cos, dummy_sin
|
|
|
|
for i in range(iter_n):
|
|
# we fail guards each iter because `str(i)` is different
|
|
self.assertEqual(foo(inp, *get_dummy_fns(str(i))), exp_out)
|
|
|
|
# each iteration captures a cond and a getitem from the tuple output
|
|
self.assertEqual(counters["stats"]["calls_captured"], iter_n * 2)
|
|
self.assertEqual(counters["stats"]["unique_graphs"], iter_n)
|
|
|
|
def test_cond_with_consecutive_make_fx_symbolic(self):
|
|
def true_fn(x):
|
|
return x - x.cos()
|
|
|
|
def false_fn(x):
|
|
return x + x.sin()
|
|
|
|
def foo(x):
|
|
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
|
|
|
|
inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3))
|
|
for inp in inps:
|
|
gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4))
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 4
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None
|
|
getitem = cond[0]; cond = None
|
|
return getitem""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
gm.true_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
cos = torch.ops.aten.cos.default(arg0_1)
|
|
sub = torch.ops.aten.sub.Tensor(arg0_1, cos); arg0_1 = cos = None
|
|
return (sub,)""",
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
gm.false_graph_0.code.strip(),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1, arg2_1):
|
|
sin = torch.ops.aten.sin.default(arg0_1)
|
|
add = torch.ops.aten.add.Tensor(arg0_1, sin); arg0_1 = sin = None
|
|
return (add,)""",
|
|
)
|
|
|
|
def _create_test_fns_for_cond(
|
|
self, pred, inner_most_fn, operands, closure_list, nested_level
|
|
):
|
|
if nested_level == 0:
|
|
if len(closure_list) > 0:
|
|
|
|
def true_fn(*operands):
|
|
return inner_most_fn(*operands) + inner_most_fn(*closure_list)
|
|
|
|
def false_fn(*operands):
|
|
return inner_most_fn(*operands) - inner_most_fn(*closure_list)
|
|
|
|
else:
|
|
|
|
def true_fn(*operands):
|
|
return inner_most_fn(*operands)
|
|
|
|
def false_fn(*operands):
|
|
return inner_most_fn(*operands)
|
|
|
|
def fn(*operands):
|
|
if len(operands) == 0 and len(closure_list) == 0:
|
|
return torch.zeros(1)
|
|
return cond(pred, true_fn, false_fn, operands)
|
|
|
|
return operands, fn
|
|
else:
|
|
args, inner_fn = self._create_test_fns_for_cond(
|
|
pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1
|
|
)
|
|
|
|
def true_fn(*operands):
|
|
return inner_most_fn(*operands) + inner_fn(*args)
|
|
|
|
def false_fn(*operands):
|
|
return inner_most_fn(*operands) - inner_fn(*args)
|
|
|
|
def fn(*operands):
|
|
if len(operands) == 0 and len(closure_list) == 0:
|
|
return torch.ones(1)
|
|
return cond(pred, true_fn, false_fn, operands)
|
|
|
|
return operands, fn
|
|
|
|
def _init_predicate(self, pred_type):
|
|
if pred_type == "bool":
|
|
return True
|
|
elif pred_type == "intTensor":
|
|
return torch.tensor(1)
|
|
elif pred_type == "floatTensor":
|
|
return torch.tensor(1.0)
|
|
elif pred_type == "boolTensor":
|
|
return torch.tensor(False)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
def _init_fn(self, inner_fn_type):
|
|
if inner_fn_type == "function":
|
|
return reduce_func
|
|
elif inner_fn_type == "module":
|
|
return ReduceMod()
|
|
elif inner_fn_type == "object":
|
|
return ReduceObj()
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
@parametrize("predType", ["bool", "intTensor", "floatTensor", "boolTensor"])
|
|
@parametrize("innerFnType", ["function", "module", "object"])
|
|
@parametrize("nOperands", [0, 1])
|
|
@parametrize("nClosure", [0, 1])
|
|
@parametrize("nesting", [0, 2])
|
|
def test_cond_tracing_with_valid_inputs(
|
|
self, predType, innerFnType, nOperands, nClosure, nesting
|
|
):
|
|
pred = self._init_predicate(predType)
|
|
inner_fn = self._init_fn(innerFnType)
|
|
operands = [torch.ones(2, 3) + i for i in range(nOperands)]
|
|
closure = [torch.ones(2, 3) - i for i in range(nClosure)]
|
|
args, fn = self._create_test_fns_for_cond(
|
|
pred, inner_fn, operands, closure, nesting
|
|
)
|
|
eager_res = fn(*args)
|
|
for tracing_mode in ["symbolic", "fake", "real"]:
|
|
# set _allow_non_fake_inputs = True to allow fake prop through closures
|
|
with self.subTest(tracing_mode=tracing_mode):
|
|
gm = make_fx(
|
|
fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True
|
|
)(*args)
|
|
self.assertEqual(gm(*args), eager_res)
|
|
|
|
@parametrize("predType", ["boolTensor"])
|
|
@parametrize("innerFnType", ["function", "module", "object"])
|
|
@parametrize("nOperands", [1, 2])
|
|
@parametrize("nClosure", [0, 1])
|
|
@parametrize("nesting", [0])
|
|
def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
|
|
pred = self._init_predicate(predType)
|
|
inner_fn = self._init_fn(innerFnType)
|
|
operands = [torch.ones(2, 3) + i for i in range(nOperands)]
|
|
closure = [torch.ones(2, 3) - i for i in range(nClosure)]
|
|
args, fn = self._create_test_fns_for_cond(
|
|
pred, inner_fn, operands, closure, nesting
|
|
)
|
|
eager_res = fn(*args)
|
|
out = torch.vmap(fn)(*args)
|
|
if nClosure == 0:
|
|
self.assertEqual(eager_res, out)
|
|
else:
|
|
self.assertEqual(eager_res, out[0])
|
|
self.assertEqual(eager_res, out[1])
|
|
|
|
def test_cond_vmap_simple(self):
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: x + 100,
|
|
false_fn=lambda x: x.clone(),
|
|
operands=(x,),
|
|
)
|
|
|
|
a = torch.arange(15).reshape((3, 5))
|
|
res = torch.vmap(fn, in_dims=(0,))(a)
|
|
self.assertEqual(res.shape, (3, 5))
|
|
self.assertEqual(res, a + 100)
|
|
|
|
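    # Under vmap the predicate here depends on the batched rows, so different
    # rows take different branches: row 0 hits the true branch (x + 100), the
    # remaining rows hit the false branch (y).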
def test_cond_vmap_multiple_inputs(self):
|
|
def fn(x, y):
|
|
return torch.cond(
|
|
pred=x.sum() < y.sum(),
|
|
true_fn=lambda x, y: x + 100,
|
|
false_fn=lambda x, y: y.clone(),
|
|
operands=(x, y),
|
|
)
|
|
|
|
a = torch.arange(15).reshape(3, 5)
|
|
b = torch.ones_like(a) + 3
|
|
res = torch.vmap(fn, in_dims=(0, 0))(a, b)
|
|
expected = torch.tensor(
|
|
[[100, 101, 102, 103, 104], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]]
|
|
)
|
|
self.assertEqual(res.shape, (3, 5))
|
|
self.assertEqual(expected, res)
|
|
|
|
def test_cond_vmap_single_input_with_closure(self):
|
|
a = torch.ones((3, 5)) + 3
|
|
c = torch.arange(5)
|
|
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: x + c,
|
|
false_fn=lambda x: x - c,
|
|
operands=(x,),
|
|
)
|
|
|
|
res = torch.vmap(fn, in_dims=(0,))(
|
|
a,
|
|
)
|
|
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
|
res = torch.vmap(fn, in_dims=(0,))(
|
|
a,
|
|
)
|
|
self.assertEqual(a + c, res)
|
|
|
|
def test_cond_vmap_multiple_args_with_closure(self):
|
|
a = torch.ones((3, 5), dtype=torch.int64) + 3
|
|
b = torch.arange(15).reshape(3, 5)
|
|
c = torch.arange(5)
|
|
|
|
def fn(x, y):
|
|
return torch.cond(
|
|
pred=torch.tensor([False]),
|
|
true_fn=lambda x, y: x + c,
|
|
false_fn=lambda x, y: y - c,
|
|
operands=(x, y),
|
|
)
|
|
|
|
res = torch.vmap(fn)(a, b)
|
|
self.assertEqual(b - c, res)
|
|
|
|
@parametrize("nClosure", [0, 1])
|
|
def test_cond_vmap_multiple_outputs(self, nClosure):
|
|
if nClosure:
|
|
c = torch.ones(5, dtype=torch.int64) + 5
|
|
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: (x + c, x - c),
|
|
false_fn=lambda x: (x.clone(), x.clone()),
|
|
operands=(x,),
|
|
)
|
|
|
|
else:
|
|
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda x: (x + 1, x - 1),
|
|
false_fn=lambda x: (x.clone(), x.clone()),
|
|
operands=(x,),
|
|
)
|
|
|
|
a = torch.arange(15).reshape(3, 5)
|
|
res = torch.vmap(fn)(
|
|
a,
|
|
)
|
|
self.assertEqual(len(res), 2)
|
|
if nClosure:
|
|
self.assertEqual(res, (a + c, a - c))
|
|
else:
|
|
self.assertEqual(res, (a + 1, a - 1))
|
|
|
|
@parametrize("boolcond", [True, False])
|
|
def test_vmap_vmap(self, boolcond):
|
|
def fn(x):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]) if not boolcond else True,
|
|
true_fn=lambda x: x + 1,
|
|
false_fn=lambda x: x - 1,
|
|
operands=(x,),
|
|
)
|
|
|
|
def wrapper(x):
|
|
return torch.vmap(fn)(x)
|
|
|
|
a = torch.ones((3, 4, 5))
|
|
res = torch.vmap(wrapper)(a)
|
|
self.assertEqual(res, a + 1)
|
|
|
|
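    # set_() on a cond input causes a graph break inside the branch, so cond
    # cannot be captured completely and is expected to raise
    # UncapturedHigherOrderOpError.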
def test_cond_trace_set__and_mutate_input(self):
|
|
def f(a, tmp):
|
|
a_view = a.view(-1)
|
|
with torch.no_grad():
|
|
a.set_(tmp)
|
|
a_view.mul_(2)
|
|
return a + tmp
|
|
|
|
inp = torch.ones(3, 3, requires_grad=True)
|
|
tmp = torch.ones(3, 3, requires_grad=True)
|
|
# graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
|
|
# due to set_
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
torch.cond(inp.sum() > 0, f, f, (inp, tmp))
|
|
|
|
@skipIfCrossRef # Arg order changes with crossref
|
|
def test_cond_trace_set__and_mutate_intermediate(self):
|
|
def f(a, tmp):
|
|
a = a.clone()
|
|
a_view = a.view(-1)
|
|
tmp = tmp.clone()
|
|
with torch.no_grad():
|
|
a.set_(tmp)
|
|
a_view.mul_(2)
|
|
return a + tmp
|
|
|
|
inp = torch.ones(3, 3, requires_grad=True)
|
|
tmp = torch.ones(3, 3, requires_grad=True)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
|
|
return torch.cond(inp.sum() > 0, f, f, (inp, tmp))
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "cannot mutate tensors with frozen storage"
|
|
):
|
|
out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "cannot mutate tensors with frozen storage"
|
|
):
|
|
out = torch.compile(Mod(), backend="inductor")(inp, tmp)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
out = torch.compile(Mod(), backend=backend)(inp, tmp)
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].cond_true_0.code.strip("\n"),
|
|
"""\
|
|
def forward(self, l_inp_, l_tmp_):
|
|
l_inp__1 = l_inp_
|
|
l_tmp__1 = l_tmp_
|
|
a = l_inp__1.clone(); l_inp__1 = None
|
|
a_view = a.view(-1)
|
|
tmp = l_tmp__1.clone(); l_tmp__1 = None
|
|
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
|
set_ = a.set_(tmp); set_ = None
|
|
mul_ = a_view.mul_(2); a_view = mul_ = None
|
|
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
|
add = a + tmp; a = tmp = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
self.assertEqual(out, f(inp, tmp))
|
|
|
|
@skipIfCrossRef # Args get renamed to r in crossref mode
|
|
@parametrize("requires_grad", [True, False])
|
|
def test_cond_symint_operands(self, requires_grad):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.num = 3
|
|
|
|
def forward(self, a, b):
|
|
return torch.cond(
|
|
pred=torch.tensor([True]),
|
|
true_fn=lambda a, b: a + b + self.num,
|
|
false_fn=lambda a, b: a - b - self.num,
|
|
operands=(a, b),
|
|
)
|
|
|
|
a = torch.ones(3, 3, requires_grad=requires_grad)
|
|
b = torch.ones(3, 3, requires_grad=requires_grad)
|
|
out = torch.compile(Mod(), backend=backend, dynamic=True)(a, b)
|
|
self.assertEqual(out, Mod()(a, b))
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
|
|
l_a_ = L_a_
|
|
l_b_ = L_b_
|
|
tensor = torch.tensor([True])
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""", # noqa: B950
|
|
)
|
|
|
|
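    # Sanity check that cond and while_loop compiled via torch.compile do not
    # share a cached code object: recompilations of one are counted
    # independently of the other by the CompileCounter below.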
def test_two_hops_not_sharing_code_obj(self):
|
|
pred, args = torch.tensor(True), (torch.ones(3, 3),)
|
|
|
|
def fn1(x):
|
|
return x + 1
|
|
|
|
def fn2(x):
|
|
return x - 1
|
|
|
|
from torch._dynamo.testing import CompileCounter
|
|
|
|
# Tests rely on automatic_dynamic = True
|
|
with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
|
|
cnt = CompileCounter()
|
|
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
args = (torch.randn(3, 3),)
|
|
# No recompilation
|
|
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def cond_fn(x):
|
|
return x.sum() > 0
|
|
|
|
args = (torch.randn(4, 4),)
|
|
torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
|
|
# recompilation
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
args = (torch.randn(4, 4),)
|
|
torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
# With recompilation due to automatic dynamic
|
|
# This also proves that while_loop doesn't share code obj with cond
|
|
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
    def test_hop_raises_if_not_overriding_call(self):
        class WrongHop(torch._ops.HigherOrderOperator):
            pass

        with self.assertRaisesRegex(TypeError, "WrongHop"):
            WrongHop("wrong_hop")

def test_scan_functionalized(self):
|
|
def f(init, xs):
|
|
return scan(get_scan_combine_fn("add", False), init, xs, dim=1)
|
|
|
|
example_inputs = torch.ones(5, 7, 4)
|
|
example_init = torch.ones(5, 4)
|
|
functional_f = torch.func.functionalize(f)
|
|
self.assertEqual(
|
|
functional_f(example_init, example_inputs), f(example_init, example_inputs)
|
|
)
|
|
|
|
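    # Mutations inside scan's combine_fn should be rejected under
    # functionalization; the exact exception currently differs between the
    # dynamo and non-dynamo configurations, hence the broad Exception match below.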
def test_scan_functionalized_elem_mutation(self):
|
|
def add1(x, y):
|
|
x.add_(4)
|
|
return x + y, x + y
|
|
|
|
def f(init, xs):
|
|
return scan(add1, init, xs, dim=1)
|
|
|
|
example_inputs = torch.ones(5, 7, 4)
|
|
example_init = torch.ones(5, 4)
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
|
# RuntimeError,
|
|
# "torch.scan might be modifying the input!",
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
|
# torch._dynamo.exc.TorchDynamoException,
|
|
# "Unexpected exception when running generated GraphModule.*"
|
|
Exception,
|
|
".*",
|
|
):
|
|
functional_f(example_init, example_inputs)
|
|
|
|
def add2(x, y):
|
|
y.add_(4)
|
|
return x + y, x + y
|
|
|
|
def f(init, xs):
|
|
return scan(add2, init, xs, dim=1)
|
|
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
|
# Should be
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
|
# RuntimeError,
|
|
# "torch.scan might be modifying the input!",
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
|
# torch._dynamo.exc.TorchDynamoException,
|
|
# "Unexpected exception when running generated GraphModule.*"
|
|
Exception,
|
|
".*",
|
|
):
|
|
functional_f(example_init, example_inputs)
|
|
|
|
def test_scan_functionalized_elem_alias(self):
|
|
def add(x, y):
|
|
return x, x
|
|
|
|
def f(init, xs):
|
|
return scan(add, init, xs, dim=1)
|
|
|
|
example_inputs = torch.ones(5, 7, 4)
|
|
example_init = torch.ones(5, 4)
|
|
functional_f = torch.func.functionalize(f)
|
|
with self.assertRaisesRegex(
|
|
# TODO: Fix this so that the HOPs show similar errors for functionalization
|
|
# Should be
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=0
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
# This is the Exception with PYTORCH_TEST_WITH_DYNAMO=1
|
|
# torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
# "scan must be captured completely with torch.compile.*",
|
|
Exception,
|
|
".*",
|
|
):
|
|
functional_f(example_init, example_inputs)
|
|
|
|
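    # scan closing over a pytree of tensors: the closed-over leaves are lifted
    # and passed to the scan HOP as its additional-inputs list (the third list
    # argument visible in the recorded graph below).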
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
|
|
def test_scan_pytree_closure(self):
|
|
param_buffer = ({"param": torch.randn(3, 3)}, (torch.randn(3),))
|
|
|
|
def add(carry, x):
|
|
ret = (carry @ param_buffer[0]["param"]) @ x + param_buffer[1][0]
|
|
return ret, ret.sum()
|
|
|
|
def f(init, xs):
|
|
return scan(add, init, xs)
|
|
|
|
init = torch.randn(4, 3)
|
|
xs = torch.randn(3, 3, 3)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
eager_out = f(init, xs)
|
|
compiled_out = torch.compile(f, backend=backend)(init, xs)
|
|
exp_out = _fake_scan(add, init, xs)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
if TEST_WITH_CROSSREF:
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_cell_contents_0_param_ : torch.Tensor, L_add_closure_0_cell_contents_1_0_ : torch.Tensor):
|
|
l_init_ = L_init_
|
|
l_xs_ = L_xs_
|
|
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
|
|
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
|
|
r = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [r], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = r = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
|
carry = scan[0]
|
|
out = scan[1]; scan = None
|
|
return (carry, out)""", # noqa: B950
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
backend.graphs[0].code.strip(),
|
|
"""\
|
|
def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_cell_contents_0_param_ : torch.Tensor, L_add_closure_0_cell_contents_1_0_ : torch.Tensor):
|
|
l_init_ = L_init_
|
|
l_xs_ = L_xs_
|
|
l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_
|
|
l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_
|
|
movedim = torch.movedim(l_xs_, 0, 0); l_xs_ = None
|
|
scan_combine_fn_0 = self.scan_combine_fn_0
|
|
scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [movedim], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = movedim = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None
|
|
carry = scan[0]
|
|
out = scan[1]; scan = None
|
|
return (carry, out)""", # noqa: B950
|
|
)
|
|
self.assertEqual(eager_out, exp_out)
|
|
self.assertEqual(compiled_out, exp_out)
|
|
|
|
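    # The while_loop export/compile tests below pin down the traced graphs for
    # loops that carry plain Python ints; those int carries show up as unbacked
    # SymInts (u0, u1, ...) in the captured graphs.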
@skipIfTorchDynamo("Skip because we're testing export")
|
|
@parametrize("strict", [True, False])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_op_int_carry_export(self, strict, dynamic):
|
|
m, args = WHILE_LOOP_TESTS["int_carry"]
|
|
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None
|
|
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
if not strict and dynamic:
|
|
self.assertExpectedInline(
|
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
x: "f32[s77, 3]";
|
|
|
|
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
|
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
|
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (0, x), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x = None
|
|
|
|
getitem_2: "Sym(u1)" = while_loop[0]
|
|
|
|
ge: "Sym(u1 >= 1)" = getitem_2 >= 1
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 1 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
gt_1: "Sym(u1 > 0)" = getitem_2 > 0
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
|
|
|
|
getitem_1: "f32[s77, 3]" = while_loop[1]; while_loop = None
|
|
|
|
add: "Sym(u1 + 1)" = getitem_2 + 1
|
|
|
|
add_1: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
|
|
|
|
lt: "Sym(u1 < s77)" = getitem_2 < sym_size_int_1; sym_size_int_1 = None
|
|
|
|
mul: "Sym(2*u1)" = getitem_2 * 2; getitem_2 = None
|
|
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
|
|
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
|
|
|
|
class while_loop_cond_graph_0(torch.nn.Module):
|
|
def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"):
|
|
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
|
|
|
lt: "Sym(u0 < s77)" = it_1 < sym_size_int_1; it_1 = sym_size_int_1 = None
|
|
return lt
|
|
|
|
class while_loop_body_graph_0(torch.nn.Module):
|
|
def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"):
|
|
clone: "f32[s77, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
|
|
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
|
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
|
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
|
|
copy_: "f32[3]" = torch.ops.aten.copy_.default(select, add); select = add = copy_ = None
|
|
add_1: "Sym(u0 + 1)" = it_1 + 1; it_1 = None
|
|
return (add_1, clone)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_while_loop_op_int_carry_compile(self, dynamic, backend):
|
|
m, args = WHILE_LOOP_TESTS["int_carry"]
|
|
if backend == "eager":
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=dynamic, backend=backend)
|
|
if (
|
|
isinstance(backend, EagerAndRecordGraphs)
|
|
and dynamic
|
|
and not TEST_WITH_CROSSREF
|
|
):
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
|
|
l_x_ = L_x_
|
|
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None
|
|
|
|
getitem_4: "Sym(u2)" = while_loop[0]
|
|
|
|
ge: "Sym(u2 >= 1)" = getitem_4 >= 1
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 1 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
gt_1: "Sym(u2 > 0)" = getitem_4 > 0
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u2 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
|
|
|
|
out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
|
|
|
|
gt: "Sym(u2 > 0)" = getitem_4 > 0
|
|
_check = torch._check(gt); gt = _check = None
|
|
|
|
add: "Sym(u2 + 1)" = getitem_4 + 1
|
|
|
|
add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
|
|
|
|
lt: "Sym(u2 < s77)" = getitem_4 < s77; s77 = None
|
|
|
|
mul: "Sym(2*u2)" = getitem_4 * 2; getitem_4 = None
|
|
ones: "f32[2*u2]" = torch.ones(mul); mul = None
|
|
return (add, add_1, lt, ones)
|
|
|
|
class cond_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint: "Sym(u0)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
size = child.size(); child = None
|
|
getitem: "Sym(s77)" = size[0]
|
|
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
|
|
lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None
|
|
return lt
|
|
|
|
class body_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
x_clone: "f32[s77, s27]" = child_1.clone()
|
|
|
|
ge: "Sym(u1 >= 0)" = unbacked_symint_0 >= 0
|
|
_check = torch._check(ge); ge = _check = None
|
|
|
|
size = child_1.size(); child_1 = None
|
|
getitem: "Sym(s77)" = size[0]
|
|
getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None
|
|
lt: "Sym(u1 < s77)" = unbacked_symint_0 < getitem; getitem = None
|
|
_check_1 = torch._check(lt); lt = _check_1 = None
|
|
|
|
select: "f32[s27]" = x_clone.select(0, unbacked_symint_0)
|
|
select_1: "f32[s27]" = x_clone.select(0, unbacked_symint_0)
|
|
add: "f32[s27]" = select_1 + unbacked_symint_0; select_1 = None
|
|
copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None
|
|
|
|
add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None
|
|
return (add_1, x_clone)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip because we're testing export")
|
|
@parametrize("strict", [True, False])
|
|
@parametrize("dynamic", [True, False])
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_while_loop_op_constant_and_symint_output_export(self, strict, dynamic):
|
|
m, args = WHILE_LOOP_TESTS["const_and_symint_output"]
|
|
dynamic_shapes = {"t": {0: torch.export.Dim("dim_t")}} if dynamic else None
|
|
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
# strict or dynamic gives a slightly different graph
|
|
if not strict and not dynamic:
|
|
self.assertExpectedInline(
|
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, t):
|
|
t: "f32[2, 3]";
|
|
|
|
t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec)
|
|
sum_1: "f32[]" = torch.ops.aten.sum.default(t)
|
|
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
|
|
to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None
|
|
item: "Sym(u0)" = torch.ops.aten.item.default(to); to = None
|
|
sin: "f32[2, 3]" = torch.ops.aten.sin.default(t)
|
|
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (2, 3, 1, 1, 1, 3, item, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = item = sin = None
|
|
|
|
getitem_8: "Sym(u8)" = while_loop[0]
|
|
getitem_9: "Sym(u9)" = while_loop[1]
|
|
getitem_10: "Sym(u10)" = while_loop[2]
|
|
getitem_11: "Sym(u11)" = while_loop[3]
|
|
getitem_12: "Sym(u12)" = while_loop[4]
|
|
getitem_13: "Sym(u13)" = while_loop[5]
|
|
getitem_14: "Sym(u14)" = while_loop[6]
|
|
|
|
getitem_7: "f32[2, 3]" = while_loop[7]; while_loop = None
|
|
|
|
add: "Sym(u8 + 1)" = getitem_8 + 1
|
|
add_1: "Sym(u9 + 1)" = getitem_9 + 1
|
|
add_2: "Sym(u10 + 1)" = getitem_10 + 1
|
|
add_3: "Sym(u11 + 1)" = getitem_11 + 1
|
|
add_4: "Sym(u12 + 1)" = getitem_12 + 1
|
|
add_5: "Sym(u13 + 1)" = getitem_13 + 1
|
|
add_6: "Sym(u14 + 1)" = getitem_14 + 1
|
|
add_7: "f32[2, 3]" = torch.ops.aten.add.Tensor(getitem_7, 1)
|
|
|
|
add_8: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_8); getitem_8 = None
|
|
add_9: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_9); getitem_9 = None
|
|
add_10: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_10); getitem_10 = None
|
|
add_11: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_11); getitem_11 = None
|
|
add_12: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_12); getitem_12 = None
|
|
add_13: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_13); getitem_13 = None
|
|
add_14: "f32[2, 3]" = torch.ops.aten.add.Tensor(t, getitem_14); getitem_14 = None
|
|
add_15: "f32[2, 3]" = torch.ops.aten.add.Tensor(getitem_7, t); getitem_7 = t = None
|
|
return pytree.tree_unflatten((add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15), self._out_spec)
|
|
|
|
class while_loop_cond_graph_0(torch.nn.Module):
|
|
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[2, 3]"):
|
|
mul: "Sym(u3*u4)" = c1_1 * c2_1; c1_1 = c2_1 = None
|
|
mul_1: "Sym(u3*u4*u5)" = mul * c3_1; mul = c3_1 = None
|
|
mul_2: "Sym(u1*u2)" = a_1 * b_1; a_1 = b_1 = None
|
|
lt: "Sym(u3*u4*u5 < u1*u2)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class while_loop_body_graph_0(torch.nn.Module):
|
|
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[2, 3]"):
|
|
add: "Sym(u7 + 1)" = u0_1 + 1; u0_1 = None
|
|
add_1: "f32[2, 3]" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = None
|
|
return (b_1, c1_1, c2_1, c3_1, a_1, 0, add, add_1)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_while_loop_op_constant_and_symint_output_compile(self, dynamic, backend):
|
|
m, args = WHILE_LOOP_TESTS["const_and_symint_output"]
|
|
if backend == "eager":
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=dynamic, backend=backend)
|
|
if (
|
|
isinstance(backend, EagerAndRecordGraphs)
|
|
# cross ref or dynamic gives a slightly different graph
|
|
and not dynamic
|
|
and not TEST_WITH_CROSSREF
|
|
):
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_t_: "f32[2, 3]"):
|
|
l_t_ = L_t_
|
|
|
|
sum_1: "f32[]" = l_t_.sum()
|
|
to: "i64[]" = sum_1.to(torch.int64); sum_1 = None
|
|
item: "Sym(u0)" = to.item(); to = None
|
|
sin: "f32[2, 3]" = l_t_.sin()
|
|
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, sin), ()); cond_fn_0 = body_fn_0 = item = sin = None
|
|
|
|
getitem_8: "Sym(u15)" = while_loop[0]
|
|
getitem_9: "Sym(u16)" = while_loop[1]
|
|
getitem_10: "Sym(u17)" = while_loop[2]
|
|
getitem_11: "Sym(u18)" = while_loop[3]
|
|
getitem_12: "Sym(u19)" = while_loop[4]
|
|
getitem_13: "Sym(u20)" = while_loop[5]
|
|
getitem_14: "Sym(u21)" = while_loop[6]
|
|
|
|
child: "f32[2, 3]" = while_loop[7]; while_loop = None
|
|
|
|
add: "Sym(u15 + 1)" = getitem_8 + 1
|
|
add_1: "Sym(u16 + 1)" = getitem_9 + 1
|
|
add_2: "Sym(u17 + 1)" = getitem_10 + 1
|
|
add_3: "Sym(u18 + 1)" = getitem_11 + 1
|
|
add_4: "Sym(u19 + 1)" = getitem_12 + 1
|
|
add_5: "Sym(u20 + 1)" = getitem_13 + 1
|
|
add_6: "Sym(u21 + 1)" = getitem_14 + 1
|
|
add_7: "f32[2, 3]" = child + 1
|
|
|
|
add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None
|
|
add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None
|
|
add_10: "f32[2, 3]" = getitem_10 + l_t_; getitem_10 = None
|
|
add_11: "f32[2, 3]" = getitem_11 + l_t_; getitem_11 = None
|
|
add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None
|
|
add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None
|
|
add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None
|
|
add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None
|
|
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15)
|
|
|
|
class cond_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unbacked_symint_1: "Sym(u3)", unbacked_symint_2: "Sym(u4)", unbacked_symint_3: "Sym(u5)", unbacked_symint_4: "Sym(u6)", unbacked_symint_5: "Sym(u7)", child: "f32[2, 3]"):
|
|
mul: "Sym(u3*u4)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
|
|
mul_1: "Sym(u3*u4*u5)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
|
|
mul_2: "Sym(u1*u2)" = unbacked_symint * unbacked_symint_0; unbacked_symint = unbacked_symint_0 = None
|
|
lt: "Sym(u3*u4*u5 < u1*u2)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class body_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", unbacked_symint_8: "Sym(u10)", unbacked_symint_9: "Sym(u11)", unbacked_symint_10: "Sym(u12)", unbacked_symint_11: "Sym(u13)", unbacked_symint_12: "Sym(u14)", child_1: "f32[2, 3]"):
|
|
add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None
|
|
child: "f32[2, 3]" = child_1 + 1; child_1 = None
|
|
return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Skip because we're testing export")
|
|
@parametrize("strict", [True, False])
|
|
@parametrize("dynamic", [True, False])
|
|
def test_while_loop_op_pytree_int_carry_export(self, strict, dynamic):
|
|
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
|
|
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None
|
|
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
|
if strict and dynamic:
|
|
self.assertExpectedInline(
|
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
x: "f32[s77, 3]";
|
|
|
|
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
|
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
|
|
|
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
|
|
|
|
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
|
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
|
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 2, 2, 3, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = sin = None
|
|
|
|
getitem_6: "Sym(u10)" = while_loop[0]
|
|
getitem_7: "Sym(u11)" = while_loop[1]
|
|
getitem_8: "Sym(u12)" = while_loop[2]
|
|
getitem_9: "Sym(u13)" = while_loop[3]
|
|
getitem_10: "Sym(u14)" = while_loop[4]
|
|
|
|
getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
|
|
|
|
add: "Sym(u12 + 1)" = getitem_8 + 1
|
|
add_1: "Sym(u13 + 1)" = getitem_9 + 1
|
|
add_2: "Sym(u14 + 1)" = getitem_10 + 1
|
|
|
|
add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
|
|
add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
|
|
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
|
|
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
|
|
|
|
class while_loop_cond_graph_0(torch.nn.Module):
|
|
def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
|
|
mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
|
|
mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None
|
|
mul_2: "Sym(u20*u21)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
|
|
lt: "Sym(u22*u23*u24 < u20*u21)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class while_loop_body_graph_0(torch.nn.Module):
|
|
def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
|
|
add: "Sym(u20 + 1)" = arg0_1 + 1; arg0_1 = None
|
|
add_1: "Sym(u21 + 1)" = arg1_1 + 1; arg1_1 = None
|
|
|
|
add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None
|
|
add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None
|
|
add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None
|
|
|
|
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
|
|
return (add, add_1, add_2, add_3, add_4, add_5)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend):
|
|
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
|
|
if backend == "eager":
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=dynamic, backend=backend)
|
|
if (
|
|
isinstance(backend, EagerAndRecordGraphs)
|
|
and dynamic
|
|
and not TEST_WITH_CROSSREF
|
|
):
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertExpectedInline(
|
|
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"):
|
|
l_x_ = L_x_
|
|
|
|
child: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
|
|
|
|
cond_fn_0 = self.cond_fn_0
|
|
body_fn_0 = self.body_fn_0
|
|
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None
|
|
|
|
getitem_10: "Sym(u10)" = while_loop[0]
|
|
getitem_11: "Sym(u11)" = while_loop[1]
|
|
getitem_12: "Sym(u12)" = while_loop[2]
|
|
getitem_13: "Sym(u13)" = while_loop[3]
|
|
getitem_14: "Sym(u14)" = while_loop[4]
|
|
|
|
out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None
|
|
|
|
add: "Sym(u12 + 1)" = getitem_12 + 1
|
|
add_1: "Sym(u13 + 1)" = getitem_13 + 1
|
|
add_2: "Sym(u14 + 1)" = getitem_14 + 1
|
|
|
|
add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None
|
|
add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None
|
|
add_5: "f32[s77, s27]" = getitem_14 + out_x; getitem_14 = None
|
|
return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x)
|
|
|
|
class cond_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
mul: "Sym(u2*u3)" = unbacked_symint_1 * unbacked_symint_2; unbacked_symint_1 = unbacked_symint_2 = None
|
|
mul_1: "Sym(u2*u3*u4)" = mul * unbacked_symint_3; mul = unbacked_symint_3 = None
|
|
mul_2: "Sym(u0*u1)" = unbacked_symint * unbacked_symint_0; unbacked_symint = unbacked_symint_0 = None
|
|
lt: "Sym(u2*u3*u4 < u0*u1)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
|
return lt
|
|
|
|
class body_fn_0(torch.nn.Module):
|
|
def forward(self, unbacked_symint_4: "Sym(u5)", unbacked_symint_5: "Sym(u6)", unbacked_symint_6: "Sym(u7)", unbacked_symint_7: "Sym(u8)", unbacked_symint_8: "Sym(u9)", child_2: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"):
|
|
s27_1 = s27
|
|
s77_1 = s77
|
|
|
|
add: "Sym(u5 + 1)" = unbacked_symint_4 + 1; unbacked_symint_4 = None
|
|
add_1: "Sym(u6 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None
|
|
|
|
add_2: "Sym(u7 + 1)" = unbacked_symint_6 + 1; unbacked_symint_6 = None
|
|
add_3: "Sym(u8 + 1)" = unbacked_symint_7 + 1; unbacked_symint_7 = None
|
|
add_4: "Sym(u9 + 1)" = unbacked_symint_8 + 1; unbacked_symint_8 = None
|
|
|
|
child: "f32[s77, s27]" = child_2 + 1; child_2 = None
|
|
return (add, add_1, add_2, add_3, add_4, child)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_input_output_alias(self):
|
|
def fn(f, *args):
|
|
return torch.cond(args[0].sum() > 0, f, f, args)
|
|
|
|
x = torch.randn(2, 2)
|
|
for f in ALIAS_FN:
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
torch.compile(fn)(f, x)
|
|
|
|
def test_input_input_alias(self):
|
|
def fn(view_f, arg):
|
|
def f(arg1, arg2):
|
|
return arg1.cos(), arg2.sin()
|
|
|
|
return torch.cond(arg.sum() > 0, f, f, (arg, view_f(arg)))
|
|
|
|
x = torch.randn(2, 2)
|
|
# ALIAS_FN[0] is the identity function; cond de-duplicates the aliased
# input as a result of auto-lifting, so it is skipped here.
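# For reference, ALIAS_FN (defined earlier in this file) is a list of
# view-producing callables; a minimal illustrative sketch of the kind of
# entries it holds (not the exact definitions) would be:
#   ALIAS_FN = [lambda x: x, lambda x: x.view(x.size()), lambda x: x[:]]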
for view_f in ALIAS_FN[1:]:
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
torch.compile(fn)(view_f, x)
|
|
|
|
@parametrize("inference_mode", [True, False])
|
|
def test_input_mutation(self, inference_mode):
|
|
def fn(view_f, *args):
|
|
def mutate_f(x):
|
|
v = view_f(x)
|
|
v.add_(1)
|
|
return v.sin()
|
|
|
|
return torch.cond(args[0].sum() > 0, mutate_f, mutate_f, args)
|
|
|
|
x = torch.randn(2, 2)
|
|
for f in ALIAS_FN:
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
torch.compile(fn)(f, x)
|
|
|
|
with self.assertRaisesRegex(
|
|
# Should be
|
|
# torch._dynamo.exc.Unsupported,
|
|
# "Encountered aliasing during higher order op tracing for HOP.*"
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
"Cond doesn't work unless it is captured completely with torch.compile.*",
|
|
):
|
|
with torch.inference_mode(inference_mode):
|
|
torch.compile(fn)(f, x)
|
|
|
|
@skipIfTorchDynamo("Graph is not captured correctly when test with dynamo")
|
|
def test_while_loop_unbacked_bindings(self):
|
|
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
|
|
backend = EagerAndRecordGraphs()
|
|
self._check_compile(m, args, dynamic=True, backend=backend)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
while_loop_nodes = backend.graphs[0].graph.find_nodes(
|
|
op="call_function", target=torch.ops.higher_order.while_loop
|
|
)
|
|
self.assertEqual(len(while_loop_nodes), 1)
|
|
self.assertEqual(len(while_loop_nodes[0].meta.get("unbacked_bindings")), 5)
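# "unbacked_bindings" on the while_loop node roughly maps each freshly
# created unbacked symbol to the place in the node's output where it is
# bound; the five entries here correspond to the five int carries of the
# pytree_int_carry model (the tensor carry keeps its backed shape and needs
# no binding).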
|
|
|
|
# Runs both strict and non-strict export, checks that each matches eager,
# and returns the .module() graph str of the non-strict export.
def _check_export_ret_graph_str(self, fn, args, dynamic_shapes=None) -> str:
|
|
strict_ep = torch.export.export(
|
|
fn, args, dynamic_shapes=dynamic_shapes, strict=True
|
|
)
|
|
non_strict_ep = torch.export.export(
|
|
fn, args, dynamic_shapes=dynamic_shapes, strict=False
|
|
)
|
|
eager_res = fn(*args)
|
|
self.assertEqual(strict_ep.module()(*args), eager_res)
|
|
self.assertEqual(non_strict_ep.module()(*args), eager_res)
|
|
return normalize_gm(non_strict_ep.module().print_readable(print_output=False))
|
|
|
|
@skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
|
|
def test_cond_eager_run_with_item(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, a, b1, b2, c):
|
|
def true_fn(x):
|
|
return x * b1.item()
|
|
|
|
def false_fn(x):
|
|
return x * b2.item()
|
|
|
|
r = torch.cond(a, true_fn, false_fn, (c,))
|
|
return r * 2
|
|
|
|
x = torch.randn(10, requires_grad=True)
|
|
args = (
|
|
torch.tensor(True),
|
|
torch.tensor([3]),
|
|
torch.tensor([4]),
|
|
x,
|
|
)
|
|
model = M()
|
|
torch.export.export(model, args, strict=True)
|
|
graph_str = self._check_export_ret_graph_str(model, args, None)
|
|
self.assertExpectedInline(
|
|
graph_str,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, a, b1, b2, c):
|
|
a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]";
|
|
|
|
a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, (c, b1, b2)); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
|
|
getitem: "f32[10]" = cond[0]; cond = None
|
|
|
|
mul: "f32[10]" = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None
|
|
return pytree.tree_unflatten((mul,), self._out_spec)
|
|
|
|
class true_graph_0(torch.nn.Module):
|
|
def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
|
|
item: "Sym(u0)" = torch.ops.aten.item.default(b1); b1 = None
|
|
|
|
mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
|
|
return (mul,)
|
|
|
|
class false_graph_0(torch.nn.Module):
|
|
def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
|
|
item: "Sym(u1)" = torch.ops.aten.item.default(b2); b2 = None
|
|
|
|
mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
|
|
return (mul,)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_merge_graph_preserves_ph_meta(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b * z
|
|
|
|
return torch.cond(x.sum() > 5, true_fn, false_fn, (x,))
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
_ = torch.compile(M(), backend=backend)(
|
|
torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)
|
|
)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
gm = backend.graphs[0]
|
|
subgraph_attr = gm.graph.find_nodes(op="get_attr")[0]
|
|
subgm = getattr(gm, subgraph_attr.target)
|
|
for ph in subgm.graph.find_nodes(op="placeholder"):
|
|
self.assertTrue("example_value" in ph.meta)
|
|
|
|
@skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.")
|
|
def test_cond_symint_closure(self):
|
|
from torch.export import Dim
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
return x + a
|
|
|
|
def false_fn(x):
|
|
return x + b * z
|
|
|
|
# When exporting with non-strict: a and b are symints,
# so torch.compile needs to wrap and trace symint inputs.
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
|
|
|
args = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
|
|
model = M()
|
|
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
|
non_strict_graph_str = self._check_export_ret_graph_str(
|
|
model, args, dynamic_shapes
|
|
)
|
|
self.assertExpectedInline(
|
|
non_strict_graph_str,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
x: "f32[s68, 3]"; y: "f32[s17]"; z: "f32[s68, 3]";
|
|
|
|
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
|
|
sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None
|
|
sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0)
|
|
|
|
gt: "Sym(s68 > 5)" = sym_size_int_5 > 5
|
|
|
|
true_graph_0 = self.true_graph_0
|
|
false_graph_0 = self.false_graph_0
|
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, sym_size_int_4, sym_size_int_5, z)); gt = true_graph_0 = false_graph_0 = x = sym_size_int_4 = sym_size_int_5 = z = None
|
|
getitem: "f32[s68, 3]" = cond[0]; cond = None
|
|
return pytree.tree_unflatten((getitem,), self._out_spec)
|
|
|
|
class true_graph_0(torch.nn.Module):
|
|
def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"):
|
|
add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, sym_size_int_4); x = sym_size_int_4 = None
|
|
return (add,)
|
|
|
|
class false_graph_0(torch.nn.Module):
|
|
def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"):
|
|
mul: "f32[s68, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_5); z = sym_size_int_5 = None
|
|
|
|
add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
|
|
return (add,)
|
|
""", # noqa: B950
|
|
)
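# Both branch subgraphs above receive the closed-over symints as explicit
# sym_size_int placeholders, and each branch gets the full closure even for
# the symint it does not use, because cond requires both branches to share
# one operand list.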
|
|
|
|
# unbacked symint inputs are created during non-strict export,
|
|
# which causes a graph break
|
|
@unittest.expectedFailure
|
|
def test_cond_unbacked_symint_closure(self):
|
|
from torch.export import Dim
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
# c is an unbacked symint in non-strict export
|
|
c = y.sum().item()
|
|
|
|
def true_fn(x):
|
|
return x + a + c
|
|
|
|
def false_fn(x):
|
|
return x + b * z * c
|
|
|
|
# When exporting with non-strict: a and b are symints,
# so torch.compile needs to wrap and trace symint inputs.
|
|
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
|
|
|
args = (torch.ones(3, 3), torch.ones(5, dtype=torch.int32), torch.ones(3, 3))
|
|
model = M()
|
|
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
|
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
|
|
|
|
@skipIfTorchDynamo(
|
|
"Skip because _merge_output is not intended for dynamo to compile"
|
|
)
|
|
def test_merge_output(self):
|
|
from torch._higher_order_ops.cond import _merge_output
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
# The shapes and strides are from randomly generated pairs of tensors, then transposed via swapaxes
|
|
valid_test_cases = [
|
|
# [(size1, stride1), (size2, stride2), (expected_size, expected_stride)]
|
|
[((3,), (1,)), ((4,), (1,)), ("(u0,)", "(1,)")],
|
|
[((1, 3), (3, 1)), ((3, 2), (2, 1)), ("(u0, u1)", "(u1, 1)")],
|
|
[((2, 1), (1, 1)), ((7, 3), (3, 1)), ("(u0, u1)", "(u1, 1)")],
|
|
[((5, 5), (1, 5)), ((4, 5), (1, 4)), ("(u0, 5)", "(1, u0)")],
|
|
[
|
|
((7, 3, 1), (1, 7, 1)),
|
|
((4, 3, 3), (3, 12, 1)),
|
|
("(u0, 3, u1)", "(u1, u0*u1, 1)"),
|
|
],
|
|
[
|
|
((5, 7, 4), (7, 1, 35)),
|
|
((7, 4, 4), (4, 1, 28)),
|
|
("(u0, u1, 4)", "(u1, 1, u0*u1)"),
|
|
],
|
|
[
|
|
((1, 6, 3, 2), (36, 1, 6, 18)),
|
|
((4, 2, 2, 6), (24, 1, 2, 4)),
|
|
("(u0, u1, u2, u3)", "(u1*u2*u3, 1, u1, u1*u2)"),
|
|
],
|
|
[
|
|
((6, 1, 6, 3), (18, 1, 1, 6)),
|
|
((2, 1, 3, 4), (12, 1, 1, 3)),
|
|
("(u0, 1, u1, u2)", "(u1*u2, 1, 1, u1)"),
|
|
],
|
|
[
|
|
((3, 1, 2, 4, 1), (8, 8, 4, 1, 1)),
|
|
((2, 4, 1, 4, 1), (16, 4, 4, 1, 1)),
|
|
("(u0, u1, u2, 4, 1)", "(4*u1*u2, 4*u2, 4, 1, 1)"),
|
|
],
|
|
]
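# The pattern in the expected results: dimensions (and strides) that agree
# between t1 and t2 stay concrete, while mismatched ones become fresh
# unbacked symints, with strides re-expressed in terms of the merged sizes.
# E.g. merging sizes (1, 3) and (3, 2) yields (u0, u1) with contiguous
# strides (u1, 1).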
|
|
|
|
def _inner(case):
|
|
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
|
|
|
(size1, stride1), (size2, stride2), (merged_size, merged_stride) = case
|
|
with fake_mode:
|
|
t1 = torch.empty_strided(size1, stride1)
|
|
t2 = torch.empty_strided(size2, stride2)
|
|
out = _merge_output(t1, t2, fake_mode)
|
|
self.assertEqual(str(tuple(out.size())), merged_size)
|
|
self.assertEqual(str(tuple(out.stride())), merged_stride)
|
|
|
|
for case in valid_test_cases:
|
|
_inner(case)
|
|
|
|
# The shapes and strides are from randomly generated pairs of tensors, then transposed via swapaxes
|
|
invalid_test_cases = [
|
|
# [(size1, stride1), (size2, stride2)]
|
|
[((1,), (1,)), ((1,), (0,))],
|
|
[
|
|
((1, 3), (1, 1)),
|
|
((5, 6), (6, 1)),
|
|
], # t1 is not contiguous, t2 is contiguous
|
|
[
|
|
((2, 1), (1, 1)),
|
|
((7, 3), (1, 3)),
|
|
], # t1 is contiguous, t2 is not contiguous
|
|
[
|
|
((5, 4), (4, 1)),
|
|
((5, 5), (1, 5)),
|
|
], # t1 is contiguous, t2 is not contiguous
|
|
[((7, 3, 1), (1, 7, 1)), ((4, 3, 3), (9, 1, 3))], # layout is different
|
|
[((5, 7, 4), (7, 1, 35)), ((7, 4, 4), (4, 28, 1))], # layout is different
|
|
[
|
|
((1, 6, 3, 2), (36, 1, 6, 18)),
|
|
((4, 1, 1, 6), (1, 4, 4, 4)),
|
|
], # layout is different
|
|
[
|
|
((6, 1, 6, 3), (18, 1, 1, 6)),
|
|
((1, 1, 1, 1), (1, 1, 1, 1)),
|
|
], # layout is different
|
|
[
|
|
((6, 1, 1, 6, 3), (3, 18, 18, 18, 1)),
|
|
((5, 1, 2, 1, 1), (2, 10, 1, 10, 1)),
|
|
], # layout is different
|
|
]
|
|
for case in invalid_test_cases:
|
|
with self.assertRaisesRegex(Exception, r"."):
|
|
_inner(case)
|
|
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_cond_mismatched_branch_output(self, dynamic, backend):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y, z):
|
|
a = y.shape[0]
|
|
b = z.shape[0]
|
|
|
|
def true_fn(x):
|
|
# clone the outputs so branches have the same storage_offset
|
|
return (x + a)[2:].clone()
|
|
|
|
def false_fn(x):
|
|
# clone the outputs so branches have the same storage_offset
|
|
return (x + b * z)[:2].clone()
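# The two branches return tensors with different first dimensions
# ((s17 - 2, s94) vs. (2, s94)); cond merges them into a single output whose
# first dimension becomes a fresh unbacked symint u0, which is what the
# recorded graph below checks.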
|
|
|
|
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
|
|
return y.sum() - ret
|
|
|
|
m = M()
|
|
x, y, z = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)
|
|
out = m(x, y, z)
|
|
if not (backend == "eager" and dynamic and not TEST_WITH_CROSSREF):
|
|
compiled_out = torch.compile(
|
|
m, backend=backend, dynamic=dynamic, fullgraph=True
|
|
)(x, y, z)
|
|
self.assertEqual(compiled_out, out)
|
|
else:
|
|
bk = EagerAndRecordGraphs()
|
|
compiled_out = torch.compile(
|
|
m, backend=bk, dynamic=dynamic, fullgraph=True
|
|
)(x, y, z)
|
|
self.assertEqual(compiled_out, out)
|
|
self.assertExpectedInline(
|
|
normalize_gm(bk.graphs[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: "f32[s17, s94]", L_x_: "f32[s17, s94]"):
|
|
l_y_ = L_y_
|
|
l_z_ = L_z_
|
|
l_x_ = L_x_
|
|
|
|
sum_1: "f32[]" = l_x_.sum()
|
|
gt: "b8[]" = sum_1 > 0; sum_1 = None
|
|
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, (l_x_, s94, s17, s17, l_z_)); gt = cond_true_0 = cond_false_0 = l_x_ = s94 = s17 = l_z_ = None
|
|
|
|
getitem_5: "f32[u0, s94]" = cond[0]
|
|
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(getitem_5, 0); getitem_5 = None
|
|
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int >= 0; sym_size_int = None
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
ret: "f32[u0, s94]" = cond[0]; cond = None
|
|
|
|
sum_2: "f32[]" = l_y_.sum(); l_y_ = None
|
|
sub: "f32[u0, s94]" = sum_2 - ret; sum_2 = ret = None
|
|
return (sub,)
|
|
|
|
class cond_true_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"):
|
|
l_x__1 = l_x_
|
|
s94_1 = s94
|
|
|
|
add: "f32[s17, s94]" = l_x__1 + s17_true_branch; l_x__1 = s17_true_branch = None
|
|
getitem: "f32[s17 - 2, s94]" = add[slice(2, None, None)]; add = None
|
|
clone: "f32[s17 - 2, s94]" = getitem.clone(); getitem = None
|
|
return (clone,)
|
|
|
|
class cond_false_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"):
|
|
l_x__1 = l_x_
|
|
s94_1 = s94
|
|
|
|
mul: "f32[s17, s94]" = getitem_2_false_branch * l_z__false_branch; getitem_2_false_branch = l_z__false_branch = None
|
|
add: "f32[s17, s94]" = l_x__1 + mul; l_x__1 = mul = None
|
|
getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None
|
|
clone: "f32[2, s94]" = getitem.clone(); getitem = None
|
|
return (clone,)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("dynamic", [True, False])
|
|
@parametrize("backend", ["eager", "aot_eager"])
|
|
def test_cond_mismatched_branch_strided_output(self, dynamic, backend):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
def true_fn(x, y):
|
|
return (
|
|
(x.swapaxes(-1, 0) + 1)
|
|
.unsqueeze(1)
|
|
.expand(-1, 5, -1, -1, -1, -1, -1),
|
|
torch.empty_strided((3, 3), (0, 1)),
|
|
)
|
|
|
|
def false_fn(x, y):
|
|
return (
|
|
(y.swapaxes(-1, 0) + 1)
|
|
.unsqueeze(1)
|
|
.expand(-1, 4, -1, -1, -1, -1, -1),
|
|
torch.empty_strided((4, 5), (0, 1)),
|
|
)
|
|
|
|
ret = torch.cond(x.sum() > 0, true_fn, false_fn, (x, y))
|
|
return y.sum() + ret[0]
|
|
|
|
m = M()
|
|
x, y = torch.randn(1, 6, 1, 5, 4, 3), torch.randn(1, 4, 5, 1, 3, 8)
|
|
out = m(x, y)
|
|
compiled_out = torch.compile(
|
|
m, backend=backend, dynamic=dynamic, fullgraph=True
|
|
)(x, y)
|
|
self.assertEqual(compiled_out, out)
|
|
|
|
|
|
_hop_schema_test_schema_types = [
    "bool",
    "int",
    "float",
    "str",
    "Tensor",
    "SymInt",
    "SymBool",
    "GraphModule",
    "ScriptObj",
]


@skipIfTorchDynamo("We don't expect users to torch.compile hop schema generation.")
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
class TestHopSchema(TestCase):
|
|
def _get_example_val(self, ty: str):
|
|
from torch.fx.experimental.sym_node import SymNode
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
def create_symtype(cls, pytype, shape_env, val):
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
symbol = shape_env.create_symbol(
|
|
val,
|
|
source=ConstantSource(
|
|
f"__testing_hop_schema{len(shape_env.var_to_val)}"
|
|
),
|
|
)
|
|
return cls(SymNode(symbol, shape_env, pytype, hint=val))
|
|
|
|
if ty == "bool":
|
|
return True
|
|
elif ty == "int":
|
|
return 1
|
|
elif ty == "float":
|
|
return 1.0
|
|
elif ty == "str":
|
|
return "foo"
|
|
elif ty == "Tensor":
|
|
return torch.tensor(1)
|
|
elif ty == "SymInt":
|
|
shape_env = ShapeEnv()
|
|
return create_symtype(torch.SymInt, int, shape_env, 1)
|
|
elif ty == "SymBool":
|
|
shape_env = ShapeEnv()
|
|
return create_symtype(torch.SymBool, bool, shape_env, True)
|
|
elif ty == "GraphModule":
|
|
|
|
def f(x):
|
|
return x.sin()
|
|
|
|
return make_fx(f)(torch.ones(1))
|
|
elif ty == "ScriptObj":
|
|
from torch.testing._internal.torchbind_impls import (
|
|
init_torchbind_implementations,
|
|
)
|
|
|
|
init_torchbind_implementations()
|
|
foo = torch.classes._TorchScriptTesting._Foo(3, 4)
|
|
return foo
|
|
else:
|
|
raise NotImplementedError(ty)
|
|
|
|
@parametrize("schema_type", _hop_schema_test_schema_types)
|
|
def test_type_gen(self, schema_type):
|
|
from torchgen.gen_schema_utils import TypeGen
|
|
|
|
example_val = self._get_example_val(schema_type)
|
|
ty = TypeGen.from_example(example_val)
|
|
# Test the generated type can be parsed
|
|
self.assertEqual(ty.parse(str(ty)), ty)
|
|
|
|
@parametrize("schema_type", _hop_schema_test_schema_types)
|
|
def test_list_gen(self, schema_type):
|
|
from torchgen.gen_schema_utils import TypeGen
|
|
|
|
example_val = self._get_example_val(schema_type)
|
|
li1 = [example_val]
|
|
ty1 = TypeGen.from_example(li1)
|
|
ty2 = TypeGen.from_example(li1)
|
|
self.assertEqual(ty1.parse(str(ty1)), ty1)
|
|
self.assertEqual(ty2.parse(str(ty2)), ty2)
|
|
|
|
def test_function_schema_gen(self):
|
|
from torchgen.gen_schema_utils import FunctionSchemaGen
|
|
|
|
inps = [
|
|
(schema_type + "_v", self._get_example_val(schema_type))
|
|
for schema_type in _hop_schema_test_schema_types
|
|
]
|
|
schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
|
|
schema2 = FunctionSchemaGen.from_example(
|
|
"test_op2",
|
|
inps,
|
|
[
|
|
torch.ones(1),
|
|
],
|
|
)
|
|
schema3 = FunctionSchemaGen.from_example(
|
|
"test_op3", inps, [torch.ones(1), torch.ones(1)]
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema1),
|
|
"""test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema2),
|
|
"""test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema3),
|
|
"""test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""", # noqa: B950,
|
|
)
|
|
self.assertEqual(schema1.parse(str(schema1)), schema1)
|
|
self.assertEqual(schema2.parse(str(schema2)), schema2)
|
|
self.assertEqual(schema3.parse(str(schema3)), schema3)
|
|
|
|
def test_while_loop_schema_gen(self):
|
|
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
|
|
graph = make_fx(fn)(*inp).graph
|
|
while_loop_node = next(
|
|
node
|
|
for node in graph.nodes
|
|
if node.op == "call_function"
|
|
and node.target is torch.ops.higher_order.while_loop
|
|
)
|
|
schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""", # noqa: B950
|
|
)
|
|
self.assertEqual(schema.parse(str(schema)), schema)
|
|
|
|
def test_schema_tree_spec(self):
|
|
schema_gen = HopSchemaGenerator(torch.ops.higher_order.cond)
|
|
args = (torch.randn(3, 4), torch.randn(2, 3))
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Please only add flattened inputs to the hop schema"
|
|
):
|
|
schema_gen.add_arg("tuple_args", args)
|
|
|
|
for i, arg in enumerate(args):
|
|
schema_gen.add_arg(f"tuple_args{i}", arg)
|
|
schema_gen.add_schema_tree_spec(pytree.tree_flatten(args)[1])
|
|
flat_schema = schema_gen.gen_schema()
|
|
self.assertExpectedInline(
|
|
str(flat_schema), """cond(Tensor tuple_args0, Tensor tuple_args1) -> ()"""
|
|
)
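# The tree spec registered via add_schema_tree_spec only records how the
# flattened tensor arguments regroup into the original (tensor, tensor)
# tuple; the generated schema itself stays flat, as the expected string
# above shows.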
|
|
|
|
def test_cond_gen_schema_tensor_inputs(self):
|
|
schema = torch.ops.higher_order.cond.gen_schema(
|
|
torch.tensor(True),
|
|
lambda x: x.sin(),
|
|
lambda x: x.cos(),
|
|
(torch.randn(3, 4),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""cond(Tensor pred, Any true_fn, Any false_fn, Tensor operand0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_cond_gen_schema_symbool_inputs(self):
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
|
with fake_mode, fake_mode.shape_env.ignore_fresh_unbacked_symbols():
|
|
sym_bool = torch.randn(3, 4).nonzero().size(0) == 0
|
|
|
|
schema = torch.ops.higher_order.cond.gen_schema(
|
|
sym_bool,
|
|
lambda x: x.sin(),
|
|
lambda x: x.cos(),
|
|
(torch.randn(3, 4),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""cond(SymBool pred, Any true_fn, Any false_fn, Tensor operand0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_while_loop_gen_schema_tensor_inputs(self):
|
|
def cond_fn(x, y):
|
|
return x.sum() < 10
|
|
|
|
def body_fn(x, y):
|
|
return x + 1, y.sin()
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(torch.randn(3, 4), torch.randn(2, 3)),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, Tensor carried_input0, Tensor carried_input1) -> (Tensor, Tensor)""",
|
|
)
|
|
|
|
def test_while_loop_gen_schema_with_additional_inputs(self):
|
|
def cond_fn(x, y, z):
|
|
return x.sum() < z
|
|
|
|
def body_fn(x, y, z):
|
|
return x + 1, y.sin()
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(torch.randn(3, 4), torch.randn(2, 3)),
|
|
(torch.tensor(10),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, Tensor carried_input0, Tensor carried_input1, Tensor additional_input0) -> (Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_scan_gen_schema_tensor_inputs(self):
|
|
def combine_fn(carry, x):
|
|
return carry + x, carry * x
|
|
|
|
schema = torch.ops.higher_order.scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(3, 4),),
|
|
(torch.randn(5, 3, 4),),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""scan(Any combine_fn, Tensor init0, Tensor xs0) -> (Tensor, Tensor)""",
|
|
)
|
|
|
|
def test_scan_gen_schema_with_additional_inputs(self):
|
|
def combine_fn(carry, x, scale):
|
|
return carry + x * scale, carry * x
|
|
|
|
schema = torch.ops.higher_order.scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(3, 4),),
|
|
(torch.randn(5, 3, 4),),
|
|
(torch.tensor(2.0),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""scan(Any combine_fn, Tensor init0, Tensor xs0, Tensor additional_input0) -> (Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_scan_gen_schema_multiple_inputs(self):
|
|
def combine_fn(carry1, carry2, x1, x2):
|
|
return carry1 + x1, carry2 * x2, carry1 - x1, carry2 + x2
|
|
|
|
schema = torch.ops.higher_order.scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(3, 4), torch.randn(2, 3)),
|
|
(torch.randn(5, 3, 4), torch.randn(5, 2, 3)),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""scan(Any combine_fn, Tensor init0, Tensor init1, Tensor xs0, Tensor xs1) -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950
|
|
)
|
|
|
|
def test_associative_scan_gen_schema_tensor_inputs(self):
|
|
def combine_fn(x, y):
|
|
return x + y
|
|
|
|
schema = torch.ops.higher_order.associative_scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(5, 3, 4),),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""associative_scan(Any combine_fn, Tensor xs0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_associative_scan_gen_schema_with_additional_inputs(self):
|
|
def combine_fn(x, y, scale):
|
|
return x * y * scale
|
|
|
|
schema = torch.ops.higher_order.associative_scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(5, 3, 4),),
|
|
(torch.tensor(2.0),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""associative_scan(Any combine_fn, Tensor xs0, Tensor additional_input0) -> ((Tensor))""",
|
|
)
|
|
|
|
def test_associative_scan_gen_schema_multiple_inputs(self):
|
|
def combine_fn(x1, x2, y1, y2):
|
|
return x1 + y1, x2 * y2
|
|
|
|
schema = torch.ops.higher_order.associative_scan.gen_schema(
|
|
combine_fn,
|
|
(torch.randn(5, 3, 4), torch.randn(5, 2, 3)),
|
|
(),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""associative_scan(Any combine_fn, Tensor xs0, Tensor xs1) -> (Tensor, Tensor)""",
|
|
)
|
|
|
|
def test_while_loop_gen_schema_with_int_carries(self):
|
|
def cond_fn(x, y, z, c):
|
|
return x < y
|
|
|
|
def body_fn(x, y, z, c):
|
|
return x + 1, y - 1, z.sin(), c + x
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(2, 10, torch.randn(2, 3)),
|
|
(torch.tensor(10),),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, int carried_input0, int carried_input1, Tensor carried_input2, Tensor additional_input0) -> (int, int, Tensor, Tensor)""", # noqa: B950
|
|
)
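# Non-tensor leaves among the carried inputs (the two Python ints here) show
# up as plain `int` arguments and return values in the generated schema,
# while tensor carries and additional inputs keep their Tensor types.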
|
|
|
|
def test_while_loop_gen_schema_with_input_mutation(self):
|
|
def cond_fn(x, y, z, c):
|
|
return x < y
|
|
|
|
def body_fn(x, y, z, c):
|
|
x.add_(1)
|
|
y.sub_(1)
|
|
z.sin_()
|
|
c.add_(x)
|
|
return x, y, z
|
|
|
|
c = torch.randn(3, 3)
|
|
|
|
schema = torch.ops.higher_order.while_loop.gen_schema(
|
|
cond_fn,
|
|
body_fn,
|
|
(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
|
|
(c,),
|
|
)
|
|
self.assertExpectedInline(
|
|
str(schema),
|
|
"""while_loop(Any cond_fn, Any body_fn, Tensor(a2!) carried_input0, Tensor(a3!) carried_input1, Tensor(a4!) carried_input2, Tensor(a5!) additional_input0) -> (Tensor, Tensor, Tensor)""", # noqa: B950
|
|
)
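# The "(a2!)" style markers follow the standard schema alias-annotation
# convention: a trailing "!" flags arguments that the traced body_fn mutates
# in place (all three carried tensors plus the additional input here), while
# the outputs carry no alias annotation.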


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)

instantiate_parametrized_tests(TestControlFlow)
instantiate_parametrized_tests(AssociativeScanTests)


if __name__ == "__main__":
    run_tests()