# Owner(s): ["module: inductor"]
import contextlib
import itertools
import sys
import unittest
from typing import Callable
from unittest.mock import patch

import numpy as np
import sympy
import torch
import torch._dynamo
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import codecache, config, metrics
from torch._inductor.codegen.cpp import (
    CppOverrides,
    CppVecKernelChecker,
    CppVecOverrides,
)
from torch._inductor.compile_fx import compile_fx_inner, complex_memory_overlap
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
from torch._inductor.utils import timed
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import IS_MACOS
from torch.utils._python_dispatch import TorchDispatchMode

try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


vec_dtypes = test_torchinductor.vec_dtypes
slow = test_torchinductor.slow
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
aten = torch.ops.aten


class CPUReproTests(TestCase):
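    # Check that the compiled convolution receives its input and weight in the
    # requested memory format; the aten.convolution call is observed through a
    # TorchDispatchMode hook below.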
    def test_conv_stride_constraints(self):
        for fmt in [torch.channels_last, torch.contiguous_format]:
            # TorchDispatch doesn't work in our cuda invocation for some reason
            m = torch.nn.Conv2d(5, 6, [3, 3])

            def fn(inp, weight):
                return (
                    F.conv2d(
                        inp, weight, None, m.stride, m.padding, m.dilation, m.groups
                    ),
                )

            inp = torch.randn([2, 5, 16, 16])
            inps = [inp, m.weight.to(memory_format=fmt)]
            fn_fx = make_fx(fn)(*inps)
            fn_compiled = compile_fx_inner(fn_fx, inps)
            test_self = self
            conv_seen = False

            class RecordFunctions(TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    kwargs = kwargs if kwargs else {}
                    if func == torch.ops.aten.convolution.default:
                        test_self.assertTrue(args[0].is_contiguous(memory_format=fmt))
                        test_self.assertTrue(args[1].is_contiguous(memory_format=fmt))
                        nonlocal conv_seen
                        conv_seen = True

                    return func(*args, **kwargs)

            with RecordFunctions():
                out = fn_compiled(inps)

            self.assertTrue(conv_seen)

    def test_inplace_squeeze_needed(self):
        mod = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.LayerNorm(10),
            torch.nn.ReLU(),
        ).eval()

        @torch._dynamo.optimize("inductor")
        def fn(x):
            return mod(x)

        v = torch.randn(10)
        result = fn(v)
        # TODO: OMP parallel reduction order is not deterministic.
        # Hence, the accuracy might vary up and down. For the short term,
        # we increase the tolerance and will fix it later by using
        # aten parallel.
        assert same(result, mod(v), tol=5e-1)

    def test_cat_mul(self):
        # https://github.com/pytorch/pytorch/issues/93365
        def fn(p0, p1):
            y1 = torch.cat([p0, p1], dim=0)
            y2 = torch.mul(y1, y1)
            return y1, y2

        p0 = torch.randn(3, 4)
        p1 = torch.randn(3, 4)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        opt_fn(p0, p1)
        real_out = fn(p0, p1)
        compiled_out = opt_fn(p0, p1)
        assert same(real_out, compiled_out)

    def test_reduce_with_masked(self):
        # https://github.com/pytorch/pytorch/issues/96484
        def fn(a, b):
            a = torch.nn.functional.pad(a, (0, -1))
            c = a + b
            return c.min(0).values

        a = torch.randn([2])
        b = torch.randn([2])
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        opt_fn(a, b)
        real_out = fn(a, b)
        compiled_out = opt_fn(a, b)
        assert same(real_out, compiled_out)

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_sigmoid_with_reduction(self):
        def fn(x):
            x = torch.ops.aten.sigmoid.default(x)
            return torch.ops.aten.mean.dim(x, [-1, -2], True)

        x = torch.randn((1, 8, 8, 8))
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            opt_fn(x)

            real_out = fn(x)
            compiled_out = opt_fn(x)
            assert same(real_out, compiled_out, equal_nan=True)

    def test_inplace_add_alpha(self):
        def fn(x, y):
            aten.add_.Tensor(x, y, alpha=0.55)
            return (x,)

        x1 = torch.zeros(10)
        x2 = torch.zeros(10)
        x3 = torch.zeros(10)
        y = torch.randn(10)
        fn_fx = make_fx(fn)(x1, y)
        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
        fn(x2, y)
        fn_compiled([x3, y])
        assert same(x2, x3)

    def test_no_op_squeeze(self):
        @torch._dynamo.optimize("inductor")
        def forward(arg0_1):
            return torch.ops.aten.squeeze.dim(arg0_1, 1)

        x = torch.randn((10, 20))
        assert same(x, forward(x))

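    # The same compiled kernel should produce correct results regardless of the
    # number of intra-op threads configured at run time.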
    def test_parallel_num_threads(self):
        @torch._dynamo.optimize("inductor")
        def fn(x1, x2):
            return x1 + x2

        @contextlib.contextmanager
        def set_num_threads(num_threads):
            orig_num_threads = torch.get_num_threads()
            torch.set_num_threads(num_threads)
            yield
            torch.set_num_threads(orig_num_threads)

        x1 = torch.randn((10, 20))
        x2 = torch.randn((10, 20))
        with set_num_threads(1):
            assert same(x1 + x2, fn(x1, x2))
        with set_num_threads(4):
            assert same(x1 + x2, fn(x1, x2))

    @patch("torch.cuda.is_available", lambda: False)
    def test_timed_cpu_only(self):
        timed(lambda: torch.randn(10), ())

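    # complex_memory_overlap() should flag only layouts where distinct indices
    # can alias the same storage location (e.g. expanded views), not ordinary
    # transposes, splits, or unsqueezes.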
    def test_complex_memory_overlap(self):
        dense = torch.zeros(64, 32)
        self.assertFalse(complex_memory_overlap(dense))
        self.assertFalse(complex_memory_overlap(dense.t()))

        strided = dense.split(4, dim=1)
        self.assertFalse(complex_memory_overlap(strided[0]))
        self.assertFalse(complex_memory_overlap(strided[0].t()))

        unsqueezed = dense.unsqueeze(1)
        self.assertFalse(complex_memory_overlap(unsqueezed))
        self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0)))

        expanded = unsqueezed.expand(-1, 2, -1)
        self.assertTrue(complex_memory_overlap(expanded))
        self.assertTrue(complex_memory_overlap(expanded.permute(1, 2, 0)))

        gathered = dense.index_select(0, torch.IntTensor([1, 0, 1]))
        self.assertFalse(complex_memory_overlap(gathered))
        self.assertFalse(complex_memory_overlap(gathered.t()))

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @torch._dynamo.config.patch(dynamic_shapes=True)
    def test_vec_dynamic_shapes(self):
        def fn(x):
            return torch.softmax(x, -1)

        value = torch.randn((2, 10))
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            opt_fn(value)

            real_out = fn(value)
            compiled_out = opt_fn(value)
            assert same(real_out, compiled_out, equal_nan=True)

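    # Exercise ISA selection: pick_vec_isa() should honor cpp.simdlen, where None
    # auto-selects the widest valid ISA, 0/1 and unsupported widths disable
    # vectorization, and 512/256 pick AVX512/AVX2 when available.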
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_auto_simd(self):
        vec_avx512 = codecache.supported_vec_isa_list[0]
        vec_avx2 = codecache.supported_vec_isa_list[1]
        self.assertTrue(vec_avx512.bit_width() == 512)
        self.assertTrue(vec_avx2.bit_width() == 256)
        self.assertTrue(vec_avx512.nelements() == 16)
        self.assertTrue(vec_avx2.nelements() == 8)
        self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
        self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)

        with config.patch({"cpp.simdlen": None}):
            isa = codecache.pick_vec_isa()
            if vec_avx512 in codecache.valid_vec_isa_list():
                self.assertTrue(isa == vec_avx512)
            else:
                self.assertTrue(isa == vec_avx2)

        with config.patch({"cpp.simdlen": 0}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 1}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 257}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 513}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 512}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertTrue(isa == vec_avx512)

        with config.patch({"cpp.simdlen": 256}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx2 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertTrue(isa == vec_avx2)

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_masked_fill_softmax(self):
        def fn(value, mask):
            mask = mask.to(torch.bool)
            x = torch.masked_fill(value, mask, -33.0)
            return torch.softmax(x, -1)

        for dtype in vec_dtypes:
            value = torch.randn((2, 17), dtype=dtype)
            mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
            with config.patch({"cpp.simdlen": None}):
                for cpp_wrapper_flag in [True, False]:
                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                        torch._dynamo.reset()
                        metrics.reset()
                        opt_fn = torch._dynamo.optimize("inductor")(fn)
                        opt_fn(value, mask)

                        real_out = fn(value, mask)
                        compiled_out = opt_fn(value, mask)
                        assert same(real_out, compiled_out, equal_nan=True)
                        assert metrics.generated_cpp_vec_kernel_count >= 1

    def test_load_same_bool_tensor_twice(self):
        @torch._dynamo.optimize("inductor")
        def fn(a, b):
            x = torch.masked_fill(a, b, -33.0)
            y = torch.masked_fill(a, b, -33.0)
            return x, y

        value = torch.randn((2, 17))
        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
        fn(value, mask)

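    # Every op implemented by CppOverrides should also be covered by
    # CppVecOverrides, except for the known gaps listed in `diff` below.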
    def test_cpu_vec_cosim(self):
        cpp_vec_op_list = []
        cpp_op_list = []

        for k, v in CppVecOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                cpp_vec_op_list.append(k)
        for k, v in CppOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                cpp_op_list.append(k)

        diff = [
            "index_expr",
            "signbit",
            "isinf",
            "mod",
            "masked",
            "randn",
            "isnan",
            "rand",
        ]
        union = {*cpp_vec_op_list, *diff}
        self.assertTrue(set(cpp_op_list).issubset(union))

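    # The backward of gather is a scatter-add; compare the bf16 forward and
    # backward results of the compiled function against eager.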
    def test_atomic_add_bf16(self):
        def fn(test_args):
            res = torch.gather(**test_args)
            return res

        input_tensor_for_ref = torch.tensor(
            [[3.0, -5.0]], dtype=torch.bfloat16, requires_grad=True
        )
        input_tensor_for_opt = torch.tensor(
            [[3.0, -5.0]], dtype=torch.bfloat16, requires_grad=True
        )

        test_args_for_ref = {
            "input": input_tensor_for_ref,
            "dim": 1,
            "index": torch.tensor([[1]]),
        }
        test_args_for_opt = {
            "input": input_tensor_for_opt,
            "dim": 1,
            "index": torch.tensor([[1]]),
        }

        opt_fn = torch.compile(fn)

        ref_fwd = fn(test_args_for_ref)
        res_fwd = opt_fn(test_args_for_opt)
        self.assertEqual(res_fwd, ref_fwd)

        torch.manual_seed(1)
        bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=torch.bfloat16)
        torch.manual_seed(1)
        bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=torch.bfloat16)
        self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt)

        ref_fwd.backward(bwd_tensor_for_ref)
        res_fwd.backward(bwd_tensor_for_opt)

        ref_grad = test_args_for_ref["input"].grad
        res_grad = test_args_for_opt["input"].grad
        self.assertEqual(ref_grad, res_grad)

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_new_vec_op_cpu_only(self):
        def fn(x):
            return (torch.log1p(torch.expm1(torch.erf(x))),)

        for dtype in vec_dtypes:
            torch.manual_seed(0)
            x = torch.randn((2, 9), dtype=dtype)
            x[0, 0] = torch.nan
            x[1, -1] = torch.nan

            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4

            with config.patch({"cpp.simdlen": None}):
                for cpp_wrapper_flag in [True, False]:
                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                        torch._dynamo.reset()
                        metrics.reset()
                        traced = make_fx(fn)(x)
                        compiled = compile_fx_inner(traced, [x])
                        assert same(fn(x)[0], compiled([x])[0], equal_nan=True, tol=tol)
                        assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_cpu_only_for_all_available_isa(self):
        def fn(x):
            return (torch.sin(torch.cos(torch.erf(x))),)

        x = torch.randn((2, 9))
        x[0, 0] = torch.nan
        x[1, -1] = torch.nan

        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
        for item in bit_widths:
            with config.patch({"cpp.simdlen": item}):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x)
                compiled = compile_fx_inner(traced, [x])
                assert same(fn(x)[0], compiled([x])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

    @slow()
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test__adaptive_avg_pool2d(self):
        def wrap_fn(oh, ow):
            def fn(x):
                return torch._adaptive_avg_pool2d(x, (oh, ow))

            return fn

        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()]
        ih = [16, 65]
        iw = ih
        oh = ih
        ow = ih
        for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product(
            ih, iw, oh, ow, bit_widths, vec_dtypes
        ):
            x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to(
                memory_format=torch.channels_last
            )
            _fn = wrap_fn(_oh, _ow)
            with config.patch({"cpp.simdlen": _simd_len}):
                torch._dynamo.reset()
                metrics.reset()
                compiled = torch.compile(_fn)
                compiled(x)
                assert same(_fn(x), compiled(x), equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_logical_and_or(self):
        def wrap_fn(op: Callable):
            def fn(x: torch.Tensor, y: torch.Tensor):
                return torch.where(op(x, y), 1.0, 0.0)

            return fn

        for dtype in vec_dtypes:
            x = torch.randn(64, dtype=dtype)
            y = torch.randn(64, dtype=dtype)
            logical_fns = [torch.logical_and, torch.logical_or]
            for logical_fn in logical_fns:
                _fn = wrap_fn(logical_fn)
                torch._dynamo.reset()
                metrics.reset()
                compiled = torch.compile(_fn)

                compiled(x, y)
                assert same(_fn(x, y), compiled(x, y), equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_compare_op_cpu_only(self):
        def fn(x):
            y1 = torch.eq(x, 1.0)
            x = torch.where(y1, x, -x)
            y2 = torch.ne(x, 0.0)
            x = torch.where(y2, x, -x)
            y3 = torch.lt(x, 5.0)
            x = torch.where(y3, x, x - 1.0)
            y4 = torch.gt(x, -2.0)
            x = torch.where(y4, x, x + 1.0)
            y5 = torch.le(x, 8.0)
            x = torch.where(y5, x, x - 1.0)
            y6 = torch.ge(x, -3.0)
            x = torch.where(y6, x, x + 1.0)
            y7 = x == 1.0
            x = torch.where(y7, x, -x)
            y8 = x != 0.0
            x = torch.where(y8, x, -x)
            y9 = x < 5.0
            x = torch.where(y9, x, x - 1.0)
            y10 = x > -2.0
            x = torch.where(y10, x, x + 1.0)
            y11 = x <= 8.0
            x = torch.where(y11, x, x - 1.0)
            y12 = x >= -3.0
            x = torch.where(y12, x, x + 1.0)
            return (x,)

        for dtype in vec_dtypes:
            x = torch.randn((2, 9), dtype=dtype)

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x)
                compiled = compile_fx_inner(traced, [x])
                assert same(fn(x)[0], compiled([x])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1
                assert (
                    metrics.generated_kernel_count
                    - metrics.generated_cpp_vec_kernel_count
                ) == 0

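    # With disable_cpp_codegen, no C++ kernel should be emitted (FileCheck asserts
    # "void kernel" is absent) while the results still match eager, including the
    # constant-tensor slicing and empty-tensor cases below.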
    def test_skip_cpp_codegen(self):
        with config.patch({"disable_cpp_codegen": True}):
            inps = torch.ones([20]), torch.rand([20])

            def f(x, y):
                return x + y + torch.tensor(1)

            f_opt = torch.compile()(f)

            code = run_and_get_cpp_code(f_opt, inps[0], inps[1])
            FileCheck().check_not("void kernel").run(code)

            self.assertEqual(
                f(*inps),
                f_opt(*inps),
            )

            # constant needs to be propagated on fallback
            def f(x):
                return x[torch.tensor(1) :] * 2

            f_opt = torch.compile()(f)
            code = run_and_get_cpp_code(f_opt, inps[0])
            FileCheck().check_not("void kernel").run(code)
            self.assertEqual(f_opt(inps[0]), f(inps[0]))

            class Model(torch.nn.Module):
                def __init__(
                    self,
                ):
                    super().__init__()

                def forward(self, v1: torch.Tensor):
                    vx = v1.min(dim=1).values
                    v2 = torch.randn_like(vx)
                    return v2

            model = Model()
            x = torch.rand(10, 3, 0)
            model_f = torch.compile()(model)

            self.assertEqual(model(x), model_f(x))

    def test_redundant_to_node_elimination_bf16(self):
        def fn(x, y):
            res = x + y
            res = torch.mean(res)
            return (res,)

        x = torch.randn((2, 9), dtype=torch.bfloat16)
        y = torch.randn((2, 9), dtype=torch.bfloat16)

        for torch_compile_debug in [True, False]:
            with config.patch(
                {"trace.enabled": torch_compile_debug, "cpp.simdlen": None}
            ):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x, y)
                compiled = compile_fx_inner(traced, [x, y])
                assert same(fn(x, y)[0], compiled([x, y])[0], equal_nan=True, tol=1e-2)
                if codecache.valid_vec_isa_list():
                    assert metrics.generated_cpp_vec_kernel_count == 1

    def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self):
        def fn(x):
            res = x.clone()
            return (res,)

        x = torch.randn((100, 100), dtype=torch.bfloat16)

        torch._dynamo.reset()
        metrics.reset()
        traced = make_fx(fn)(x)
        compiled = compile_fx_inner(traced, [x])
        assert same(fn(x)[0], compiled([x])[0])
        assert metrics.cpp_to_dtype_count == 0
        if codecache.valid_vec_isa_list():
            assert metrics.generated_cpp_vec_kernel_count == 1

    def test_insert_to_dtype_count(self):
        def fn(x):
            res = x.relu()
            return (res,)

        x = torch.randn((100, 100), dtype=torch.bfloat16)

        torch._dynamo.reset()
        metrics.reset()
        traced = make_fx(fn)(x)
        compiled = compile_fx_inner(traced, [x])
        assert same(fn(x)[0], compiled([x])[0])
        assert metrics.cpp_to_dtype_count == 2
        if codecache.valid_vec_isa_list():
            assert metrics.generated_cpp_vec_kernel_count == 1

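    # Drive CppVecKernelChecker directly over a small FX graph of ops.constant
    # calls: values representable in int32/fp32 keep simd_vec True, while
    # out-of-range values disable vectorization.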
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_cpp_vec_constant_checker(self):
        _graph: torch.fx.Graph = torch.fx.Graph()
        a: torch.fx.Node = _graph.create_node("placeholder", "ops")
        iv: torch.fx.Node = _graph.create_node("placeholder", "iv")
        fv: torch.fx.Node = _graph.create_node("placeholder", "fv")
        b: torch.fx.Node = _graph.create_node(
            "call_method",
            "constant",
            args=(
                a,
                iv,
                torch.int64,
            ),
        )
        c: torch.fx.Node = _graph.create_node(
            "call_method",
            "constant",
            args=(
                a,
                fv,
                torch.double,
            ),
        )
        d: torch.fx.Node = _graph.create_node(
            "call_method",
            "ge",
            args=(
                a,
                b,
                b,
            ),
        )
        _graph.output((d, c))

        def get_index():
            return ""

        submodules = {"get_index": get_index}

        graph_lowering = GraphLowering(
            torch.fx.GraphModule(submodules, _graph),
            shape_env=None,
            num_static_inputs=0,
        )
        with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
            graph_lowering
        ):
            # The innermost loop variable is used in the index_expr
            tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:
                i32_iinfo = np.iinfo(np.int32)
                f32_iinfo = np.finfo(np.float32)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, np.inf
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, -np.inf
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min
                )
                self.assertFalse(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max
                )
                self.assertFalse(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5)
                )
                self.assertFalse(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5)
                )
                self.assertFalse(vec_checker.simd_vec)

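    # Drive CppVecKernelChecker over an ops.index_expr graph: vectorization stays
    # enabled only when the innermost loop variable does not appear in the index
    # expression and all possible index values fit into int32.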
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_cpp_vec_index_expr_checker(self):
        _graph: torch.fx.Graph = torch.fx.Graph()
        a: torch.fx.Node = _graph.create_node("placeholder", "ops")
        b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=())
        c: torch.fx.Node = _graph.create_node(
            "call_method",
            "index_expr",
            args=(
                a,
                b,
                torch.int64,
            ),
        )
        d: torch.fx.Node = _graph.create_node(
            "call_method",
            "ge",
            args=(
                a,
                c,
                c,
            ),
        )
        _graph.output(d)

        def get_index():
            return ""

        submodules = {"get_index": get_index}
        graph_lowering = GraphLowering(
            torch.fx.GraphModule(submodules, _graph),
            shape_env=None,
            num_static_inputs=0,
        )
        with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
            graph_lowering
        ):
            itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]

            tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
            # The innermost loop variable is used in the index_expr
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]

                ranges = [0, 100, 200]
                vec_checker.itervars = itervars[:2]
                vec_checker.ranges = ranges[:2]
                submodules = {"get_index": get_index}
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertFalse(vec_checker.simd_vec)

            # The innermost loop variable is irrelevant to the index_expr
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]

                ranges = [0, 100, 200]
                vec_checker.itervars = itervars
                vec_checker.ranges = ranges
                submodules = {"get_index": get_index}
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertTrue(vec_checker.simd_vec)

            i32_iinfo = np.iinfo(np.int32)
            _max_value = i32_iinfo.max + 1
            ranges = [_max_value, _max_value, _max_value]
            # The innermost loop variable is irrelevant, but the maximum index
            # value exceeds the maximum value of INT32
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return itervars[0]

                submodules = {"get_index": get_index}
                vec_checker.itervars = itervars
                vec_checker.ranges = ranges
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertFalse(vec_checker.simd_vec)

            # The innermost loop variable is irrelevant, but the minimum index
            # value falls below the minimum value of INT32
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return -itervars[0] - 2

                submodules = {"get_index": get_index}
                vec_checker.itervars = itervars
                vec_checker.ranges = ranges
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertFalse(vec_checker.simd_vec)

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_maxpool2d_cpu_only(self):
        for dtype in vec_dtypes:
            input = torch.randn(10, 32, 20, 20, dtype=dtype).to(
                memory_format=torch.channels_last
            )
            maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

            def func(x):
                return maxpool(x)

            with patch.object(config.cpp, "simdlen", None):
                torch._dynamo.reset()
                metrics.reset()
                graph = torch.compile(func, backend="inductor")
                graph(input)
                assert same(graph(input), func(input), equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_maxpool2d_with_pre_loop_collapse_cpu_only(self):
        x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
        x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
        maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        def func(x1, x2):
            y = x1 + x2
            return maxpool(y)

        with patch.object(config.cpp, "simdlen", None):
            torch._dynamo.reset()
            metrics.reset()
            graph = torch.compile(func, backend="inductor")
            graph(x1, x2)
            assert same(graph(x1, x2), func(x1, x2), equal_nan=True)
            assert metrics.generated_cpp_vec_kernel_count == 2

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_sign_cpu_only(self):
        def fn(x):
            return (torch.sign(x),)

        for dtype in vec_dtypes:
            x = torch.randn((2, 9), dtype=dtype)
            x[0, 0] = torch.nan
            x[1, -1] = torch.nan

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x)
                compiled = compile_fx_inner(traced, [x])
                assert same(fn(x)[0], compiled([x])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_reduction_cpu_only(self):
        def fn(x):
            return (torch.argmax(x, -1),)

        for dtype in vec_dtypes:
            x = torch.randn((10, 10), dtype=dtype)

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x)
                compiled = compile_fx_inner(traced, [x])
                assert same(fn(x)[0], compiled([x])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 0

    # Currently, we enable AVX2 and AVX512 for vectorization. If the platform is
    # not supported, vectorization will not work and this test case is skipped.
    # To support ARM or other platforms, we just need to add the ISA info to
    # supported_vec_isa_list and include the proper ATen vectorization header file.
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_kernel_cpu_only(self):
        def fn(x1, x2):
            # Currently, there are some limitations as follows.
            # rsqrt:
            # assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
            # round:
            # couldn't find symbolic meta function/decomposition
            # fmod/logical_and/logical_or:
            # the vec kernel does not support to_type yet
            x = torch.abs(x1)
            x = torch.sin(x)
            x = torch.neg(x)
            x = torch.square(x)
            x = torch.sigmoid(x)
            x = torch.relu(x)
            x = torch.cos(x)
            x = torch.exp(x)
            x = torch.sqrt(x)
            x = torch.add(x, x1)
            x = torch.sub(x, x2)
            x = torch.mul(x, x1)
            x = torch.div(x, x1)
            x = torch.pow(x, 10)
            x = torch.log(x)
            x = torch.floor(x)
            x = torch.ceil(x)
            x = torch.trunc(x)
            x = torch.lgamma(x)
            x = torch.fmod(x, x2)
            x = torch.sign(x)
            res = x + x2
            return (res,)

        for dtype in vec_dtypes:
            torch.manual_seed(0)
            x1 = torch.randn((5, 20), dtype=dtype)
            x2 = torch.randn((5, 20), dtype=dtype)

            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
            with config.patch({"cpp.simdlen": 1}):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x1, x2)
                compiled = compile_fx_inner(traced, [x1, x2])
                assert same(
                    fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True, tol=tol
                )
                assert metrics.generated_cpp_vec_kernel_count == 0

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                traced = make_fx(fn)(x1, x2)
                compiled = compile_fx_inner(traced, [x1, x2])
                assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
                assert metrics.generated_cpp_vec_kernel_count == 1

        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            x1 = torch.randn(10, 20).permute(1, 0)
            x2 = torch.randn((20, 10))
            traced = make_fx(fn)(x1, x2)
            compiled = compile_fx_inner(traced, [x1, x2])
            assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
            assert metrics.generated_cpp_vec_kernel_count == 2

            torch._dynamo.reset()
            metrics.reset()
            x1 = torch.randn((10, 7))
            x2 = torch.randn((10, 7))
            traced = make_fx(fn)(x1, x2)
            compiled = compile_fx_inner(traced, ([x1, x2]))
            assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
            assert metrics.generated_cpp_vec_kernel_count == 1

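    # With cpp.enable_kernel_profile, the generated C++ kernel should be visible
    # to the PyTorch profiler as a "kernel_cpp_0" event.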
    @unittest.skipIf(
        sys.platform != "linux", "cpp kernel profile only supports linux now"
    )
    @patch("torch.cuda.is_available", lambda: False)
    @config.patch({"cpp.enable_kernel_profile": True})
    def test_cpp_kernel_profile(self):
        from torch.profiler import profile

        @torch._dynamo.optimize("inductor", nopython=True)
        def fn(a, b):
            return a + b

        a = torch.rand((100,))
        b = torch.rand((100,))
        with profile() as prof:
            fn(a, b)

        kernel_profile_events = []
        for e in prof.profiler.function_events:
            if "kernel_cpp_0" in e.name:
                kernel_profile_events.append(e.name)
        assert len(kernel_profile_events) > 0

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_channel_shuffle_cl_output(self):
        """code and shape extracted from shufflenet_v2_x1_0"""

        def channel_shuffle(x, groups):
            batchsize, num_channels, height, width = x.size()
            channels_per_group = num_channels // groups
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.transpose(x, 1, 2).contiguous()
            x = x.view(batchsize, -1, height, width)
            return x.contiguous(memory_format=torch.channels_last)

        for simdlen in (None, 256, 1):
            with config.patch({"cpp.simdlen": simdlen}):
                torch._dynamo.reset()
                metrics.reset()
                x = torch.randn(64, 58, 28, 28)
                opt_fn = torch._dynamo.optimize("inductor")(channel_shuffle)
                self.assertTrue(same(channel_shuffle(x, 2), opt_fn(x, 2)))
                if simdlen != 1:
                    assert metrics.generated_cpp_vec_kernel_count == 2

    @slow()
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_transpose_with_norm(self):
        """a sub-module from TIMM gmlp_s16_224"""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_features=256, out_features=1536, bias=True
                )
                self.act = torch.nn.GELU()
                self.norm = torch.nn.LayerNorm(768)
                self.proj = torch.nn.Linear(196, 196)
                self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True)

            def forward(self, x):
                x = self.linear(x)
                x = self.act(x)
                u, v = x.chunk(2, dim=-1)
                v = self.norm(v)
                v = self.proj(v.transpose(-1, -2))
                y = u * v.transpose(-1, -2)
                return self.fc(y)

        x = torch.randn(128, 196, 256)
        for simdlen in (None, 256, 1):
            with config.patch({"cpp.simdlen": simdlen}):
                for eval_mode in [True, False]:
                    torch._dynamo.reset()
                    metrics.reset()
                    m = Model().eval() if eval_mode else Model()
                    opt_fn = torch._dynamo.optimize("inductor")(m)
                    same(m(x), opt_fn(x))
                    if simdlen != 1:
                        assert metrics.generated_cpp_vec_kernel_count == 6

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_transpose_copy(self):
        def fn(a):
            return a.t().contiguous()

        for simdlen in (None, 256, 1):
            with config.patch({"cpp.simdlen": simdlen}):
                for dtype in (torch.float, torch.bfloat16):
                    for shape in (
                        (7, 7),
                        (8, 8),
                        (9, 9),
                        (16, 16),
                        (17, 17),
                        (32, 32),
                        (33, 33),
                    ):
                        torch._dynamo.reset()
                        metrics.reset()
                        x = torch.randn(shape, dtype=dtype)
                        opt_fn = torch._dynamo.optimize("inductor")(fn)
                        self.assertTrue(same(fn(x), opt_fn(x)))
                        if simdlen != 1:
                            assert metrics.generated_cpp_vec_kernel_count == 2

    def test_transpose_non_contiguous(self):
        def fn(a):
            # From part of timm HaloAttn:
            # (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97).
            # Fixed https://github.com/pytorch/pytorch/issues/94269 accuracy issue.
            as_strided = torch.ops.aten.as_strided.default(
                a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]
            )
            as_strided_1 = torch.ops.aten.as_strided.default(
                as_strided,
                [1, 384, 2, 2, 12, 12],
                [153600, 1, 61440, 3072, 7680, 384],
            )
            clone_1 = torch.ops.aten.clone.default(
                as_strided_1, memory_format=torch.contiguous_format
            )
            _unsafe_view_1 = torch.ops.aten._unsafe_view.default(
                clone_1, [8, 48, 4, 144]
            )
            permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1])
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                permute_2, [16, 32], -1
            )
            getitem = split_with_sizes[0]
            getitem_1 = split_with_sizes[1]
            permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2])
            expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144])
            clone_3 = torch.ops.aten.clone.default(
                expand_1, memory_format=torch.contiguous_format
            )
            return clone_3

        metrics.reset()
        x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        same(fn(x), opt_fn(x))
        assert metrics.generated_cpp_vec_kernel_count == 0

    def test_invalid_index_of_empty_tensor(self):
        def fn(a):
            b = a[[0]]
            return b

        a = torch.tensor([])
        with self.assertRaises(RuntimeError):
            torch.compile(fn)(a)

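    # Smoke test that str() works on every IR node produced during lowering, by
    # stringifying each GraphLowering.run_node result.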
    def test_ir_node_str(self):
        @torch.compile
        def fn(x: torch.Tensor) -> torch.Tensor:
            return x.sin(), torch.nn.Softmax(dim=1)(x.cos())

        def run_node_alt(*args, **kwargs):
            rv = run_node(*args, **kwargs)
            strings.append(str(rv))
            return rv

        strings = []
        run_node = GraphLowering.run_node
        with patch.object(GraphLowering, "run_node", run_node_alt):
            fn(torch.randn([8, 128]))
        self.assertGreater(len(strings), 3)

    def test_vertical_sum_cpu_only(self):
        def fn1(a):
            return a.sum(dim=0)

        def fn2(a):
            return a.sum(dim=1)

        metrics.reset()
        x = torch.randn(100, 100)
        opt_fn1 = torch._dynamo.optimize("inductor")(fn1)
        self.assertTrue(same(fn1(x), opt_fn1(x)))
        assert metrics.generated_cpp_vec_kernel_count == 1

        metrics.reset()
        x = torch.randn(100, 100, 100)
        opt_fn2 = torch._dynamo.optimize("inductor")(fn2)
        self.assertTrue(same(fn2(x), opt_fn2(x)))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_transpose_vertical_sum_cpu_only(self):
        def fn(a, b):
            c = a * b
            return c.sum(dim=1)

        metrics.reset()
        x = torch.randn(100, 50, 50)
        y = torch.randn(100, 50, 50).transpose(1, 2)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        self.assertTrue(same(fn(x, y), opt_fn(x, y)))
        assert metrics.generated_cpp_vec_kernel_count == 2

    def test_transpose_sum2d_cpu_only(self):
        def fn(a, b):
            c = a * b
            return c.sum()

        metrics.reset()
        x = torch.randn(50, 50)
        y = torch.randn(50, 50).transpose(0, 1)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        self.assertTrue(same(fn(x, y), opt_fn(x, y)))
        assert metrics.generated_cpp_vec_kernel_count == 2


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")