pytorch/test/inductor/test_auto_functionalize.py
Laith Sakka c8ab9b06a2 Redesign custom op functionalization for better re-inplace (#134409)
- The new implementation (auto_functionalized_v2) is enabled by default but can be disabled using an inductor flag.
- In export mode the old implementation is used.

**Motivation**
The previous functionalization fails to re-inplace arguments when they are views over other tensors;
see issue https://github.com/pytorch/pytorch/issues/131192.
The new functionalization makes re-inplacing views much easier.

**A) Functionalization pass**
Consider the program:

```
def func(t):
    x = t[0]
    y = t[1]
    foo(x, y)  # custom operator with x, y mutable
    return (x, y, t)
```

- To functionalize `foo` we generate a function that operates on the base tensors of the inputs (x.base() and y.base())
and record how to regenerate each view out of its base; for argument x we record ```ViewInfo = (x.base(), x.size(), x.stride(), x.storage_offset())```. (A minimal sketch of this view regeneration appears after the functionalized graph below.)

- Due to some limitations on the torch.export argument format, we have to generate a lot of arguments, but this is something we can simplify in the future. For the example above, we get the following function:

   ```
   auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default,
     _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0,
     _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1,
     _all_bases = [arg0_1])
   ```
- In the code above:
    - `_all_bases` refers to a unique set of bases for all of foo's arguments (here, the single base `t`).
    - for each argument x we have `_x_base_index`, `_x_size`, `_x_stride`, and `_x_storage_offset`, which can be used to regenerate x from `_all_bases[_x_base_index]` (or from a copy of that base).

- The output of auto_functionalized is foo's output, followed by one tensor for each base in `_all_bases`: a copy of that base after observing the mutations of all the arguments that are views of it.

- For each use of a base in `_all_bases` (or of a view of it) that comes after the call to foo, replace it with a view of the new output.

For the function above, after functionalization we get:
 ```
    def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized[1];  auto_functionalized = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = copy_ = None

        # No stacktrace found for following nodes
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1);  getitem_1 = None
        return (select_2, select_3)
```
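
To make the recorded view information concrete, here is a minimal sketch (not the actual implementation) of how an argument can be regenerated from its base with `torch.as_strided`; the `ViewInfo` dataclass and `regenerate_view` helper are illustrative names, not PyTorch APIs:

```
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class ViewInfo:
    # Index into _all_bases, plus the metadata needed to rebuild the view.
    base_index: int
    size: Tuple[int, ...]
    stride: Tuple[int, ...]
    storage_offset: int


def regenerate_view(all_bases, info):
    # Rebuild the argument as a view over (a copy of) its base.
    base = all_bases[info.base_index]
    return torch.as_strided(base, info.size, info.stride, info.storage_offset)


t = torch.arange(4.0)
x = regenerate_view([t], ViewInfo(0, (), (), 0))  # x = t[0]
y = regenerate_view([t], ViewInfo(0, (), (), 1))  # y = t[1]
assert x.item() == 0.0 and y.item() == 1.0
```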

**B) Semantics of auto_functionalized**
The new semantics of `auto_functionalized` are as follows (a minimal sketch follows the list):
1. For each base in `_all_bases`, create a copy of the base. (If a base is re-inplaced, we do not need to copy it.)
2. For each arg, regenerate the arg from the copy of its base using the view information above.
3. Return the original foo output, followed by the new bases.
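
A minimal Python sketch of these semantics, reusing the illustrative `regenerate_view` helper from the sketch above (the real higher-order op lives in `torch._higher_order_ops.auto_functionalize`; this is a simplification, not its implementation):

```
def auto_functionalized_semantics(op, view_infos, all_bases, inplaced_bases=()):
    # Step 1: copy each base, unless the re-inplace pass decided the base
    # can be mutated directly, in which case the copy is skipped.
    new_bases = [
        b if i in inplaced_bases else b.clone()
        for i, b in enumerate(all_bases)
    ]
    # Step 2: regenerate each mutable arg as a view over the (copied) base,
    # so the op's mutations write through to new_bases.
    args = [regenerate_view(new_bases, info) for info in view_infos]
    out = op(*args)
    # Step 3: return the original op output followed by the new bases.
    return (out, *new_bases)
```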

**C) Re-inplace pass**
Since `auto_functionalized` copies the bases rather than the args, what we actually re-inplace are the bases
(the pass runs just like before, but on the bases instead of the args).

1. For each base b in `_all_bases`, check whether there is any use of b (or of its aliases/views) after the call to `auto_functionalized` and before b is overwritten with a copy. If there is none, re-inplace b (i.e., avoid copying it in step 1 above); a sketch of this check follows.
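
A minimal sketch of that check over a `torch.fx.Graph`, assuming a hypothetical `collect_aliases` helper that returns a base together with all view nodes derived from it (the real re-inplacing pass in Inductor handles many more cases):

```
import torch
import torch.fx


def can_reinplace_base(graph, auto_func_node, base):
    aliases = collect_aliases(graph, base)  # hypothetical helper
    after = False
    for node in graph.nodes:
        if node is auto_func_node:
            after = True
            continue
        if not after:
            continue
        # The copy_ that writes the functional result back into the base
        # overwrites it; uses from that point on no longer block re-inplacing.
        if node.target is torch.ops.aten.copy_.default and node.args[0] is base:
            break
        if any(inp in aliases for inp in node.all_input_nodes):
            return False  # the pre-mutation value of the base is still read
    return True
```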

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134409
Approved by: https://github.com/zou3519
2024-09-04 17:08:58 +00:00


# Owner(s): ["module: functionalization"]
import numpy as np
import torch
import torch._dynamo.testing
import torch._inductor.config as inductor_config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch.testing._internal.logging_utils import logs_to_string


class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
    def test_auto_functionalize_can_with_default(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            def foo_impl(a, b, c=None, d=None, e=-1):
                a + b
                return

            def f(a, mode):
                return torch.ops.mylib.foo(
                    a,
                    0,
                )

            a = torch.tensor([10, 10, 10], dtype=torch.int64)
            torch.compile(f)(a, 0)

    def test_auto_functionalize_can_with_none_return(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) out) -> None")

            def foo_impl(x, out):
                out.copy_(x)

            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
            x = torch.randn(3)
            out = torch.zeros(3)

            @torch.compile
            def f(x, out):
                torch.ops.mylib.foo(x, out)

            f(x, out)

    def test_auto_functionalize_self_as_mutate_arg(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor(a!) self) -> None")

            def foo_impl(self: torch.Tensor) -> None:
                self.sin_()

            x = torch.randn(3)
            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

            @torch.compile(backend="inductor", fullgraph=True)
            def f(x):
                torch.ops.mylib.foo(x)

            f(x)

    def test_auto_functionalize_tensorlist(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
                for o in out:
                    o.copy_(all_gather_output)

            def f(all_gather_output, all_gather_input_split_sizes, dim, out):
                torch.ops.mylib.foo(
                    all_gather_output, all_gather_input_split_sizes, dim, out
                )

            a = torch.ones(4)
            b = [2, 3]
            c = 0
            d = [torch.empty(4) for _ in range(2)]
            orig_args = (a, b, c, d)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

    def test_can_auto_functionalize(self):
        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

        expected_true = [
            "(Tensor(a!) x) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
        ]
        expected_false = [
            "(Tensor x) -> ()",
            "(Tensor(a) x) -> Tensor(a)",
            "(Tensor(a!) x) -> Tensor(a!)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
        ]
        for schema in expected_true:
            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertTrue(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
        for schema in expected_false:
            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertFalse(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))

    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Check the graph under static shapes
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_with_returns_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)

    def test_auto_functionalize_on_view(self):
        for value in [True, False]:
            with torch.library._scoped_library(
                "mylib", "FRAGMENT"
            ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}):
                torch.library.define(
                    "mylib::foo",
                    "(Tensor(a!) x) -> ()",
                    tags=torch.Tag.pt2_compliant_tag,
                    lib=lib,
                )

                @torch.library.impl("mylib::foo", "cpu", lib=lib)
                @torch._dynamo.disable
                def foo_impl(x):
                    x_np = x.detach().numpy()  # view
                    np.sin(x_np, out=x_np)
                    return

                x = torch.randn(3)
                expected = x.sin()
                torch.ops.mylib.foo(x)
                assert torch.allclose(x, expected)

                @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
                def f(x):
                    x = x.clone()
                    y = x[:]
                    torch.ops.mylib.foo(y)
                    return x

                y = f(x)
                self.assertEqual(y, x.sin())

    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_optional_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \
arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_auto_functionalize_op(self):
        @torch.library.custom_op(
            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
        )
        def mk_image(decoder: Tensor) -> Tensor:
            return torch.randn(2, 3, 4, 5)

        @torch.library.register_fake("mylib::mk_image")
        def _(decoder: Tensor) -> Tensor:
            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
            return torch.empty(image_size)

        @torch.compile(fullgraph=True)
        def f(x):
            return torch.ops.mylib.mk_image.default(x)

        x = torch.zeros(100, dtype=torch.int64)
        f(x)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_v2(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)(
                    *compiled_args
                )

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

    def run_aot_eager(self, f, orig_args, _dynamic=False):
        aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )

        result = None
        with ctx():
            result = torch.compile(
                f, backend="aot_eager", fullgraph=True, dynamic=_dynamic
            )(*aot_eager_args)

        graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        return [aot_eager_args, result, graph]

    def run_inductor(self, f, orig_args, _dynamic=False):
        compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        log_stream, ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )

        result = None
        with ctx():
            result = torch.compile(
                f, backend="inductor", fullgraph=True, dynamic=_dynamic
            )(*compiled_args)

        graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        return [compiled_args, result, graph]

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_with_returns_v2(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)

    # foo takes two inputs that are not views.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra1(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x, y):
                torch.ops.mylib.foo(x, y)
                return x + y

            orig_args = (torch.randn(2), torch.randn(2))

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(
                f, orig_args, _dynamic
            )
            [inductor_args, result2, graph_inductor] = self.run_inductor(
                f, orig_args, _dynamic
            )
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)
            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
        return (add,)""",
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
        return (add,)""",
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

    # foo takes two views on the same input, function does not have return.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra2(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(
                f, orig_args, _dynamic
            )
            [inductor_args, result2, graph_inductor] = self.run_inductor(
                f, orig_args, _dynamic
            )
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)
            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            # 2. Run with inductor backend
            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

    # foo takes two views on the same input, function returns both views and the input
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra3(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return (a, b, x)

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)
            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None
        return (select_2, select_3)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
        return (select_2, select_3)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

    # foo takes a mutable list with views in addition to other args.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra4(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!)[] y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y[0].sin_()

            def f(x, y, z):
                a = x[0]
                b = z[0]
                torch.ops.mylib.foo(a, [b, y])

            orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args[2], eager_args[2])
            self.assertEqual(inductor_args, aot_eager_args)
            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]
        getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3]; auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None
        copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]); as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
        copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_optional_v2(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_inference_mode1_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra1()

    # In inference mode we do not support inplacing views yet.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode2_v2(self):
        with torch.inference_mode(), torch.library._scoped_library(
            "mylib", "FRAGMENT"
        ) as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)
            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
        select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None
        getitem_1: "f32[][]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
        select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None
        select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
        select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
        clone_default: "f32[][]cpu" = torch.ops.aten.clone.default(select); select = None
        clone_default_1: "f32[][]cpu" = torch.ops.aten.clone.default(select_1); select_1 = None
        foo_default = torch.ops.mylib.foo.default(clone_default, clone_default_1); foo_default = None
        select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, clone_default, 0, 0); clone_default = None
        select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, clone_default_1, 0, 1); select_scatter_default = clone_default_1 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic_v2(self):
        self.test_auto_functionalize_v2(_dynamic=True)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic2_v2(self):
        self.test_auto_functionalize_extra1(_dynamic=True)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic3_v2(self):
        self.test_auto_functionalize_extra2(_dynamic=True)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()