There are five parts in this PR (they are hard to break into smaller pieces because they are highly coupled):
1. **Whenever we call create_graph_input, we try to bind the symbols in the graph input.**
Since we have enforced the invariant that every create_graph_input call must provide an example value, we can intercept at those calls (this PR only handles free symbols in tensors).
2. **We cache the bound_symbols** so that the same symbol is never lifted more than once (see the sketch after this list).
3. For lifted symbols, we reuse **lifted_freevars**, i.e. the mapping from a symbol's proxy in the parent graph to the lifted placeholder in the current subgraph, which is the same mechanism we use for lifted tensors. This way, every hop that supports lifted tensors should handle lifted symints automatically (at least on the Dynamo side).
4. For **unbacked symbols** created during tracing, we also need to bind these symbols to their proxies. This supports the test cases where we want to lift unbacked symbols as inputs: we need the proxy of the unbacked symbol in the parent graph in order to properly create the args to the hop.
5. We update all the tests now that free symbols are lifted in subgraphs, and add support for lifted symbols in existing higher-order ops.
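To make parts 1 and 2 concrete, here is a minimal, self-contained sketch. It is illustrative only: `FakeTracer` and `shape_symbols` are simplified stand-ins for Dynamo's real `SubgraphTracer` machinery, not its actual API; only `maybe_bind_symbol` borrows its name from the description above.
```python
# Illustrative sketch only -- a simplified stand-in for Dynamo's SubgraphTracer.
class FakeTracer:
    def __init__(self, parent=None):
        self.parent = parent
        self.placeholders = []     # graph inputs (phs) of this (sub)graph
        self.bound_symbols = {}    # symbol -> placeholder proxy (the de-dup cache)
        self.lifted_freevars = {}  # parent proxy -> placeholder in this graph

    def create_graph_input(self, name, example_value=None):
        ph = f"ph_{name}"
        self.placeholders.append(ph)
        # Part 1: whenever a graph input is created, bind the free symbols
        # appearing in its example value (reduced here to just its shape).
        for sym in getattr(example_value, "shape_symbols", ()):
            self.maybe_bind_symbol(sym)
        return ph

    def maybe_bind_symbol(self, sym):
        # Part 2: bound_symbols de-duplicates, so a symbol is lifted at most once.
        if sym not in self.bound_symbols:
            ph = f"ph_{sym}"
            self.placeholders.append(ph)
            self.bound_symbols[sym] = ph
        return self.bound_symbols[sym]
```
With this cache, lifting x1 with shape (s1, s2) and then x2 with shape (s2, s3) creates symbol inputs for s1, s2, and s3 exactly once each, which is the de-duplication behavior described in the cases below.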
**The interaction of nested tracers:**
The existing design for lifting tensor closures is: suppose we are in nested tracers; whenever we see a proxy that was not created by the current tracer, we recursively look for it in the parent tracers until we find the tracer that created it (either as a placeholder or as some intermediate result). More detail is in Note [Nested SubgraphTracer and free_variable handling].
Given that design, the plan for lifting free symbols is: whenever we lift a free tensor to be an input of the current subgraph, we look at the symbols in it and bind them at the same time.
For example, suppose we have the following function:
```python
def f(x: [s1, s2]):
    def true_f():
        def true_f_inner():
            return x.sin()
```
What happens, in time order:
1. We create subtracer 1 and start to speculate the outer cond's true_f.
2. We create another subtracer 2 and start to speculate the inner cond's true_f_inner.
3. Dynamo realizes the tensor input x by calling wrap_tensor at the top level to create graph input x (tracer 0); we bind the symbols s1 and s2 after the placeholder for x is created. So the graph now looks like:
```python
def gm(s1, s2, x):
```
4. When it sees TensorVariable.call_method on x, tracer 2 wants to create call_function(sin, proxy_of_x), but it finds that proxy_of_x was not created by the current tracer. So it recursively looks up its parent, tracer 1, finds that tracer 1 also does not track proxy_of_x, and finally reaches the root tracer 0, which created it and tracks it as a placeholder. Then tracer 1 calls create_graph_input to lift the closure into its input ph1 and adds the (proxy_of_x: ph1) entry to tracer 1's **lifted_freevars**.
Now the graph looks like:
```python
def gm(s1, s2, x):
    def true_gm(x):
```
5. Since there are free symbols inside this new tensor input, tracer 1 also binds the symbols (maybe_bind_symbol), which calls create_graph_input for s1 and s2. Now the graph looks like:
```python
def gm(s1, s2, x):
    def true_gm(s1, s2, x):
```
6. Then we return to tracer 2, which calls create_graph_input for x and gets ph2; tracer 2's **lifted_freevars** records (ph1: ph2). Tracer 2 also binds the symbols in this new tensor input. Now the graph looks like:
```python
def gm(s1, s2, x):
    def true_gm(s1, s2, x):
        def true_gm_inner(s1, s2, x):
```
7. Finally, the sin call_function node is created by tracer 2. A sketch of this recursive lookup is given below.
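The recursive lookup in steps 4-6 can be sketched like this (again illustrative only, reusing the hypothetical `FakeTracer` above; the real logic lives in `SubgraphTracer`, see Note [Nested SubgraphTracer and free_variable handling]):
```python
# Illustrative sketch only: make a proxy created by an ancestor tracer usable
# in the current subgraph by lifting it level by level.
def lift_to_current(tracer, proxy):
    if tracer.parent is None or proxy in tracer.placeholders:
        # The root tracer (or any tracer that already has the proxy) uses it as is.
        return proxy
    # First make sure the parent graph has the proxy, then lift it here and
    # remember the parent-proxy -> local-placeholder mapping (lifted_freevars).
    parent_proxy = lift_to_current(tracer.parent, proxy)
    local_ph = tracer.create_graph_input(str(parent_proxy))
    tracer.lifted_freevars[parent_proxy] = local_ph
    return local_ph
```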
**This PR also handles the following cases:**
- What if we lift two tensors that share the same symbol, e.g. x1: [s1, s2] and x2: [s2, s3]? Each subtracer maintains bound_symbols as a cache that maps a symbol.expr to its proxy in the current tracer (this is the cache sketched after the part list above). So when we see x1, we track s1 and s2 as inputs and bind s1 to ph1 and s2 to ph2. When we then try to bind the symbols of x2, s2 is already tracked, so no new graph input is created for it.
- What if a subgraph closes over a symint? e.g.
```python
def f(x):
    def true_f():
        c = x.size(0)
        def true_fn_inner():
            return c
```
When we speculate true_fn_inner, we find that proxy_of_c is not tracked by tracer 2, so it recursively looks up its parent. At this point, x and its symbols have already been lifted as inputs of true_f (as a result of lifting x while tracing true_f in tracer 1). Specifically, the graph looks like:
```python
def gm(s1, s2, x):
    def true_gm(s1, s2, x):
        def true_gm_inner():
```
So tracer 2 finds that s1 has already been tracked as a placeholder in tracer 1, stops the recursion there, and calls create_graph_input for s1 in tracer 2. The graph now looks like:
```python
def gm(s1, s2, x):
    def true_gm(s1, s2, x):
        def true_gm_inner(s1):
            return s1
```
- What if a subgraph closes over an unbacked symint? e.g.
```python
def f(x):
    def true_f():
        c = x.item()
        def true_f_inner():
            return c
```
When x.item() is called, proxy_of_c and its symnode variable are created in tracer 1, and we also call track_unbacked_symbols to record this relationship. So when tracer 2 finds that proxy_of_c was not created by the current tracer, it recursively looks up its parent tracer and finds that the expression u0 has already been tracked in tracer 1 as a result of track_unbacked_symbols. It therefore stops the recursion and calls create_graph_input for u0 in tracer 2. The graph looks like:
```python
def f(x):
    def true_f(s1, s2, x):
        c = x.item()
        def true_gm_inner(u0):
            return u0
        cond(pred, true_gm_inner, false_gm_inner, (c,))
```
- What if a subgraph closes over a tensor with an unbacked symint shape?
```python
def f(x):
    def true_f():
        c = x.item()
        r = torch.randn((c,))
        def true_f_inner():
            return r + 1
```
This is the same as the case of closing over tensors with backed shapes: we first lift r, then bind the u0 in its shape, which recursively calls bind_symint for u0 in the parent and finds that u0 is already tracked in the parent tracer as a result of the .item() call. A sketch of this binding recursion is given below.
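The unbacked-symbol handling in the last three cases can be sketched in the same style (illustrative only; `track_unbacked_symbol` and `bind_symint` below are simplified stand-ins built on the hypothetical `FakeTracer` above, not the real implementations):
```python
# Illustrative sketch only: unbacked symbols created during tracing (e.g. by
# .item() or .nonzero()) are recorded in the tracer that created them, so a
# child tracer's recursive binding stops there instead of walking to the root.
def track_unbacked_symbol(tracer, sym, proxy):
    tracer.bound_symbols.setdefault(sym, proxy)


def bind_symint(tracer, sym):
    if sym in tracer.bound_symbols:
        # Either created in this tracer (recorded above) or already lifted here.
        return tracer.bound_symbols[sym]
    # Otherwise the symbol must be bound in some ancestor: bind it in the
    # parent first, then lift it as an input of this subgraph.
    parent_proxy = bind_symint(tracer.parent, sym)
    local_ph = tracer.create_graph_input(str(sym))
    tracer.bound_symbols[sym] = local_ph
    tracer.lifted_freevars[parent_proxy] = local_ph
    return local_ph
```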
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138363
Approved by: https://github.com/zou3519
# Owner(s): ["module: dynamo"]

import enum
import functools
import pprint
import re
import unittest
import warnings

import functorch.experimental.control_flow as control_flow
import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
    check_dynamic_shape_capture,
    CompileCounter,
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    empty_line_normalizer,
    normalize_gm,
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
    munge_exc,
    TEST_WITH_TORCHDYNAMO,
    xfailIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


def count_ops(gm, args, freq, op):
    actual = [node.target for node in gm.graph.nodes].count(op)
    assert actual == freq, f"expected={freq}, actual={actual}"
    return gm


class Obj:
    pass


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.existing = torch.nn.Parameter(torch.ones([]))

    def forward(self, x):
        return self.existing * x


global_obj = Obj()
global_module = MyModule()
global_var = torch.randn(3)
global_num = 3.14
global_list = []


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def op_count(gm):
    result = 0
    for node in gm.graph.nodes:
        if "call" in node.op:
            result += 1
    return result


# Checks that a dict matches a dict with "regex keys". That is,
# the keys are regex expressions.
def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
    regex_keys = dct_with_regex_keys.keys()
    regex_key_to_actual_key = {}
    for regex_key in regex_keys:
        for key in dct:
            if re.match(regex_key, key):
                if regex_key in regex_key_to_actual_key:
                    raise AssertionError(
                        f"Single key regex mapped to multiple keys. Please improve your "
                        f"regex. Got: regex='{regex_key}' "
                        f"keys='{regex_key_to_actual_key[regex_key]}',"
                        f"'{key}'"
                    )
                regex_key_to_actual_key[regex_key] = key
    new_dct = {}
    for regex_key in regex_keys:
        if regex_key not in regex_key_to_actual_key:
            raise AssertionError(
                f"Got regex '{regex_key}' but could not match any key in dict with "
                f"keys {dct.keys()}"
            )
        new_dct[regex_key_to_actual_key[regex_key]] = dct_with_regex_keys[regex_key]
    self.assertEqual(dct, new_dct)


def default_args_generator(seed_value):
    flat_args, args_spec = pytree.tree_flatten(seed_value)
    for i in range(3):
        new_flat_arg = []
        for val in flat_args:
            if isinstance(val, torch.Tensor):
                new_val = val + 0.1 * i
            elif isinstance(val, int):
                new_val = val + 1 * i
            elif isinstance(val, float):
                new_val = val + 0.1 * i
            elif isinstance(val, enum.Enum):
                new_val = val
            else:
                raise AssertionError("unexpected arg type")

            new_flat_arg.append(new_val)
        new_args = pytree.tree_unflatten(new_flat_arg, args_spec)
        yield new_args


class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
    def _assert_wrap_fallback(self, func, args, setup=lambda: None):
        counters.clear()
        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        setup()
        expected = func(*args)
        setup()
        result = torch.compile(func, backend=cnt, fullgraph=False)(*args)
        num_graph_breaks = len(counters["graph_break"].keys())
        self.assertGreater(num_graph_breaks, 0)

        for gm in backend.graphs:
            for node in gm.graph.nodes:
                self.assertFalse(node.target is wrap)

        self.assertEqual(result, expected)

    def _test_wrap_simple(
        self,
        func,
        args_generator,
        expected_num_wrap_args,
        expected_opcount=2,
        return_graph=False,
    ):
        # Given a `func` that has a single call to `wrap`,
        # we check that:
        # - there are no graph breaks
        # - eager vs torch.compile has the same result (correctness)
        # - other compilation metrics, e.g, # of ops in the dynamo captured graph,
        #   the wrap has the expected number of args, etc
        #
        # we have one or multiple runs through with each of the args from args_generator,
        # and we will check:
        # - correctness and no graph breaks for every run
        # - other compilation metrics only for the first run, since automatic_dynamic_shapes
        #   may compile another dynamic version graph for the later runs
        graph = None
        for i, args in enumerate(args_generator):
            backend = EagerAndRecordGraphs()
            cnt = CompileCounterWithBackend(backend)
            expected = func(*args)
            result = torch.compile(func, fullgraph=True, backend=cnt)(*args)
            # check correctness and no graph breaks
            self.assertEqual(result, expected)
            self.assertEqual(cnt.frame_count, 1)
            self.assertEqual(len(backend.graphs), 1)
            # check other compilation metrics
            if i == 0:
                self.assertEqual(cnt.op_count, expected_opcount)
                graph = backend.graphs[0]
                wrap_node = find_first_node(graph, wrap)
                self.assertEqual(len(wrap_node.args), expected_num_wrap_args)
        # We always return/check the graph from the first run if return_graph = True
        if return_graph:
            return normalize_gm(graph.print_readable(print_output=False))

    def test_error_message_sane(self):
        foo = []

        def inner(x):
            foo.append(x)
            return x.clone()

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return wrap(inner, x)

        x = torch.randn(3)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)",
        ):
            f(x)

    def test_no_freevars(self):
        def f(x):
            return wrap(lambda x: torch.sin(x), x)

        x = torch.randn(3)
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(f, default_args_generator((x,)), arg_count)

    def test_enum_arg(self):
        class SomeEnum(enum.Enum):
            A = 0
            B = 1

        def g(x, val):
            if val == SomeEnum.A:
                return torch.sin(x)
            return torch.cos(x)

        def f(x, val):
            return wrap(g, x, val)

        x = torch.randn(3)
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), arg_count)

    def test_return_captured_var(self):
        freevar = torch.randn(3)

        def test(x):
            return freevar

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)

        # Since, `x` is unused, we don't lift it to
        # be the input.

        # when testing with dynamic shape, symbols are lifted as input
        arg_count = ifdynstaticdefault(2, 3)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)

    def test_return_captured_vars(self):
        freevar1 = torch.randn(3)
        freevar2 = torch.randn(3)

        def test(x):
            return freevar1, freevar2, freevar1

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)

        # Since, `x` is unused, we don't lift it to
        # be the input.
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)

    def test_return_captured_var_used_multiple_times(self):
        freevar = torch.randn(3)

        def test(x):
            y = x + freevar
            return y, freevar

        def fn(x):
            return wrap(test, x)

        x = torch.randn(3)
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)

    def test_capture_untracked_global(self):
        def f(x):
            return wrap(lambda x: x + global_var, x)

        x = torch.randn(3)
        # when testing with dynamic shape, a symbol is lifted as input
        arg_count = ifdynstaticdefault(3, 4)
        self._test_wrap_simple(f, default_args_generator((x,)), arg_count)

def test_symint_input(self):
|
|
def f(x):
|
|
i = x.size(0)
|
|
return wrap(lambda x, i: x.view(i), x, i)
|
|
|
|
x = torch.randn(3, 1)
|
|
self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x,)),
|
|
ifdynstaticdefault(2, 3),
|
|
expected_opcount=2,
|
|
)
|
|
|
|
def test_wrap_pytree_args_nested(self):
|
|
def f(x, y, z):
|
|
def fn(d):
|
|
return d["x"].sin() + d["y"][0].cos() - d["y"][1][2].sin()
|
|
|
|
return wrap(fn, d)
|
|
|
|
x = torch.tensor(1.5)
|
|
y = torch.tensor(2.0)
|
|
z = torch.tensor(3.0)
|
|
d = {"x": x, "y": (y, [x, y, z])}
|
|
|
|
def my_args_generator(t):
|
|
yield t
|
|
yield t[0] + 0.1, t[1], t[2]
|
|
yield t[0], t[1] + 0.1, t[2]
|
|
|
|
actual_graph = self._test_wrap_simple(
|
|
f,
|
|
my_args_generator((x, y, z)),
|
|
4,
|
|
return_graph=True,
|
|
)
|
|
self.assertExpectedInline(
|
|
actual_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"):
|
|
l_d_x_ = L_d_x_
|
|
l_d_y_0_ = L_d_y_0_
|
|
l_d_y_1_2_ = L_d_y_1_2_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
|
|
getitem: "f32[]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
|
|
sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None
|
|
cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None
|
|
add: "f32[]" = sin + cos; sin = cos = None
|
|
sin_1: "f32[]" = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
|
|
sub: "f32[]" = add - sin_1; add = sin_1 = None
|
|
return (sub,)
|
|
""", # NOQA: B950
|
|
)
|
|
|
|
def test_wrap_pytree_args_with_symint_constant(self):
|
|
def f(x, y):
|
|
i = x.size(0)
|
|
return wrap(lambda t: t[0].view(t[2]) + t[1], (x, y, i))
|
|
|
|
x = torch.randn(3, 1)
|
|
y = 0.5
|
|
actual_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
ifdynstaticdefault(2, 3),
|
|
expected_opcount=2,
|
|
return_graph=True,
|
|
)
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(
|
|
actual_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 1]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3, 1]"):
|
|
view: "f32[3]" = l_x_.view(3); l_x_ = None
|
|
add: "f32[3]" = view + 0.5; view = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
actual_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_); wrap_body_0 = s0 = l_x_ = None
|
|
getitem: "f32[s0]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0, 1]"):
|
|
view: "f32[s0]" = l_x_.view(s0); l_x_ = s0 = None
|
|
add: "f32[s0]" = view + 0.5; view = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
def test_wrap_pytree_kwargs(self):
|
|
def f(x, y, z):
|
|
def fn(*, x, y, z):
|
|
z1, z2 = z
|
|
return (x * 2) + y + z1
|
|
|
|
return wrap(fn, x=x, y=y, z=z)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def my_args_generator(t):
|
|
yield t
|
|
x1 = t[0] + 0.1
|
|
y1 = t[1] + 0.1
|
|
yield (x1, y1, (x1, y1))
|
|
x2 = t[0] + 0.2
|
|
y2 = t[0] + 0.2
|
|
yield (x2, y2, (x2, y2))
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), arg_count)
|
|
|
|
def test_wrap_pytree_args_not_const_symint_tensor(self):
|
|
class MyClass:
|
|
def __init__(self, x):
|
|
self.val = x
|
|
|
|
def f(x, y):
|
|
return wrap(lambda z: z[0].sin() * z[1].val.cos(), (x, y))
|
|
|
|
x = torch.tensor(1.2)
|
|
y = MyClass(torch.tensor(3.4))
|
|
self._test_wrap_simple(f, [(x, y)], 3)
|
|
|
|
def test_capture_constants(self):
|
|
x = torch.randn(3, 3)
|
|
y = 4.0
|
|
|
|
def fn(x, y, z):
|
|
if z:
|
|
return x + y
|
|
return x * y
|
|
|
|
def f(x, y, z):
|
|
return wrap(fn, x, y, z)
|
|
|
|
args = (x, 4.0, None)
|
|
opt_f = torch.compile(f, fullgraph=True, backend=CompileCounter())
|
|
expected = f(*args)
|
|
result = opt_f(*args)
|
|
self.assertEqual(result, expected)
|
|
|
|
# Ensure that we recompile here
|
|
args = (x, 5.0, None)
|
|
expected = f(*args)
|
|
result = opt_f(*args)
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_capture_untracked_global_nested(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: wrap(lambda x: x + global_var, x), x)
|
|
|
|
x = torch.randn(3)
|
|
result = f(x)
|
|
|
|
self.assertEqual(result, x + global_var)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
wrap_node = find_first_node(backend.graphs[0], wrap)
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 2)
|
|
inner_wrap_node = find_first_node(body_function, wrap)
|
|
self.assertTrue(len(inner_wrap_node.args), 3)
|
|
|
|
def test_capture_untracked_nonlocal(self):
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def f(x, y):
|
|
def g(x):
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
# when testing with dynamic shape, a symbol is lifted as input
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
|
|
return g(x)
|
|
|
|
f(x, y)
|
|
|
|
def test_capture_tracked(self):
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def f(x, y):
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
# when testing with dynamic shape, a symbol is lifted as input
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_capture_tracked_nested(self):
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def f(x, y):
|
|
return wrap(lambda x: wrap(lambda x: x + y, x), x)
|
|
|
|
# when testing with dynamic shape, a symbol is lifted as input
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_inlined_functions(self):
|
|
def g(x, y):
|
|
return x + y
|
|
|
|
def f(x, y):
|
|
return wrap(lambda x: g(x, y), x)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
# when testing with dynamic shape, a symbol is lifted as input
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_same_freevar_twice(self):
|
|
free = torch.randn(3)
|
|
|
|
def g(x):
|
|
y = free.sin()
|
|
z = free.cos()
|
|
return y, z
|
|
|
|
def f(x):
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
|
|
# Since, `x` is unused, we don't lift it to
|
|
# be the input.
|
|
# when testing with dynamic shape, a symbol is lifted as input
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(f, default_args_generator((x,)), arg_count, 3)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_scalar_outputs=True,
|
|
)
|
|
def test_unbacked_symbol_closure(self):
|
|
def f(x):
|
|
c = x.sum().item()
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return x + c
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
out_graph = self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, 4, return_graph=True
|
|
)
|
|
|
|
if check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
|
l_x_ = L_x_
|
|
|
|
sum_1: "f32[]" = l_x_.sum()
|
|
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, item); wrap_body_1 = s0 = l_x_ = item = None
|
|
getitem: "f32[s0]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, item); wrap_body_0 = s0 = l_x_ = item = None
|
|
getitem: "f32[s0]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", item: "Sym(zuf0)"):
|
|
add: "f32[s0]" = l_x_ + item; l_x_ = item = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
sum_1: "f32[]" = l_x_.sum()
|
|
item: "Sym(zuf0)" = sum_1.item(); sum_1 = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, item); wrap_body_1 = l_x_ = item = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, item); wrap_body_0 = l_x_ = item = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", item: "Sym(zuf0)"):
|
|
add: "f32[3]" = l_x_ + item; l_x_ = item = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_tensor_with_unbacked_shape_closure(self):
|
|
def f(x):
|
|
c = x.nonzero()
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return x.sin(), c.sin()
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(4, 5)
|
|
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
|
|
expected_op_count = ifdynstaticdefault(10, 8)
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x,)),
|
|
arg_count,
|
|
expected_op_count,
|
|
return_graph=True,
|
|
)
|
|
|
|
if check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
|
|
l_x_ = L_x_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, l_x_, sym_size_int_1, c); wrap_body_1 = s0 = l_x_ = sym_size_int_1 = c = None
|
|
getitem: "f32[s0]" = wrap[0]
|
|
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, l_x_, u0, c); wrap_body_0 = s0 = l_x_ = u0 = c = None
|
|
child: "f32[s0]" = wrap[0]
|
|
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (child, child_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", l_x_: "f32[s0]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
child: "f32[s0]" = l_x_.sin(); l_x_ = None
|
|
child_1: "f32[u0, 1]" = c.sin(); c = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
le: "Sym(u0 <= 3)" = sym_size_int_1 <= 3
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int_1, c); wrap_body_1 = l_x_ = sym_size_int_1 = c = None
|
|
getitem: "f32[3]" = wrap[0]
|
|
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
|
|
child: "f32[3]" = wrap[0]
|
|
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (child, child_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
|
|
child: "f32[3]" = l_x_.sin(); l_x_ = None
|
|
child_1: "f32[u0, 1]" = c.sin(); c = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_tensor_to_list_closure(self):
|
|
def f(x):
|
|
li = x.tolist()
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return li[0] + x
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.tensor([1, 2, 3], dtype=torch.int16)
|
|
arg_count = ifdynstaticdefault(3, 3)
|
|
out_graph = self._test_wrap_simple(f, ((x,),), arg_count, 4, return_graph=True)
|
|
|
|
# tolist will specialize on input shapes, so dynamic and static tests
|
|
# have the same graph
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "i16[3]"):
|
|
l_x_ = L_x_
|
|
|
|
getitem = l_x_[0]
|
|
item: "Sym(u0)" = getitem.item(); getitem = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, item, l_x_); wrap_body_1 = item = l_x_ = None
|
|
getitem_3: "i16[3]" = wrap[0]; wrap = None
|
|
return (getitem_3,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, item, l_x_); wrap_body_0 = item = l_x_ = None
|
|
getitem: "i16[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, item: "Sym(u0)", l_x_: "i16[3]"):
|
|
add: "i16[3]" = item + l_x_; item = l_x_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_tensor_and_unbacked_symbol_closure(self):
|
|
def f(x):
|
|
c = x.nonzero()
|
|
sz = c.size(0)
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return x.sin() + sz, c.sin()
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(4, 5)
|
|
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
|
|
expected_op_count = ifdynstaticdefault(10, 8)
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x,)),
|
|
arg_count,
|
|
expected_op_count,
|
|
return_graph=True,
|
|
)
|
|
|
|
# Note that u0 is accessed from sz and the shape of c
|
|
# We cached via the symbol u0 and de-duplicate them.
|
|
if not check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
le: "Sym(u0 <= 3)" = sym_size_int <= 3
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, l_x_, sym_size_int, c); wrap_body_1 = l_x_ = sym_size_int = c = None
|
|
getitem: "f32[3]" = wrap[0]
|
|
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
|
|
child: "f32[3]" = wrap[0]
|
|
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
|
|
return (child, child_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
|
|
sin: "f32[3]" = l_x_.sin(); l_x_ = None
|
|
child: "f32[3]" = sin + size; sin = size = None
|
|
child_1: "f32[u0, 1]" = c.sin(); c = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_concat_unbacked_shape_tensor(self):
|
|
def f(x, y):
|
|
c = x.nonzero()
|
|
d = y.nonzero()
|
|
cat = torch.cat((c, d))
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return cat.sum() + x
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(5, 6)
|
|
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
|
|
expected_op_count = ifdynstaticdefault(17, 13)
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
arg_count,
|
|
expected_op_count,
|
|
return_graph=True,
|
|
)
|
|
|
|
if not check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]", L_y_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
c: "i64[u0, 1]" = l_x_.nonzero()
|
|
|
|
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
le: "Sym(u0 <= 3)" = sym_size_int_2 <= 3
|
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 3 on node 'le'"); le = _assert_scalar_default_1 = None
|
|
|
|
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
|
|
|
|
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
|
|
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
|
|
|
|
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
|
|
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None
|
|
le_1: "Sym(u1 <= 3)" = sym_size_int_3 <= 3
|
|
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u1 <= 3 on node 'le_1'"); le_1 = _assert_scalar_default_3 = None
|
|
|
|
cat: "i64[u0 + u1, 1]" = torch.cat((c, d)); c = d = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_2, sym_size_int_3, cat, l_x_); wrap_body_1 = sym_size_int_2 = sym_size_int_3 = cat = l_x_ = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, u1, cat, l_x_); wrap_body_0 = u0 = u1 = cat = l_x_ = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", u1: "Sym(u1)", cat: "i64[u0 + u1, 1]", l_x_: "f32[3]"):
|
|
sum_1: "i64[]" = cat.sum(); cat = None
|
|
add: "f32[3]" = sum_1 + l_x_; sum_1 = l_x_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
assume_static_by_default=False,
|
|
dynamic_shapes=True,
|
|
)
|
|
def test_lift_tensors_with_shared_symbols(self):
|
|
def f(x, y):
|
|
def g(x):
|
|
def k(x):
|
|
return x @ y
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(3, 4)
|
|
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
6,
|
|
2,
|
|
return_graph=True,
|
|
)
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
|
|
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
|
|
getitem: "f32[s0, s2]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
|
|
matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
|
|
return (matmul,)
|
|
""",
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
assume_static_by_default=False,
|
|
dynamic_shapes=True,
|
|
capture_dynamic_output_shape_ops=True,
|
|
)
|
|
def test_lift_tensors_with_compound_expressions(self):
|
|
def f(x, y):
|
|
x = x.view(-1, 2)
|
|
c = y.nonzero()
|
|
d = torch.concat((x, c))
|
|
|
|
def g(x):
|
|
def k(x):
|
|
return d.sum() + x
|
|
|
|
return wrap(k, x)
|
|
|
|
return wrap(g, x)
|
|
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(3, 4)
|
|
|
|
f(x, y)
|
|
|
|
if not check_dynamic_shape_capture():
|
|
out_graph = self._test_wrap_simple(
|
|
f,
|
|
default_args_generator((x, y)),
|
|
6,
|
|
9,
|
|
return_graph=True,
|
|
)
|
|
self.assertExpectedInline(
|
|
out_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_x_: "f32[s0, s1]", s2: "Sym(s2)", L_y_: "f32[s1, s2]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = l_x_.view(-1, 2); l_x_ = None
|
|
|
|
c: "i64[u0, 2]" = l_y_.nonzero(); l_y_ = None
|
|
|
|
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
|
|
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
|
|
|
|
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
|
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
|
|
|
|
d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = torch.concat((x, c)); c = None
|
|
|
|
wrap_body_1 = self.wrap_body_1
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_1, sym_size_int_1, s1, s0, d, x); wrap_body_1 = sym_size_int_1 = s1 = s0 = d = x = None
|
|
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_1(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, u0, s1, s0, d, x); wrap_body_0 = u0 = s1 = s0 = d = x = None
|
|
getitem: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, u0: "Sym(u0)", s1: "Sym(s1)", s0: "Sym(s0)", d: "f32[u0 + ((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]", x: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]"):
|
|
sum_1: "f32[]" = d.sum(); d = None
|
|
add: "f32[((s0*s1)//2), ((s0*s1)//(((s0*s1)//2)))]" = sum_1 + x; sum_1 = x = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
def test_register_subclass(self):
|
|
from torch._higher_order_ops.cond import cond_op
|
|
from torch.testing._internal.two_tensor import TwoTensor
|
|
|
|
a = torch.tensor([1.0, 0.0, 1.0])
|
|
b = torch.randn(3)
|
|
t = TwoTensor(a, b)
|
|
with self.assertRaisesRegex(
|
|
NotImplementedError,
|
|
"no rule registered for HOP cond and subclass .*TwoTensor'>",
|
|
):
|
|
res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
|
|
|
|
called = 0
|
|
|
|
# Using cond.py_impl
|
|
@cond_op.py_impl(TwoTensor)
|
|
def _(pred, true_fn, false_fn, operands):
|
|
nonlocal called
|
|
called += 1
|
|
assert len(operands) == 1
|
|
a = cond_op(pred, true_fn, false_fn, (operands[0].a,))
|
|
b = cond_op(pred, true_fn, false_fn, (operands[0].b,))
|
|
return TwoTensor(a, b)
|
|
|
|
res = cond_op(a.sum() > 0, torch.sin, torch.cos, (t,))
|
|
self.assertEqual(res.a, torch.sin(a))
|
|
self.assertEqual(res.b, torch.sin(b))
|
|
self.assertEqual(called, 1)
|
|
|
|
def test_register_mode(self):
|
|
from torch._higher_order_ops.cond import cond_op
|
|
|
|
torch_dispatch_called = 0
|
|
|
|
class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
nonlocal torch_dispatch_called
|
|
torch_dispatch_called += 1
|
|
return func(*args, **kwargs)
|
|
|
|
a = torch.tensor([1.0, 0.1, 1.0])
|
|
pred = a.sum() > 0
|
|
with self.assertRaisesRegex(
|
|
NotImplementedError,
|
|
"no rule registered for HOP cond and mode .*MyMode",
|
|
):
|
|
with MyMode():
|
|
res = cond_op(pred, torch.sin, torch.cos, (a,))
|
|
|
|
py_impl_called = 0
|
|
|
|
# Using cond.py_impl
|
|
@cond_op.py_impl(MyMode)
|
|
def _(mode, pred, true_fn, false_fn, operands):
|
|
nonlocal py_impl_called
|
|
py_impl_called += 1
|
|
return cond_op(pred, true_fn, false_fn, operands)
|
|
|
|
a = torch.tensor([1.0, 0.1, 1.0])
|
|
pred = a.sum() > 0
|
|
with MyMode():
|
|
res = cond_op(pred, torch.sin, torch.cos, (a,))
|
|
self.assertEqual(res, a.sin())
|
|
|
|
def test_capture_value_created_in_subgraph(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def inner(x, y):
|
|
z = x + y
|
|
return wrap(lambda x: wrap(lambda x: x + z, x), x)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x, y):
|
|
return wrap(inner, x, y)
|
|
|
|
result = f(x, y)
|
|
|
|
self.assertEqual(result, x + y + x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
|
|
# No changes to args of outer wrap
|
|
gm = backend.graphs[0]
|
|
wrap_node = find_first_node(gm, wrap)
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
# z was lifted to arg of inner wrap
|
|
body_function = getattr(gm, wrap_node.args[0].name)
|
|
# addition + wrap + getitem
|
|
self.assertEqual(op_count(body_function), 3)
|
|
inner_wrap_node = find_first_node(body_function, wrap)
|
|
self.assertTrue(len(inner_wrap_node.args), 3)
|
|
|
|
# Innermost body function: z was also lifted to arg
|
|
body_function = getattr(body_function, inner_wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 2)
|
|
inner_wrap_node = find_first_node(body_function, wrap)
|
|
self.assertTrue(len(inner_wrap_node.args), 3)
|
|
|
|
def test_side_effect_set_new_attr_global_obj(self):
|
|
def setup():
|
|
global global_obj
|
|
global_obj = Obj()
|
|
|
|
def f(x):
|
|
def h(x):
|
|
def g(x):
|
|
global_obj.foo = x + 1
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y + global_obj.foo
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_set_existing_attr_global_obj(self):
|
|
def setup():
|
|
global global_obj
|
|
global_obj = Obj()
|
|
global_obj.foo = nn.Parameter(torch.tensor(4.0))
|
|
|
|
def f(x):
|
|
def h(x):
|
|
def g(x):
|
|
global_obj.foo = x + 1
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y + global_obj.foo
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_del_existing_attr_global_obj(self):
|
|
def setup():
|
|
global global_obj
|
|
global_obj = Obj()
|
|
global_obj.foo = torch.tensor(4.0)
|
|
|
|
def f(x):
|
|
def h(x):
|
|
def g(x):
|
|
del global_obj.foo
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_set_new_attr_global_module(self):
|
|
def setup():
|
|
global global_module
|
|
global_module = MyModule()
|
|
|
|
def h(x):
|
|
def g(x):
|
|
global_module.foo = nn.Parameter(x + 1)
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y + global_module.foo
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,), setup=setup)
|
|
|
|
def test_side_effect_set_existing_attr_global_module(self):
|
|
def setup():
|
|
global global_module
|
|
global_module = MyModule()
|
|
|
|
def h(x):
|
|
def g(x):
|
|
global_module.existing = nn.Parameter(torch.tensor(4.0))
|
|
return global_module(x)
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,), setup=setup)
|
|
|
|
def test_side_effect_del_existing_attr_global_module(self):
|
|
def setup():
|
|
global global_module
|
|
global_module = MyModule()
|
|
|
|
def h(x):
|
|
def g(x):
|
|
del global_module.existing
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_num(self):
|
|
def setup():
|
|
global global_num
|
|
global_num = 3.14
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_num
|
|
global_num = global_num + 1
|
|
return x + global_num
|
|
|
|
y = wrap(g, x)
|
|
return y + global_num
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_num_builtin(self):
|
|
def setup():
|
|
global global_num
|
|
global_num = 3.14
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_num
|
|
global_num += 1
|
|
return x + global_num
|
|
|
|
y = wrap(g, x)
|
|
return y + global_num
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_tensor(self):
|
|
def setup():
|
|
global global_var
|
|
global_var = torch.ones(3)
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_var
|
|
global_var = global_var + 1
|
|
return x + global_var
|
|
|
|
y = wrap(g, x)
|
|
return y + global_var
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_tensor_builtin(self):
|
|
def setup():
|
|
global global_var
|
|
global_var = torch.ones(3)
|
|
|
|
def f(x):
|
|
def g(x):
|
|
global global_var
|
|
global_var += 1
|
|
return x + global_var
|
|
|
|
y = wrap(g, x)
|
|
return y + global_var
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_global_list(self):
|
|
def setup():
|
|
global global_list
|
|
global_list = []
|
|
|
|
def f(x):
|
|
def g(x):
|
|
val = x + 1
|
|
global_list.append(val)
|
|
return global_list[-1]
|
|
|
|
y = wrap(g, x)
|
|
z = y + global_list[-1]
|
|
return z
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,), setup=setup)
|
|
|
|
def test_side_effect_mutate_nonlocal_num(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = 1
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val = val + 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_set_new_attr_nonlocal_obj(self):
|
|
def f(x):
|
|
def h(x):
|
|
obj = Obj()
|
|
|
|
def g(x):
|
|
obj.val = x.dim()
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
z = y + obj.val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_set_existing_attr_nonlocal_obj(self):
|
|
def f(x):
|
|
def h(x):
|
|
obj = Obj()
|
|
obj.val = 3
|
|
|
|
def g(x):
|
|
obj.val = x.dim()
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
z = y + obj.val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_del_existing_attr_nonlocal_obj(self):
|
|
def f(x):
|
|
def h(x):
|
|
obj = Obj()
|
|
obj.val = 3
|
|
|
|
def g(x):
|
|
del obj.val
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_set_new_attr_nonlocal_module(self):
|
|
def h(x):
|
|
obj = MyModule()
|
|
|
|
def g(x):
|
|
obj.val = x.dim()
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
z = y + obj.val
|
|
return z
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,))
|
|
|
|
def test_side_effect_set_existing_attr_nonlocal_module(self):
|
|
def h(x):
|
|
obj = MyModule()
|
|
|
|
def g(x):
|
|
obj.existing = nn.Parameter(torch.tensor(3.14))
|
|
return obj(x)
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,))
|
|
|
|
def test_side_effect_del_existing_attr_nonlocal_module(self):
|
|
def h(x):
|
|
obj = MyModule()
|
|
|
|
def g(x):
|
|
del obj.existing
|
|
return x.clone()
|
|
|
|
y = wrap(g, x)
|
|
return y
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(h, (x,))
|
|
|
|
def test_side_effect_mutate_nonlocal_tensor(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = torch.tensor(1.0)
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val = val + 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_mutate_nonlocal_num_builtin(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = 1
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val += 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_mutate_nonlocal_tensor_builtin(self):
|
|
def f(x):
|
|
def h(x):
|
|
val = torch.tensor(1.0)
|
|
|
|
def g(x):
|
|
nonlocal val
|
|
val += 1
|
|
return x + val
|
|
|
|
y = wrap(g, x)
|
|
z = y + val
|
|
return z
|
|
|
|
return h(x)
|
|
|
|
x = torch.zeros([])
|
|
self._assert_wrap_fallback(f, (x,))
|
|
|
|
def test_side_effect_nonlocal_list_append_graph_break(self):
|
|
def g(x):
|
|
y = []
|
|
|
|
def f(k):
|
|
m = k + 1
|
|
y.append(m)
|
|
return k
|
|
|
|
wrap(f, x)
|
|
return y[0]
|
|
|
|
x = torch.randn(3, 3)
|
|
self._assert_wrap_fallback(g, (x,))
|
|
|
|
def test_side_effect_nested_nonlocal_list_append_graph_break(self):
|
|
def g(x):
|
|
def h(x):
|
|
y = []
|
|
|
|
def f(k):
|
|
m = k + 1
|
|
y.append(m)
|
|
return k
|
|
|
|
wrap(f, x)
|
|
return y[0]
|
|
|
|
return h(x)
|
|
|
|
x = torch.randn(3, 3)
|
|
self._assert_wrap_fallback(g, (x,))
|
|
|
|
def test_side_effect_local_list_append_no_graph_break(self):
|
|
def g(x):
|
|
def f(k):
|
|
y = []
|
|
y.append(k + 1)
|
|
return y[0]
|
|
|
|
return wrap(f, x)
|
|
|
|
x = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(g, default_args_generator((x,)), arg_count)
|
|
|
|
def test_wrap_kwarg(self):
|
|
def f(x, y):
|
|
return wrap(lambda x, y: x + y, x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_int(self):
|
|
def f(x, y):
|
|
return wrap(lambda x, y: x + y, x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = 8
|
|
|
|
arg_count = (
|
|
ifdynstaticdefault(2, 3) + 1
|
|
if check_dynamic_shape_capture()
|
|
else ifdynstaticdefault(2, 3)
|
|
)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_all_kwarg(self):
|
|
def f(y, x):
|
|
return wrap(lambda x, y: (x * 2) + y, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_only(self):
|
|
def f(x, y):
|
|
def fn(*, x, y):
|
|
return (x * 2) + y
|
|
|
|
return wrap(fn, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_default(self):
|
|
def f(x, y):
|
|
def fn(*, x, y, z=8):
|
|
return (x * 2) + y + z
|
|
|
|
return wrap(fn, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_default_if_branch(self):
|
|
def f(x, y):
|
|
def fn(*, x, y, z=None):
|
|
if z is None:
|
|
return (x * 2) + y
|
|
else:
|
|
return 2 * x
|
|
|
|
return wrap(fn, x=x, y=y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_wrap_kwarg_recompile(self):
|
|
def f(x, y, z=None):
|
|
def fn(*, x, y, z=None):
|
|
if z is None:
|
|
return (x * 2) + y
|
|
else:
|
|
return 2 * x
|
|
|
|
return wrap(fn, x=x, y=y, z=z)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
counters.clear()
|
|
opt = torch.compile(f, backend="eager", fullgraph=True)
|
|
opt(x, y)
|
|
self.assertEqual(counters["stats"]["calls_captured"], 2)
|
|
|
|
# verify that we `don't` recompile
|
|
opt(x, y)
|
|
self.assertEqual(counters["stats"]["calls_captured"], 2)
|
|
|
|
output = opt(x, y, 8)
|
|
self.assertEqual(counters["stats"]["calls_captured"], 4)
|
|
self.assertEqual(output, 2 * x)
|
|
|
|
def test_wrap_kwarg_default_else_branch(self):
|
|
def f(x, y, z):
|
|
def fn(*, x, y, z=None):
|
|
if z is None:
|
|
return (x * 2) + y
|
|
else:
|
|
return 2 * x
|
|
|
|
return wrap(fn, x=x, y=y, z=z)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 3)
|
|
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(f, default_args_generator((x, y, 8)), arg_count)
|
|
|
|
def test_map_subgraph_name_is_valid(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3)
|
|
|
|
def map_f(xs, y):
|
|
def inner(x, y):
|
|
def inner2(x, y):
|
|
return x + y
|
|
|
|
return control_flow.map(inner2, x, y)
|
|
|
|
return control_flow.map(inner, xs, y)
|
|
|
|
graphs = self._check_map_graph_and_extract(map_f, (xs, y))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_xs_ : torch.Tensor, L_y_ : torch.Tensor):
|
|
l_xs_ = L_xs_
|
|
l_y_ = L_y_
|
|
map_body_1 = self.map_body_1
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_1, [l_xs_], [l_y_]); map_body_1 = l_xs_ = l_y_ = None
|
|
getitem_1 = map_impl[0]; map_impl = None
|
|
return (getitem_1,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor, l_y_ : torch.Tensor):
|
|
child_1 = child[0]; child_1 = None
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [child], [l_y_]); map_body_0 = child = l_y_ = None
|
|
getitem_1 = map_impl[0]; map_impl = None
|
|
return (getitem_1,)""",
|
|
)
|
|
|
|
def test_map_multi_return(self):
|
|
cnt = CompileCounter()
|
|
|
|
def f(x):
|
|
return control_flow.map(lambda x: (x.sin(), x.sin()), x)
|
|
|
|
x = torch.randn(3)
|
|
graphs = self._check_map_graph_and_extract(f, (x,))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
|
|
getitem_1 = map_impl[0]
|
|
getitem_2 = map_impl[1]; map_impl = None
|
|
return (getitem_1, getitem_2)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor):
|
|
child_1 = child.sin()
|
|
child_2 = child.sin(); child = None
|
|
return (child_1, child_2)""",
|
|
)
|
|
|
|
def test_map_pytree_return(self):
|
|
cnt = CompileCounter()
|
|
|
|
def _construct_pytree(a):
|
|
return (a, [[[a]]], a, (a, (a,), a), {"a": a})
|
|
|
|
def f(x):
|
|
def inner_f(xs):
|
|
return _construct_pytree(xs)
|
|
|
|
return control_flow.map(inner_f, x)
|
|
|
|
x = torch.randn(3)
|
|
graphs = self._check_map_graph_and_extract(f, (x,))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], []); map_body_0 = l_x_ = None
|
|
getitem_1 = map_impl[0]
|
|
getitem_2 = map_impl[1]
|
|
getitem_3 = map_impl[2]
|
|
getitem_4 = map_impl[3]
|
|
getitem_5 = map_impl[4]
|
|
getitem_6 = map_impl[5]
|
|
getitem_7 = map_impl[6]; map_impl = None
|
|
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor):
|
|
return (child, child, child, child, child, child, child)""",
|
|
)
|
|
|
|
def test_map_kwargs(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
return control_flow.map(lambda x: x.sin(), x=x)
|
|
|
|
x = torch.randn(3)
|
|
self.assertRaises(TypeError, lambda: f(x))
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
|
|
def test_map_symint_input(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
def fn(x, y):
|
|
def inner(x, y):
|
|
return torch.sin(x + y)
|
|
|
|
return control_flow.map(inner, x, y.size(0))
|
|
|
|
x = torch.randn(3, 1)
|
|
y = torch.randn(3, 1)
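# With static shapes, y.size(0) is specialized to the constant 3 in the captured graph (see the expected output below).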
|
|
graphs = self._check_map_graph_and_extract(fn, (x, y))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
|
|
getitem_1 = map_impl[0]; map_impl = None
|
|
return (getitem_1,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor, const_unused : int):
|
|
add = child + 3; child = None
|
|
sin = torch.sin(add); add = None
|
|
return (sin,)""",
|
|
)
|
|
|
|
def test_map_lowers_to_graph(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
def fn(x, y):
|
|
def inner(x, y):
|
|
return torch.sin(x + y)
|
|
|
|
return control_flow.map(inner, x, y.size(0))
|
|
|
|
x = torch.randn(3, 1)
|
|
y = torch.randn(3, 1)
|
|
graphs = self._check_map_graph_and_extract(fn, (x, y))
|
|
if graphs:
|
|
graph, body_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
map_body_0 = self.map_body_0
|
|
map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_x_], [3]); map_body_0 = l_x_ = None
|
|
getitem_1 = map_impl[0]; map_impl = None
|
|
return (getitem_1,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
body_graph,
|
|
"""\
|
|
def forward(self, child : torch.Tensor, const_unused : int):
|
|
add = child + 3; child = None
|
|
sin = torch.sin(add); add = None
|
|
return (sin,)""",
|
|
)
|
|
|
|
def test_map_example_value_metadata_consistent_with_eager(self):
|
|
from torch._higher_order_ops.map import map_dense
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
def inner(x):
|
|
return x.sin(), x.cos().T, x.sin().view(-1)
|
|
|
|
rand_44 = torch.randn(4, 4)
|
|
inps = [
|
|
torch.randn(3),
|
|
torch.randn(3, 4),
|
|
torch.randn(3, 4, 5, requires_grad=True),
|
|
torch.randn(3, 4, 5, requires_grad=True).permute((2, 0, 1)),
|
|
torch.randn(3, 4, 5, requires_grad=True).detach(),
|
|
torch.randn(3, 4, 5, requires_grad=True).narrow(1, 1, 2),
|
|
rand_44.T,
|
|
rand_44[::2],
|
|
rand_44[::2, ::2],
|
|
rand_44[1::3, 1::3],
|
|
rand_44[1::3, 1::2].T,
|
|
rand_44.unsqueeze(1),
|
|
rand_44.squeeze(0),
|
|
rand_44.reshape(2, 8),
|
|
]
|
|
for x in inps:
|
|
compiled_ret = torch.compile(
|
|
control_flow.map, backend=backend, fullgraph=True
|
|
)(inner, x)
|
|
eager_sin, eager_transpose, eager_view = map_dense(inner, (x,), ())
|
|
|
|
map_node = next(
|
|
node
|
|
for node in backend.graphs[0].graph.nodes
|
|
if node.op == "call_function" and "map" in node.name
|
|
)
|
|
|
|
fake_sin, fake_transpose, fake_view = map_node.meta["example_value"]
|
|
|
|
def _check_size_stride_contiguous(x, y):
|
|
self.assertEqual(y.size(), x.size())
|
|
self.assertEqual(y.stride(), x.stride())
|
|
self.assertEqual(y.requires_grad, x.requires_grad)
|
|
self.assertEqual(x.is_contiguous(), True)
|
|
self.assertEqual(y.is_contiguous(), True)
|
|
|
|
_check_size_stride_contiguous(eager_sin, fake_sin)
|
|
_check_size_stride_contiguous(eager_transpose, fake_transpose)
|
|
_check_size_stride_contiguous(eager_view, fake_view)
|
|
|
|
torch._dynamo.reset()
|
|
backend.graphs.clear()
|
|
|
|
def test_cond_subgraph_name_is_valid(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
pred = torch.tensor(True)
|
|
pred2 = torch.tensor(False)
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def cond_f(pred, pred2, x, y):
|
|
def true_fn(pred2, x, y):
|
|
return x + y
|
|
|
|
def false_fn(pred2, x, y):
|
|
def true_fn2(x, y):
|
|
return x.sin() - y.cos()
|
|
|
|
def false_fn2(x, y):
|
|
return x.cos() - y.sin()
|
|
|
|
return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
|
|
|
|
return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
|
|
|
|
result = cond_f(pred, pred2, xs, y)
|
|
self.assertEqual(result, xs + y)
|
|
|
|
cond_gm = backend.graphs[0]
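# Nested cond branches should be installed as uniquely named submodules on the top-level graph module.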
|
|
name_set = set()
|
|
name_set.update(name for name, _ in cond_gm.named_modules())
|
|
self.assertEqual(
|
|
name_set,
|
|
{
|
|
"",
|
|
"cond_true_1",
|
|
"cond_false_1",
|
|
"cond_false_1.cond_false_0",
|
|
"cond_false_1.cond_true_0",
|
|
},
|
|
)
|
|
|
|
@torch._dynamo.config.patch(
|
|
assume_static_by_default=True,
|
|
dynamic_shapes=True,
|
|
)
|
|
def test_cond_graph_break_in_one_branch(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.buffer = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, x):
|
|
def true_fn(x):
|
|
self.buffer += 1
|
|
return self.buffer.sum() + x.sum()
|
|
|
|
def false_fn(x):
|
|
return (x - 1).sum()
|
|
|
|
return control_flow.cond(x.sum() > 4, true_fn, false_fn, [x])
|
|
|
|
mod_for_compile = torch.compile(Foo(), backend=cnt, dynamic=True)
|
|
mod_for_eager = Foo()
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_eager(torch.ones(6, 4))
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_compile(torch.ones(3, 4))
|
|
|
|
def test_cond_free_variable_in_both_branches(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
z = torch.ones(4, 4)
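# z is a free variable referenced in both branches, so each branch graph should receive it as a lifted input.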
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.buffer = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, x, y):
|
|
def true_fn(x):
|
|
return x.sum() + self.buffer.sum() + z.sum()
|
|
|
|
def false_fn(x):
|
|
return x.sum() - z.sum() - self.buffer.sum()
|
|
|
|
return control_flow.cond(y, true_fn, false_fn, [x])
|
|
|
|
mod_for_compile = torch.compile(
|
|
Foo(), backend=cnt, dynamic=True, fullgraph=True
|
|
)
|
|
mod_for_eager = Foo()
|
|
|
|
self.assertEqual(
|
|
mod_for_compile(torch.tensor(True), torch.tensor(5)),
|
|
mod_for_eager(torch.tensor(True), torch.tensor(5)),
|
|
)
|
|
|
|
for node in backend.graphs[0].graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.higher_order.cond
|
|
):
|
|
_, _, _, operands = node.args
|
|
# Since we compile with dynamic shapes, each branch takes 4 inputs (buffer, x, z, s1)
|
|
self.assertEqual(len(operands), 4)
|
|
if node.op == "get_attr":
|
|
if str(node.target) in ("cond_true_0", "cond_false_0"):
|
|
num_placeholders = len(
|
|
[
|
|
node
|
|
for node in getattr(
|
|
backend.graphs[0], str(node.target)
|
|
).graph.nodes
|
|
if node.op == "placeholder"
|
|
]
|
|
)
|
|
self.assertEqual(num_placeholders, 4)
|
|
|
|
def _check_cond_graph_and_extract(self, fn, args):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
|
|
self.assertEqual(out, fn(*args))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
gm = backend.graphs[0]
|
|
graph = gm.code.strip()
|
|
true_graph = gm.cond_true_0.code.strip()
|
|
false_graph = gm.cond_false_0.code.strip()
|
|
return (graph, true_graph, false_graph)
|
|
|
|
def _check_map_graph_and_extract(self, fn, args):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
out = torch.compile(fn, backend=cnt, fullgraph=True)(*args)
|
|
self.assertEqual(out, fn(*args))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
gm = backend.graphs[0]
|
|
graph = gm.code.strip()
|
|
subgraphs = []
|
|
for module_name in gm._modules.keys():
|
|
subgraphs.append(getattr(gm, module_name).code.strip())
|
|
return (graph, *subgraphs)
|
|
|
|
def test_cond_branches_no_arguments(self):
|
|
def fn(x):
|
|
def true_fn():
|
|
return torch.sin(x)
|
|
|
|
def false_fn():
|
|
return torch.cos(x)
|
|
|
|
return control_flow.cond(x.sum() > 0, true_fn, false_fn, ())
|
|
|
|
graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),))
|
|
if graphs is not None:
|
|
graph, true_graph, false_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
sum_1 = l_x_.sum()
|
|
gt = sum_1 > 0; sum_1 = None
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, [l_x_]); gt = cond_true_0 = cond_false_0 = l_x_ = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
true_graph,
|
|
"""\
|
|
def forward(self, l_x_):
|
|
l_x__1 = l_x_
|
|
sin = torch.sin(l_x__1); l_x__1 = None
|
|
return (sin,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
false_graph,
|
|
"""\
|
|
def forward(self, l_x_):
|
|
l_x__1 = l_x_
|
|
cos = torch.cos(l_x__1); l_x__1 = None
|
|
return (cos,)""",
|
|
)
|
|
|
|
def test_cond_branches_no_arguments_no_closure(self):
|
|
def fn(x):
|
|
def true_fn():
|
|
return torch.ones(3, 4)
|
|
|
|
def false_fn():
|
|
return torch.ones(3, 4).sin()
|
|
|
|
return control_flow.cond(x.sum() > 0, true_fn, false_fn, ())
|
|
|
|
|
|
graphs = self._check_cond_graph_and_extract(fn, (torch.randn(4, 5),))
|
|
if graphs is not None:
|
|
graph, true_graph, false_graph = graphs
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
def forward(self, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
sum_1 = l_x_.sum(); l_x_ = None
|
|
gt = sum_1 > 0; sum_1 = None
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(gt, cond_true_0, cond_false_0, []); gt = cond_true_0 = cond_false_0 = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
true_graph,
|
|
"""\
|
|
def forward(self):
|
|
ones = torch.ones(3, 4)
|
|
return (ones,)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
false_graph,
|
|
"""\
|
|
def forward(self):
|
|
ones = torch.ones(3, 4)
|
|
sin = ones.sin(); ones = None
|
|
return (sin,)""",
|
|
)
|
|
|
|
def test_cond_side_effect_in_one_branches(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
z = [torch.ones(4, 4)]
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, y, x):
|
|
def true_fn(x):
|
|
z.append(x)
|
|
z.append(x)
|
|
z.pop()
|
|
return x.sum() + z[-1].sum()
|
|
|
|
def false_fn(x):
|
|
return x.sum() - z[0].sum()
|
|
|
|
return control_flow.cond(y, true_fn, false_fn, [x])
|
|
|
|
mod_for_eager = Foo()
|
|
mod_for_compile = torch.compile(
|
|
Foo(), backend=cnt, dynamic=True, fullgraph=False
|
|
)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_eager(torch.tensor(True), torch.tensor(5))
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
mod_for_compile(torch.tensor(True), torch.tensor(5))
|
|
|
|
def test_cond_with_constant_pred(self):
|
|
def test(pred, x):
|
|
def true_fn(x):
|
|
return x
|
|
|
|
def false_fn(x):
|
|
return -x
|
|
|
|
return control_flow.cond(pred, true_fn, false_fn, [x])
|
|
|
|
opt_test = torch.compile(test, backend="eager")
|
|
inp = torch.ones(3, 3)
|
|
self.assertTrue(torch.allclose(test(True, inp), opt_test(True, inp)))
|
|
self.assertTrue(torch.allclose(test(False, inp), opt_test(False, inp)))
|
|
|
|
def test_map_graph_break(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
class Module(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, xs):
|
|
def body(x):
|
|
self.w += 1
|
|
return x
|
|
|
|
return control_flow.map(body, xs)
|
|
|
|
mod = Module()
|
|
|
|
mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
|
|
mod_for_eager = Module()
|
|
|
|
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
# There is a graph break right when we enter the body of map
|
|
self.assertEqual(len(backend.graphs), 0)
|
|
self.assertEqual(
|
|
res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
)
|
|
|
|
def test_map_side_effect(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
z = [torch.ones(6, 4)]
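# The map body mutates the global list z; this side effect prevents capture, so no graphs should be recorded.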
|
|
|
|
class Module(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.w = torch.nn.Buffer(torch.ones(6, 4))
|
|
|
|
def forward(self, xs):
|
|
def body(x):
|
|
z.append(x)
|
|
z.append(x)
|
|
z.pop()
|
|
return x + z[-1].sum()
|
|
|
|
return control_flow.map(body, xs)
|
|
|
|
mod = Module()
|
|
|
|
mod_for_compile = torch.compile(mod, backend=cnt, dynamic=True, fullgraph=False)
|
|
mod_for_eager = Module()
|
|
|
|
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
|
|
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
|
|
|
|
self.assertEqual(len(backend.graphs), 0)
|
|
self.assertEqual(res, eager)
|
|
|
|
def test_wrap_subgraph_name_is_valid(self):
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
|
|
def inner(x, y):
|
|
z = x + y
|
|
return wrap(lambda x: wrap(lambda x: x + z, x), x)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x, y):
|
|
return wrap(inner, x, y)
|
|
|
|
result = f(x, y)
|
|
|
|
self.assertEqual(result, x + y + x)
|
|
wrap_gm = backend.graphs[0]
|
|
names = set()
|
|
names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
|
|
self.assertEqual(
|
|
names,
|
|
{
|
|
"",
|
|
"wrap_body_2",
|
|
"wrap_body_2.wrap_body_1",
|
|
"wrap_body_2.wrap_body_1.wrap_body_0",
|
|
},
|
|
)
|
|
|
|
def test_wrap_allow_local_assign_in_body_fn(self):
|
|
def f(arg1, arg2):
|
|
def inner_f(arg1, arg2):
|
|
a = arg1
|
|
b = arg2
|
|
ret = []
|
|
for x in a:
|
|
ret.append(x + 1)
|
|
for x in b:
|
|
ret.append(x + 1)
|
|
return ret
|
|
|
|
return wrap(inner_f, arg1, arg2)
|
|
|
|
x = torch.ones(3)
|
|
|
|
def my_args_generator():
|
|
yield [x], [x.sin()]
|
|
yield (x,), (x.sin(),)
|
|
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
actual_graph = self._test_wrap_simple(
|
|
f,
|
|
my_args_generator(),
|
|
arg_count,
|
|
3,
|
|
return_graph=True,
|
|
)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
actual_graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_arg1_0_: "f32[3]", L_arg2_0_: "f32[3]"):
|
|
l_arg1_0_ = L_arg1_0_
|
|
l_arg2_0_ = L_arg2_0_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_arg1_0_, l_arg2_0_); wrap_body_0 = l_arg1_0_ = l_arg2_0_ = None
|
|
getitem: "f32[3]" = wrap[0]
|
|
getitem_1: "f32[3]" = wrap[1]; wrap = None
|
|
return (getitem, getitem_1)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
|
|
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
|
|
|
|
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
def test_capture_global_num(self):
|
|
def f(x):
|
|
return wrap(lambda x: x + global_num, x)
|
|
|
|
x = torch.zeros([])
|
|
# Numbers don't get lifted, so the arg count is still 2.
|
|
self._test_wrap_simple(f, default_args_generator((x,)), 2)
|
|
|
|
def test_capture_global_num_adds_guard(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: x + global_num, x)
|
|
|
|
global global_num
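# global_num is read inside the wrapped body; after it changes, the compiled function should still return the updated result (its value is guarded on).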
|
|
x = torch.zeros([])
|
|
result = f(x)
|
|
self.assertEqual(result, x + global_num)
|
|
|
|
global_num = torch.randn([]).item()
|
|
result = f(x)
|
|
self.assertEqual(result, x + global_num)
|
|
|
|
def test_capture_input_num(self):
|
|
def f(x, y):
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
x = torch.zeros([])
|
|
y = 3.14
|
|
# Numbers don't get lifted, so the arg count is still 2.
|
|
self._test_wrap_simple(f, default_args_generator((x, y)), 2)
|
|
|
|
def test_side_effect_in_body(self):
|
|
counters.clear()
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
x = torch.randn([])
|
|
y = torch.randn([])
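# Mutating the nonlocal y inside the wrapped body should cause a graph break, checked via the graph_break counter below.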
|
|
|
|
def inner(x):
|
|
nonlocal y
|
|
y = x
|
|
return x.clone()
|
|
|
|
@torch.compile(backend=backend)
|
|
def f(x):
|
|
return wrap(inner, x)
|
|
|
|
f(x)
|
|
self.assertEqual(y, x)
|
|
assert_dict_matches_regex(
|
|
self,
|
|
dict(counters["graph_break"]),
|
|
{
|
|
r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 1
|
|
},
|
|
)
|
|
|
|
def test_fallback_on_graph_break_simple(self):
|
|
# In the future, there should be a per-HigherOrderOperator switch
|
|
# on whether or not to fall back or raise a loud error.
|
|
# For now we just fall back by default.
|
|
cnt = CompileCounter()
|
|
x = torch.randn([])
|
|
|
|
def inner(x):
|
|
y = x.sin()
|
|
torch._dynamo.graph_break()
|
|
z = y.sin()
|
|
return z
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
return wrap(inner, x)
|
|
|
|
result = f(x)
|
|
self.assertEqual(result, inner(x))
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
|
|
def test_fallback_on_graph_break_complicated(self):
|
|
cnt = CompileCounter()
|
|
x = torch.randn([])
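# The graph break inside the wrapped body makes wrap fall back to eager, while the surrounding code is still compiled (two frames).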
|
|
|
|
def inner(x):
|
|
y = x.sin()
|
|
y = y * global_var
|
|
torch._dynamo.graph_break()
|
|
z = y.sin()
|
|
return z
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
x = x.clone()
|
|
result = wrap(inner, x)
|
|
return result.clone()
|
|
|
|
result = f(x)
|
|
self.assertEqual(result, inner(x))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_modules(self):
|
|
counters.clear()
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
mod = torch.nn.Linear(3, 3)
|
|
x = torch.randn(3, 3)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: mod(x), x)
|
|
|
|
result = f(x)
|
|
|
|
self.assertEqual(result, mod(x))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
wrap_node = find_first_node(backend.graphs[0], wrap)
|
|
# 3 args - 1 for the input, and the other 2 for the weight and bias
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
# Check that the linear bias and weight appear as get_attr nodes in the outer graph
|
|
if not torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2)
|
|
|
|
# Check that the inner function has one op and it's a linear op
|
|
body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 1)
|
|
linear_node = find_first_node(body_function, torch._C._nn.linear)
|
|
self.assertTrue(linear_node is not None)
|
|
|
|
# Check that the innermost graph does not have any params
|
|
self.assertTrue(len(dict(body_function.named_parameters())) == 0)
|
|
self.assertTrue(len(dict(body_function.named_children())) == 0)
|
|
|
|
def test_flat_list_output(self):
|
|
def f(x):
|
|
return wrap(lambda x: [torch.sin(x), torch.cos(x)], x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, expected_opcount=3
|
|
)
|
|
|
|
def test_fallback_on_python_primitives_output(self):
|
|
counters.clear()
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def f(x):
|
|
return wrap(lambda x: [1, torch.sin(x), 2.0], x)
|
|
|
|
x = torch.randn(3)
|
|
result = f(x)
|
|
self.assertEqual(result, [1, torch.sin(x), 2.0])
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
assert_dict_matches_regex(
|
|
self,
|
|
dict(counters["graph_break"]),
|
|
{".*HigherOrderOperator body's output must consist of tensors only": 1},
|
|
)
|
|
|
|
def test_nested_tuple_output(self):
|
|
def f(x):
|
|
((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x)
|
|
return a + b
|
|
|
|
x = torch.randn(2, 3)
|
|
|
|
counters.clear()
|
|
arg_count = ifdynstaticdefault(2, 4)
|
|
graph = self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, 4, return_graph=True
|
|
)
|
|
self.assertEqual(len(counters["graph_break"]), 0)
|
|
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[2, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
a: "f32[2, 3]" = wrap[0]
|
|
b: "f32[2, 3]" = wrap[1]; wrap = None
|
|
|
|
add: "f32[2, 3]" = a + b; a = b = None
|
|
return (add,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[2, 3]"):
|
|
child: "f32[2, 3]" = l_x_.sin()
|
|
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
|
|
return (child, child_1)
|
|
""",
|
|
)
|
|
|
|
def test_output_with_dict(self):
|
|
def f(x):
|
|
return wrap(lambda x: [{"a": -x}], x)
|
|
|
|
x = torch.randn(3)
|
|
|
|
counters.clear()
|
|
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
graph = self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, 2, return_graph=True
|
|
)
|
|
self.assertEqual(len(counters["graph_break"]), 0)
|
|
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
graph,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
getitem: "f32[3]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3]"):
|
|
child: "f32[3]" = -l_x_; l_x_ = None
|
|
return (child,)
|
|
""",
|
|
)
|
|
|
|
def test_access_module_attr(self):
|
|
counters.clear()
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
mod = torch.nn.Linear(3, 3)
|
|
x = torch.randn(3, 3)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
y = mod(x)
|
|
return wrap(lambda y: y - mod.bias, y)
|
|
|
|
result = f(x)
|
|
self.assertEqual(result, mod(x) - mod.bias)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
wrap_node = find_first_node(backend.graphs[0], wrap)
|
|
self.assertTrue(len(wrap_node.args), 3)
|
|
|
|
# Check that the linear bias and weight appear as get_attr nodes in the outer graph
|
|
if not torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2)
|
|
|
|
# Check that the inner function has exactly one op
|
|
body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
|
|
self.assertEqual(op_count(body_function), 1)
|
|
|
|
# Check that the innermost graph does not have any params
|
|
self.assertTrue(len(dict(body_function.named_parameters())) == 0)
|
|
self.assertTrue(len(dict(body_function.named_children())) == 0)
|
|
|
|
def test_make_closure(self):
|
|
def f(x, y):
|
|
def g(x):
|
|
return x + y
|
|
|
|
return g(x)
|
|
|
|
def h(x, y):
|
|
return wrap(f, x, y)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_internal_nonlocal(self):
|
|
def f(x, y):
|
|
w = 1
|
|
|
|
def g(x):
|
|
nonlocal w
|
|
w = x
|
|
return x
|
|
|
|
def h(x):
|
|
nonlocal w
|
|
w = w + 1
|
|
return x
|
|
|
|
g(x)
|
|
h(x)
|
|
return w + y
|
|
|
|
def h(x, y):
|
|
return wrap(f, x, y)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(h, default_args_generator((x, y)), arg_count)
|
|
|
|
def test_capture_numpy_number(self):
|
|
import numpy as np
|
|
|
|
y = np.float32(1.0)
|
|
|
|
def f(x):
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
x = torch.randn(3)
|
|
# np.number values are lifted to graph inputs
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
|
|
|
def test_freevars_as_inputs_to_wrap(self):
|
|
y = torch.randn(3)
|
|
|
|
def f(x):
|
|
return wrap(lambda x, y: x + y, x, y)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(f, default_args_generator((x,)), arg_count)
|
|
|
|
def test_lift_tensor_constant(self):
|
|
def f(x):
|
|
y = torch.tensor(1.0)
|
|
return wrap(lambda x: x + y, x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(3, 4)
|
|
self._test_wrap_simple(
|
|
f, default_args_generator((x,)), arg_count, expected_opcount=3
|
|
)
|
|
|
|
def test_nested_wrap(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
mod = MockModule()
|
|
|
|
# Two levels of wrap ops
|
|
def gn(x):
|
|
return torch.cos(x) + wrap(mod, x)
|
|
|
|
def fn(x):
|
|
return wrap(gn, x)
|
|
|
|
arg_count = ifdynstaticdefault(4, 5)
|
|
self._test_wrap_simple(
|
|
fn, default_args_generator((torch.randn(10, 10),)), arg_count
|
|
)
|
|
|
|
def test_fn_with_kwargs_in_torch_ops(self):
|
|
def fn(x):
|
|
return wrap(lambda z: torch.cos(input=z), x)
|
|
|
|
x = torch.randn(3)
|
|
arg_count = ifdynstaticdefault(2, 3)
|
|
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
|
|
|
|
def test_hooks(self):
|
|
class ToyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.net = torch.nn.Linear(10, 10)
|
|
|
|
def forward(self, x):
|
|
return self.net(x)
|
|
|
|
model = ToyModel()
|
|
forward_handles = {}
|
|
activations = {}
|
|
|
|
def save_activations(mod, inp, out):
|
|
activations[name] = inp
|
|
|
|
for name, module in model.named_children():
|
|
forward_handles[name] = module.register_forward_hook(save_activations)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return wrap(lambda x: model(x), x)
|
|
|
|
for i in range(2):
|
|
# second iteration is key, hooks would have fired during aot trace
|
|
# on first iter
|
|
activations.clear()
|
|
x = torch.randn((10, 10))
|
|
pred = fn(x)
|
|
loss = pred.sum()
|
|
loss.backward()
|
|
|
|
self.assertTrue(activations.keys() == forward_handles.keys())
|
|
|
|
def _get_source_fn_stack(self, gm, node_names):
|
|
ret = {}
|
|
for mod in gm.modules():
|
|
for node in mod.graph.nodes:
|
|
if node.name in node_names:
|
|
actual_stack = [
|
|
name for name, _ in node.meta.get("source_fn_stack", [])
|
|
]
|
|
ret[node.name] = actual_stack
|
|
return ret
|
|
|
|
def test_wrap_source_fn_stack(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
mod = MockModule()
|
|
|
|
def gn(x):
|
|
return torch.cos(x) + wrap(mod, x)
|
|
|
|
def fn(x):
|
|
return wrap(gn, x)
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
inp = torch.randn((4, 4))
|
|
torch.compile(fn, backend=backend, fullgraph=True)(inp)
|
|
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""\
|
|
{'add': ['wrap', 'add'],
|
|
'cos': ['wrap', 'cos'],
|
|
'linear': ['wrap', 'wrap', 'linear']}""",
|
|
)
|
|
|
|
def test_cond_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
@torch.compile(backend=backend, fullgraph=True)
|
|
def cond_f(pred, pred2, x, y):
|
|
def true_fn(pred2, x, y):
|
|
return x + y
|
|
|
|
def false_fn(pred2, x, y):
|
|
def true_fn2(x, y):
|
|
return x.sin() - y.cos()
|
|
|
|
def false_fn2(x, y):
|
|
return x.cos() - y.sin()
|
|
|
|
return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
|
|
|
|
return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
|
|
|
|
pred = torch.tensor(True)
|
|
pred2 = torch.tensor(False)
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
cond_f(pred, pred2, xs, y)
|
|
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""\
|
|
{'add': ['cond', 'add'],
|
|
'cos': ['cond', 'cond', 'cos'],
|
|
'sin': ['cond', 'cond', 'sin'],
|
|
'sub': ['cond', 'cond', 'sub']}""",
|
|
)
|
|
|
|
def test_map_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
xs = torch.randn(2, 3, 3)
|
|
y = torch.randn(3)
|
|
|
|
@torch.compile(backend=backend, fullgraph=True)
|
|
def map_f(xs, y):
|
|
def inner(x, y):
|
|
def inner2(x, y):
|
|
return x + y
|
|
|
|
return control_flow.map(inner2, x, y) * y.cos()
|
|
|
|
return control_flow.map(inner, xs, y).sin()
|
|
|
|
result = map_f(xs, y)
|
|
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
|
|
)
|
|
|
|
def test_grad_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
@torch.compile(backend=backend, fullgraph=False)
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(torch.func.grad(fn))(x)
|
|
|
|
x = torch.randn(())
|
|
|
|
wrapper_fn(x)
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""{'sin': ['sin']}""",
|
|
)
|
|
|
|
def test_vmap_multiply_scalar(self):
|
|
@torch.compile(backend="inductor", fullgraph=True)
|
|
def g(x):
|
|
return torch.vmap(torch.mul, in_dims=(0, None))(x, 3.14)
|
|
|
|
x = torch.randn(3)
|
|
y = g(x)
|
|
self.assertEqual(y, x * 3.14)
|
|
|
|
@torch.compile(backend="inductor", fullgraph=True)
|
|
def f(x):
|
|
return torch.vmap(torch.mul, in_dims=(0, None))(x, 314)
|
|
|
|
x = torch.randn(3)
|
|
y = f(x)
|
|
self.assertEqual(y, x * 314)
|
|
|
|
def test_vmap_source_fn_stack(self):
|
|
backend = EagerAndRecordGraphs()
|
|
|
|
def inner_fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
|
|
|
|
@torch.compile(backend=backend, fullgraph=True)
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: inner_fn(x.cos()))(x)
|
|
|
|
x = torch.randn(3, 3, 3, 3)
|
|
fn(x)
|
|
gm = backend.graphs[0]
|
|
actual_stack = self._get_source_fn_stack(
|
|
gm, {"sum_1", "sum_2", "batched_output"}
|
|
)
|
|
self.assertExpectedInline(
|
|
pprint.pformat(actual_stack),
|
|
"""{'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""",
|
|
)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/137061
|
|
def test_dynamic_shapes_over_vmap_batch_size(self):
|
|
def gn(a, b, c, d):
|
|
return a + b + c + d
|
|
|
|
def fn(func, a, b, c, d):
|
|
a = torch.arange(a)
|
|
b = torch.arange(b)
|
|
c = torch.arange(c)
|
|
d = torch.arange(d)
|
|
func = torch.vmap(func, in_dims=(0, None, None, None))
|
|
func = torch.vmap(func, in_dims=(None, 0, None, None))
|
|
func = torch.vmap(func, in_dims=(None, None, 0, None))
|
|
func = torch.vmap(func, in_dims=(None, None, None, 0))
|
|
return func(a, b, c, d)
|
|
|
|
cnt = CompileCounterWithBackend("eager")
|
|
# We generate a corresponding dynamic shapes test case at
|
|
# `test/dynamo/test_dynamic_shapes.py` automatically.
|
|
compiled_fn = torch.compile(fn, backend=cnt)
|
|
a, b, c, d = 2, 4, 8, 8
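# The second call below with larger sizes should not trigger a recompile when dynamic shapes are enabled.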
|
|
self.assertEqual(fn(gn, a, b, c, d), compiled_fn(gn, a, b, c, d))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
a, b, c, d = 4, 8, 16, 16
|
|
self.assertEqual(fn(gn, a, b, c, d), compiled_fn(gn, a, b, c, d))
|
|
# Ensure no recompile if dynamic shapes are enabled.
|
|
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
|
|
graph = cnt.graphs[0]
|
|
|
|
# Check that dynamic shapes generate the correct graph.
|
|
if check_dynamic_shape_capture():
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, L_a_ : torch.SymInt, L_b_ : torch.SymInt, L_c_ : torch.SymInt, L_d_ : torch.SymInt):
|
|
l_a_ = L_a_
|
|
l_b_ = L_b_
|
|
l_c_ = L_c_
|
|
l_d_ = L_d_
|
|
a = torch.arange(l_a_)
|
|
b = torch.arange(l_b_)
|
|
c = torch.arange(l_c_)
|
|
d = torch.arange(l_d_)
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(l_d_, 'error'); _vmap_increment_nesting = None
|
|
child = torch._C._functorch._add_batch_dim(d, 0, 1); d = None
|
|
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(l_c_, 'error'); _vmap_increment_nesting_1 = None
|
|
child_1 = torch._C._functorch._add_batch_dim(c, 0, 2); c = None
|
|
lazy_load_decompositions_2 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_2 = None
|
|
_vmap_increment_nesting_2 = torch._C._functorch._vmap_increment_nesting(l_b_, 'error'); _vmap_increment_nesting_2 = None
|
|
child_2 = torch._C._functorch._add_batch_dim(b, 0, 3); b = None
|
|
lazy_load_decompositions_3 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_3 = None
|
|
_vmap_increment_nesting_3 = torch._C._functorch._vmap_increment_nesting(l_a_, 'error'); _vmap_increment_nesting_3 = None
|
|
_add_batch_dim_3 = torch._C._functorch._add_batch_dim(a, 0, 4); a = None
|
|
add = _add_batch_dim_3 + child_2; _add_batch_dim_3 = child_2 = None
|
|
add_1 = add + child_1; add = child_1 = None
|
|
batched_outputs = add_1 + child; add_1 = child = None
|
|
batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 4, l_a_, 0); batched_outputs = l_a_ = None
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
batched_outputs_2 = torch._C._functorch._remove_batch_dim(batched_outputs_1, 3, l_b_, 0); batched_outputs_1 = l_b_ = None
|
|
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
batched_outputs_3 = torch._C._functorch._remove_batch_dim(batched_outputs_2, 2, l_c_, 0); batched_outputs_2 = l_c_ = None
|
|
_vmap_decrement_nesting_2 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None
|
|
_remove_batch_dim_3 = torch._C._functorch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None
|
|
_vmap_decrement_nesting_3 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None
|
|
return (_remove_batch_dim_3,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_pytree_operands(self):
|
|
def _construct_pytree():
|
|
a = torch.randn(3, 3)
|
|
b = torch.randn(3, 3)
|
|
c = torch.randn(3, 3)
|
|
d = torch.randn(3, 3)
|
|
e = torch.randn(3, 3)
|
|
f = torch.randn(3, 3)
|
|
g = torch.randn(3, 3)
|
|
return (a, [[[b]]], c, (d, (e,), f), {"g": g})
|
|
|
|
pred = torch.tensor(True)
|
|
inp = _construct_pytree()
|
|
|
|
def _reduce_sum(flattened):
|
|
init = 0
|
|
for val in flattened:
|
|
init += val
|
|
return init
|
|
|
|
def _reduce_max(flattened):
|
|
init = flattened[0]
|
|
for val in flattened:
|
|
init = max(val, init)
|
|
return init
|
|
|
|
def true_fn(pytree_in):
|
|
flattened, spec = pytree.tree_flatten(pytree_in)
|
|
return _reduce_sum(flattened)
|
|
|
|
def false_fn(pytree_in):
|
|
flattened, spec = pytree.tree_flatten(pytree_in)
|
|
return _reduce_max(flattened)
|
|
|
|
def fn(pred, pytree_in):
|
|
return torch.cond(pred, true_fn, false_fn, [pytree_in])
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
compiled_res = torch.compile(fn, backend=backend)(pred, inp)
|
|
eager_res = fn(pred, inp)
|
|
self.assertEqual(compiled_res, eager_res)
|
|
graph = backend.graphs[0]
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
self.assertExpectedInline(
|
|
graph.code.strip(),
|
|
"""\
|
|
def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytree_in_1_0_0_0_ : torch.Tensor, L_pytree_in_2_ : torch.Tensor, L_pytree_in_3_0_ : torch.Tensor, L_pytree_in_3_1_0_ : torch.Tensor, L_pytree_in_3_2_ : torch.Tensor, L_pytree_in_4_g_ : torch.Tensor):
|
|
l_pred_ = L_pred_
|
|
l_pytree_in_0_ = L_pytree_in_0_
|
|
l_pytree_in_1_0_0_0_ = L_pytree_in_1_0_0_0_
|
|
l_pytree_in_2_ = L_pytree_in_2_
|
|
l_pytree_in_3_0_ = L_pytree_in_3_0_
|
|
l_pytree_in_3_1_0_ = L_pytree_in_3_1_0_
|
|
l_pytree_in_3_2_ = L_pytree_in_3_2_
|
|
l_pytree_in_4_g_ = L_pytree_in_4_g_
|
|
cond_true_0 = self.cond_true_0
|
|
cond_false_0 = self.cond_false_0
|
|
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_]); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None
|
|
getitem = cond[0]; cond = None
|
|
return (getitem,)""", # noqa: B950
|
|
)
|
|
|
|
def test_cond_pytree_operands_with_non_tensor_leaves(self):
|
|
def fn(pred, pytree_in):
|
|
return torch.cond(
|
|
pred, lambda x: x[0] + 1, lambda x: x[0] * 2, (pytree_in,)
|
|
)
|
|
|
|
pred = torch.tensor(True)
|
|
for pytree_in in [(1,), ("string",), (1.0,)]:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Expect operands to be a tuple of possibly nested dict/list/tuple",
|
|
):
|
|
fn(pred, pytree_in)
|
|
|
|
for pytree_in in [(1,), ("string",), (1.0,)]:
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
|
):
|
|
torch.compile(fn, backend="eager")(pred, pytree_in)
|
|
|
|
def test_cond_with_empty_operands(self):
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x, y, z):
|
|
def true_fn():
|
|
return y + 2
|
|
|
|
def false_fn():
|
|
return z + 1
|
|
|
|
return torch.cond(x, true_fn, false_fn)
|
|
|
|
zeros = torch.zeros(1)
|
|
ones = torch.ones(1)
|
|
self.assertEqual(fn(zeros, ones, ones), torch.tensor([2.0]))
|
|
self.assertEqual(fn(ones, ones, ones), torch.tensor([3.0]))
|
|
|
|
def test_hints_wrapper(self):
|
|
def ref_fn(x, y):
|
|
x = x + y
|
|
x = torch.relu(x)
|
|
x = x + y
|
|
return torch.abs(x)
|
|
|
|
def fn_with_hints(x, y):
|
|
x = x + y
|
|
|
|
def inner_body_fn(x, y):
|
|
x = torch.relu(x)
|
|
x = x + y
|
|
return x
|
|
|
|
def outer_body_fn(x, y):
|
|
x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True})
|
|
x = torch.abs(x)
|
|
return x
|
|
|
|
res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True})
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
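# The hints dicts should be preserved as kwargs on the hints_wrapper nodes in the captured graph (see the expected output below).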
|
|
|
|
eager_res = fn_with_hints(x, y)
|
|
compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
|
|
ref_res = ref_fn(x, y)
|
|
self.assertEqual(eager_res, ref_res)
|
|
self.assertEqual(compiled_res, ref_res)
|
|
self.assertEqual(len(cnt.graphs), 1)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
graph = backend.graphs[0]
|
|
self.assertExpectedInline(
|
|
normalize_gm(graph.print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None
|
|
|
|
hints_wrapper_body_1 = self.hints_wrapper_body_1
|
|
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
|
|
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
|
return (res,)
|
|
|
|
class hints_wrapper_body_1(torch.nn.Module):
|
|
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
|
|
hints_wrapper_body_0 = self.hints_wrapper_body_0
|
|
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
|
|
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
|
|
|
|
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
|
|
return (x_2,)
|
|
|
|
class hints_wrapper_body_0(torch.nn.Module):
|
|
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
|
|
x_1: "f32[2, 4]" = torch.relu(x); x = None
|
|
|
|
x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None
|
|
return (x_2,)
|
|
""",
|
|
)
|
|
|
|
def test_hints_wrapper_no_hints(self):
|
|
def fn_with_hints(x, y):
|
|
def outer_body_fn(x, y):
|
|
x = torch.add(x, y)
|
|
return x
|
|
|
|
res = hints_wrapper(outer_body_fn, (x, y), {})
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
msg = "hints_wrapper - key hints not provided"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
|
|
|
|
def test_hints_wrapper_incorrect_type(self):
|
|
def fn_with_hints(x, y):
|
|
def outer_body_fn(x, y):
|
|
x = torch.add(x, y)
|
|
return x
|
|
|
|
res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)})
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
msg = r"hints must be a dict containing int, float, bool or str value,"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
|
|
|
|
def test_hints_wrapper_pytree_inputs(self):
|
|
def fn_with_hints(x, y):
|
|
def outer_body_fn(x):
|
|
res = torch.add(x[0], x[1]["test"])
|
|
return res
|
|
|
|
res = hints_wrapper(
|
|
outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True}
|
|
)
|
|
return res
|
|
|
|
backend = EagerAndRecordGraphs()
|
|
cnt = CompileCounterWithBackend(backend)
|
|
|
|
x = torch.randn(2, 4)
|
|
y = torch.ones(4)
|
|
|
|
msg = r"args must be a tuple of tensors, ints, floats, or bools,"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
fn_with_hints(x, y)
|
|
|
|
|
|
class HigherOrderOpVmapGuardTests(LoggingTestCase):
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_grad_guard_ok(self, records):
|
|
vmap = torch.vmap
|
|
grad = torch.func.grad
|
|
|
|
def g(x):
|
|
return vmap(grad(torch.sin))(x)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return vmap(g)(x)
|
|
|
|
x = torch.randn(4, 5)
|
|
y = fn(x)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(x.cos(), y)
|
|
|
|
# Calling the same function again won't have any effect on guards
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_grad_guard_fail(self, records):
|
|
grad = torch.func.grad
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return grad(torch.sin)(x.sum())
|
|
|
|
x = torch.randn([])
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling again should not invalidate the graph
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling grad should retrigger compilation
|
|
x = torch.randn(3)
|
|
grad(fn)(x)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_dual_level_guard(self, records):
|
|
fwAD = torch.autograd.forward_ad
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(foo, tangent):
|
|
with fwAD.dual_level():
|
|
dual = fwAD.make_dual(foo, tangent[1:])
|
|
return dual
|
|
|
|
foo = torch.rand(2)
|
|
tangent = torch.rand(3)
|
|
fn(foo, tangent)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling again should not invalidate the graph
|
|
fn(foo, tangent)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# assertRaises is only here because Nested forward mode AD is not supported
|
|
with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
|
|
with fwAD.dual_level():
|
|
fn(foo, tangent)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "forward_ad")
|
|
self.assertIn(
|
|
"""torch.autograd.forward_ad._current_level == -1""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_jvp_guard_fail(self, records):
|
|
jvp = torch.func.jvp
|
|
vmap = torch.func.vmap
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return jvp(torch.sin, (x,), (x,))
|
|
|
|
x = torch.randn(3, 4)
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling again should not invalidate the graph
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling jvp should retrigger compilation
|
|
x = torch.randn(3, 4, 5)
|
|
jvp(vmap(fn), (x,), (x,))
|
|
|
|
self.assertGreater(len(records), 0)
|
|
if self.hasRecord(records, "pyfunctorch"):
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
elif self.hasRecord(records, "forward_ad"):
|
|
record = self.getRecord(records, "forward_ad")
|
|
self.assertIn(
|
|
"""torch.autograd.forward_ad._current_level == -1""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_guard_ok(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.randn(3, 3, 4, 5)
|
|
y = fn(x)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(x.sin(), y)
|
|
|
|
# Calling the same function again won't have any effect on guards
|
|
z = fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(x.sin(), z)
|
|
|
|
# calling with a different object will also not affect guards
|
|
w = fn(z)
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(z.sin(), w)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_guard_fail_different_state(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.zeros(3, 4)
|
|
y = torch.vmap(fn, randomness="same")(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# calling vmap with a different randomness should retrigger compilation
|
|
y = torch.vmap(fn, randomness="different")(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
|
|
record.getMessage(),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_guard_fail(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.zeros(3, 3, 4, 5)
|
|
y = torch.vmap(fn)(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
# call vmap(vmap(fn))(x) should retrigger compilation as
|
|
# _functorch.current_level() is not the same
|
|
x = torch.zeros(3, 3, 3, 4, 5)
|
|
y = torch.vmap(torch.vmap(fn))(x)
|
|
self.assertEqual(x.sin(), y)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
|
record.getMessage(),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_grad_vmap_guard_fail(self, records):
|
|
vmap = torch.vmap
|
|
grad = torch.func.grad
|
|
|
|
def g(x):
|
|
y = vmap(torch.sin, randomness="same")(x)
|
|
return y.sum(0)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return grad(g)(x)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = vmap(fn, randomness="error")(x)
|
|
self.assertEqual(x.cos(), y)
|
|
|
|
# previous FX graph should be invalidated
|
|
x = torch.randn(3, 3, 4)
|
|
y = vmap(vmap(fn, randomness="different"))(x)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@xfailIfTorchDynamo
|
|
@make_logging_test(recompiles=True)
|
|
def test_vmap_recompile_different_states(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.vmap(lambda x: x.sin())(x)
|
|
|
|
x = torch.zeros(3, 3, 4, 5)
|
|
y = torch.vmap(fn, randomness="same")(x)
|
|
self.assertEqual(len(records), 0) # sanity check
|
|
|
|
y = torch.vmap(fn, randomness="different")(x)
|
|
self.assertGreater(len(records), 0)
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(guards=True)
|
|
def test_emit_functorch_guard_if_active(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.sin(x)
|
|
|
|
x = torch.randn(3, 4)
|
|
_ = fn(x)
|
|
self.assertFalse(self.hasRecord(records, "pyfunctorch")) # sanity check
|
|
|
|
_ = torch.vmap(fn)(x)
|
|
self.assertTrue(self.hasRecord(records, "pyfunctorch"))
|
|
record = self.getRecord(records, "pyfunctorch")
|
|
self.assertIn(
|
|
"""torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_linearize_recompiles(self, records):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
out, jvp_fn = torch.func.linearize(torch.sin, x)
|
|
return out, jvp_fn(x)
|
|
|
|
x = torch.randn(2, 3)
|
|
fn(x)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
z = torch.randn(2, 3)
|
|
fn(z)
|
|
self.assertEqual(len(records), 0)
|
|
|
|
y = torch.randn(3, 4)
|
|
fn(y)
|
|
self.assertGreater(len(records), 0)
|
|
|
|
|
|
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
|
def tearDown(self):
|
|
# Ensure that in the case of a test failure, the next test won't fail
|
|
# because of a previous call to _vmap_increment_nesting that wasn't undone
|
|
# e.g. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
|
|
# and the call to increment nesting is not undone
|
|
if not TEST_WITH_TORCHDYNAMO:
|
|
return
|
|
|
|
warn = False
|
|
while ci := torch._C._functorch.peek_interpreter_stack():
|
|
if ci.key() == torch._C._functorch.TransformType.Vmap:
|
|
warn = True
|
|
torch._C._functorch._vmap_decrement_nesting()
|
|
else:
|
|
break
|
|
|
|
if warn:
|
|
msg = (
|
|
"Interpreter stack is not empty. Test should have called "
|
|
"'torch._C._functorch._vmap_decrement_nesting()'"
|
|
)
|
|
warnings.warn(msg)
|
|
|
|
def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
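# Run fn both eagerly and under torch.compile with a recording backend, check that the results match, and return the captured graph for inspection.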
|
|
backend = EagerAndRecordGraphs()
|
|
actual = fn(*inputs)
|
|
expected = torch.compile(fn, backend=backend, fullgraph=fullgraph)(*inputs)
|
|
|
|
self.assertEqual(actual, expected)
|
|
|
|
wrapped_gm = backend.graphs[graph_idx]
|
|
return wrapped_gm
|
|
|
|
def test_hessian(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.hessian(torch.sin)(x)
|
|
|
|
x = torch.randn(4, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
o: "f32[4, 3]" = torch.sin(diff_primals)
|
|
|
|
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor_1: "i64[1]" = torch.tensor((12,))
|
|
cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None
|
|
getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
|
|
neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None
|
|
unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None
|
|
|
|
chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal_1: "f32[12]" = chunk_1.diagonal(0)
|
|
fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
|
|
|
|
basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
|
|
|
|
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
|
|
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
|
|
|
_add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None
|
|
batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 4, 3]" = split[0]; split = None
|
|
|
|
output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(output_input, level = 0); output_input = None
|
|
primal: "f32[4, 3, 4, 3]" = _unpack_dual[0]
|
|
dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results_1: "f32[12, 4, 3, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
|
|
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
|
|
movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1); results_1 = None
|
|
split_2 = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None
|
|
|
|
unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
|
|
return (unflatten,)
|
|
""",
|
|
)
|
|
|
|
def test_hessian_argnums(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return x.sin()
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.hessian(fn, argnums=(1,))(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
"\n".join(actual.split("\n")[:-2]),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
|
|
|
|
child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
_wrap_for_grad_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
|
|
child_4: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
o: "f32[4, 3]" = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None
|
|
|
|
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor_1: "i64[1]" = torch.tensor((12,))
|
|
cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None
|
|
getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
|
|
neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None
|
|
unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None
|
|
|
|
chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal_1: "f32[12]" = chunk_1.diagonal(0)
|
|
fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
|
|
|
|
basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
|
|
|
|
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
|
|
|
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
|
|
|
_add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None
|
|
child_5: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = child_6.split((12,), dim = 0); child_6 = None
|
|
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
|
|
|
child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4)); split_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(child_7, level = 0); child_7 = None
|
|
primal: "f32[4, 3, 3, 4]" = _unpack_dual[0]; _unpack_dual = None
|
|
|
|
tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal)
|
|
|
|
child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None
|
|
|
|
child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
child_10: "f32[12, 4, 3, 3, 4]" = torch._C._functorch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None
|
|
|
|
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
|
|
|
movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1); child_10 = None
|
|
split_2 = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[4, 3, 3, 4, 12]" = split_2[0]; split_2 = None
|
|
|
|
unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""",
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
actual.split("\n")[-2],
|
|
""" return (unflatten,)""",
|
|
)
|
|
|
|
def test_jacrev(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.jacrev(torch.sin)(x)
|
|
|
|
x = torch.randn(4, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
o: "f32[4, 3]" = torch.sin(diff_primals)
|
|
|
|
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
|
|
batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 4, 3]" = split[0]; split = None
|
|
|
|
output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
|
|
return (output_input,)
|
|
""",
|
|
)
|
|
|
|
def test_jacrev_two_tensors_argnums(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin()
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacrev(fn, argnums=1)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = _wrap_for_grad = None
|
|
diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
o: "f32[3, 4]" = diff_primals.sin()
|
|
|
|
results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
|
|
batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
|
|
|
output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
|
|
return (output_input,)
|
|
""",
|
|
)
|
|
|
|
def test_jacrev_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin(), x
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacrev(fn, argnums=1, has_aux=True)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
o: "f32[3, 4]" = diff_primals.sin()
|
|
|
|
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = results.new_zeros(12, 12); results = None
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
|
|
batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
|
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
|
|
|
output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
|
|
return (output_input, aux_1)
|
|
""",
|
|
)
|
|
|
|
    def test_vjp(self):
        counters.clear()

        def fn(x):
            return x.sin().sum()

        def wrapper_fn(x, v):
            (out, vjpfunc) = torch.func.vjp(fn, x)
            return out

        x = torch.randn([5])
        v = torch.randn(5)
        wrapped_gm = self._compile_check(wrapper_fn, (x, v))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None

        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None

        child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None

        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None

        sin: "f32[5]" = child.sin(); child = None
        o: "f32[]" = sin.sum(); sin = None

        results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (results,)
""",
        )

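    # torch.func.vjp(fn, *primals) returns (output, vjp_fn); test_vjp above only
    # keeps the forward output, so `vjpfunc` (and `v`) is never used. A minimal
    # eager-mode sketch of the same API, assuming a single tensor primal:
    #
    #     out, vjp_fn = torch.func.vjp(torch.sin, x)
    #     (grad_x,) = vjp_fn(torch.ones_like(out))
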
def test_vjp_multiple_outputs(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x, v):
|
|
fn = lambda x: (x.sin(), x.cos()) # noqa: E731
|
|
(out, vjpfunc) = torch.func.vjp(fn, x)
|
|
vjps = vjpfunc((v, v))
|
|
return out, vjps
|
|
|
|
x = torch.randn([5])
|
|
v = torch.randn(5)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
child_1: "f32[5]" = child.sin()
|
|
child_2: "f32[5]" = child.cos(); child = None
|
|
|
|
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
|
|
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare((child_1, child_2), (l_v_, l_v_)); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = None
|
|
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
|
|
return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
|
|
""",
|
|
)
|
|
|
|
def test_vjp_multiple_outputs_python_struct(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x, v):
|
|
fn = lambda x: {"first": x.sin(), "second": x.cos()} # noqa: E731
|
|
(out, vjpfunc) = torch.func.vjp(fn, x)
|
|
vjps = vjpfunc({"first": v, "second": v.sin()})
|
|
return out, vjps
|
|
|
|
x = torch.randn([5])
|
|
v = torch.randn(5)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
child_1: "f32[5]" = child.sin()
|
|
child_2: "f32[5]" = child.cos(); child = None
|
|
|
|
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
|
|
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
|
|
child_4: "f32[5]" = l_v_.sin()
|
|
|
|
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare({'first': child_1, 'second': child_2}, {'first': l_v_, 'second': child_4}); _vjp_treespec_compare = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None
|
|
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
|
|
return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
|
|
""",
|
|
)
|
|
|
|
    def test_vjp_has_aux(self):
        counters.clear()

        def fn(x):
            return x.sin().sum(), x

        def wrapper_fn(x, v):
            (out, vjpfunc, _) = torch.func.vjp(fn, x, has_aux=True)
            return out

        x = torch.randn([5])
        v = torch.randn(5)
        wrapped_gm = self._compile_check(wrapper_fn, (x, v))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[5]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None

        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None

        child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None

        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None

        sin: "f32[5]" = child.sin()
        o: "f32[]" = sin.sum(); sin = None

        aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None

        results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (results,)
""",
        )

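    # With has_aux=True, torch.func.vjp returns (output, vjp_fn, aux), which is
    # why test_vjp_has_aux above unpacks three values; the aux tensor is handed
    # back unchanged, visible in the graph as the extra _unwrap_for_grad call on
    # `child`.
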
    @config.patch(inline_inbuilt_nn_modules=True)
    def test_functional_call(self):
        def wrapper_fn(model, params, inputs, targets):
            prediction = torch.func.functional_call(model, params, (inputs,))
            return torch.nn.functional.mse_loss(prediction, targets)

        model = torch.nn.Linear(3, 3)
        params = dict(model.named_parameters())
        inputs = torch.randn(64, 3)
        targets = torch.randn(64, 3)

        wrapped_gm = self._compile_check(wrapper_fn, (model, params, inputs, targets))
        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        if torch._dynamo.config.inline_inbuilt_nn_modules:
            self.assertExpectedInline(
                actual,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
        l_model_parameters_weight_ = L_model_parameters_weight_
        l_model_parameters_bias_ = L_model_parameters_bias_
        l_inputs_ = L_inputs_
        l_targets_ = L_targets_

        prediction: "f32[64, 3]" = torch._C._nn.linear(l_inputs_, l_model_parameters_weight_, l_model_parameters_bias_); l_inputs_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None

        mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None
        return (mse_loss,)
""",
            )
        else:
            self.assertExpectedInline(
                actual,
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
        l_inputs_ = L_inputs_
        l_targets_ = L_targets_

        prediction: "f32[64, 3]" = self.model(l_inputs_); l_inputs_ = None

        mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None
        return (mse_loss,)
""",
            )

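    # torch.func.functional_call(module, parameters_and_buffers, args) runs the
    # module with the supplied parameter dict instead of its registered ones;
    # with inline_inbuilt_nn_modules=True the dict entries show up as ordinary
    # graph inputs (L_model_parameters_weight_ / L_model_parameters_bias_ above).
    # A minimal eager-mode sketch:
    #
    #     model = torch.nn.Linear(3, 3)
    #     out = torch.func.functional_call(
    #         model, dict(model.named_parameters()), (torch.randn(2, 3),)
    #     )
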
@config.patch(inline_inbuilt_nn_modules=True)
|
|
def test_functional_call_sequential_params_and_buffers(self):
|
|
# copied from test/test_stateless.py
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.l1 = torch.nn.Linear(1, 1)
|
|
self.register_buffer("buffer", torch.ones(1))
|
|
self.foo = 0.0
|
|
|
|
def forward(self, x):
|
|
return self.l1(x) + self.buffer
|
|
|
|
def wrapper_fn(model, params, buffers, inputs):
|
|
# params and buffers are passed as two separate dictionaries
|
|
return torch.func.functional_call(model, (params, buffers), inputs)
|
|
|
|
model = MockModule()
|
|
params = dict(model.named_parameters())
|
|
buffers = dict(model.named_buffers())
|
|
inputs = torch.tensor([[1.5]])
|
|
|
|
wrapped_gm = self._compile_check(
|
|
wrapper_fn, (model, params, buffers, inputs), fullgraph=False
|
|
)
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
expected = """\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
|
|
l_params_l1_weight_ = L_params_l1_weight_
|
|
l_params_l1_bias_ = L_params_l1_bias_
|
|
l_buffers_buffer_ = L_buffers_buffer_
|
|
l_inputs_ = L_inputs_
|
|
|
|
linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None
|
|
|
|
add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None
|
|
return (add,)
|
|
"""
|
|
# Windows and Linux builds can differ in empty lines here; empty_line_normalizer removes that difference before comparing.
|
|
self.assertExpectedInline(
|
|
empty_line_normalizer(actual),
|
|
empty_line_normalizer(normalize_gm(expected)),
|
|
)
|
|
else:
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[1, 1]"):
|
|
l_x_ = L_x_
|
|
|
|
l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_); l_x_ = None
|
|
l__self___buffer: "f32[1]" = self.L__self___buffer
|
|
add: "f32[1, 1]" = l__self___l1 + l__self___buffer; l__self___l1 = l__self___buffer = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
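# torch.func.functional_call also accepts a tuple of dicts, here (params,
# buffers), as exercised by wrapper_fn above; both end up as graph inputs when
# inline_inbuilt_nn_modules is enabled.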
    @config.patch(inline_inbuilt_nn_modules=False)
    def test_functional_call_disable_inline_nn_module(self):
        counters.clear()

        def wrapper_fn(model, params, inputs, targets):
            prediction = torch.func.functional_call(model, params, (inputs,))
            return torch.nn.functional.mse_loss(prediction, targets)

        model = torch.nn.Linear(3, 3)
        params = dict(model.named_parameters())
        inputs = torch.randn(64, 3)
        targets = torch.randn(64, 3)

        actual = wrapper_fn(model, params, inputs, targets)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
            model, params, inputs, targets
        )
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(
            {
                "torch.func.functional_call capture is disabled, it can be "
                "turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1,
            },
            dict(counters["graph_break"]),
        )
        self.assertEqual(actual, expected)

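    # With inline_inbuilt_nn_modules=False, functional_call capture is disabled:
    # the compiled wrapper_fn above falls back around exactly one graph break,
    # and the counters assertion checks both the break count and its message.
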
    def test_grad(self):
        counters.clear()

        def fn(x):
            return x.sin().sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        wrapped_gm = self._compile_check(wrapper_fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3, 3]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None

        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None

        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None

        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None

        sin: "f32[3, 3, 3]" = diff_args.sin()
        output: "f32[]" = sin.sum(); sin = None

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
        grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None

        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None

        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (grad_input_1,)
""",
        )

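    # torch.func.grad(fn)(x) returns the gradient of fn's scalar output with
    # respect to x; in the graph above it is computed via
    # eager_transforms._autograd_grad and returned as (grad_input_1,). A minimal
    # eager-mode sketch:
    #
    #     g = torch.func.grad(lambda t: t.sin().sum())(torch.randn(3))
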
    def test_grad_freevar_tensor(self):
        counters.clear()
        y = torch.randn(3, 3)

        def fn(x):
            return (x.sin() + y).sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        expected = wrapper_fn(x)
        actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
        self.assertEqual(actual, expected)

    def test_grad_freevar_python_scalar(self):
        counters.clear()
        y = 3

        def fn(x):
            return (x.sin() + y).sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        wrapped_gm = self._compile_check(wrapper_fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3, 3]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None

        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None

        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None

        set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None

        sin: "f32[3, 3, 3]" = diff_args.sin()
        add: "f32[3, 3, 3]" = sin + 3; sin = None
        output: "f32[]" = add.sum(); add = None

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
        grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None

        grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None

        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (grad_input_1,)
""",
        )

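    # The free Python scalar y = 3 is baked into the traced graph as the literal
    # constant in `add = sin + 3`, unlike tensor closures, which appear as
    # tensor inputs or tensor ops in the graph.
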
def test_grad_capture_tensor(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
y = torch.randn(3)
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
y: "f32[3]" = torch.randn(3)
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + y; sin = y = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
def test_grad_closure_scalar(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
y = 3.14
|
|
|
|
def fn(x):
|
|
return (x.sin() + y).sum()
|
|
|
|
return torch.func.grad(fn)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
|
|
# Graph break: dynamo is unable to get the source of `fn`, and the
# functools.wraps call inside `grad` leads to a graph break.
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + 3.14; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
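# As with the integer closure above, the free Python float y = 3.14 is inlined
# into the graph as a constant (`add = sin + 3.14`) rather than lifted as an
# input.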
def test_grad_has_aux(self):
|
|
counters.clear()
|
|
|
|
y = 3.14
|
|
|
|
def fn(x):
|
|
return ((x.sin() + y).sum(), x.cos())
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(fn, has_aux=True)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + 3.14; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = diff_args.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1, aux_1)
|
|
""",
|
|
)
|
|
|
|
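# With has_aux=True, torch.func.grad returns (gradient, aux), which is why the
# graphs above end with `return (grad_input_1, aux_1)`.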
def test_grad_two_tensor_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return ((x.sin() + y).sum(), x.cos())
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.grad(fn, has_aux=True)(x, y)
|
|
|
|
y = torch.randn(3, 3, 3)
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
_wrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
add: "f32[3, 3, 3]" = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = diff_args.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_grad_two_tensor_all_grad_has_aux(self):
|
|
counters.clear()
|
|
|
|
nums = (0, 1)
|
|
|
|
def fn(x, y):
|
|
return ((x.sin() + y).sum(), x.cos())
|
|
|
|
def wrapper_fn_const_var(x, y):
|
|
return torch.func.grad(fn, argnums=(0, 1), has_aux=True)(x, y)
|
|
|
|
def wrapper_fn_tuple_var(x, y):
|
|
return torch.func.grad(fn, argnums=nums, has_aux=True)(x, y)
|
|
|
|
y = torch.randn(3, 3, 3)
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm_const_var = self._compile_check(wrapper_fn_const_var, (x, y))
|
|
wrapped_gm_tuple_var = self._compile_check(wrapper_fn_tuple_var, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual_const_var = normalize_gm(
|
|
wrapped_gm_const_var.print_readable(print_output=False)
|
|
)
|
|
actual_tuple_var = normalize_gm(
|
|
wrapped_gm_tuple_var.print_readable(print_output=False)
|
|
)
|
|
self.assertExpectedInline(
|
|
actual_const_var,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
|
|
|
|
_set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
|
|
|
|
sin: "f32[3, 3, 3]" = child.sin()
|
|
add: "f32[3, 3, 3]" = sin + child_1; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = child.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None
|
|
child_2: "f32[3, 3, 3]" = _autograd_grad[0]
|
|
child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None
|
|
|
|
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
|
|
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1)
|
|
""",
|
|
)
|
|
self.assertExpectedInline(
|
|
actual_tuple_var,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
|
|
|
|
_set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
|
|
|
|
sin: "f32[3, 3, 3]" = child.sin()
|
|
add: "f32[3, 3, 3]" = sin + child_1; sin = None
|
|
output: "f32[]" = add.sum(); add = None
|
|
aux: "f32[3, 3, 3]" = child.cos()
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None
|
|
child_2: "f32[3, 3, 3]" = _autograd_grad[0]
|
|
child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None
|
|
|
|
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
|
|
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_grad_over_grad(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.grad(torch.func.grad(fn))(x)
|
|
|
|
x = torch.randn(())
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False)
|
|
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_1 = None
|
|
_grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting_1 = None
|
|
|
|
diff_args_1: "f32[]" = torch._C._functorch._wrap_for_grad(diff_args, 2)
|
|
|
|
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
|
|
|
|
_set_tensor_requires_grad_1: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1); _set_tensor_requires_grad_1 = None
|
|
|
|
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
|
|
|
|
sin: "f32[]" = diff_args_1.sin()
|
|
output: "f32[]" = sin.sum(); sin = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None
|
|
grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_2 = None
|
|
|
|
_autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None
|
|
|
|
grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None
|
|
|
|
output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None
|
|
|
|
_grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_3,)
|
|
""",
|
|
)
|
|
|
|
    def test_grad_with_graph_break(self):
        counters.clear()

        def fn(x):
            torch._dynamo.graph_break()
            return x.sin().sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        actual = wrapper_fn(x)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(actual, expected)
|
|
|
|
    def test_grad_with_side_effect(self):
        counters.clear()

        foo = [1, 2]

        def fn(x):
            foo.append(3)
            return x.sin().sum()

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x = torch.randn(3, 3, 3)
        actual = wrapper_fn(x)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)
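    # Note on the test above: fn mutates the Python list `foo` it closes over, yet zero
    # graph breaks are expected; Dynamo appears to record the list mutation as a side
    # effect and replay it after the compiled graph runs, so the append still happens.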
|
|
|
|
    def test_grad_pytree(self):
        counters.clear()

        def fn(x):
            x1, x2 = x
            return x1.sin().sum() + x2

        def wrapper_fn(x):
            return torch.func.grad(fn)(x)

        x1 = torch.randn(3, 3, 3)
        x2 = torch.randn(())
        actual = wrapper_fn((x1, x2))
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
            (x1, x2)
        )
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)
|
|
|
|
def test_grad_non_tensor_input(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return x.sin().sum() + y
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.grad(fn)(x, y)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
y = 3.0
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
|
|
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
|
|
|
|
diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
|
|
|
|
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
|
|
|
|
_set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
|
|
|
|
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
|
|
|
|
sin: "f32[3, 3, 3]" = diff_args.sin()
|
|
sum_1: "f32[]" = sin.sum(); sin = None
|
|
output: "f32[]" = sum_1 + 3.0; sum_1 = None
|
|
|
|
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
|
grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
|
|
|
|
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
|
|
|
output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None
|
|
|
|
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
|
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
|
return (grad_input_1,)
|
|
""",
|
|
)
|
|
|
|
    def test_grad_fn_with_kwargs(self):
        def fn(x, y):
            return (x + y).sum()

        def wrapper_fn(x, y):
            return torch.func.grad(fn)(x, y=y)

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)
        actual = wrapper_fn(x, y)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)
|
|
|
|
def test_jacfwd(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.jacfwd(torch.sin)(x)
|
|
|
|
x = torch.randn(4, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
|
|
result_duals: "f32[4, 3]" = torch.sin(_make_dual); _make_dual = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[4, 3]" = _unpack_dual[0]
|
|
dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
movedim: "f32[4, 3, 12]" = results.movedim(0, -1); results = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[4, 3, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
|
|
return (unflatten,)
|
|
""",
|
|
)
|
|
|
|
def test_jacfwd_two_tensors_argnums(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin()
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacfwd(fn, argnums=1)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
|
|
|
|
result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 4]" = _unpack_dual[0]
|
|
dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[3, 4, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None
|
|
return (unflatten,)
|
|
""",
|
|
)
|
|
|
|
def test_jacfwd_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin(), x
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacfwd(fn, argnums=1, has_aux=True)(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_y_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
|
|
|
|
aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
|
|
|
|
result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None
|
|
|
|
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2); aux = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 4]" = _unpack_dual[0]
|
|
dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
|
|
|
|
tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
|
aux_2: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
aux_3: "f32[4, 3]" = aux_2[0]; aux_2 = None
|
|
|
|
movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[3, 4, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None
|
|
return (unflatten, aux_3)
|
|
""",
|
|
)
|
|
|
|
def test_jacfwd_randomness(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return y.sin(), x
|
|
|
|
def wrapper_fn(x, y):
|
|
return torch.func.jacfwd(fn, randomness="same")(x, y)
|
|
|
|
x = torch.randn(4, 3)
|
|
y = torch.randn(3, 4)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y))
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
tensor: "i64[1]" = torch.tensor((12,))
|
|
cumsum: "i64[1]" = tensor.cumsum(dim = 0); tensor = None
|
|
getitem: "i64[0]" = cumsum[slice(None, -1, None)]; cumsum = None
|
|
neg: "i64[0]" = getitem.neg(); getitem = None
|
|
unbind = neg.unbind(); neg = unbind = None
|
|
|
|
chunk: "f32[12, 12]" = l_x_.new_zeros(12, 12)
|
|
|
|
diagonal: "f32[12]" = chunk.diagonal(0)
|
|
fill_: "f32[12]" = diagonal.fill_(1); diagonal = fill_ = None
|
|
|
|
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None
|
|
|
|
child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
|
|
|
|
_wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
|
|
_wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None
|
|
|
|
child_2: "f32[3, 4]" = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(child_2, level = 0); child_2 = None
|
|
primal: "f32[3, 4]" = _unpack_dual[0]; _unpack_dual = None
|
|
|
|
tangent: "f32[3, 4]" = torch.zeros_like(primal)
|
|
|
|
_unpack_dual_1 = torch._unpack_dual(child_3, level = 0); child_3 = None
|
|
primal_1: "f32[4, 3]" = _unpack_dual_1[0]
|
|
dual: "f32[4, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None
|
|
|
|
child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None
|
|
child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None
|
|
|
|
child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
|
|
child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
child_8: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None
|
|
child_9: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
|
|
movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1); child_8 = None
|
|
split = movedim.split((12,), dim = -1); movedim = None
|
|
jac_out_in: "f32[3, 4, 12]" = split[0]; split = None
|
|
|
|
unflatten: "f32[3, 4, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
|
|
|
|
movedim_1: "f32[4, 3, 12]" = child_9.movedim(0, -1); child_9 = None
|
|
split_1 = movedim_1.split((12,), dim = -1); movedim_1 = None
|
|
jac_out_in_1: "f32[4, 3, 12]" = split_1[0]; split_1 = None
|
|
|
|
unflatten_1: "f32[4, 3, 4, 3]" = jac_out_in_1.unflatten(-1, (4, 3)); jac_out_in_1 = None
|
|
return (unflatten, unflatten_1)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_simple(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
return torch.func.jvp(fn, (x,), (v,))
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
return (primals_out_unflatten, tangents_out_unflatten)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum(), x
|
|
|
|
def wrapper_fn(x, v):
|
|
return torch.func.jvp(fn, (x,), (v,), has_aux=True)
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = aux.sin()
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
return (primals_out_unflatten, tangents_out_unflatten, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_two_tensors_has_aux(self):
|
|
counters.clear()
|
|
|
|
def fn(x, y):
|
|
return (x.sin().sum() + y.cos()), x
|
|
|
|
def wrapper_fn(x, y, v):
|
|
return torch.func.jvp(fn, (x, y), (v, v), has_aux=True)
|
|
|
|
x = torch.randn(3, 3)
|
|
y = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, y, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
l_v_ = L_v_
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_, l_y_), (l_v_, l_v_)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None
|
|
|
|
_maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None
|
|
|
|
_make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = aux.sin()
|
|
sum_1: "f32[]" = sin.sum(); sin = None
|
|
cos: "f32[3, 3]" = _make_dual_1.cos(); _make_dual_1 = None
|
|
result_duals: "f32[3, 3]" = sum_1 + cos; sum_1 = cos = None
|
|
|
|
aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 3]" = _unpack_dual[0]
|
|
dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
return (primals_out_unflatten, tangents_out_unflatten, aux_1)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_two_tensors_disable_grad(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(False):
|
|
return torch.func.jvp(fn, (x,), (v,))
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
_set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None
|
|
return (primals_out_unflatten, tangents_out_unflatten)
|
|
""",
|
|
)
|
|
|
|
def test_jvp_two_tensors_disable_enable_disable_grad(self):
|
|
counters.clear()
|
|
|
|
def fn(x):
|
|
return x.sin().sum()
|
|
|
|
def wrapper_fn(x, v):
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(False): # (1)
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(True): # (2)
|
|
with torch.autograd.forward_ad._set_fwd_grad_enabled(False): # (3)
|
|
return torch.func.jvp(fn, (x,), (v,)) # (4)
|
|
|
|
# Start True
|
|
# False (1)
|
|
# True (2)
|
|
# False (3)
|
|
# True (4)
|
|
# True (undo 3)
|
|
# False (undo 2)
|
|
# True (undo 1)
|
|
|
|
x = torch.randn(3, 3)
|
|
v = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x, v))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_v_ = L_v_
|
|
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_v_,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
_make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
|
|
|
|
sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
|
|
result_duals: "f32[]" = sin.sum(); sin = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[]" = _unpack_dual[0]
|
|
dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
|
|
|
|
tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_4 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
_set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_5 = None
|
|
_set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_6 = None
|
|
_set_fwd_grad_enabled_7 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_7 = None
|
|
return (primals_out_unflatten, tangents_out_unflatten)
|
|
""",
|
|
)
|
|
|
|
    def test_jvp_freevar_tensor(self):
        counters.clear()
        y = torch.randn(3, 3)

        def fn(x):
            return (x.sin() + y).sum()

        def wrapper_fn(x):
            return torch.func.jvp(fn, (x,), (x,))

        x = torch.randn(3, 3)
        expected = wrapper_fn(x)
        actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
        self.assertEqual(actual, expected)
|
|
|
|
def test_jvp_jvp(self):
|
|
counters.clear()
|
|
|
|
if check_dynamic_shape_capture():
|
|
self.skipTest("test fails with dynamic shapes")
|
|
|
|
def fn(x):
|
|
return torch.func.jvp(torch.sin, (x,), (x,))
|
|
|
|
def wrapper_fn(x):
|
|
return torch.func.jvp(fn, (x,), (x,))
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (l_x_,)); _jvp_treespec_compare = None
|
|
|
|
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
|
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
|
_enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None
|
|
|
|
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
|
|
|
|
child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None
|
|
|
|
_jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,)); _jvp_treespec_compare_1 = None
|
|
|
|
_jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting_1 = None
|
|
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
|
|
|
_maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None
|
|
|
|
_make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0); child = None
|
|
|
|
result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1); _make_dual_1 = None
|
|
|
|
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
|
|
primal: "f32[3, 3, 3]" = _unpack_dual[0]
|
|
dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None
|
|
|
|
primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None
|
|
|
|
tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
|
|
|
|
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None
|
|
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
|
|
|
_unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None
|
|
primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0]
|
|
dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None
|
|
_unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0); tangents_out_unflatten = None
|
|
primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0]
|
|
dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1]; _unpack_dual_2 = None
|
|
|
|
_unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None
|
|
_unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None
|
|
|
|
_unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None
|
|
_unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None
|
|
|
|
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
|
|
_set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None
|
|
_jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting_1 = None
|
|
return (_unwrap_for_grad_2, _unwrap_for_grad_3, _unwrap_for_grad_4, _unwrap_for_grad_5)
|
|
""",
|
|
)
|
|
|
|
    def test_jvp_freevar_python_scalar(self):
        counters.clear()
        y = 3

        def fn(x):
            return (x.sin() + y).sum()

        def wrapper_fn(x):
            return torch.func.jvp(fn, (x,), (x,))

        x = torch.randn(3, 3, 3)
        expected = wrapper_fn(x)
        actual = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)(x)
        self.assertEqual(actual, expected)
|
|
|
|
def test_linearize_jvp_fn(self):
|
|
counters.clear()
|
|
|
|
def wrapper_fn(x):
|
|
output, jvp_fn = torch.func.linearize(torch.sin, x)
|
|
return output, jvp_fn(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=0)
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
|
|
l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
|
|
|
|
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
|
|
|
|
sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
|
|
|
|
alias_default_1: "f32[3, 3, 3]" = torch.ops.aten.alias.default(alias_default)
|
|
|
|
cos_default: "f32[3, 3, 3]" = torch.ops.aten.cos.default(alias_default_1); alias_default_1 = None
|
|
|
|
alias_default_2: "f32[3, 3, 3]" = torch.ops.aten.alias.default(sin_default); alias_default_2 = None
|
|
return (alias_default, cos_default, sin_default)
|
|
""",
|
|
)
|
|
|
|
wrapped_gm = self._compile_check(wrapper_fn, (x,), fullgraph=False, graph_idx=1)
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
|
|
l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
|
|
l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
|
|
l_flat_tangents_1_ = L_flat_tangents_1_
|
|
|
|
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
|
|
|
|
copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None
|
|
|
|
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
|
|
return (mul_tensor,)
|
|
""",
|
|
)
|
|
|
|
    @config.patch(error_on_recompile=True)
    def test_vmap_recompile(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        y = torch.vmap(fn)(x)
        # should not recompile on second call. See PyTorch issue #118493
        y = torch.vmap(fn)(x)

    @xfailIfTorchDynamo
    @config.patch(error_on_recompile=True)
    def test_vmap_recompile_different_config(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        y = torch.vmap(fn)(x)
        with self.assertRaises(torch._dynamo.exc.RecompileError):
            fn(x)
|
|
|
|
    @config.patch(error_on_recompile=True)
    def test_vmap_recompile_same_config(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        torch.vmap(torch.vmap(fn, randomness="same"), randomness="same")(x)
        with self.assertRaises(torch._dynamo.exc.RecompileError):
            torch.vmap(torch.vmap(fn, randomness="same"), randomness="error")(x)

    @config.patch(error_on_recompile=True)
    def test_vmap_recompile_with_randomness(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.vmap(lambda x: x.sin())(x)

        x = torch.zeros(3, 3, 4, 5)
        torch.vmap(fn, randomness="same")(x)
        with self.assertRaises(torch._dynamo.exc.RecompileError):
            torch.vmap(fn, randomness="different")(x)
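    # The two recompile tests above rely on the vmap randomness setting being part of the
    # state Dynamo guards on: calling with the same randomness reuses the compiled graph,
    # while switching it (same -> error, same -> different) is expected to trigger a
    # RecompileError under error_on_recompile.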
|
|
|
|
    def test_vmap_call_torch_compile_fn(self):
        def wrapped_fn(x):
            return x.sin()

        x = torch.randn(3, 4)
        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Calling torch.func.vmap\\(compiled_fn\\) function from eager mode is not supported",
        ):
            torch.func.vmap(fn)(x)

    def test_grad_call_torch_compile_fn(self):
        def wrapped_fn(x):
            return x.sin().sum()

        x = torch.randn(3, 4)
        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Calling torch.func.grad\\(compiled_fn\\) function from eager mode is not supported",
        ):
            torch.func.grad(fn)(x)

    def test_jvp_call_torch_compile_fn(self):
        def wrapped_fn(x):
            return x.sin().sum()

        x = torch.randn(3, 4)
        fn = torch.compile(backend="aot_eager", fullgraph=True)(wrapped_fn)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Calling torch.func.jvp\\(compiled_fn\\) function from eager mode is not supported",
        ):
            torch.func.jvp(fn, (x,), (x,))

    @config.patch(error_on_recompile=True)
    def test_grad_recompile(self):
        @torch.compile(backend="eager")
        def fn(x):
            return torch.func.grad(torch.sin)(x)

        x = torch.randn([])
        torch.func.grad(fn)(x)
        # should not recompile on second call
        torch.func.grad(fn)(x)
|
|
|
|
    def test_vmap_get_wrapped(self):
        counters.clear()

        def g(x):
            return x.sin()

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn():
            return torch.vmap(g)

        x = torch.randn(3, 4)
        expected = torch.vmap(g)(x)
        wrapper = fn()
        got = wrapper(x)
        self.assertEqual(expected, got)
|
|
|
|
    def test_vmap_with_conditional_graph_break(self):
        def g(x):
            if len(x.shape) < 2:
                torch._dynamo.graph_break()
                return x.sin()
            else:
                return x.cos()

        @torch.compile(backend="aot_eager")
        def fn(x):
            return torch.vmap(g)(x)

        counters.clear()
        x = torch.randn(2, 3)
        expected = x.sin()
        got = fn(x)
        self.assertEqual(expected, got)
        self.assertEqual(len(counters["graph_break"]), 1)

        counters.clear()
        y = torch.randn(2, 3, 4)
        expected = y.cos()
        got = fn(y)
        self.assertEqual(expected, got)
        self.assertEqual(len(counters["graph_break"]), 0)
|
|
|
|
    def test_vmap_with_graph_break(self):
        counters.clear()

        def g(x):
            y = x.cos()
            print("hi")
            return y.sin()

        def fn(x):
            return torch.vmap(g)(x)

        x = torch.randn(3, 4)
        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
        expected = fn(x)
        got = opt(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(expected, got)
|
|
|
|
    def test_vmap_with_graph_break_2(self):
        counters.clear()

        def cos(x):
            print("cos")
            return x.cos()

        def sin(x):
            print("sin")
            return x.sin()

        def g(x):
            y = cos(x)
            return sin(y)

        def fn(x):
            return torch.vmap(g, randomness="same")(x)

        x = torch.randn(3, 4)
        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
        expected = fn(x)
        got = opt(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(expected, got)
|
|
|
|
    def test_vmap_with_graph_break_lambda(self):
        counters.clear()

        def sin(x):
            print("sin")
            return x.sin()

        def fn(x):
            return torch.vmap(lambda x: sin(x))(x)

        x = torch.randn(3, 4)
        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
        expected = fn(x)
        got = opt(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(expected, got)
|
|
|
|
def test_vmap(self):
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
batched_outputs: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_free_const(self):
|
|
y = 3
|
|
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]"):
|
|
l_x_ = L_x_
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3]" = add + 3; add = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_free_tensor(self):
|
|
y = torch.randn(3, 3)
|
|
|
|
def fn(x):
|
|
return torch.func.vmap(lambda x: x.sum(0) + x.sum(1) + y)(x)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x,))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3, 3]" = add + l_y_; add = l_y_ = None
|
|
|
|
_remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_two_inputs(self):
|
|
def fn(x, y):
|
|
return torch.func.vmap(
|
|
lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=(0, 1)
|
|
)(x, y)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
_add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
def test_vmap_two_inputs_tuple_in_dims(self):
|
|
in_dims = (0, 1)
|
|
|
|
def fn(x, y):
|
|
return torch.func.vmap(
|
|
lambda x, y: x.sum(0) + x.sum(1) + y, in_dims=in_dims
|
|
)(x, y)
|
|
|
|
x = torch.randn(3, 3, 3)
|
|
y = torch.randn(3, 3)
|
|
wrapped_gm = self._compile_check(fn, (x, y))
|
|
|
|
# Dynamic shapes produce a slightly different graph.
|
|
if check_dynamic_shape_capture():
|
|
return
|
|
|
|
actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"):
|
|
l_x_ = L_x_
|
|
l_y_ = L_y_
|
|
|
|
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
|
|
|
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
|
|
|
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
|
_add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
|
|
|
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
|
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
|
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
|
batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
|
|
|
|
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
|
|
|
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
|
return (_remove_batch_dim,)
|
|
""",
|
|
)
|
|
|
|
    def test_vmap_over_vmap_two_inputs(self):
        def fn(x, y):
            return torch.func.vmap(torch.func.vmap(lambda x, y: x + y, in_dims=1))(x, y)

        x = torch.randn(3, 3, 3)
        y = torch.randn(3, 3, 3)
        wrapped_gm = self._compile_check(fn, (x, y))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"):
        l_x_ = L_x_
        l_y_ = L_y_

        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None

        child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
        child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None

        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None

        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None

        _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2); child = None
        _add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None

        batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None

        batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None

        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

        _remove_batch_dim_1: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None

        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
        return (_remove_batch_dim_1,)
""",
        )

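    # Note: x is closed over by the inner lambda rather than mapped over, so it is
    # lifted as an extra graph input (L_x_) alongside the mapped input L_y_.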
    def test_vmap_over_vmap_captured(self):
        x = torch.ones(2, 3)
        y = torch.ones(5, 3)

        def fn(x):
            return torch.func.vmap(torch.func.vmap(lambda y: x * y))(y)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[5, 3]", L_x_: "f32[2, 3]"):
        l_y_ = L_y_
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None

        child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None

        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None

        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None

        _add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2); child = None

        batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None

        batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None

        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

        _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None

        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
        return (_remove_batch_dim_1,)
""",
        )

    def test_vmap_multiple_outputs(self):
        x = torch.ones(2, 4, 3)

        def fn(x):
            return torch.vmap(lambda x: (x.sum(0), x.sum(1)))(x)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        child: "f32[3]" = _add_batch_dim.sum(0)
        child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None

        _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0); child = None
        _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None

        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim, _remove_batch_dim_1)
""",
        )

    def test_vmap_multiple_outputs_diff_dims(self):
        x = torch.ones(2, 4, 3)

        def fn(x):
            return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=(1, 0))(x)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        child: "f32[3]" = _add_batch_dim.sum(0)
        child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None

        _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None
        _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None

        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim, _remove_batch_dim_1)
""",
        )

    def test_vmap_multiple_outputs_out_dims_tuple(self):
        x = torch.ones(2, 4, 3)
        out_dims = (1, 0)

        def fn(x):
            return torch.vmap(lambda x: (x.sum(0), x.sum(1)), out_dims=out_dims)(x)

        wrapped_gm = self._compile_check(fn, (x,))

        # Dynamic shapes produce a slightly different graph.
        if check_dynamic_shape_capture():
            return

        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[2, 4, 3]"):
        l_x_ = L_x_

        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None

        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None

        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None

        child: "f32[3]" = _add_batch_dim.sum(0)
        child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None

        _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None
        _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None

        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
        return (_remove_batch_dim, _remove_batch_dim_1)
""",
        )

    def test_vmap_kwargs(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        def fn(x, y):
            return torch.func.vmap(lambda x, y: x + y)(x, y=y)

        actual = fn(x, y)
        expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_vmap_pytree_inputs(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        def vmap_fn(inps):
            x = inps["x"]
            y = inps["y"]
            return x + y

        def fn(x, y):
            return torch.func.vmap(vmap_fn)({"x": x, "y": y})

        actual = fn(x, y)
        expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

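    # Note: f mutates a list defined outside of vmap; the eager run and the
    # compiled run each append once, so some_list should end up as [1, 1].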
    def test_vmap_side_effects(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        some_list = []

        def f(x, y):
            some_list.append(1)
            return x + y

        def wrapper_fn(x, y):
            return torch.func.vmap(f)(x, y)

        actual = wrapper_fn(x, y)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)
        self.assertEqual(some_list, [1, 1])

    @unittest.expectedFailure
    def test_vmap_side_effects_append_input(self):
        counters.clear()
        x = torch.ones(2, 3)
        y = torch.randn(2, 3)

        some_list = []

        def f(x, y):
            some_list.append(x)
            return x + y

        def wrapper_fn(x, y):
            return torch.func.vmap(f)(x, y)

        actual = wrapper_fn(x, y)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_vmap_previous_illegal_op_no_graph_break(self):
        counters.clear()

        # calling .stride() would previously graph break
        def bad_fn(x):
            y = x.view((4, 3))
            y.stride()
            return y

        def wrapper_fn(x):
            return torch.func.vmap(bad_fn)(x)

        x = torch.randn(2, 3, 4)
        actual = wrapper_fn(x)
        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
        self.assertEqual(len(counters["graph_break"]), 0)
        self.assertEqual(actual, expected)

    def test_vmap_multiple_invocation_in_dims(self):
        counters.clear()

        def wrapper_fn(x, in_dims):
            return torch.func.vmap(torch.sum, in_dims)(x)

        x = torch.randn(3, 3, 3, 3)
        cnt = CompileCounter()
        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
        # The third invocation of `opt` traces `in_dims` as a SymInt.
        actual = opt(x, 0), opt(x, 1), opt(x, 2)
        self.assertEqual(expected, actual)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 18)

    def test_vmap_multiple_invocation_out_dims(self):
        counters.clear()

        def wrapper_fn(x, out_dims):
            return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)

        x = torch.randn(3, 3, 3, 3)
        cnt = CompileCounter()
        opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
        expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
        # The third invocation of `opt` traces `out_dims` as a SymInt.
        actual = opt(x, 0), opt(x, 1), opt(x, 2)
        self.assertEqual(expected, actual)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 18)

    def test_vmap_new_tensor_in_body(self):
        def fn(x):
            return x + torch.ones(3)

        def wrapper_fn(x):
            return torch.func.vmap(fn)(x)

        x = torch.randn(
            3,
        )
        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
        expected = wrapper_fn(x)
        actual = opt(x)
        self.assertEqual(expected, actual)

    def test_vmap_new_tensor_unused_in_body(self):
        def fn(x):
            return torch.tensor(0.5)

        def wrapper_fn(x):
            return torch.func.vmap(fn)(x)

        x = torch.randn(3)
        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
        expected = wrapper_fn(x)
        actual = opt(x)
        self.assertEqual(expected, actual)

    def test_vmap_new_tensor_implicit_via_op(self):
        def wrapper_fn(x):
            return torch.func.vmap(lambda t: torch.add(t, 0.5))(x)

        x = torch.randn(3)
        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
        expected = wrapper_fn(x)
        actual = opt(x)
        self.assertEqual(expected, actual)


class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
        cloned_args = []
        for arg in args:
            cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))

        torch.manual_seed(0)
        expected = fn(*args)
        expected.sum().backward()

        opt_fn = torch.compile(fn, fullgraph=fullgraph, backend=backend)
        torch.manual_seed(0)
        result = opt_fn(*cloned_args)
        result.sum().backward()

        if not skip_check:
            self.assertEqual(result, expected)
            for arg, cloned_arg in zip(args, cloned_args):
                self.assertEqual(arg.grad, cloned_arg.grad)

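    # Note: the fw_compiler/bw_compiler helpers below use count_ops to assert how
    # often a given aten op appears in the forward and backward graphs.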
    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_function(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_function_with_kwargs(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                torch.sin(x),
                y,
                use_reentrant=True,
                preserve_rng_state=False,
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_dropout(self):
        def gn(x, y):
            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default
        )
        # The philox_rand output is passed from the forward, so none appears in the backward graph.
        bw_compiler = functools.partial(
            count_ops, freq=0, op=torch.ops.rngprims.philox_rand.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(
            fn, backend, x, y, skip_check=True
        )  # dropout decomp is known to diverge with eager

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_dropout_inductor(self):
        def gn(x, y):
            return torch.nn.functional.dropout(torch.matmul(x, y), p=0.2)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        backend = "inductor"
        self._validate(
            fn, backend, x, y, skip_check=True
        )  # dropout decomp is known to diverge with eager

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_fallback(self):
        def gn(x, y):
            torch._dynamo.graph_break()
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.cos(
                torch.utils.checkpoint.checkpoint(
                    gn, torch.sin(x), y, use_reentrant=True
                ),
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        args = (x, y)

        backend = EagerAndRecordGraphs()
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        # One graph for torch.sin on the input, and another for torch.cos.
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(len(backend.graphs), 2)

    @requires_cuda
    @torch._functorch.config.patch(functionalize_rng_ops=True)
    def test_module(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return torch.sigmoid(self.linear(x))

        mod = MockModule()

        def fn(x):
            return torch.utils.checkpoint.checkpoint(
                mod, torch.sin(x), use_reentrant=True
            )

        x = torch.randn(10, 10, requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        # The sigmoid output is passed from the forward, so it does not appear in the backward graph.
        bw_compiler = functools.partial(
            count_ops, freq=0, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

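    # Note: registering a py_impl for a key that falls through by default should
    # move that key out of the op's fallthrough set.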
    def test_override_fallthrough_dispatch_key(self):
        class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
            def __init__(self):
                super().__init__("_fallthrough_test_only")

            def __call__(self, *args, **kwargs):
                return super().__call__(*args, **kwargs)

        test_op = _FallthroughTestOnly()
        default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
        self.assertTrue(
            not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)
        )

        foos = [lambda x=i: x for i, k in enumerate(default_keys)]
        for foo, fallthrough_key in zip(foos, default_keys):
            test_op.py_impl(fallthrough_key)(foo)

        self.assertTrue(
            all(test_op.non_fallthrough_keys.has(key) for key in default_keys)
        )
        self.assertEqual(
            list(range(len(default_keys))),
            [test_op.py_kernels[key]() for key in default_keys],
        )

    def test_cond_with_kwargs(self):
        from torch._higher_order_ops.cond import cond_op

        def test(pred, x):
            def true_fn(x):
                return x

            def false_fn(x):
                return -x

            return cond_op(pred=pred, true_fn=true_fn, false_fn=false_fn, operands=[x])

        cnt = CompileCounter()
        opt_test = torch.compile(test, backend=cnt, fullgraph=True)
        inp = torch.ones(3, 3)
        true_pred = torch.Tensor([True])
        false_pred = torch.Tensor([False])
        self.assertTrue(torch.allclose(test(true_pred, inp), opt_test(true_pred, inp)))
        self.assertEqual(cnt.frame_count, 1)
        self.assertTrue(
            torch.allclose(test(false_pred, inp), opt_test(false_pred, inp))
        )
        self.assertEqual(cnt.frame_count, 1)

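    # Note: the first call below passes an unexpected keyword (invalid=True); the
    # second passes pred both positionally and as a keyword. Both are expected to raise.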
    def test_cond_with_invalid_kwargs(self):
        from torch._higher_order_ops.cond import cond_op

        def test(pred, mode, x):
            def true_fn(x):
                return x

            def false_fn(x):
                return -x

            if mode:
                return cond_op(
                    pred=pred,
                    true_fn=true_fn,
                    false_fn=false_fn,
                    operands=[x],
                    invalid=True,
                )
            else:
                return cond_op(
                    pred,
                    pred=pred,
                    true_fn=true_fn,
                    false_fn=false_fn,
                    operands=[x],
                )

        cnt = CompileCounter()
        opt_test = torch.compile(test, backend=cnt)
        inp = torch.ones(3, 3)
        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
            opt_test(True, True, inp)

        with self.assertRaises(AssertionError):
            opt_test(True, False, inp)

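    # Note: _assert_tensors_nonaliasing checks that no tensor object in the inputs
    # is reused (aliased) in the outputs; passing the same structure for both sides
    # should therefore raise.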
    def test_non_aliasing_util(self):
        from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing

        a = [torch.tensor(1), {"a": torch.tensor(1)}]
        b = (torch.tensor(1),)
        _assert_tensors_nonaliasing(a, b)

        with self.assertRaisesRegex(
            AssertionError, "inputs to function body cannot alias outputs"
        ):
            _assert_tensors_nonaliasing(a, a)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()