pytorch/test/inductor/test_perf.py

# Owner(s): ["module: inductor"]
import contextlib
import re
from unittest.mock import patch
import functorch
import torch
import torch._inductor.config as config
import torch.autograd
from torch._inductor import metrics
from torch._inductor.compile_fx import compile_fx, compile_fx_inner
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
########################
# Explanation of Tests #
########################
# These tests are all testing *memory accesses* of TorchInductor.
# They are intended to be deterministic performance tests.
# The expect tests are all measuring the number of memory bytes read/written by
# the code that Inductor has generated
#
# If the test is failing because the number became smaller, feel free to lower it.
# On the other hand, if the test is failing because the number became larger,
# that means that your change is leading to *more* memory accesses on this test.
#
# That may still be acceptable, but be aware that you are likely lowering
# performance for that setting.
#
# Defines all the kernels for tests
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
if HAS_CUDA:
import triton # @manual
import triton.language as tl # @manual
from torch.testing._internal.triton_utils import add_kernel
aten = torch.ops.aten
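# compile_but_use_eager runs Inductor's compile_fx_inner so that code is generated and
# the memory-access metrics are populated, but its inner_compile discards the compiled
# callable and returns the original graph module, so outputs are still computed eagerly.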
def compile_but_use_eager(gm, example_inputs):
def inner_compile(gm, *args, **kwargs):
compile_fx_inner(gm, *args, **kwargs)
return gm
return compile_fx(gm, example_inputs, inner_compile=inner_compile)
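# count_numel compiles and runs f once, then reports metrics.num_bytes_accessed // 4
# (i.e. the number of fp32 elements read or written), as a string for assertExpectedInline.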
def count_numel(f, *args):
"""
Assumes all inputs are fp32
"""
metrics.reset()
torch.compile(f, backend=compile_but_use_eager)(*args)
print(metrics.nodes_num_elem)
return str(metrics.num_bytes_accessed // 4)
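# count_numel_train additionally runs a backward pass (by reducing the outputs to a
# scalar and calling backward), so its count includes the traffic of the backward kernels.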
def count_numel_train(f, *args):
"""
Assumes all inputs are fp32
"""
metrics.reset()
f = torch.compile(f, backend=compile_but_use_eager)
out = f(*args)
res = 0
for o in out:
res += o.mean()
res.backward()
print(metrics.nodes_num_elem)
return str(metrics.num_bytes_accessed // 4)
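# Tensor factories used throughout: T builds float32 tensors from randn on DEVICE,
# TI builds int32 index tensors with values in [0, mx).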
DEVICE = "cuda"
def T(*size, dtype=torch.float32, device=DEVICE, grad=False):
return torch.randn(size, dtype=dtype, device=device, requires_grad=grad)
def TI(*size, mx=10, dtype=torch.int32, device=DEVICE):
return torch.randint(0, mx, size, dtype=dtype, device=device)
class TestCase(InductorTestCase):
device = DEVICE
class NumBytesMetricTests(TestCase):
"""
Primarily used for sanity testing that the num_bytes_accessed metric is correct.
"""
def test_pointwise(self):
def f(x):
return x.cos()
inp = (T(10),)
self.assertExpectedInline(count_numel(f, *inp), """20""")
def f(x, y):
return x + y
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """30""")
def f(x, y):
return x + y
inp = (T(10, 10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """210""")
def f(x):
return x + x
inp = (T(10),)
self.assertExpectedInline(count_numel(f, *inp), """20""")
def f(x):
return x + x.t()
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """200""")
def f(a, b, c):
return a.cos(), b.sin() + c.sin()
inp = (T(10), T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """50""")
def test_reduction(self):
def f(x):
return x.sum(dim=1)
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """110""")
def f(x):
return x.sum(dim=0)
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """110""")
def test_extern(self):
def f(x):
return torch.mm(x, x)
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """200""")
def f(a, b):
return torch.mm(a, b)
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """300""")
def f(x):
x = x.cos()
x = torch.mm(x, x)
x = x.cos()
return x
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """600""")
def f(x):
a = x.cos()
b = x.sin()
x = torch.mm(a, b)
return x
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """600""")
def test_cat(self):
def f(a, b):
return torch.cat([a.sin(), b.sin()])
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """40""")
def f(a, b):
return torch.cat([a, b])
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """40""")
def f(a, b):
return torch.cat([a.cos(), b])
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """40""")
def f(a):
return torch.cat([a.cos(), a.sin()])
inp = (T(10),)
self.assertExpectedInline(count_numel(f, *inp), """30""")
def f(a, b):
return torch.cat([torch.mm(a, a), b.sin()])
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
def f(a, b, c):
return torch.cat((a + 1, b + 2, c + 3)) + 10
inp = (T(10, 10), T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """600""")
def f(a, b, c, d, e):
return torch.cat((a + 1, b + 2, c + 3, d + 4, e + 5)) + 10
inp = [T(10, 10) for _ in range(5)]
self.assertExpectedInline(count_numel(f, *inp), """1000""")
def f(a, b):
return torch.cat([a.sum(dim=0), b.sum(dim=0)]) + 10
inp = [T(10, 10, 10), T(10, 10, 10)]
self.assertExpectedInline(count_numel(f, *inp), """2600""")
def test_cat_pointwise(self):
def f(a, b):
return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
def f(a, b):
return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]).cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """680""")
# Should turn into pointwise even if only some of the inputs are pointwise.
def f(a, b):
out = torch.cat([a.cos(), torch.mm(b, b)])
return out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """600""")
# Should not turn into pointwise if none of the inputs are pointwise
def f(a, b):
out = torch.cat([torch.mm(a, a), torch.mm(b, b)])
return out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """800""")
def f(a, b):
out = torch.cat([a, b])
return out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
def f(a, b):
b = b.cos()
return torch.cat([a, b])
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
def f(a, b):
a = a @ a
return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """680""")
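# The patches below turn off the split-cat FX passes and restrict pre-grad fusion to the
# listed batch fusions (with post-grad fusions cleared), presumably so the pointwise-cat
# behavior is measured without interference from those passes.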
@patch.object(config, "split_cat_fx_passes", False)
@patch.object(
config,
"pre_grad_fusion_options",
{
"batch_linear": {},
"batch_linear_lhs": {},
"batch_layernorm": {},
"batch_tanh": {},
"batch_relu": {},
"batch_sigmoid": {},
},
)
@patch.object(config, "post_grad_fusion_options", {})
def test_cat_pointwise_many_complex_inputs(self):
def f(*inputs):
input = [torch.nn.functional.gelu(val) for val in inputs]
return torch.cat(input) + 10
inp = (T(10, 10) for _ in range(16))
self.assertExpectedInline(count_numel(f, *inp), """6400""")
@patch.object(config, "split_cat_fx_passes", False)
@patch.object(
config,
"pre_grad_fusion_options",
{
"batch_linear": {},
"batch_linear_lhs": {},
"batch_layernorm": {},
"batch_tanh": {},
"batch_relu": {},
"batch_sigmoid": {},
},
)
@patch.object(config, "post_grad_fusion_options", {})
def test_cat_pointwise_many_simple_inputs(self):
def f(*inputs):
input = [torch.nn.functional.relu(val) for val in inputs]
return torch.cat(input) + 10
inp = (T(10, 10) for _ in range(16))
self.assertExpectedInline(count_numel(f, *inp), """9600""")
@patch.object(config, "max_pointwise_cat_inputs", 0)
def test_cat_pointwise_config_option(self):
def f(a, b):
return torch.cat([a + 1, b + 2]) + 3
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
def test_index(self):
def f(a, b):
return a[b]
inp = (T(10), TI(10, mx=10))
self.assertExpectedInline(count_numel(f, *inp), """30""")
class FusionTests(TestCase):
"""
Tests that things can be fused into a single kernel
"""
def test_horizontal_reduction_pointwise(self):
def f(a):
b = a.sum(dim=1)
c = a.cos()
return b, c
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """210""")
def test_horizontal_reduction_reduction(self):
def f(a):
b = a.sum(dim=1)
c = a.amax(dim=1)
return b, c
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """120""")
def test_horizontal_reduction_pointwise2(self):
def f(a, b):
c = a.sum(dim=1)
b = b.cos()
return b + c
inp = (T(10, 10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """120""")
def test_horizontal_reduction_outer_pointwise(self):
def f(a, b):
c = a.sum(dim=0)
b = b.cos()
return b + c
inp = (T(10, 10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """120""")
def test_horizontal_sum_pw_broadcast(self):
def f(a, b):
a = a.sum(dim=1, keepdim=True)
b = b.cos()
return a * b
inp = (T(10, 10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """210""")
def test_vertical_sum_pw(self):
def f(a):
a = a.cos()
a = a.sum(dim=1)
return a.cos()
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """110""")
def test_norm_chain(self):
def f(a):
b = a.sum(dim=1, keepdim=True)
a = a * b
b = a.sum(dim=1, keepdim=True)
a = a * b
b = a.sum(dim=1, keepdim=True)
a = a * b
return a
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """200""")
def test_softmax_inner(self):
def f(a):
return torch.softmax(a, dim=1)
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """200""")
def test_layer_norm(self):
# TODO: Suboptimal! We shouldn't need to save normalization stats.
mod = torch.nn.LayerNorm(10, device=self.device)
def f(x):
return mod(x)
inp = (T(10, 10),)
with torch.no_grad():
self.assertExpectedInline(count_numel(f, *inp), """220""")
def test_double_softmax(self):
def f(x):
x = torch.softmax(x, dim=1)
x = torch.softmax(x, dim=1)
return x
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """200""")
def test_softmax_backward(self):
def f(grad_out, out):
return aten._softmax_backward_data(grad_out, out, 1, torch.float32)
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """300""")
def test_neighbor(self):
def f(a, b):
return ((a - b) ** 2).sum(dim=-1).amax(dim=1)
inp = (T(10, 1, 4), T(1, 10, 4))
self.assertExpectedInline(count_numel(f, *inp), """90""")
def test_factory_reduction(self):
def f():
a = torch.ones(10, device=self.device)
b = torch.ones(10, 10, device=self.device)
return (a + b).sum(dim=-1)
inp = ()
self.assertExpectedInline(count_numel(f, *inp), """10""")
def test_index_pointwise(self):
def f(a, b):
return a[b].cos()
inp = (T(10, 10), TI(20, mx=10))
self.assertExpectedInline(count_numel(f, *inp), """320""")
def test_index_reduction(self):
def f(a, b):
return a[b].cos().sum(dim=1)
inp = (T(10, 10), TI(20, mx=10))
self.assertExpectedInline(count_numel(f, *inp), """140""")
def test_mutation_fusion(self):
def f(a, b, c):
a0 = a.add(c)
b0 = b.add(a0)
b.copy_(b0)
a.copy_(a0)
inp = (T(10, 10), T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """500""")
def test_reduction_pointwise_multi_level_reduction(self):
hidden_size = 4096
layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()
@torch.inference_mode()
def f(x, scale, amax_keep_dim):
x = layer_norm(x.to(dtype=torch.float))
amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
x_scaled = x * scale
y = torch.nn.functional.sigmoid(x_scaled)
return (y, amax)
inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))
# 2 kernels:
# kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction))
# kernel 2: (input = first-level amax, output = final amax)
# scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1)
expected_numel = (
1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
)
self.assertExpectedInline(count_numel(f, *inp, True), str(expected_numel))
self.assertExpectedInline(count_numel(f, *inp, False), str(expected_numel))
def test_pointwise_multi_level_reduction(self):
# TODO: this can be optimized by having the first pointwise kernel use the block sizes
# of the first-level reduction kernel.
hidden_size = 4096
def f(x, scale, amax_keep_dim):
x = x * 1.1
amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
x_scaled = x * scale
y = torch.nn.functional.sigmoid(x_scaled)
return (y, amax)
inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))
compiled_f = torch.compile(f)
compiled_f(*inp, True)
# 3 kernels:
# kernel 1: (input = X, scale, output = pointwise(X))
# kernel 2: (input = X, output = first-level amax)
# kernel 3: (input = first-level amax, output = final amax)
# scale (1) + X (4*2048*hidden_size) * 3 + amax (num_splits * 2 + 1)
# num_splits depends on SM architectures.
expected_numel = 1 + 4 * 2048 * hidden_size * 3 + 1
actual_numel_amax_keep_dim = count_numel(f, *inp, True)
actual_numel_amax_no_keep_dim = count_numel(f, *inp, False)
self.assertEqual(actual_numel_amax_keep_dim, actual_numel_amax_no_keep_dim)
self.assertGreaterAlmostEqual(actual_numel_amax_keep_dim, str(expected_numel))
class SchedulerFusionTests(TestCase):
"""
Testing the fusion group creation heuristic (i.e. cases where we can't fuse
everything into a single kernel)
Disables inductor rematerialization to make the tests easier to reason about.
"""
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0))
@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()
@patch.object(config, "pattern_matcher", False)
def test_fusion_choice1(self):
# It doesn't matter where we break the fusion group here
def f(a):
c = a.cos()
d = torch.mm(c, c)
e = c.cos()
return d + e
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """700""")
@patch.object(config, "pattern_matcher", False)
def test_fusion_choice2(self):
# We should materialize e (it's smaller!)
# [c, e]: 210, [f]: 210, [d]: 200
def f(a):
c = a.cos()
d = torch.mm(c, c)
e = c.sum(dim=1)
f = d + e
return f
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """620""")
@patch.object(config, "pattern_matcher", False)
def test_fusion_choice3(self):
# We should materialize e.
# [c, e]: 300, [f]: 300, [d]: 200
def f(a):
c = a.cos()
d = torch.mm(c, c)
e = c + a
f = d + e
return f, e
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """800""")
@patch.object(config, "pattern_matcher", False)
def test_fusion_choice4_cpu(self):
# Fuse nodes with the same number of elements and compatible original var ranges
# [buf0: {d0: 60, d1: 11}, buf1: {d0: 660}] -> buf0_buf1
def f(x, w):
o1 = x * w
output = o1 + 1.0
return output
inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"))
self.assertExpectedInline(count_numel(f, *inp), """1331""")
# [buf0_buf1: {d0: 60, d1: 11}, buf2: {d0: 660}] -> buf0_buf1_buf2
def f(x, w1, w2):
o1 = x * w1
o2 = x * w2
output = o1 + o2
return output
inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"), T(11, device="cpu"))
self.assertExpectedInline(count_numel(f, *inp), """1342""")
class TilingTests(TestCase):
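# Adding a tensor to a transposed/permuted view forces the generated kernel to handle
# mismatched strides; these tests exercise Inductor's tiling of such kernels.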
def test_tiling_simple(self):
def f(a, b):
return a + b.t()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """300""")
def f(a, b):
return a.t() + b
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """300""")
def test_tiling_three(self):
def f(a, b, c):
return a + b.permute(1, 2, 0) + c.permute(2, 0, 1)
inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10))
self.assertExpectedInline(count_numel(f, *inp), """4000""")
class MinCutPartitioningTests(TestCase):
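# These tests exercise the min-cut partitioner: count_numel_train measures forward plus
# backward traffic, so the expected values reflect which intermediates are saved for
# backward versus recomputed.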
def test_partitioning_full_remat(self):
def f(x):
return x.cos().cos().cos()
inp = (T(10, grad=True),)
self.assertExpectedInline(count_numel_train(f, *inp), """50""")
def test_partitioning_partial_remat(self):
def f(a, b, c, d):
x = a + b + c + d
return x.cos().cos()
inp = (T(10, grad=True), T(10, grad=True), T(10, grad=True), T(10, grad=True))
self.assertExpectedInline(count_numel_train(f, *inp), """90""")
def test_partitioning_dtype(self):
def f(x):
return (x < 0) * x
inp = (T(100, grad=True),)
self.assertExpectedInline(count_numel_train(f, *inp), """450""")
@patch.object(functorch.compile.config, "max_dist_from_bw", 1000)
def test_partitioning_unremat_bw(self):
def f(x):
return torch.mm(x, x.new_ones(x.shape)).tanh().tanh()
inp = (T(10, 10, grad=True),)
self.assertExpectedInline(count_numel_train(f, *inp), """1300""")
@patch.object(config, "pattern_matcher", False)
def test_partitioning_unremat_bw2(self):
def f(a):
a = torch.mm(a, a)
a = a + 1
b = a + 2
c = torch.mm(a, b)
return c
inp = (T(10, 10, grad=True),)
self.assertExpectedInline(count_numel_train(f, *inp), """2600""")
def test_partitioning_keops(self):
def f(a, b):
return (a * b).cos().sum(dim=1)
inp = (T(20, 1, grad=True), T(1, 20, grad=True))
self.assertExpectedInline(count_numel_train(f, *inp), """220""")
def test_partitioning_cat(self):
def f(a, b):
a = torch.tanh(a)
return torch.cat([a, b])
inp = (T(10, grad=True), T(10, grad=True))
self.assertExpectedInline(count_numel_train(f, *inp), """70""")
def test_partitioning_with_view(self):
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = x.sin()
x = x.cos()
x = x.view(10, 10)
ctx.save_for_backward(x, y)
x = x.cos()
return x
@staticmethod
def backward(ctx, gradOut):
x, y = ctx.saved_tensors
return torch.mm(gradOut, x).view(100) * y
def f(a):
return Foo.apply(a)
inp = (T(100, grad=True),)
# We do not want to recompute the x.cos().view() chain, as it's
# materialized in the backward pass
self.assertExpectedInline(count_numel_train(f, *inp), """900""")
@patch.object(config, "pattern_matcher", False)
def test_partitioning_long_chain_add(self):
def f(x):
orig = x
for _ in range(2):
x = x * x
x = torch.mm(x, x)
x = x * 2
x = orig + x
orig = x
return x
inp = (T(10, 10, grad=True),)
self.assertExpectedInline(count_numel_train(f, *inp), """3900""")
def unfusible(x):
# For the purpose of noop tests, we want inductor to fall back to
# eager mode, so below we must use an aten operator that has neither
# a decomposition nor a lowering:
return aten._lazy_clone(x)
class NoopTests(TestCase):
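# These tests check that no-op operations (trivial clones, slice_scatter over the whole
# tensor, same-dtype/device conversions, cat of a single tensor, etc.) are eliminated
# and do not add memory traffic.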
def test_noop_clones(self):
def f(a):
b = a.clone()
b = unfusible(b)
return b
inp = T(10)
self.assertExpectedInline(count_numel(f, inp), """20""")
def f(a):
b = a.clone()
c = unfusible(b)
return b, c
self.assertExpectedInline(count_numel(f, inp), """40""")
def test_noop_slice_scatter(self):
def f(a):
b = aten.slice_scatter(a, a)
c = unfusible(b)
return c
inp = T(10)
self.assertExpectedInline(count_numel(f, inp), """20""")
def test_noop_dtype_conversion(self):
def f(a):
b = torch.ops.prims.convert_element_type(a, torch.float32)
c = unfusible(b)
return c
inp = T(10)
self.assertExpectedInline(count_numel(f, inp), """20""")
def test_noop_device_conversion(self):
def f(a):
b = torch.ops.prims.device_put(a, "cuda")
c = unfusible(b)
return c
inp = T(10)
self.assertExpectedInline(count_numel(f, inp), """20""")
def test_noop_int_ops(self):
def f1(a):
b = torch.ceil(a)
c = unfusible(b)
return c
def f2(a):
d = torch.floor(a)
e = unfusible(d)
return e
def f3(a):
f = torch.round(a)
g = unfusible(f)
return g
def f4(a):
f = torch.pow(a, 1)
g = unfusible(f)
return g
inp = TI(10)
self.assertExpectedInline(count_numel(f1, inp), """20""")
self.assertExpectedInline(count_numel(f2, inp), """20""")
self.assertExpectedInline(count_numel(f3, inp), """20""")
self.assertExpectedInline(count_numel(f4, inp), """20""")
def test_noop_cat(self):
def f1(a):
b = torch.cat([a])
return unfusible(b)
inp = T(10)
self.assertExpectedInline(count_numel(f1, inp), """20""")
def f2(a):
b = torch.cat([a])
c = torch.cat([b])
return c
self.assertExpectedInline(count_numel(f2, inp), """20""")
class InplacingTests(TestCase):
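# These tests check that Inductor reinplaces mutations (scatter/index_put, mutable custom
# ops, and user-defined Triton kernels) instead of allocating extra intermediate buffers.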
def test_inplace_scatter(self):
def f(a, b):
a = a.cos()
a[b] = 1
return a
inp = (T(10), TI(2, mx=5))
self.assertExpectedInline(count_numel(f, *inp), """26""")
def f(a, b):
out = aten.index_put(a, (b,), torch.tensor(1.0))
return a.copy_(out)
inp = (T(10), TI(2, mx=5))
self.assertExpectedInline(count_numel(f, *inp), """6""")
def f(a, b):
out = aten._unsafe_index_put(a, (b,), torch.tensor(1.0))
return a.copy_(out)
inp = (T(10), TI(2, mx=5))
self.assertExpectedInline(count_numel(f, *inp), """6""")
def test_inplace_scatter_noop_view(self):
def f(a, b):
a[:, b] = 1
return a
inp = (T(10, 10), TI(2, mx=5))
self.assertExpectedInline(count_numel(f, *inp), """42""")
@requires_cuda
def test_inplace_triton_kernel_training(self):
@triton.jit
def sin_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.sin(x)
tl.store(out_ptr + offsets, output, mask=mask)
def sin_triton(x, out):
n_elements = x.numel()
sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
factory_op = torch.empty_like
class MySin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
out = factory_op(x)
sin_triton(x, out)
ctx.save_for_backward(out)
return out
@staticmethod
def backward(ctx, grad):
(saved,) = ctx.saved_tensors
out = factory_op(grad)
sin_triton(saved, out)
return out
def f(x):
return MySin.apply(x)
x = T(3, grad=True)
self.assertExpectedInline(count_numel_train(f, x), """9""")
@requires_cuda
def test_inplace_custom_op_training_two_mutated_inputs(self):
@torch.library.custom_op(
"_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}
)
def sin_cos(
x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor
) -> None:
out_sin.copy_(x.sin())
out_cos.copy_(x.cos())
def f(x):
out0 = torch.empty_like(x)
out1 = torch.empty_like(x)
sin_cos(x, out0, out1)
return x.clone(), out0, out1
x = T(3, grad=True)
self.assertExpectedInline(count_numel(f, x), """21""")
@requires_cuda
def test_inplace_custom_op_training(self):
@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
def sin(x: torch.Tensor, result: torch.Tensor) -> None:
result.copy_(x.sin())
factory_op = torch.empty_like
class MySin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
out = factory_op(x)
sin(x, out)
ctx.save_for_backward(out)
return out
@staticmethod
def backward(ctx, grad):
(saved,) = ctx.saved_tensors
out = factory_op(grad)
sin(saved, out)
return out
def f(x):
return MySin.apply(x)
x = T(3, grad=True)
self.assertExpectedInline(count_numel_train(f, x), """9""")
@requires_cuda
def test_inplace_custom_op(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
m.define("foo(Tensor x, Tensor(a!) out) -> ()")
def foo(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x.sin())
m.impl("foo", foo, "CompositeExplicitAutograd")
def f(x, out):
torch.ops.mylib.foo(x, out)
torch.ops.mylib.foo(out, out)
torch.ops.mylib.foo(out, out)
return out
x = T(3)
out = T(3)
compiled_out, (code,) = run_and_get_code(
torch.compile(f, fullgraph=True), x, out
)
self.assertEqual(compiled_out, x.sin().sin().sin())
# Check that we are allocating the minimum number of intermediate buffers
matches = re.findall(r"empty_strided_\w+\(", code)
self.assertEqual(len(matches), 0)
self.assertExpectedInline(count_numel(f, x, out), """21""")
@requires_cuda
def test_inplace_custom_op_intermediate(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
m.define("foo(Tensor x, Tensor(a!) out) -> ()")
def foo(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x.sin())
m.impl("foo", foo, "CompositeExplicitAutograd")
def f(x, out):
out = torch.empty_like(x)
torch.ops.mylib.foo(x, out)
torch.ops.mylib.foo(out, out)
torch.ops.mylib.foo(out, out)
return out
x = T(3)
out = T(3)
compiled_out, (code,) = run_and_get_code(
torch.compile(f, fullgraph=True), x, out
)
self.assertEqual(compiled_out, x.sin().sin().sin())
# Check that we are allocating the minimum number of intermediate buffers
matches = re.findall(r"empty_strided_\w+\(", code)
self.assertEqual(len(matches), 1)
self.assertExpectedInline(count_numel(f, x, out), """21""")
@requires_cuda
def test_inplace_custom_op_two_mutated_inputs(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor")
def foo(q, k_cache, v_cache):
k_cache.add_(1)
v_cache.add_(1)
return q + 1
m.impl("foo", foo, "CompositeExplicitAutograd")
q = T(3)
k_cache = T(3)
v_cache = torch.rand_like(k_cache)
def f():
x = 0
for _ in range(2):
x = x + torch.ops.mylib.foo(q, k_cache, v_cache)
return x
compiled_out, (code,) = run_and_get_code(
torch.compile(f, fullgraph=True),
)
# Check that we are allocating the minimum number of intermediate buffers
matches = re.findall(r"empty_strided_\w+\(", code)
self.assertEqual(len(matches), 1)
self.assertExpectedInline(count_numel(f), """39""")
@requires_cuda
def test_inplace_triton_kernel_v1(self):
def f(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (n_elements,)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """50""")
@requires_cuda
def test_inplace_triton_kernel_v2(self):
def f(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (n_elements,)
tmp = torch.add(x, 1)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output, tmp
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """70""")
@requires_cuda
def test_inplace_triton_kernel_v3(self):
def f(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (n_elements,)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
x.add_(1)
return output
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """80""")
@requires_cuda
def test_inplace_triton_kernel_v4(self):
def f(x: torch.Tensor, y: torch.Tensor):
x_view = x.view(-1)
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (n_elements,)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
output2 = x_view.mul(2)
return output, output2
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """70""")
@requires_cuda
def test_inplace_triton_kernel_v5(self):
def f(x: torch.Tensor, y: torch.Tensor):
x_view = x.view(-1)
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (n_elements,)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
x_view.mul_(2)
return output
inp = (T(10), T(10))
self.assertExpectedInline(count_numel(f, *inp), """80""")
@requires_cuda
def test_inplace_triton_kernel_v6(self):
def f(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = (n_elements,)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output
t = T(10)
inp = (t, t.view(-1))
self.assertExpectedInline(count_numel(f, *inp), """50""")
def test_inplace_randperm_scatter(self):
def scaled_index_add(x, y, scale_y):
index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
out = x.index_add_(dim=0, source=y * scale_y, index=index)
return out
inp = (T(10, 10), T(5, 10), T(10))
self.assertExpectedInline(count_numel(scaled_index_add, *inp), """250""")
# Test cases where we don't do the right thing yet.
class WouldBeNiceIfItWorked:
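# Note: this class does not inherit from TestCase, so these tests are not collected by the
# test runner; they document the desired (currently unachieved) behavior.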
def test_horizontal(self):
def f(a):
b = a.sum(dim=0)
c = a.cos()
return b, c
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """210""")
# TODO: We aren't fusing outer dim softmaxes
def test_softmax_outer(self):
def f(a):
return torch.softmax(a, dim=0)
inp = (T(10, 10),)
self.assertExpectedInline(count_numel(f, *inp), """200""")
# TODO: The greedy fusion strategy results in suboptimal grouping
@patch.object(config, "realize_opcount_threshold", 0)
def test_fusion_choice4(self):
def f(a, b, b2):
c = a + b
d = torch.mm(c, c)
e = c + b + b2
f = d + e + b2
return f, e
inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """1000""")
# TODO: We materialize the intermediate if we don't unroll the reduction
def test_neighbor(self):
def f(a, b):
return ((a - b) ** 2).sum(dim=-1).amax(dim=1)
inp = (T(10, 1, 8), T(1, 10, 8))
self.assertExpectedInline(count_numel(f, *inp), """170""")
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_CUDA:
run_tests(needs="filelock")