# Owner(s): ["module: inductor"]
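#
# Tests for Inductor's reinplacing pass: after mutable custom ops are
# auto-functionalized, the pass tries to write results back into the original
# buffers (dropping the extra clones) whenever the mutated buffer has no later
# users.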

from typing import List

import torch
import torch._inductor.config as inductor_config
from functorch import make_fx
from torch import Tensor
from torch._dynamo.utils import ReinplaceCounters
from torch._higher_order_ops.auto_functionalize import (
    auto_functionalized,
    auto_functionalized_v2,
)
from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
    subtest,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.logging_utils import logs_to_string

aten = torch.ops.aten
const = torch.tensor(0.0)
device = GPU_TYPE


def num_reinplacing_failures():
    return ReinplaceCounters.get_total_missed()


def miss_inplaced_bytes():
    return ReinplaceCounters.get_total_missed_bytes()
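

# The custom ops below declare their out-arguments via ``mutates_args``; under
# torch.compile they get auto-functionalized, and the reinplacing pass should
# later turn the functionalized calls back into in-place writes.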
@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
def sin(x: torch.Tensor, result: torch.Tensor) -> None:
    result.copy_(x.sin())


@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())
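

# Element-wise sin as a user-defined Triton kernel, with a no-op fallback so the
# module still imports (and tests can be parametrized) on machines without a GPU.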
if HAS_GPU:
    import triton  # @manual
    import triton.language as tl  # @manual

    @triton.jit
    def sin_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.sin(x)
        tl.store(out_ptr + offsets, output, mask=mask)

    def sin_triton(x, out):
        n_elements = x.numel()
        sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

else:

    def sin_triton(x, out):
        return
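

# An op that mutates its single argument in place; the view tests route views
# and bases of a tensor through auto_functionalized_v2 calls to this op.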
@torch.library.custom_op("test_view::boo", mutates_args={"x"})
def boo(x: torch.Tensor) -> None:
    x.sin_()


class TestReinplacingPassCorrectness(InductorTestCase):
    def setUp(self):
        ReinplaceCounters.clear()
        return super().setUp()
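
    # Helper: run ``f`` eagerly and compiled on cloned inputs, then check that
    # the outputs and the (possibly mutated) inputs match between the two runs.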
    def _test(self, f):
        nf = torch.compile(f)
        inp = (
            torch.randn(4, device=device),
            torch.ones(2, device=device, dtype=torch.int),
        )
        inp2 = (inp[0].clone(), inp[1].clone())
        self.assertEqual(f(*inp), nf(*inp2))
        self.assertEqual(inp, inp2)
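
    # The tests below check the safety conditions of the pass: a buffer may only
    # be mutated in place if neither it nor any alias of it is used afterwards,
    # and a graph input is only mutated when the user code already mutated it.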
    def test_dont_modify_live(self):
        def f(x, y):
            x = x.cos()
            x2 = x.index_put((y,), const)
            return x2, x

        self._test(f)

    def test_dont_modify_view_of_live(self):
        def f(x, y):
            x = x.cos()
            x2 = aten.alias(x)
            x2 = x2.index_put((y,), const)
            y = x2 + x.cos()
            return y

        self._test(f)

    def test_dont_modify_input(self):
        def f(x, y):
            return x.index_put((y,), const)

        self._test(f)

    def test_should_modify_inner(self):
        def f(x, y):
            x = x.cos()
            x = x.index_put((y,), const)
            return x

        self._test(f)

    def test_should_modify_input(self):
        def f(x, y):
            x = x.index_put_((y,), const)
            return x

        self._test(f)
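
    # ReinplaceCounters track how many mutated tensors the pass failed to
    # reinplace and how many bytes of extra clones that implies; the two tests
    # below build graphs where exactly one reinplacement must fail.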
    def test_counters_functionalize_old(self):
        ReinplaceCounters.clear()

        def f(x):
            out = torch.empty_like(x)
            _, new_out = auto_functionalized(sin._opoverload, x=x, result=out)
            y = out * new_out
            return new_out, y

        x = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x)
        reinplace_inplaceable_ops_core(gm.graph)

        # We shouldn't have been able to reinplace `out` because it was used after
        # auto_functionalized. Note that this usually doesn't happen in practice;
        # we're artificially creating this example to test the counter.
        # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
        self.assertEqual(num_reinplacing_failures(), 1)
        # `out` has 3 float32 elements, so 3 * 4 = 12 bytes were left un-reinplaced.
        self.assertEqual(miss_inplaced_bytes(), 12)

    def test_counters_functionalize_v2(self):
        ReinplaceCounters.clear()

        def f(x):
            out = torch.empty_like(x)
            # The `_result_*` kwargs describe `result` as a view of base 0 in
            # `_all_bases` (here: the whole of `out`).
            _, new_out = auto_functionalized_v2(
                sin._opoverload,
                x=x,
                _result_base_index=0,
                _result_size=(3,),
                _result_stride=(1,),
                _result_storage_offset=0,
                _all_bases=[out],
            )
            y = out * new_out
            return new_out, y

        x = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x)
        reinplace_inplaceable_ops_core(gm.graph)

        # We shouldn't have been able to reinplace `out` because it was used after
        # auto_functionalized. Note that this usually doesn't happen in practice;
        # we're artificially creating this example to test the counter.
        # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
        self.assertEqual(num_reinplacing_failures(), 1)

    def get_not_inplaced_count(self, graph):
        # Counts, across all auto_functionalized(_v2) nodes, how many tensors the
        # pass still has to clone (i.e. how many it failed to reinplace).
        counter = 0
        auto_functionalized_found = False
        for node in graph.nodes:
            if (node.target == torch.ops.higher_order.auto_functionalized) or (
                node.target == torch.ops.higher_order.auto_functionalized_v2
            ):
                auto_functionalized_found = True
                counter += len(node.meta["only_clone_these_tensors"])
        assert auto_functionalized_found
        return counter
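
    # The tests below feed hand-built auto_functionalized_v2 graphs through the
    # pass and count, via get_not_inplaced_count, how many bases still had to be
    # cloned when views of the mutated base are alive at various program points.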

    def test_view_inplaced_functionalize_v2(self):
        def f(arg0_1):
            torch.ops.aten.select.int(arg0_1, 0, 0)
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return ()

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)

    # introduce a view another_view that is used `after` the copy
    def test_view_inplaced2_functionalize_v2(self):
        def f(arg0_1):
            _select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            _copy = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)

    # introduce a view another_view that is used `before` the copy
    def test_views_not_inplaced_functionalize_v2(self):
        def f(arg0_1):
            _select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            use_another_view = another_view * 10
            _copy = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return use_another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    # a view over input without copy node, inplace not allowed
    def test_views_not_inplaced2_functionalize_v2(self):
        def f(arg0_1):
            _select = torch.ops.aten.select.int(arg0_1, 0, 0)
            _another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            _getitem_1 = auto_functionalized[1]
            return

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    # no copy nodes, view over local, with a use for another view
    def test_views_not_inplaced3_functionalize_v2(self):
        def f(arg0_1):
            a = torch.ones(10)
            another_view = a[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(),
                _x_stride=(),
                _x_storage_offset=0,
                _all_bases=[a],
            )
            _getitem_1 = auto_functionalized[1]
            return another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)
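
    # End-to-end torch.compile tests where every mutation should be reinplaced,
    # whether the target buffer is an intermediate created inside the function
    # or an input that the user already mutates.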
    def test_multi_output_intermediate(self):
        for requires_grad in [False, True]:
            for enable_v2 in [False, True]:
                with inductor_config.patch(
                    {"enable_auto_functionalized_v2": enable_v2}
                ):
                    ReinplaceCounters.clear()

                    def f(x):
                        out1 = torch.empty_like(x)
                        out2 = torch.empty_like(x)
                        sin_cos(x, out1, out2)
                        return out1, out2, x**2

                    x = torch.randn(3, device=device, requires_grad=requires_grad)
                    res1, res2, _ = torch.compile(f)(x)
                    self.assertEqual(res1, x.sin())
                    self.assertEqual(res2, x.cos())
                    self.assertEqual(num_reinplacing_failures(), 0)

    def test_multiple_mutations(self):
        ReinplaceCounters.clear()

        def f(x, out):
            sin(x, out)
            sin(out, out)
            sin(out, out)
            return out

        x = torch.randn(3, device=device)
        out = torch.randn(3, device=device)
        result = torch.compile(f)(x, out)
        self.assertEqual(result, x.sin().sin().sin())
        self.assertEqual(result, out)
        self.assertEqual(num_reinplacing_failures(), 0)

    def test_multiple_intermediate(self):
        ReinplaceCounters.clear()

        def f(x):
            out = torch.empty_like(x)
            sin(x, out)
            sin(out, out)
            sin(out, out)
            return out

        x = torch.randn(3, device=device)
        result = torch.compile(f)(x)
        self.assertEqual(result, x.sin().sin().sin())
        self.assertEqual(num_reinplacing_failures(), 0)
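
    # Custom ops that mutate a List[Tensor]: with auto_functionalized_v2 the
    # shared base can be reinplaced and no clones are emitted, while the old
    # functionalization cannot reinplace the individual views and has to clone.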
    def test_lists_functionalize_v2(self):
        with inductor_config.patch({"enable_auto_functionalized_v2": True}):

            @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
            def mutate_op(y: List[Tensor]) -> None:
                y[0].add_(2)
                y[1].add_(3)

            @torch.compile(fullgraph=True, dynamic=False, backend="inductor")
            def f(b):
                mutate_op([b[0], b[1]])

            x1 = torch.tensor([0.3, 0.4], device=device)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(x1)
            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # We can reinplace the base of y; no clones are emitted.
            self.assertEqual(num_reinplacing_failures(), 0)
            self.assertEqual(miss_inplaced_bytes(), 0)
            self.assertEqual(post_grad_graphs.count("aten.clone"), 0)

    def test_lists_old_functionalize(self):
        with inductor_config.patch({"enable_auto_functionalized_v2": False}):

            @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
            def mutate_op(y: List[Tensor]) -> None:
                y[0].add_(2)
                y[1].add_(3)

            @torch.compile(fullgraph=True, dynamic=False, backend="inductor")
            def f(b):
                mutate_op([b[0], b[1]])

            x1 = torch.tensor([0.3, 0.4], device=device)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(x1)
            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Can't reinplace on views yet (1 for the "entire list" failing to reinplace).
            self.assertEqual(num_reinplacing_failures(), 1)
            self.assertEqual(miss_inplaced_bytes(), 8)
            # Both list inputs failed to reinplace, so we should have emitted clones for them.
            self.assertEqual(post_grad_graphs.count("aten.clone"), 2)
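
    # Parametrized over the factory op used to allocate the out buffer and over
    # two sin implementations (custom op vs. user-defined Triton kernel). The
    # partitioner is expected to recompute the factory op for the backward pass,
    # so reinplacing should succeed in every combination.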
    @parametrize(
        "factory_op",
        [
            subtest(torch.ones_like, name="ones_like"),
            subtest(torch.empty_like, name="empty_like"),
        ],
    )
    @parametrize(
        "sin_op",
        [
            subtest(sin, name="sin_op"),
            subtest(sin_triton, name="sin_triton"),
        ],
    )
    def test_partitioner_recomputes_factory(self, factory_op, sin_op):
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin_op(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin_op(saved, out)
                return out

        @torch.compile(backend="inductor")
        def f(x):
            return MySin.apply(x)

        x = torch.randn(3, requires_grad=True, device=device)
        f(x)
        self.assertEqual(num_reinplacing_failures(), 0)


instantiate_parametrized_tests(TestReinplacingPassCorrectness)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests(needs="filelock")