Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.jit accept simple generator expressions, which lets us enable lint rules that replace unnecessary list comprehensions with generator expressions in any()/all() calls. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
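For context, the cleanup described above rewrites eager list comprehensions inside any()/all() into lazy generator expressions, which avoid materializing the intermediate list and can short-circuit at the first decisive element. A minimal sketch of the kind of rewrite involved (illustrative only, not a line taken from this file):

    # Before: builds the full list, then scans it.
    any([op._schema.is_mutable for op in all_ops])

    # After: evaluates lazily and stops at the first True.
    any(op._schema.is_mutable for op in all_ops)

The assertions in the test file below already use the generator form, e.g. self.assertFalse(any(op._schema.is_mutable for op in all_ops)).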
# Owner(s): ["module: functorch"]
import unittest

import torch
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from functorch.experimental.control_flow import UnsupportedAliasMutationException
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch._dynamo.exc import CondOpArgsMismatchError


class TestControlFlow(TestCase):
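    # Eager-mode behavior of cond and control_flow.map, with no tracing involved.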
    def test_cond_no_trace(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        x = torch.randn(4)
        result = cond(False, true_fn, false_fn, [x])
        self.assertEqual(result, torch.cos(x))

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    def test_cond_gpu(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        x = torch.randn(4, device="cuda")
        pred = torch.tensor(False, device="cuda")
        result = cond(pred, true_fn, false_fn, [x])
        self.assertEqual(result, torch.cos(x))

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    def test_map_gpu(self):
        def f(x, y):
            return x + y

        xs = torch.ones(3, 2, 2, device="cuda")
        y = torch.ones(2, device="cuda")
        res = control_flow.map(f, xs, y)

        self.assertEqual(res, control_flow.map(f, torch.ones(3, 2, 2), torch.ones(2)))


class TestControlFlowTraced(TestCase):
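    # Tracing behavior: cond and control_flow.map under make_fx in "real", "fake",
    # and "symbolic" tracing modes, including functionalization and error cases.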
    def test_cond_traced_not_nested(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False))
        result_true = graph.forward(x, torch.tensor(True))
        result_false = graph.forward(x, torch.tensor(False))
        self.assertFalse(torch.allclose(result_true, result_false))
        self.assertEqual(result_true, torch.sin(x))
        self.assertEqual(result_false, torch.cos(x))

        graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
        self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))

    def test_cond_nested_traced(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(x, pred2):
            z = cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def f(x, pred, pred2):
            return cond(pred, true_fn, false_fn, [x, pred2])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        result_true_true = graph.forward(x, torch.tensor(True), torch.tensor(True))  # pred=True, pred2=True -> x * x
        result_true_false = graph.forward(x, torch.tensor(True), torch.tensor(False))  # pred=True, pred2=False -> x + x
        result_false_true = graph.forward(x, torch.tensor(False), torch.tensor(True))  # pred=False, either -> cos
        result_false_false = graph.forward(x, torch.tensor(False), torch.tensor(False))  # pred=False, either -> cos

        self.assertNotEqual(result_true_true, result_true_false)
        self.assertFalse(torch.allclose(result_false_true, result_true_true))

        self.assertEqual(result_false_true, result_false_false)

        self.assertEqual(result_true_true, (x * x) + x)
        self.assertEqual(result_true_false, x + x + x)

        self.assertEqual(result_false_true, torch.cos(x))

        graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False), torch.tensor(False))
        self.assertEqual(graph(x, torch.tensor(True), torch.tensor(True)), f(x, torch.tensor(True), torch.tensor(True)))

    def test_cond_functionalized(self):
        def true_fn(x):
            y = x.sin()
            y.add_(4)
            return x.sin().max() + y.sum()

        def false_fn(x):
            return x.cos().min()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = torch.func.functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        all_ops_in_true_branch = []
        for node in graph_module.true_graph_0.graph.nodes:
            if node.op == "call_function":
                all_ops_in_true_branch.append(node.target)

        self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))

        graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

    def test_cond_retrace_functionalized(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x):
            return cond(x.all(), true_fn, false_fn, (x,))

        inp = torch.ones(1, 2)
        gm_non_functional = make_fx(f, tracing_mode="real")(inp)
        gm_functional = make_fx(torch.func.functionalize(gm_non_functional), tracing_mode="real")(inp)
        self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))

    def test_cond_functionalized_nested(self):
        def true_true_fn(x):
            y = x.cos()
            y.add_(4)
            return x.sin().max() + y.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = torch.func.functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        gm_true_true_branch = graph_module.true_graph_0.true_graph_0

        graph_module1 = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(*example_inputs)
        self.assertEqual(graph_module1(*example_inputs), f(*example_inputs))

        all_ops = []
        for node in gm_true_true_branch.graph.nodes:
            if node.op == "call_function":
                all_ops.append(node.target)

        self.assertFalse(any(op._schema.is_mutable for op in all_ops))

    def test_cond_functionalized_data_dependent_pred(self):
        def true_fn(x):
            return x.sin().sum()

        def false_fn(x):
            return x.cos().sum()

        def f(x):
            pred = x.nonzero().shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = torch.func.functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

    def test_cond_functionalized_input_mutation_on_true_branch(self):
        def true_fn(x):
            view_x = x.view(x.shape)
            view_x.add_(1)
            return view_x.sin().sum()

        def false_fn(x):
            return x.cos().sum()

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(torch.func.functionalize(f))(*example_inputs)

    def test_cond_functionalized_input_mutation_on_false_branch(self):
        def true_fn(x):
            return x.sin().sum()

        def false_fn(x):
            view_x = x.view(x.shape)
            view_x.add_(1)
            return view_x.cos().sum()

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 5),)
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(torch.func.functionalize(f))(*example_inputs)

    def test_cond_functionalized_output_alias_input(self):
        def true_fn(x):
            return x

        def false_fn(x):
            view_x = x.view(x.shape)
            return view_x

        def f(x):
            pred = x.shape[0] == 4
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(5, 5),)
        functional_f = torch.func.functionalize(f)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch might be aliasing"):
            make_fx(torch.func.functionalize(f))(*example_inputs)

    def test_cond_functionalized_nested_input_mutation(self):
        def true_true_fn(x):
            x.add_(4)
            return x.sin().max()

        def true_false_fn(x):
            return x.cos().min()

        def true_fn(x):
            pred = x.shape[0] == 1
            return cond(pred, true_true_fn, true_false_fn, [x])

        def false_fn(x):
            return x.sum()

        def f(x):
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            functional_f(*example_inputs)

        with self.assertRaisesRegex(UnsupportedAliasMutationException, "One of torch.cond branch"):
            make_fx(torch.func.functionalize(f))(*example_inputs)

    def test_cond_nested_traced_other_inputs(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(k, pred2):
            z = cond(pred2, true_nested, false_nested, [k])
            return torch.add(torch.tensor([0.25, 0.25]), z)

        def false_fn(k, _):
            return k.cos()

        def f(k, pred, pred2):
            return cond(pred, true_fn, false_fn, [k, pred2])

        x = torch.tensor([0.5, 0.5])
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        a = torch.tensor([1.0, 1.0])
        result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))

        b = torch.tensor([2.0, 2.0])
        result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))

    def test_cond_nested_traced_multi(self):
        def true_a(y):
            return y * y

        def false_a(y):
            return y + y

        def true_b(y, z):
            return y + z

        def false_b(y, z):
            return y * z

        def f(x, pred, pred2):
            a_out = cond(pred, true_a, false_a, [x])
            b_out = cond(pred2, true_b, false_b, [x, x])
            return a_out + b_out

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        # Brittle, yet, delicious
        out = """
def forward(self, x_1, pred_1, pred2_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
    pred_1 = true_graph_0 = false_graph_0 = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
    pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
    add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
    return add
"""
        code = graph.code
        # Normalization hack, because .code emits inconsistent whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

        code = graph.true_graph_0.code
        out = """
def forward(self, y_1):
    mul = torch.ops.aten.mul.Tensor(y_1, y_1); y_1 = None
    return mul
"""
        # Normalization hack, because .code emits inconsistent whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

    def test_raise_error_on_mismatch_type_size(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return (x, x)

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaisesRegex(
            CondOpArgsMismatchError,
            "Expected to return same number of outputs but got",
        ):
            make_fx(f)(x, torch.tensor(False))

    def test_raise_error_on_mismatch_tensor_size(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return torch.zeros([10, 10])

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaisesRegex(
            CondOpArgsMismatchError,
            "Expected each tensor to have same metadata but got",
        ):
            make_fx(f)(x, torch.tensor(False))

    def test_cond_traced_not_nested_fake_tensor(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
        result_true = graph.forward(x, torch.tensor(True))
        result_false = graph.forward(x, torch.tensor(False))
        self.assertFalse(torch.allclose(result_true, result_false))
        self.assertEqual(result_true, torch.sin(x))
        self.assertEqual(result_false, torch.cos(x))

    def test_cond_nested_traced_fake_tensor(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(x, pred2):
            z = cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def f(x, pred, pred2):
            return cond(pred, true_fn, false_fn, [x, pred2])

        x = torch.randn(4)
        graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False), torch.tensor(False))

        result_true_true = graph.forward(x, torch.tensor(True), torch.tensor(True))  # pred=True, pred2=True -> x * x
        result_true_false = graph.forward(x, torch.tensor(True), torch.tensor(False))  # pred=True, pred2=False -> x + x
        result_false_true = graph.forward(x, torch.tensor(False), torch.tensor(True))  # pred=False, either -> cos
        result_false_false = graph.forward(x, torch.tensor(False), torch.tensor(False))  # pred=False, either -> cos

        self.assertNotEqual(result_true_true, result_true_false)
        self.assertFalse(torch.allclose(result_false_true, result_true_true))

        self.assertEqual(result_false_true, result_false_false)

        self.assertEqual(result_true_true, (x * x) + x)
        self.assertEqual(result_true_false, x + x + x)

        self.assertEqual(result_false_true, torch.cos(x))

    def test_cond_nested_traced_other_inputs_fake_tensor(self):
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(k, pred2):
            z = cond(pred2, true_nested, false_nested, [k])
            return torch.add(torch.tensor([0.25, 0.25]), z)

        def false_fn(k, _):
            return k.cos()

        def f(k, pred, pred2):
            return cond(pred, true_fn, false_fn, [k, pred2])

        x = torch.tensor([0.5, 0.5])
        graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False), torch.tensor(False))

        a = torch.tensor([1.0, 1.0])
        result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))

        b = torch.tensor([2.0, 2.0])
        result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))

    def test_cond_nested_traced_multi_fake_tensor(self):
        def true_a(y):
            return y * y

        def false_a(y):
            return y + y

        def true_b(y, z):
            return y + z

        def false_b(y, z):
            return y * z

        def f(x, pred, pred2):
            a_out = cond(pred, true_a, false_a, [x])
            b_out = cond(pred2, true_b, false_b, [x, x])
            return a_out + b_out

        x = torch.randn(4)
        graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False), torch.tensor(False))

        # Brittle, yet, delicious
        out = """
def forward(self, x_1, pred_1, pred2_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
    pred_1 = true_graph_0 = false_graph_0 = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
    pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
    add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
    return add
"""
        code = graph.code
        # Normalization hack, because .code emits inconsistent whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

        code = graph.true_graph_0.code
        out = """
def forward(self, y_1):
    mul = torch.ops.aten.mul.Tensor(y_1, y_1); y_1 = None
    return mul
"""
        # Normalization hack, because .code emits inconsistent whitespace
        code = "".join(code.split())
        out = "".join(out.split())
        self.assertEqual(code, out)

    def test_raise_error_on_mismatch_type_size_fake_tensor(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return (x, x)

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaisesRegex(
            CondOpArgsMismatchError,
            "Expected to return same number of outputs but got",
        ):
            make_fx(f, tracing_mode="fake")(x, torch.tensor(False))

    def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return torch.zeros([10, 10])

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaisesRegex(
            CondOpArgsMismatchError,
            "Expected each tensor to have same metadata but got",
        ):
            make_fx(f, tracing_mode="fake")(x, torch.tensor(False))

    def check_map_graph(self, gm, key):
        # The traced module should contain exactly one torch.ops.map call whose
        # output keeps the same leading (mapped) dimension as the mapped input.
        i = 0
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.map:
                i += 1
                self.assertEqual(
                    node.meta[key].shape[0], node.args[1].meta[key].shape[0]
                )
        self.assertEqual(i, 1)

    def test_map_real(self):
        def f(x, y):
            return x + y

        def g(xs, y):
            return control_flow.map(f, xs, y)

        gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2))
        x = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(x, y)
        self.assertEqual(res, g(x, y))
        self.check_map_graph(gm, "tensor_meta")

    def test_map_symbolic(self):
        def f(x, y):
            return x + y

        def g(xs, y):
            return control_flow.map(f, xs, y)

        gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4))
        x = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(x, y)
        self.assertEqual(res, g(x, y))
        self.check_map_graph(gm, "val")

    def test_map_functionalized(self):
        def map_fn(x, y):
            z = x + y
            z.add_(4)
            return z

        def f(xs, y):
            return control_flow.map(map_fn, xs, y)

        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
        functional_f = torch.func.functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))

        gm = make_fx(torch.func.functionalize(f))(*example_inputs)
        self.assertEqual(gm(*example_inputs), f(*example_inputs))

        gm = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(*example_inputs)
        self.assertEqual(gm(*example_inputs), f(*example_inputs))

        for node in gm.body_graph_0.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(not node.target._schema.is_mutable)

    def test_map_functionalized_arg_mutation(self):
        def map_fn(x, y):
            y.add_(4)
            return x + y

        def f(xs, y):
            return control_flow.map(map_fn, xs, y)

        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "torch.map is mutating the input!"):
            functional_f(*example_inputs)

    def test_map_functionalized_elem_mutation(self):
        def map_fn(x, y):
            x.add_(4)
            return x + y

        def f(xs, y):
            return control_flow.map(map_fn, xs, y)

        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "torch.map is mutating the input!"):
            functional_f(*example_inputs)

    def test_map_functionalized_elem_alias(self):
        def map_fn(x):
            x.view(x.shape)
            return x

        def f(xs):
            return control_flow.map(map_fn, xs)

        example_inputs = (torch.ones(3, 2, 4),)
        functional_f = torch.func.functionalize(f)
        with self.assertRaisesRegex(UnsupportedAliasMutationException, "torch.map is aliasing the input!"):
            functional_f(*example_inputs)

    def test_nested_map_cond_real(self):
        def true_fn(x, y):
            return x * y

        def false_fn(x, y):
            return x + y

        def f(x, pred, y):
            return cond(pred, true_fn, false_fn, [x, y])

        def g(pred, xs, y):
            return control_flow.map(f, xs, pred, y)

        gm = make_fx(g, tracing_mode="real")(
            torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
        )
        pred = torch.tensor(False)
        x = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(pred, x, y)
        self.assertEqual(res, g(pred, x, y))
        self.check_map_graph(gm, "tensor_meta")

    def test_nested_map_cond_symbolic(self):
        def true_fn(x, y):
            return x * y

        def false_fn(x, y):
            return x + y

        def f(x, pred, y):
            return cond(pred, true_fn, false_fn, [x, y])

        def g(pred, xs, y):
            return control_flow.map(f, xs, pred, y)

        gm = make_fx(g, tracing_mode="symbolic")(
            torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
        )
        pred = torch.tensor(False)
        x = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(pred, x, y)
        self.assertEqual(res, g(pred, x, y))
        self.check_map_graph(gm, "val")

    def test_nested_cond_map_cond_symbolic(self):
        def true_fn(x, y):
            return x * y

        def false_fn(x, y):
            return x + y

        def f(x, pred, y):
            return cond(pred, true_fn, false_fn, [x, y])

        def g(pred, xs, y):
            return control_flow.map(f, xs, pred, y)

        def main_true_fn(pred, xs, y):
            return g(pred, xs, y) * 2

        def main_false_fn(pred, xs, y):
            return g(pred, xs, y) + 1

        def main(p, pred, xs, y):
            return cond(p, main_true_fn, main_false_fn, [pred, xs, y])

        gm = make_fx(main, tracing_mode="symbolic")(
            torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
        )
        p = torch.tensor(False)
        pred = torch.tensor(False)
        xs = torch.randn(3, 2, 2)
        y = torch.randn(2)
        res = gm(p, pred, xs, y)
        self.assertEqual(res, main(p, pred, xs, y))

    def test_cond_with_sym_pred(self):
        def true_fn(x):
            return x + x

        def false_fn(x):
            return x * x

        def foo(x):
            return cond(x.shape[0] == 4, true_fn, false_fn, [x])

        gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
        x = torch.ones(4, 3, 2)
        self.assertEqual(foo(x), gm(x))


if __name__ == '__main__':
    run_tests()