mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable FP8 cast for this issue https://github.com/pytorch/pytorch/issues/117119. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117737 Approved by: https://github.com/jgong5, https://github.com/jansel
2808 lines
98 KiB
Python
# Owner(s): ["oncall: cpu inductor"]
import contextlib
import copy
import itertools
import math
import platform
import sys
import unittest
from typing import Callable
from unittest.mock import patch

import numpy as np
import sympy
import torch
from torch._C import FileCheck
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import codecache, config, metrics
from torch._inductor.codegen.common import OptimizationContext
from torch._inductor.codegen.cpp import (
    CppOverrides,
    CppVecKernelChecker,
    CppVecOverrides,
)
from torch._inductor.compile_fx import (
    compile_fx,
    compile_fx_inner,
    complex_memory_overlap,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
from torch._inductor.utils import timed
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_MACOS,
    parametrize,
    slowTest,
)
from torch.utils._python_dispatch import TorchDispatchMode

try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


vec_dtypes = test_torchinductor.vec_dtypes
_lowp_fp_dtypes = (
    torch.bfloat16,
    torch.float16,
)
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
aten = torch.ops.aten
check_model = test_torchinductor.check_model


class LstmModule(torch.nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        bias=True,
        bidirectional=False,
        batch_first=False,
    ):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
        )

    def forward(self, x, h=None):
        x, h = self.lstm(x, h)
        return x, h


@instantiate_parametrized_tests
class CPUReproTests(TestCase):
    common = check_model

    def test_conv_stride_constraints(self):
        for fmt in [torch.contiguous_format, torch.channels_last]:
            # TorchDispatch doesn't work in our cuda invocation for some reason
            m = torch.nn.Conv2d(5, 6, [3, 3])

            def fn(inp, weight):
                return (
                    F.conv2d(
                        inp, weight, None, m.stride, m.padding, m.dilation, m.groups
                    ),
                )

            inp = torch.randn([2, 5, 16, 16])
            inps = [inp, m.weight.to(memory_format=fmt)]
            fn_fx = make_fx(fn)(*inps)
            fn_compiled = compile_fx_inner(fn_fx, inps)
            test_self = self
            conv_seen = False

            class RecordFunctions(TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    kwargs = kwargs if kwargs else {}
                    if func == torch.ops.aten.convolution.default:
                        # For CPU with mkldnn enabled, we always use channels_last
                        nonlocal fmt
                        if (
                            torch.backends.mkldnn.enabled
                            and torch.backends.mkldnn.is_available()
                        ):
                            fmt = torch.channels_last
                        test_self.assertTrue(args[0].is_contiguous(memory_format=fmt))
                        test_self.assertTrue(args[1].is_contiguous(memory_format=fmt))
                        nonlocal conv_seen
                        conv_seen = True

                    return func(*args, **kwargs)

            with RecordFunctions():
                out = fn_compiled(inps)

            self.assertTrue(conv_seen)

    @patch("torch.cuda.is_available", lambda: False)
    def test_conv2d_bn_mixed_dtype(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3,
                    16,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                    dtype=torch.bfloat16,
                )
                self.bn = torch.nn.BatchNorm2d(
                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
        mod = Model().eval()
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_conv2d_packed(self):
        options = itertools.product([[3, 56, 56]], [True, False], [0, (0,)])
        for x_shape, mode_train, padding in options:
            mod = torch.nn.Sequential(
                torch.nn.Conv2d(3, 64, 3, 3, padding=padding)
            ).train(mode=mode_train)
            v = torch.randn(x_shape, dtype=torch.float32)

            with torch.no_grad():
                self.common(
                    mod,
                    (v,),
                )

    @patch("torch.cuda.is_available", lambda: False)
    def test_conv2d_autocast(self):
        v = torch.randn(1, 3, 28, 18, dtype=torch.float32)
        mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval()
        with torch.no_grad(), torch.cpu.amp.autocast():
            self.common(
                mod,
                (v,),
            )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_unsupported_conv_transpose(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_transpose = torch.nn.ConvTranspose2d(
                    3, 6, 3, stride=1, padding=1, output_padding=1
                )

            def forward(self, input_tensor):
                x = self.conv_transpose(input_tensor)
                output = torch.tanh(x)
                return output

        input = torch.randn(1, 3, 28, 28)
        m = Model().eval()

        with torch.no_grad():
            compiled_m = torch.compile(m)
            with self.assertRaisesRegex(
                RuntimeError,
                "output padding must be smaller than either stride or dilation",
            ):
                compiled_m(input)

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_conv_used_from_multiple_places(self):
        class M(torch.nn.Module):
            def __init__(self, conv_in_channel, conv_out_channel) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(conv_in_channel, conv_out_channel, (3, 3))

            def forward(self, x):
                res = self.conv(x)
                res = F.relu(res)
                res = self.conv(res)
                return res

        with torch.no_grad():
            mod = M(3, 3).eval()
            x = torch.randn(1, 3, 224, 224)
            self.common(
                mod,
                (x,),
            )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_linear_used_from_multiple_places(self):
        class M(torch.nn.Module):
            def __init__(self, in_channel, out_channel) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(in_channel, out_channel)

            def forward(self, x):
                res = self.linear(x)
                res = F.relu(res)
                res = self.linear(res)
                return res

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        for dtype in dtypes:
            with torch.no_grad():
                m = M(224, 224).to(dtype).eval()
                m_opt = torch.compile(m)
                x = torch.randn(224, 224, dtype=dtype)
                m_opt(x)
                self.assertEqual(m(x), m_opt(x))

    @config.patch(implicit_fallbacks=True)
    def test_multihead_attention_cpu(self):
        def fn(
            q,
            k,
            v,
            embed_dim,
            num_heads,
            qkv_weight,
            qkv_bias,
            proj_weight,
            proj_bias,
            mask,
            need_weights,
        ):
            return torch._native_multi_head_attention(
                q,
                k,
                v,
                embed_dim,
                num_heads,
                qkv_weight,
                qkv_bias,
                proj_weight,
                proj_bias,
                mask,
                need_weights,
            )

        B = 1
        T = 3
        embed_dim = 6
        num_heads = 2
        q = torch.randn([B, T, embed_dim])
        k = torch.randn([B, T, embed_dim])
        v = torch.randn([B, T, embed_dim])
        qkv_weight = torch.randn([3 * embed_dim, embed_dim])
        qkv_bias = torch.randn([3 * embed_dim])
        proj_weight = torch.randn([3 * embed_dim, embed_dim])
        proj_bias = torch.randn([3 * embed_dim])
        mask = None
        need_weights = False

        inps = [
            q,
            k,
            v,
            embed_dim,
            num_heads,
            qkv_weight,
            qkv_bias,
            proj_weight,
            proj_bias,
            mask,
            need_weights,
        ]
        self.common(fn, inps)

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_linear_packed(self):
        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        options = itertools.product(
            [[2, 3, 10], [2, 10], [10], [2, 0]], [3, 0], [True, False], dtypes
        )
        for input_shape, out_dim, bias, dtype in options:
            mod = torch.nn.Sequential(
                torch.nn.Linear(input_shape[-1], out_dim, bias=bias)
            ).eval()

            v = torch.randn(input_shape)
            with torch.no_grad():
                self.common(
                    mod.to(dtype),
                    (v.to(dtype),),
                )

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    def test_conv_transpose2d_packed_cpu(self):
        options = itertools.product([[1, 3, 28, 28], [3, 28, 28]], [0, (0,)])
        for x_shape, padding in options:
            mod = torch.nn.Sequential(
                torch.nn.ConvTranspose2d(3, 64, 3, 3, padding=padding)
            ).eval()
            v = torch.randn(x_shape, dtype=torch.float32)
            with torch.no_grad():
                self.common(
                    mod,
                    (v,),
                )

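    # NOTE: a reading of the decorator stack on the helper below, inferred from
    # the assertions in its body rather than from inductor documentation:
    # allow_rnn=True lets dynamo trace through nn.LSTM, and freezing=True turns
    # the LSTM weights into constants so they can be prepacked, which is what
    # allows "aten.mkldnn_rnn_layer" to appear in the generated code.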
    @unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    @torch._dynamo.config.patch(dynamic_shapes=True)
    @torch._dynamo.config.patch(assume_static_by_default=False)
    @torch._dynamo.config.patch(allow_rnn=True)
    @config.patch(freezing=True)
    def _test_lstm_packed(self, params_dict, change_input_sizes=False):
        from torch._dynamo.utils import counters

        for (
            unbatched,
            input_size,
            hidden_size,
            num_layers,
            bidirectional,
            bias,
            empty_state,
            batch_first,
            batch_size,
            seq_len,
        ) in itertools.product(*list(params_dict.values())):
            dtypes = [torch.float]
            if torch.ops.mkldnn._is_mkldnn_bf16_supported():
                dtypes.append(torch.bfloat16)
            if torch.ops.mkldnn._is_mkldnn_fp16_supported():
                dtypes.append(torch.float16)
            for dtype in dtypes:
                counters.clear()
                num_directions = 2 if bidirectional else 1

                seq_len_var = seq_len + 3
                if unbatched:
                    v = torch.randn(seq_len, input_size)
                    v_var = torch.randn(seq_len_var, input_size)
                    h = torch.randn(num_layers * num_directions, hidden_size)
                    c = torch.randn(num_layers * num_directions, hidden_size)
                else:
                    if batch_first:
                        v = torch.randn(batch_size, seq_len, input_size)
                        v_var = torch.randn(batch_size, seq_len_var, input_size)
                    else:
                        v = torch.randn(seq_len, batch_size, input_size)
                        v_var = torch.randn(seq_len_var, batch_size, input_size)
                    h = torch.randn(
                        num_layers * num_directions, batch_size, hidden_size
                    )
                    c = torch.randn(
                        num_layers * num_directions, batch_size, hidden_size
                    )

                mod = LstmModule(
                    input_size,
                    hidden_size,
                    num_layers,
                    bias,
                    bidirectional,
                    batch_first,
                ).eval()
                maybe_autocast = (
                    torch.cpu.amp.autocast()
                    if dtype == torch.bfloat16
                    else contextlib.nullcontext()
                )

                with torch.no_grad(), maybe_autocast:
                    inps = [v]
                    if not empty_state:
                        inps.append((h, c))

                    fn_opt = torch._dynamo.optimize("inductor")(mod)
                    _, code = run_and_get_cpp_code(fn_opt, *inps)

                    # Check that _flat_weights are not functional_tensor, otherwise
                    # deepcopy will fail during recompilation.
                    fn_opt_copy = copy.deepcopy(fn_opt)
                    _flat_weights = fn_opt_copy.lstm._flat_weights
                    for _flat_weight in _flat_weights:
                        self.assertFalse(torch._is_functional_tensor(_flat_weight))

                    self.assertTrue("aten.mkldnn_rnn_layer" in code)
                    self.assertEqual(fn_opt(*inps), mod(*inps))
                    self.assertEqual(
                        counters["inductor"]["pattern_matcher_count"],
                        num_layers * num_directions
                        + 2,  # number of mkldnn_rnn_layer calls + 2 view calls on the concatenated hy, cy.
                    )

                    # Change input sizes
                    if change_input_sizes:
                        inps_var = [v_var]
                        self.assertEqual(fn_opt(*inps_var), mod(*inps_var))

    @slowTest
    def test_lstm_packed(self):
        params_dict = {
            "unbatched": [True, False],
            "input_size": [1, 2],
            "hidden_size": [2],
            "num_layers": [1, 2],
            "bidirectional": [False, True],
            "bias": [False, True],
            "empty_state": [False, True],
            "batch_first": [True, False],
            "batch_size": [1, 2],
            "seq_len": [1, 2],
        }
        self._test_lstm_packed(params_dict)

    def test_lstm_packed_change_input_sizes_cpu(self):
        params_dict = {
            "unbatched": [False],
            "input_size": [2],
            "hidden_size": [5],
            "num_layers": [3],
            "bidirectional": [True],
            "bias": [True],
            "empty_state": [False],
            "batch_first": [False],
            "batch_size": [2],
            "seq_len": [3],
        }
        self._test_lstm_packed(params_dict, change_input_sizes=True)

    @torch._dynamo.config.patch(dynamic_shapes=True)
    @torch._dynamo.config.patch(assume_static_by_default=False)
    @torch._dynamo.config.patch(allow_rnn=True)
    def test_pack_padded_sequence_lstm(self):
        embedding_dim = 12
        hidden_dim = 10
        batch_size = 24
        num_layers = 1
        bidirectional = True
        num_direc = 2
        max_lens = 96

        sent = torch.randn(batch_size, max_lens, embedding_dim)
        hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim)
        hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim)

        sent_lens = torch.Tensor(
            [1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6, 1, 2, 4, 6, 2, 1]
        )

        assert sent_lens.shape[0] == batch_size
        assert sent_lens.max().item() == max_lens

        hidden_0 = hid_0.clone().requires_grad_(False)
        hidden_1 = hid_1.clone().requires_grad_(False)
        embeds = torch.nn.utils.rnn.pack_padded_sequence(
            sent, sent_lens, batch_first=True, enforce_sorted=False
        )

        mod = LstmModule(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            bias=True,
            bidirectional=bidirectional,
            batch_first=True,
        ).eval()

        with torch.no_grad():
            inps = [embeds, (hidden_0, hidden_1)]
            fn_opt = torch._dynamo.optimize("inductor")(mod)
            _, code = run_and_get_cpp_code(fn_opt, *inps)
            # This case is unsupported
            self.assertFalse("torch.ops.mkldnn._lstm" in code)
            self.assertEqual(fn_opt(*inps), mod(*inps))

    @patch("torch.cuda.is_available", lambda: False)
    def test_conv_transpose2d_has_output_size_input(self):
        # https://github.com/pytorch/pytorch/issues/100344.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_transpose = torch.nn.ConvTranspose2d(
                    in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1
                )

            def forward(self, x):
                return self.conv_transpose(x, output_size=(10, 10))

        mod = M().eval()
        v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

    def test_pad_with_nan_value(self):
        # https://github.com/pytorch/pytorch/issues/100988.
        class Model(torch.nn.Module):
            def forward(self, x):
                x = F.pad(x, (1, 1, 1, 1), value=float("nan"))
                return x

        mod = Model().eval()
        v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

    def test_masked_fill_with_inf_or_nan_value(self):
        def fn(value, mask):
            y1 = torch.masked_fill(value, mask, float("inf"))
            y2 = torch.masked_fill(value, mask, float("-inf"))
            y3 = torch.masked_fill(value, mask, float("nan"))
            return y1, y2, y3

        value = torch.randn((2, 17))
        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
        with torch.no_grad():
            self.common(
                fn,
                (value, mask),
            )

    def test_relu_with_inf_value(self):
        # https://github.com/pytorch/pytorch/issues/117544.

        def fn(out):
            out = torch.sinh(input=out)
            out = torch.relu(input=out)
            return out

        x = torch.Tensor([-572373.5000, 755109.1250, 330995.5625])
        with torch.no_grad():
            self.common(
                fn,
                (x,),
            )

    @config.patch(implicit_fallbacks=True)
    def test_repeat_interleave(self):
        def fn(y):
            return torch.repeat_interleave(y, 2, output_size=8)

        a = torch.tensor([[1, 2], [3, 4]])
        self.common(
            fn,
            (a,),
        )

    def test_inplace_squeeze_needed(self):
        mod = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.LayerNorm(10),
            torch.nn.ReLU(),
        ).eval()

        def fn(x):
            return mod(x)

        v = torch.randn(10)
        # TODO: OMP parallel reduction order is not deterministic. Hence, the
        # accuracy might vary up and down. For the short term, we increase the
        # tolerance; we will fix it later by using aten parallel.
        self.common(fn, (v,), atol=5e-1, rtol=5e-1)

    def test_cat_mul(self):
        # https://github.com/pytorch/pytorch/issues/93365
        def fn(p0, p1):
            y1 = torch.cat([p0, p1], dim=0)
            y2 = torch.mul(y1, y1)
            return y1, y2

        p0 = torch.randn(3, 4)
        p1 = torch.randn(3, 4)
        self.common(fn, (p0, p1))

    def test_pow_cos(self):
        # https://github.com/pytorch/pytorch/issues/98149
        def fn(x):
            t = x.pow(5)
            return torch.cos(t)

        x = torch.tensor([4], dtype=torch.uint8)
        self.common(fn, (x,))

    def test_reduce_with_masked(self):
        # https://github.com/pytorch/pytorch/issues/96484
        def fn(a, b):
            a = torch.nn.functional.pad(a, (0, -1))
            c = a + b
            return c.min(0).values

        a = torch.randn([2])
        b = torch.randn([2])
        self.common(fn, (a, b))

    def test_scalar_sign_with_min(self):
        # https://github.com/pytorch/pytorch/issues/101340
        def fn(a):
            t1 = torch.tanh(a)
            t2 = torch.sign(t1)
            return torch.min(t1, t2)

        a = torch.randn(1, 3)
        self.common(fn, (a,))

    def test_index_propagation_issue_102065(self):
        def fn(x):
            x = torch.arange(x.numel())
            return (x.unsqueeze(0) - x.unsqueeze(1)) ** 2

        self.common(
            fn,
            (torch.randn(8),),
        )

    def test_ModularIndexing_range_issue_103133(self):
        def fn(q, k):
            einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k))
            constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                einsum, [0, 0, 0, 1], 0.0
            )
            view = torch.ops.aten.view.default(constant_pad_nd, [12, 1, 512, 513])
            y = view.new_zeros((12, 2, 256, 513))
            y[:, :-1, :, 256:] = view[:, :, :256, :257]
            return y

        self.common(
            fn,
            (
                torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
                torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
            ),
        )

    @patch("torch.cuda.is_available", lambda: False)
    def test_max_reduction_lowp_fp(self):
        def fn(x):
            return torch.ops.aten.max(x, 1, keepdim=True)[0].float()

        for dtype in _lowp_fp_dtypes:
            self.common(
                fn,
                (torch.randn(1, 32, 4, 4).to(dtype),),
            )

    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_transpose_lowp_fp(self):
        for dtype in _lowp_fp_dtypes:

            def fn(x):
                return x.to(memory_format=torch.channels_last).to(dtype)

            self.common(
                fn,
                (torch.randn(2, 3, 4, 4),),
            )

    def test_load_inf_bf16(self):
        def fn1(x):
            return torch.where(x > 0, x, math.inf)

        def fn2(x):
            return torch.where(x > 0, x, -math.inf)

        for fn in [fn1, fn2]:
            self.common(
                fn,
                (torch.randn(1, 3, 16, 16),),
            )

    @patch("torch.cuda.is_available", lambda: False)
    def test_fp32_load_with_to_lowp_fp(self):
        # From llama model.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.cache_k = torch.zeros(8, 4, 2, 2)

            def forward(self, x, xk):
                bsz, seqlen, _ = x.shape
                self.cache_k = self.cache_k.to(x)
                self.cache_k[:bsz, 1 : 1 + seqlen] = xk
                return self.cache_k

        for dtype in _lowp_fp_dtypes:
            ref_model = Model().eval()
            opt_model = torch.compile()(Model().eval())
            x = torch.randn(4, 2, 2).to(dtype)
            xk = torch.randn(4, 2, 2, 2).to(dtype)
            self.assertEqual(opt_model(x, xk), ref_model(x, xk))

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_sigmoid_with_reduction(self):
        def fn(x):
            x = torch.ops.aten.sigmoid.default(x)
            return torch.ops.aten.mean.dim(x, [-1, -2], True)

        x = torch.randn((1, 8, 8, 8))
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(fn, (x,))

    def test_slice_scatter_default_end_value(self):
        # From HF AllenaiLongformerBase.
        def fn(query, key, window_overlap):
            batch_size, seq_len, num_heads, head_dim = query.size()
            assert (
                seq_len % (window_overlap * 2) == 0
            ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"

            chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
            diagonal_chunked_attention_scores = key
            diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
                (
                    batch_size * num_heads,
                    chunks_count + 1,
                    window_overlap,
                    window_overlap * 2 + 1,
                )
            )
            diagonal_attention_scores[
                :, :3, :, window_overlap:
            ] = diagonal_chunked_attention_scores[
                :, :, :window_overlap, : window_overlap + 1
            ]
            return diagonal_attention_scores

        self.common(
            fn,
            (
                torch.randn(1, 1024, 12, 64),
                torch.randn(12, 3, 512, 513),
                256,
            ),
        )

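    # NOTE: the values in the test below (4.4, 4.5, 4.6, 5.5) appear chosen so
    # that truncation, round-half-up, and round-half-to-even disagree on at
    # least one of them; the eager float->uint8 cast truncates toward zero, so
    # a vectorized cast that rounds instead would fail the comparison. This
    # rationale is inferred from the values, not stated by the test.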
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_to_uint8_rounding_method(self):
        def fn(x):
            return x.to(torch.uint8)

        numerical_testsuit = [4.4, 4.5, 4.6, 5.5]
        for numerical_number in numerical_testsuit:
            x = torch.ones(17) * numerical_number
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_decomposed_dequant_relu_quant(self):
        def fn(x, scale, zero_point, use_dequant, use_quant):
            # For quantized_decomposed.dequantize_per_tensor
            # Refer to torch/ao/quantization/fx/_decomposed.py
            if use_dequant:
                x = (x.to(torch.float32) - zero_point) * scale

            x = torch.relu(x)

            # For quantized_decomposed.quantize_per_tensor
            # Refer to torch/ao/quantization/fx/_decomposed.py
            if use_quant:
                inv_scale = 1.0 / scale
                x = torch.clamp(torch.round(x * inv_scale) + zero_point, 0, 255).to(
                    torch.uint8
                )
            return x

        use_dequant_list = [False, True]
        use_quant_list = [False, True]
        for use_dequant, use_quant in itertools.product(
            use_dequant_list, use_quant_list
        ):
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
            )
            if use_dequant:
                x = x.to(torch.uint8)
            zero_point = 100
            scale = 0.01
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x, scale, zero_point, use_dequant, use_quant))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_dequant_quant_lowering(self):
        def fn(x, scale, zero_point, use_dequant, use_quant):
            if use_dequant:
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, scale, zero_point, 0, 255, torch.uint8
                )

            x = torch.relu(x)

            if use_quant:
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, scale, zero_point, 0, 255, torch.uint8
                )
            return x

        use_dequant_list = [False, True]
        use_quant_list = [False, True]
        use_tensor_overload_list = [False, True]
        for use_dequant, use_quant, use_tensor_overload in itertools.product(
            use_dequant_list, use_quant_list, use_tensor_overload_list
        ):
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
            )
            if use_dequant:
                x = x.to(torch.uint8)
            zero_point = 100
            scale = 0.01
            if use_tensor_overload:
                zero_point = torch.tensor(zero_point, dtype=torch.int64)
                scale = torch.tensor(scale)
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x, scale, zero_point, use_dequant, use_quant))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_dequant_maxpool2d_lowering(self):
        def fn(x, scale, zero_point):
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, scale, zero_point, 0, 255, torch.uint8
            )
            max_pool2d_with_indices_default = (
                torch.ops.aten.max_pool2d_with_indices.default(
                    x, [2, 2], [2, 2], [1, 1]
                )[0]
            )
            return max_pool2d_with_indices_default

        use_tensor_overload_list = [False, True]
        for use_tensor_overload in use_tensor_overload_list:
            x = (
                torch.clamp(
                    torch.randn((3, 16, 8, 8), dtype=torch.float32) * 100, 0, 255
                )
                .to(torch.uint8)
                .contiguous(memory_format=torch.channels_last)
            )
            zero_point = 100
            scale = 0.01
            if use_tensor_overload:
                zero_point = torch.tensor(zero_point, dtype=torch.int64)
                scale = torch.tensor(scale)
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x, scale, zero_point))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_tile2d_load_decomposed_dequant_add_relu_quant(self):
        def fn(
            x,
            scale,
            zero_point,
            x2,
            scale2,
            zero_point2,
            output_scale,
            output_zero_point,
            use_dequant,
            use_dequant2,
            use_quant,
        ):
            if use_dequant:
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, scale, zero_point, 0, 255, torch.uint8
                )
            if use_dequant2:
                x2 = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x2, scale2, zero_point2, 0, 255, torch.uint8
                )
            temp = x + x2
            y = torch.relu(temp)

            if use_quant:
                y = torch.ops.quantized_decomposed.quantize_per_tensor(
                    y, output_scale, output_zero_point, 0, 255, torch.uint8
                )
            return y.contiguous()

        use_dequant_list = [False, True]
        use_dequant_list2 = [False, True]
        use_quant_list = [False, True]
        for use_dequant, use_dequant2, use_quant in itertools.product(
            use_dequant_list, use_dequant_list2, use_quant_list
        ):
            x = torch.clamp(
                torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100, 0, 255
            ).contiguous(memory_format=torch.channels_last)
            x2 = torch.clamp(
                torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100, 0, 255
            ).contiguous(memory_format=torch.channels_last)
            if use_dequant:
                x = x.to(torch.uint8).contiguous(memory_format=torch.channels_last)
            if use_dequant2:
                x2 = x2.to(torch.uint8).contiguous(memory_format=torch.channels_last)
            zero_point = 1
            scale = 0.01
            zero_point2 = 2
            scale2 = 0.02
            output_zero_point = 3
            output_scale = 0.03
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(
                    fn,
                    (
                        x,
                        scale,
                        zero_point,
                        x2,
                        scale2,
                        zero_point2,
                        output_scale,
                        output_zero_point,
                        use_dequant,
                        use_dequant2,
                        use_quant,
                    ),
                )
                assert metrics.generated_cpp_vec_kernel_count == 2

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_non_contiguous_load_buf_quant(self):
        def fn(
            x1,
            x2,
            groups,
        ):
            x = torch.cat((x1, x2), dim=1)
            batchsize, num_channels, height, width = x.size()
            channels_per_group = num_channels // groups
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, 1.0, 0, 0, 255, torch.uint8
            )
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, 1.0, 0, 0, 255, torch.uint8
            )
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, 1.0, 0, 0, 255, torch.uint8
            )
            x = torch.transpose(x, 1, 2).contiguous()
            x = x.view(batchsize, num_channels, height, width)
            return x

        x = torch.randint(0, 8, (1, 116, 28, 28), dtype=torch.uint8).contiguous(
            memory_format=torch.channels_last
        )
        x2 = torch.randint(0, 8, (1, 116, 28, 28), dtype=torch.uint8).contiguous(
            memory_format=torch.channels_last
        )

        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(
                fn,
                (
                    x,
                    x2,
                    2,
                ),
            )
            assert metrics.generated_cpp_vec_kernel_count == 2

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_tile2d_store_channel_shuffle_cl_quant_output(self):
        def channel_shuffle(x, groups, output_scale, output_zero_point):
            batchsize, num_channels, height, width = x.size()
            channels_per_group = num_channels // groups
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.transpose(x, 1, 2).contiguous()
            x = x.view(batchsize, -1, height, width)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, output_scale, output_zero_point, 0, 255, torch.uint8
            )
            return x.contiguous(memory_format=torch.channels_last)

        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            x = torch.randn(64, 58, 28, 28)
            output_zero_point = 3
            output_scale = 0.03
            self.common(channel_shuffle, (x, 2, output_scale, output_zero_point))
            assert metrics.generated_cpp_vec_kernel_count == 2

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_dequant_relu_quant_dequant_relu_quant_lowering(self):
        def fn(x, scale, zero_point, scale2, zero_point2, scale3, zero_point3):
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, scale, zero_point, 0, 255, torch.uint8
            )
            x = torch.relu(x)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, scale2, zero_point2, 0, 255, torch.uint8
            )
            x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                x, scale2, zero_point2, 0, 255, torch.uint8
            )
            x = torch.relu(x)
            x = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, scale3, zero_point3, 0, 255, torch.uint8
            )
            return x

        for use_tensor_overload in [True, False]:
            x = torch.clamp(
                torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
            ).to(torch.uint8)
            zero_point_list = [100, 101, 102]
            scale_list = [0.01, 0.02, 0.03]
            if use_tensor_overload:
                for i in range(len(zero_point_list)):
                    zero_point_list[i] = torch.tensor(
                        zero_point_list[i], dtype=torch.int64
                    )
                    scale_list[i] = torch.tensor(scale_list[i])
            zero_point, zero_point2, zero_point3 = zero_point_list
            scale, scale2, scale3 = scale_list
            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(
                    fn,
                    (x, scale, zero_point, scale2, zero_point2, scale3, zero_point3),
                    rtol=1e-2,
                    atol=1e-2,
                )
                assert metrics.generated_cpp_vec_kernel_count == 1

    def test_inplace_add_alpha(self):
        def fn(x, y):
            aten.add_.Tensor(x, y, alpha=0.55)
            return (x,)

        x1 = torch.zeros(10)
        x2 = torch.zeros(10)
        x3 = torch.zeros(10)
        y = torch.randn(10)
        fn_fx = make_fx(fn)(x1, y)
        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
        fn(x2, y)
        fn_compiled([x3, y])
        assert same(x2, x3)

    def test_int_div(self):
        def fn(x, y):
            s3 = x.size(1)
            a = torch.zeros((1 + s3) // 2)
            a += y
            return a, s3

        p0 = torch.randint(5, (1, 8))
        p1 = torch.randn(1)
        self.common(fn, (p0, p1))

    def test_no_op_squeeze(self):
        @torch._dynamo.optimize("inductor")
        def forward(arg0_1):
            return torch.ops.aten.squeeze.dim(arg0_1, 1)

        x = torch.randn((10, 20))
        self.common(forward, (x,))

    def test_parallel_num_threads(self):
        @torch._dynamo.optimize("inductor")
        def fn(x1, x2):
            return x1 + x2

        @contextlib.contextmanager
        def set_num_threads(num_threads):
            orig_num_threads = torch.get_num_threads()
            torch.set_num_threads(num_threads)
            try:
                yield
            finally:
                # Restore the original thread count even if the body raises.
                torch.set_num_threads(orig_num_threads)

        x1 = torch.randn((10, 20))
        x2 = torch.randn((10, 20))
        with set_num_threads(1):
            assert same(x1 + x2, fn(x1, x2))
        with set_num_threads(4):
            assert same(x1 + x2, fn(x1, x2))

    @patch("torch.cuda.is_available", lambda: False)
    def test_timed_cpu_only(self):
        timed(lambda: torch.randn(10), ())

    def test_complex_memory_overlap(self):
        dense = torch.zeros(64, 32)
        self.assertFalse(complex_memory_overlap(dense))
        self.assertFalse(complex_memory_overlap(dense.t()))

        strided = dense.split(4, dim=1)
        self.assertFalse(complex_memory_overlap(strided[0]))
        self.assertFalse(complex_memory_overlap(strided[0].t()))

        unsqueezed = dense.unsqueeze(1)
        self.assertFalse(complex_memory_overlap(unsqueezed))
        self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0)))

        gathered = dense.index_select(0, torch.IntTensor([1, 0, 1]))
        self.assertFalse(complex_memory_overlap(gathered))
        self.assertFalse(complex_memory_overlap(gathered.t()))

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_vec_dynamic_shapes(self):
        def fn(x):
            return torch.softmax(x, -1)

        value = torch.randn((2, 10))
        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            self.common(fn, (value,))

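    # NOTE on the "cpp.simdlen" knob as exercised by the test below (inferred
    # from its checks, not from inductor documentation): None lets
    # pick_vec_isa() auto-select the widest valid ISA, values matching no
    # supported bit width (0, 1, 257, 513) yield no ISA, and 512/256 explicitly
    # select AVX512/AVX2 when available.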
    @unittest.skipIf(
        platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
        "Vectorization not supported or not an x86_64 machine",
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_auto_simd(self):
        vec_avx512 = codecache.supported_vec_isa_list[0]
        vec_avx2 = codecache.supported_vec_isa_list[1]
        self.assertTrue(vec_avx512.bit_width() == 512)
        self.assertTrue(vec_avx2.bit_width() == 256)
        self.assertTrue(vec_avx512.nelements() == 16)
        self.assertTrue(vec_avx2.nelements() == 8)
        self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
        self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)

        with config.patch({"cpp.simdlen": None}):
            isa = codecache.pick_vec_isa()
            if vec_avx512 in codecache.valid_vec_isa_list():
                self.assertTrue(isa == vec_avx512)
            else:
                self.assertTrue(isa == vec_avx2)

        with config.patch({"cpp.simdlen": 0}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 1}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 257}):
            isa = codecache.pick_vec_isa()
            self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 513}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                self.assertFalse(isa)

        with config.patch({"cpp.simdlen": 512}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx512 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertTrue(isa == vec_avx512)

        with config.patch({"cpp.simdlen": 256}):
            isa_list = codecache.valid_vec_isa_list()
            if vec_avx2 in isa_list:
                isa = codecache.pick_vec_isa()
                self.assertTrue(isa == vec_avx2)

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_masked_fill_softmax(self):
        def fn(value, mask):
            mask = mask.to(torch.bool)
            x = torch.masked_fill(value, mask, -33.0)
            return torch.softmax(x, -1)

        for dtype in vec_dtypes:
            value = torch.randn((2, 17), dtype=dtype)
            mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
            with config.patch({"cpp.simdlen": None}):
                for cpp_wrapper_flag in [True, False]:
                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                        torch._dynamo.reset()
                        metrics.reset()
                        self.common(fn, (value, mask))
                        assert metrics.generated_cpp_vec_kernel_count >= 1

    def test_load_same_bool_tensor_twice(self):
        @torch._dynamo.optimize("inductor")
        def fn(a, b):
            x = torch.masked_fill(a, b, -33.0)
            y = torch.masked_fill(a, b, -33.0)
            return x, y

        value = torch.randn((2, 17))
        mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
        fn(value, mask)

    def test_cpu_vec_cosim(self):
        cpp_vec_op_list = []
        cpp_op_list = []

        for k, v in CppVecOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                cpp_vec_op_list.append(k)
        for k, v in CppOverrides.__dict__.items():
            if isinstance(v, staticmethod):
                cpp_op_list.append(k)

        diff = [
            "constant",
            "index_expr",
            "signbit",
            "isinf",
            "mod",
            "masked",
            "randn",
            "isnan",
            "rand",
            "randint64",
            "logical_and",
            "logical_not",
            "logical_or",
            "logical_xor",
            "bitwise_and",
            "bitwise_left_shift",
            "bitwise_not",
            "bitwise_right_shift",
            "bitwise_or",
            "bitwise_xor",
            "to_dtype_bitcast",
        ]
        union = {*cpp_vec_op_list, *diff}
        self.assertTrue(
            set(cpp_op_list).issubset(union), f"unexpected: {set(cpp_op_list) - union}"
        )

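    # NOTE: torch.gather's backward scatters gradients back into the input, so
    # the backward() calls below are what exercise the atomic-add path the test
    # name refers to; the forward assertions only compare eager and compiled
    # results. This is an inference from the test structure, not a documented
    # contract.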
    def test_atomic_add_lowp_fp(self):
        def fn(test_args):
            res = torch.gather(**test_args)
            return res

        for dtype in _lowp_fp_dtypes:
            input_tensor_for_ref = torch.tensor(
                [[3.0, -5.0]], dtype=dtype, requires_grad=True
            )
            input_tensor_for_opt = torch.tensor(
                [[3.0, -5.0]], dtype=dtype, requires_grad=True
            )

            test_args_for_ref = {
                "input": input_tensor_for_ref,
                "dim": 1,
                "index": torch.tensor([[1]]),
            }
            test_args_for_opt = {
                "input": input_tensor_for_opt,
                "dim": 1,
                "index": torch.tensor([[1]]),
            }

            opt_fn = torch.compile(fn)

            ref_fwd = fn(test_args_for_ref)
            res_fwd = opt_fn(test_args_for_opt)
            self.assertEqual(res_fwd, ref_fwd)

            torch.manual_seed(1)
            bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=dtype)
            torch.manual_seed(1)
            bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=dtype)
            self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt)

            ref_fwd.backward(bwd_tensor_for_ref)
            res_fwd.backward(bwd_tensor_for_opt)

            ref_grad = test_args_for_ref["input"].grad
            res_grad = test_args_for_opt["input"].grad
            self.assertEqual(ref_grad, res_grad)

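    # NOTE: the test below pins "cpp.fallback_scatter_reduce_sum" to False so
    # that scatter(..., reduce="add") is codegened rather than falling back to
    # an aten kernel, then greps the generated C++ for "atomic_add". The flag's
    # description here is inferred from its name and this use.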
    @patch("torch.cuda.is_available", lambda: False)
    def test_scatter_using_atomic_add(self):
        def fn(a, dim, index, b):
            return aten.scatter(a, dim, index, b, reduce="add")

        inps = (
            torch.randn(5, 29, 13),
            2,
            torch.tensor([[[3, 5, 7, 9]]]),
            torch.randn(1, 1, 10),
        )

        fn_opt = torch.compile()(fn)
        with config.patch({"cpp.fallback_scatter_reduce_sum": False}):
            _, code = run_and_get_cpp_code(fn_opt, *inps)
            FileCheck().check("atomic_add").run(code)

        self.assertEqual(
            fn(*inps),
            fn_opt(*inps),
        )

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_new_vec_op_cpu_only(self):
        def fn(x):
            return torch.log1p(torch.expm1(torch.erf(x)))

        for dtype in vec_dtypes:
            torch.manual_seed(0)
            x = torch.randn((2, 9), dtype=dtype)
            x[0, 0] = torch.nan
            x[1, -1] = torch.nan

            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4

            with config.patch({"cpp.simdlen": None}):
                for cpp_wrapper_flag in [True, False]:
                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                        torch._dynamo.reset()
                        metrics.reset()
                        self.common(fn, (x,))
                        assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_cpu_only_for_all_available_isa(self):
        def fn(x):
            return torch.sin(torch.cos(torch.erf(x)))

        x = torch.randn((2, 9))
        x[0, 0] = torch.nan
        x[1, -1] = torch.nan

        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
        for item in bit_widths:
            with config.patch({"cpp.simdlen": item}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @slowTest
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test__adaptive_avg_pool2d(self):
        def wrap_fn(oh, ow):
            def fn(x):
                return torch._adaptive_avg_pool2d(x, (oh, ow))

            return fn

        bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()]
        ih = [16, 65]
        iw = ih
        oh = ih
        ow = ih
        for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product(
            ih, iw, oh, ow, bit_widths, vec_dtypes
        ):
            x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to(
                memory_format=torch.channels_last
            )
            _fn = wrap_fn(_oh, _ow)
            with config.patch({"cpp.simdlen": _simd_len}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(_fn, (x,))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_logical(self):
        def wrap_fn1(op: Callable):
            def fn(x: torch.Tensor):
                return torch.where(op(x), 1.0, 0.0)

            return fn

        def wrap_fn2(op: Callable):
            def fn(x: torch.Tensor, y: torch.Tensor):
                return torch.where(op(x, y), 1.0, 0.0)

            return fn

        for dtype in vec_dtypes:
            x = torch.randn(64, dtype=dtype)
            y = torch.randn(64, dtype=dtype)
            logical_fns = [
                torch.logical_and,
                torch.logical_not,
                torch.logical_or,
                torch.logical_xor,
            ]
            for logical_fn in logical_fns:
                torch._dynamo.reset()
                metrics.reset()
                if logical_fn == torch.logical_not:
                    _fn = wrap_fn1(logical_fn)
                    _args = (x,)
                else:
                    _fn = wrap_fn2(logical_fn)
                    _args = (x, y)
                self.common(_fn, _args)
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_compare_op_cpu_only(self):
        def fn(x):
            y1 = torch.eq(x, 1.0)
            x = torch.where(y1, x, -x)
            y2 = torch.ne(x, 0.0)
            x = torch.where(y2, x, -x)
            y3 = torch.lt(x, 5.0)
            x = torch.where(y3, x, x - 1.0)
            y4 = torch.gt(x, -2.0)
            x = torch.where(y4, x, x + 1.0)
            y5 = torch.le(x, 8.0)
            x = torch.where(y5, x, x - 1.0)
            y6 = torch.ge(x, -3.0)
            x = torch.where(y6, x, x + 1.0)
            y7 = x == 1.0
            x = torch.where(y7, x, -x)
            y8 = x != 0.0
            x = torch.where(y8, x, -x)
            y9 = x < 5.0
            x = torch.where(y9, x, x - 1.0)
            y10 = x > -2.0
            x = torch.where(y10, x, x + 1.0)
            y11 = x <= 8.0
            x = torch.where(y11, x, x - 1.0)
            y12 = x >= -3.0
            x = torch.where(y12, x, x + 1.0)
            return x

        for dtype in vec_dtypes:
            x = torch.randn((2, 9), dtype=dtype)

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                assert metrics.generated_cpp_vec_kernel_count == 1
                assert (
                    metrics.generated_kernel_count
                    - metrics.generated_cpp_vec_kernel_count
                ) == 0

    def test_skip_cpp_codegen(self):
        with config.patch({"disable_cpp_codegen": True}):
            inps = torch.ones([20]), torch.rand([20])

            def f(x, y):
                return x + y + torch.tensor(1)

            f_opt = torch.compile()(f)

            _, code = run_and_get_cpp_code(f_opt, inps[0], inps[1])
            FileCheck().check_not("void kernel").run(code)

            self.assertEqual(
                f(*inps),
                f_opt(*inps),
            )

            # constant needs to be propagated on fallback
            def f(x):
                return x[torch.tensor(1) :] * 2

            f_opt = torch.compile()(f)
            _, code = run_and_get_cpp_code(f_opt, inps[0])
            FileCheck().check_not("void kernel").run(code)
            self.assertEqual(f_opt(inps[0]), f(inps[0]))

            class Model(torch.nn.Module):
                def __init__(
                    self,
                ):
                    super().__init__()

                def forward(self, v1: torch.Tensor):
                    vx = v1.min(dim=1).values
                    v2 = torch.randn_like(vx)
                    return v2

            model = Model()
            x = torch.rand(10, 3, 0)
            model_f = torch.compile()(model)

            self.assertEqual(model(x), model_f(x))

    def test_redundant_to_node_elimination_lowp_fp(self):
        def fn(x, y):
            res = x + y
            res = torch.mean(res)
            return res

        for dtype in _lowp_fp_dtypes:
            x = torch.randn((2, 9), dtype=dtype)
            y = torch.randn((2, 9), dtype=dtype)

            for torch_compile_debug in [True, False]:
                with config.patch(
                    {"trace.enabled": torch_compile_debug, "cpp.simdlen": None}
                ):
                    torch._dynamo.reset()
                    metrics.reset()
                    self.common(fn, (x, y))
                    if codecache.valid_vec_isa_list():
                        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self):
        def fn(x):
            res = x.clone()
            return res

        x = torch.randn((100, 100), dtype=torch.bfloat16)

        torch._dynamo.reset()
        metrics.reset()
        self.common(fn, (x,))
        assert metrics.cpp_to_dtype_count == 0
        if codecache.valid_vec_isa_list():
            assert metrics.generated_cpp_vec_kernel_count == 1

    def test_insert_to_dtype_count(self):
        def fn(x):
            res = x.relu()
            return res

        x = torch.randn((100, 100), dtype=torch.bfloat16)

        torch._dynamo.reset()
        metrics.reset()
        self.common(fn, (x,))
        assert metrics.cpp_to_dtype_count == 2
        if codecache.valid_vec_isa_list():
            assert metrics.generated_cpp_vec_kernel_count == 1

    def test_memory_copy_with_fusion(self):
        def fn(x):
            res = x.relu()
            x.copy_(res)
            return (res,)

        x = torch.randn((100, 100), dtype=torch.bfloat16)

        torch._dynamo.reset()
        metrics.reset()
        self.common(fn, (x,))
        assert metrics.cpp_to_dtype_count == 2
        if codecache.valid_vec_isa_list():
            assert metrics.generated_cpp_vec_kernel_count == 1

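    # NOTE on the two checker tests below, inferred from their structure: each
    # hand-builds a tiny FX graph that calls the "constant" / "index_expr" ops
    # handler methods, runs it through InterpreterShim under a
    # CppVecKernelChecker, and then asserts on checker.simd_vec, i.e. on
    # whether the checker still considers the kernel vectorizable for the
    # given values.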
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_cpp_vec_constant_checker(self):
        _graph: torch.fx.Graph = torch.fx.Graph()
        a: torch.fx.Node = _graph.create_node("placeholder", "ops")
        iv: torch.fx.Node = _graph.create_node("placeholder", "iv")
        fv: torch.fx.Node = _graph.create_node("placeholder", "fv")
        b: torch.fx.Node = _graph.create_node(
            "call_method",
            "constant",
            args=(
                a,
                iv,
                torch.int64,
            ),
        )
        c: torch.fx.Node = _graph.create_node(
            "call_method",
            "constant",
            args=(
                a,
                fv,
                torch.double,
            ),
        )
        d: torch.fx.Node = _graph.create_node(
            "call_method",
            "ge",
            args=(
                a,
                b,
                b,
            ),
        )
        _graph.output((d, c))

        def get_index():
            return ""

        submodules = {"get_index": get_index}

        graph_lowering = GraphLowering(
            torch.fx.GraphModule(submodules, _graph),
            shape_env=None,
            num_static_inputs=0,
        )

        def set_opt_dtype(graph):
            for node in graph.nodes:
                if node.target == "constant":
                    if OptimizationContext.key in node.meta:
                        opt_ctx = node.meta[OptimizationContext.key]
                    else:
                        opt_ctx = OptimizationContext()
                    opt_ctx.dtype = node.args[-1]
                    node.meta[OptimizationContext.key] = opt_ctx

        with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
            graph_lowering
        ):
            # The innermost loop variable is used in the index_expr
            tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:
                i32_iinfo = np.iinfo(np.int32)
                f32_iinfo = np.finfo(np.float32)
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, np.inf
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, -np.inf
                )
                self.assertTrue(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min
                )
                self.assertFalse(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max
                )
                self.assertFalse(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5)
                )
                self.assertFalse(vec_checker.simd_vec)

                vec_checker.simd_vec = True
                set_opt_dtype(_graph)
                InterpreterShim(_graph, submodules).run(
                    V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5)
                )
                self.assertFalse(vec_checker.simd_vec)

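    # NOTE: the boundary values below probe an int32 range check, per the
    # assertions: an index expression whose value can exceed INT32_MAX or fall
    # below INT32_MIN must disable vectorization. The presumable reason is that
    # the vectorized index math is evaluated in 32-bit lanes, but that
    # rationale is an assumption; the range check itself is what is asserted.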
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_cpp_vec_index_expr_checker(self):
        _graph: torch.fx.Graph = torch.fx.Graph()
        a: torch.fx.Node = _graph.create_node("placeholder", "ops")
        b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=())
        c: torch.fx.Node = _graph.create_node(
            "call_method",
            "index_expr",
            args=(
                a,
                b,
                torch.int64,
            ),
        )
        d: torch.fx.Node = _graph.create_node(
            "call_method",
            "ge",
            args=(
                a,
                c,
                c,
            ),
        )
        _graph.output(d)

        def get_index():
            return ""

        submodules = {"get_index": get_index}
        graph_lowering = GraphLowering(
            torch.fx.GraphModule(submodules, _graph),
            shape_env=None,
            num_static_inputs=0,
        )
        with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
            graph_lowering
        ):
            itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]

            tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
            # The innermost loop variable is used in the index_expr
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]

                ranges = [0, 100, 200]
                vec_checker.itervars = itervars[:2]
                vec_checker.ranges = ranges[:2]
                submodules = {"get_index": get_index}
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertTrue(vec_checker.simd_vec)

            # The innermost loop variable is irrelevant
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]

                ranges = [0, 100, 200]
                vec_checker.itervars = itervars
                vec_checker.ranges = ranges
                submodules = {"get_index": get_index}
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertTrue(vec_checker.simd_vec)

            i32_iinfo = np.iinfo(np.int32)
            _max_value = i32_iinfo.max + 1
            ranges = [_max_value, _max_value, _max_value]
            # The innermost loop variable is irrelevant, but the max index value
            # is greater than the max value of INT32
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return itervars[0]

                submodules = {"get_index": get_index}
                vec_checker.itervars = itervars
                vec_checker.ranges = ranges
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertFalse(vec_checker.simd_vec)

            # The innermost loop variable is irrelevant, but the min index value
            # is less than the min value of INT32
            with CppVecKernelChecker(
                args=None, num_threads=1, tiling_factor=tiling_factor
            ) as vec_checker:

                def get_index():
                    return -itervars[0] - 2

                submodules = {"get_index": get_index}
                vec_checker.itervars = itervars
                vec_checker.ranges = ranges
                InterpreterShim(_graph, submodules).run(V.get_ops_handler())
                self.assertFalse(vec_checker.simd_vec)

@unittest.skipIf(
|
|
not codecache.valid_vec_isa_list(), "Does not support vectorization"
|
|
)
|
|
@patch("torch.cuda.is_available", lambda: False)
|
|
def test_maxpool2d_cpu_only(self):
|
|
for dtype in vec_dtypes:
|
|
input = torch.randn(26, 32, 112, 112, dtype=dtype).to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
|
|
def func(x):
|
|
return maxpool(x)
|
|
|
|
with patch.object(config.cpp, "simdlen", None):
|
|
torch._dynamo.reset()
|
|
metrics.reset()
|
|
self.common(func, (input,))
|
|
assert metrics.generated_cpp_vec_kernel_count == 1
|
|
|
|
@unittest.skipIf(
|
|
not codecache.valid_vec_isa_list(), "Does not support vectorization"
|
|
)
|
|
@patch("torch.cuda.is_available", lambda: False)
|
|
def test_maxpool2d_with_pre_loop_collapse_cpu_only(self):
|
|
x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
|
|
x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
|
|
maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
|
|
|
|
def func(x1, x2):
|
|
y = x1 + x2
|
|
return maxpool(y)
|
|
|
|
with patch.object(config.cpp, "simdlen", None):
|
|
torch._dynamo.reset()
|
|
metrics.reset()
|
|
self.common(func, (x1, x2))
|
|
assert metrics.generated_cpp_vec_kernel_count == 2
|
|
|
|
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_sign_cpu_only(self):
        def fn(x):
            return torch.sign(x)

        for dtype in vec_dtypes:
            x = torch.randn((2, 9), dtype=dtype)
            x[0, 0] = torch.nan
            x[1, -1] = torch.nan

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                assert metrics.generated_cpp_vec_kernel_count == 1

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_reduction_cpu_only(self):
        def fn(x):
            return torch.argmax(x, -1)

        for dtype in vec_dtypes:
            x = torch.randn((10, 10), dtype=dtype)

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x,))
                assert metrics.generated_cpp_vec_kernel_count == 0

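    # argmax above asserts that no vectorized kernel is generated:
    # index-returning reductions are not handled by the CPU vectorizer here,
    # unlike the plain sum reductions exercised further below.
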
    # Currently, we enable AVX2 and AVX512 for vectorization. If the platform
    # is not supported, vectorization will not work and this test case is
    # skipped. To support ARM or other platforms, we just need to add the ISA
    # info to supported_vector_isa and include the proper aten vectorization
    # header file.
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    @patch("torch.cuda.is_available", lambda: False)
    def test_vec_kernel_cpu_only(self):
        def fn(x1, x2):
            # Currently, there are some limitations as follows.
            # rsqrt:
            #   assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
            # round:
            #   couldn't find symbolic meta function/decomposition
            # fmod/logical_and/logic_or:
            #   the vec kernel does not support to_type
            x = torch.abs(x1)
            x = torch.sin(x)
            x = torch.neg(x)
            x = torch.square(x)
            x = torch.sigmoid(x)
            x = torch.relu(x)
            x = torch.cos(x)
            x = torch.exp(x)
            x = torch.sqrt(x)
            x = torch.add(x, x1)
            x = torch.sub(x, x2)
            x = torch.mul(x, x1)
            x = torch.div(x, x1)
            x = torch.pow(x, 10)
            x = torch.log(x)
            x = torch.floor(x)
            x = torch.ceil(x)
            x = torch.trunc(x)
            x = torch.lgamma(x)
            x = torch.fmod(x, x2)
            x = torch.sign(x)
            res = x + x2
            return res

        for dtype in vec_dtypes:
            torch.manual_seed(0)
            x1 = torch.randn((5, 20), dtype=dtype)
            x2 = torch.randn((5, 20), dtype=dtype)

            tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
            with config.patch({"cpp.simdlen": 1}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x1, x2))
                assert metrics.generated_cpp_vec_kernel_count == 0

            with config.patch({"cpp.simdlen": None}):
                torch._dynamo.reset()
                metrics.reset()
                self.common(fn, (x1, x2))
                assert metrics.generated_cpp_vec_kernel_count == 1

        with config.patch({"cpp.simdlen": None}):
            torch._dynamo.reset()
            metrics.reset()
            x1 = torch.randn(10, 20).permute(1, 0)
            x2 = torch.randn((20, 10))
            self.common(fn, (x1, x2))
            assert metrics.generated_cpp_vec_kernel_count == 2

            torch._dynamo.reset()
            metrics.reset()
            x1 = torch.randn((10, 7))
            x2 = torch.randn((10, 7))
            self.common(fn, (x1, x2))
            assert metrics.generated_cpp_vec_kernel_count == 1

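    # The cpp.simdlen knob drives the assertions above: None lets Inductor
    # pick the widest ISA reported by codecache.valid_vec_isa_list(), while 1
    # disables vectorization entirely, which is why the kernel counter flips
    # between 0 and nonzero for the same function.
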
    @unittest.skipIf(
        sys.platform != "linux", "cpp kernel profile only supports linux now"
    )
    @patch("torch.cuda.is_available", lambda: False)
    @config.patch({"cpp.enable_kernel_profile": True})
    @config.patch({"cpp.descriptive_names": "original_aten"})
    def test_cpp_kernel_profile(self):
        from torch.profiler import profile

        @torch._dynamo.optimize("inductor", nopython=True)
        def fn(a, b):
            return a + b

        a = torch.rand((100,))
        b = torch.rand((100,))
        with profile() as prof:
            fn(a, b)

        kernel_profile_events = []
        for e in prof.profiler.function_events:
            if "cpp_fused_add_0" in e.name:
                kernel_profile_events.append(e.name)
        assert len(kernel_profile_events) > 0

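    # The event name checked above ("cpp_fused_add_0") follows from
    # cpp.descriptive_names="original_aten": generated kernels are named after
    # the aten ops they fuse, and that name is what the profiler records.
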
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_channel_shuffle_cl_output(self):
        """code and shape extracted from shufflenet_v2_x1_0"""

        def channel_shuffle(x, groups):
            batchsize, num_channels, height, width = x.size()
            channels_per_group = num_channels // groups
            x = x.view(batchsize, groups, channels_per_group, height, width)
            x = torch.transpose(x, 1, 2).contiguous()
            x = x.view(batchsize, -1, height, width)
            return x.contiguous(memory_format=torch.channels_last)

        for simdlen in (None, 256, 1):
            with config.patch({"cpp.simdlen": simdlen}):
                torch._dynamo.reset()
                metrics.reset()
                x = torch.randn(64, 58, 28, 28)
                self.common(channel_shuffle, (x, 2))
                if simdlen != 1:
                    assert metrics.generated_cpp_vec_kernel_count == 2

    @slowTest
    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_transpose_with_norm(self):
        """a sub-module from TIMM gmlp_s16_224"""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_features=256, out_features=1536, bias=True
                )
                self.act = torch.nn.GELU()
                self.norm = torch.nn.LayerNorm(768)
                self.proj = torch.nn.Linear(196, 196)
                self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True)

            def forward(self, x):
                x = self.linear(x)
                x = self.act(x)
                u, v = x.chunk(2, dim=-1)
                v = self.norm(v)
                v = self.proj(v.transpose(-1, -2))
                y = u * v.transpose(-1, -2)
                return self.fc(y)

        x = torch.randn(128, 196, 256)
        for simdlen in (None, 256, 1):
            with config.patch({"cpp.simdlen": simdlen}):
                for eval_mode in [True, False]:
                    torch._dynamo.reset()
                    metrics.reset()
                    m = Model().eval() if eval_mode else Model()
                    self.common(m, (x,))
                    if simdlen != 1:
                        assert metrics.generated_cpp_vec_kernel_count == 8

    @unittest.skipIf(
        not codecache.valid_vec_isa_list(), "Does not support vectorization"
    )
    def test_transpose_copy(self):
        def fn(a):
            return a.t().contiguous()

        for simdlen in (None, 256, 1):
            with config.patch({"cpp.simdlen": simdlen}):
                for dtype in (torch.float, torch.bfloat16):
                    for shape in (
                        (7, 7),
                        (8, 8),
                        (9, 9),
                        (16, 16),
                        (17, 17),
                        (32, 32),
                        (33, 33),
                    ):
                        torch._dynamo.reset()
                        metrics.reset()
                        x = torch.randn(shape, dtype=dtype)
                        self.common(fn, (x,))
                        if simdlen != 1:
                            assert metrics.generated_cpp_vec_kernel_count == 2

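    # The shape sweep above mixes sizes that are and are not multiples of the
    # vector width (7..33), which is meant to exercise both the main tiled
    # loop and the tail handling of the vectorized transpose; every shape is
    # still expected to produce the same two kernels.
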
    def test_horizontal_fusion(self):
        def fn(a, b, c, idx):
            _a = torch.index_select(a, dim=0, index=idx)
            _b = torch.index_select(b, dim=0, index=idx)
            _c = torch.index_select(c, dim=0, index=idx)
            return _a, _b, _c

        with config.patch({"cpp.max_horizontal_fusion_size": 0}):
            metrics.reset()
            torch._dynamo.reset()
            a = torch.randn(size=(4, 16), dtype=torch.bfloat16)
            b = torch.randn(size=(4, 16), dtype=torch.bfloat16)
            c = torch.randn(size=(4, 16), dtype=torch.bfloat16)
            idx = torch.zeros(size=[4], dtype=torch.int64)
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            opt_fn(a, b, c, idx)
            self.assertEqual(metrics.generated_kernel_count, 3)
            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))

        with config.patch({"cpp.max_horizontal_fusion_size": 1}):
            metrics.reset()
            torch._dynamo.reset()
            a = torch.randn(size=(4, 32), dtype=torch.bfloat16)
            b = torch.randn(size=(4, 32), dtype=torch.bfloat16)
            c = torch.randn(size=(4, 32), dtype=torch.bfloat16)
            idx = torch.zeros(size=[4], dtype=torch.int64)
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            opt_fn(a, b, c, idx)
            self.assertEqual(metrics.generated_kernel_count, 3)
            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))

        with config.patch({"cpp.max_horizontal_fusion_size": 2}):
            metrics.reset()
            torch._dynamo.reset()
            a = torch.randn(size=(4, 64), dtype=torch.bfloat16)
            b = torch.randn(size=(4, 64), dtype=torch.bfloat16)
            c = torch.randn(size=(4, 64), dtype=torch.bfloat16)
            idx = torch.zeros(size=[4], dtype=torch.int64)
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            opt_fn(a, b, c, idx)
            self.assertEqual(metrics.generated_kernel_count, 2)
            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))

        with config.patch({"cpp.max_horizontal_fusion_size": 3}):
            metrics.reset()
            torch._dynamo.reset()
            a = torch.randn(size=(4, 128), dtype=torch.bfloat16)
            b = torch.randn(size=(4, 128), dtype=torch.bfloat16)
            c = torch.randn(size=(4, 128), dtype=torch.bfloat16)
            idx = torch.zeros(size=[4], dtype=torch.int64)
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            opt_fn(a, b, c, idx)
            self.assertEqual(metrics.generated_kernel_count, 1)
            self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))

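    # Summary of the expectations above: with cpp.max_horizontal_fusion_size
    # set to n, at most n independent kernels may be fused side by side, so
    # the three index_select ops go from 3 generated kernels (n <= 1) to 2
    # (n == 2, one pair fused) to 1 (n == 3, all three fused).
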
    def test_lowp_fp_neg_abs(self):
        def fn(x):
            return x.neg().abs()

        for dtype in _lowp_fp_dtypes:
            metrics.reset()
            x = torch.randn(100, 100).to(dtype)
            opt_fn = torch._dynamo.optimize("inductor")(fn)
            self.assertTrue(same(fn(x), opt_fn(x)))
            assert metrics.cpp_to_dtype_count == 0
            assert metrics.generated_cpp_vec_kernel_count == 1

    def test_transpose_non_contiguous(self):
        def fn(a):
            # From part of timm HaloAttn:
            # (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97).
            # Fixed https://github.com/pytorch/pytorch/issues/94269 accuracy issue.
            as_strided = torch.ops.aten.as_strided.default(
                a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]
            )
            as_strided_1 = torch.ops.aten.as_strided.default(
                as_strided,
                [1, 384, 2, 2, 12, 12],
                [153600, 1, 61440, 3072, 7680, 384],
            )
            clone_1 = torch.ops.aten.clone.default(
                as_strided_1, memory_format=torch.contiguous_format
            )
            _unsafe_view_1 = torch.ops.aten._unsafe_view.default(
                clone_1, [8, 48, 4, 144]
            )
            permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1])
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                permute_2, [16, 32], -1
            )
            getitem = split_with_sizes[0]
            getitem_1 = split_with_sizes[1]
            permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2])
            expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144])
            clone_3 = torch.ops.aten.clone.default(
                expand_1, memory_format=torch.contiguous_format
            )
            return clone_3

        metrics.reset()
        x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last)
        self.common(fn, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_non_contiguous_index_with_constant_stride(self):
        def fn(x):
            x1 = x[:, :, :, ::2]
            x2 = x[:, :, :, 1::2]
            x = torch.stack((-x2, x1), dim=-1)
            return x.flatten(-2)

        metrics.reset()
        x = torch.randn(1, 32, 16, 68)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        _, code = run_and_get_cpp_code(opt_fn, x)
        self.assertTrue(same(fn(x), opt_fn(x)))
        # two occurrences: the kernel definition and its call site
        FileCheck().check_count("cpp_fused", 2, exactly=True).run(code)

    def test_invalid_index_of_empty_tensor(self):
        def fn(a):
            b = a[[0]]
            return b

        a = torch.tensor([])
        with self.assertRaises(RuntimeError):
            torch.compile(fn)(a)

    def test_ir_node_str(self):
        @torch.compile
        def fn(x: torch.Tensor) -> torch.Tensor:
            return x.sin(), torch.nn.Softmax(dim=1)(x.cos())

        def run_node_alt(*args, **kwargs):
            rv = run_node(*args, **kwargs)
            strings.append(str(rv))
            return rv

        strings = []
        run_node = GraphLowering.run_node
        with patch.object(GraphLowering, "run_node", run_node_alt):
            fn(torch.randn([8, 128]))
        self.assertGreater(len(strings), 3)

    def test_vertical_sum_cpu_only(self):
        def fn1(a):
            return a.sum(dim=0)

        def fn2(a):
            return a.sum(dim=1)

        metrics.reset()
        x = torch.randn(100, 100)
        self.common(fn1, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

        metrics.reset()
        x = torch.randn(100, 100, 100)
        self.common(fn2, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_transpose_vertical_sum_cpu_only(self):
        def fn(a, b):
            c = a * b
            return c.sum(dim=1)

        metrics.reset()
        x = torch.randn(100, 50, 50)
        y = torch.randn(100, 50, 50).transpose(1, 2)
        self.common(fn, (x, y))
        assert metrics.generated_cpp_vec_kernel_count == 2

    def test_transpose_sum2d_cpu_only(self):
        def fn(a, b):
            c = a * b
            return c.sum()

        metrics.reset()
        x = torch.randn(50, 50)
        y = torch.randn(50, 50).transpose(0, 1)
        self.common(fn, (x, y))
        assert metrics.generated_cpp_vec_kernel_count == 2

    def test_transpose_sum_outer(self):
        # https://github.com/pytorch/pytorch/issues/98573
        def fn(a):
            return a.transpose(2, 3).sum(dim=1).contiguous()

        metrics.reset()
        x = torch.randn(10, 50, 50, 50)
        self.common(fn, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_to_dtype_bool_float(self):
        # https://github.com/pytorch/pytorch/issues/100800
        def f(a):
            return torch.where(
                torch.ones_like(a).to(torch.bool),
                torch.zeros_like(a),
                torch.ones_like(a) * 2,
            )

        self.common(f, (torch.ones(16),))

    def test_to_dtype_float_bool(self):
        # https://github.com/pytorch/pytorch/issues/100466
        def f(a):
            a = a * torch.tensor(a >= 0, dtype=torch.float32)
            return a

        x = torch.rand(16)
        self.common(f, (x,))

    def test_constant_store(self):
        # https://github.com/pytorch/pytorch/issues/104515
        def f(a):
            a[0, [3, 3]] = -float("inf")
            return a

        x = torch.rand(4, 5)
        self.common(f, (x,))

    def test_to_channels_last_lowp_fp(self):
        def f(a):
            return a.to(memory_format=torch.channels_last)

        for dtype in _lowp_fp_dtypes:
            x = torch.rand(2, 3, 14, 14).to(dtype)
            self.common(f, (x,))

    def test_broadcast_mul_lowp_fp(self):
        def f(a, b):
            return a * b

        for dtype in _lowp_fp_dtypes:
            a = torch.randn(2, 16, 16).to(dtype)
            b = torch.randn(2, 1, 1).to(dtype)
            self.common(f, (a, b))

    def test_linear_buffer_reuse(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(16, 16)
                self.tanh = torch.nn.Tanh()
                self.linear2 = torch.nn.Linear(16, 16)

            def forward(self, x):
                x = self.linear1(x)
                x = self.tanh(x)
                x = self.linear2(x)
                return x

        mod = M().eval()
        v = torch.randn(1, 16)

        with torch.no_grad():

            def compile_fx_wrapper(model_, example_inputs_):
                return compile_fx(model_, example_inputs_)

            def run(*ex, **kwargs):
                return mod(*ex, **kwargs)

            run = torch._dynamo.optimize(compile_fx_wrapper)(run)
            _, code = run_and_get_cpp_code(run, v)
            self.assertFalse("= as_strided(" in code)
            self.assertEqual(run(*v), mod(*v))

    def test_invalid_dropout_args(self):
        class MyModel(torch.nn.Module):
            def forward(self, x):
                x = x * 2
                x = torch.nn.functional.dropout(x, p=0.5)
                x = torch.relu(x)
                return x

        example_inputs = torch.tensor([[1, 2, 3], [4, 5, 6]])

        func = MyModel()
        jit_func = torch.compile(func)
        self.assertRaises(RuntimeError, lambda: func(example_inputs))
        self.assertRaises(RuntimeError, lambda: jit_func(example_inputs))

    @config.patch(inplace_buffers=True)
    def test_in_out_buffer(self):
        def fn(x, y):
            z = torch.matmul(x, y.transpose(-1, -2)) / 8.0
            return z

        inps = [torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)]
        fn_opt = torch._dynamo.optimize("inductor")(fn)
        _, code = run_and_get_cpp_code(fn_opt, *inps)
        self.assertTrue("in_out_ptr" in code)
        self.assertEqual(fn_opt(*inps), fn(*inps))

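    # "in_out_ptr" in the generated source indicates that inplace_buffers took
    # effect: the matmul output buffer is reused in place as both input and
    # output of the following division kernel.
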
    def test_eliminate_meaningless_copy(self):
        def fn(x1, x2):
            permute = torch.ops.aten.permute.default(x2, [0, 2, 1, 3])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            view = torch.ops.aten.view.default(clone, [1024, -1, 32])
            bmm = torch.ops.aten.bmm.default(view, x1)
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            return (bmm, permute)

        metrics.reset()
        self.common(
            fn,
            [
                rand_strided(
                    (1024, 32, 128), (4096, 1, 32), device="cpu", dtype=torch.float32
                ),
                rand_strided(
                    (64, 128, 16, 32),
                    (65536, 512, 32, 1),
                    device="cpu",
                    dtype=torch.float32,
                ),
            ],
        )
        self.assertEqual(metrics.generated_kernel_count, 1)

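    # A single kernel is expected above because only the clone feeding the bmm
    # has to be materialized; the returned permute is a view over the same
    # buffer and should not trigger an additional copy kernel.
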
    def test_attention_size_mismatch(self):
        class Attention(torch.nn.Module):
            def __init__(self, hidden_size, num_heads):
                super().__init__()
                self.hidden_size = hidden_size
                self.num_heads = num_heads
                self.head_size = hidden_size // num_heads
                self.query = torch.nn.Linear(hidden_size, hidden_size)
                self.key = torch.nn.Linear(hidden_size, hidden_size)
                self.value = torch.nn.Linear(hidden_size, hidden_size)
                self.inv_scale = torch.nn.Parameter(
                    torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
                )

            def forward(self, x):
                query = self.query(x)
                key = self.key(x)
                value = self.value(x)
                (batch_size, seq_len, hidden_size) = query.size()
                query = query.view(
                    batch_size, seq_len, self.num_heads, self.head_size
                ).permute(0, 2, 1, 3)
                key = key.view(
                    batch_size, seq_len, self.num_heads, self.head_size
                ).permute(0, 2, 3, 1)
                value = value.view(
                    batch_size, seq_len, self.num_heads, self.head_size
                ).permute(0, 2, 1, 3)
                attention_weights = (
                    torch.matmul(query, key).div(self.inv_scale).softmax(dim=-1)
                )
                output = torch.matmul(attention_weights, value)
                return output

        torch.manual_seed(123)
        hidden_size = 16
        num_heads = 1
        seq_len = 4
        batch_size = 1
        x = torch.randn(batch_size, seq_len, hidden_size)

        func = Attention(hidden_size, num_heads).to("cpu")

        with torch.no_grad():
            res1 = func(x)
            jit_func = torch.compile(func)
            res2 = jit_func(x)
            self.assertEqual(res1, res2)

    def test_scalar_mul_bfloat16(self):
        def f(x):
            return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)

        metrics.reset()
        x = torch.randn(4, 5, dtype=torch.bfloat16)
        self.common(f, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_bf16_zeros(self):
        def fn():
            x = torch.zeros(1, 1, 32, dtype=torch.bfloat16)
            return x

        self.common(fn, ())

    def test_select_tiliing_with_index_expr(self):
        def fn(x, y):
            x = torch.ops.aten.view.default(x, [8, 8, 8, 3136])
            x = torch.ops.aten.permute.default(x, [0, 1, 3, 2])
            y = torch.ops.aten.mul.Tensor(y, x)
            return torch.ops.aten.constant_pad_nd.default(y, [0, 0, 1, 0, 0, 0], 0.0)

        x = torch.randn(8, 64, 56, 56)
        y = torch.randn(8, 8, 3136, 8)
        self.common(fn, (x, y))

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    @config.patch(freezing=True)
    def test_linear_with_no_default_contiguous_input(self):
        dtypes = [
            torch.float32,
        ]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        mod = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
        temp = torch.randn(1, 16, 1, 1)
        v = torch.as_strided(temp, [1, 16], [0, 1], 0)
        self.assertTrue(v.is_contiguous())
        for dtype in dtypes:
            with torch.no_grad():
                self.common(
                    mod.to(dtype),
                    (v.to(dtype),),
                )

    @patch("torch.cuda.is_available", lambda: False)
    @config.patch(freezing=True)
    def test_linear_with_reshape(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(16, 16, bias=False)

            def forward(self, x):
                x = self.linear(x)
                return x.view(4, 4, 4)

        mod = M().eval()
        v = torch.randn(4, 16)
        with torch.no_grad():
            torch._dynamo.reset()
            metrics.reset()
            self.common(
                mod,
                (v,),
            )
            assert metrics.generated_kernel_count == 0

    @config.patch(implicit_fallbacks=True)
    def test_aten_normal_dtype(self):
        for dtype in [torch.float64, torch.float16, None]:

            def fn():
                return torch.normal(2, 3, (10, 10), dtype=dtype, device="cpu")

            self.assertEqual(
                torch.compile(fn, backend="aot_eager_decomp_partition")().dtype,
                dtype if dtype else torch.float32,
            )
            self.assertEqual(
                torch.compile(fn, backend="inductor")().dtype,
                dtype if dtype else torch.float32,
            )

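    # implicit_fallbacks=True above lets aten.normal fall back to the eager
    # kernel when Inductor lacks a lowering, while the assertions confirm the
    # requested dtype (or the float32 default) survives compilation.
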
    def test_group_norm_vec(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.group_norm = torch.nn.GroupNorm(32, 32)

            def forward(self, x):
                return self.group_norm(x)

        metrics.reset()
        mod = M().eval()
        x = torch.randn(2, 32, 32, 32)
        with torch.no_grad():
            self.common(mod, (x,))
            # 2 generated kernels (one for var_mean, the other for result)
            assert metrics.generated_cpp_vec_kernel_count == 2

    def test_int_div_vec(self):
        def fn(x, y, mode):
            return torch.div(x, y, rounding_mode=mode)

        x = torch.randint(1, 100, (32, 32))
        y = torch.randint(1, 100, (32, 32))
        for mode in [None, "trunc", "floor"]:
            with torch.no_grad():
                metrics.reset()
                self.common(fn, (x, y, mode))
                # TODO: support vectorization for int div
                assert metrics.generated_cpp_vec_kernel_count == 0

    def test_uint8_add(self):
        # https://github.com/pytorch/pytorch/issues/113016
        def fn(x, y):
            return torch.add(x, y).neg().to(torch.int32)

        x = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
        y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
        self.common(fn, (x, y))

    def test_uint8_sub(self):
        # https://github.com/pytorch/pytorch/issues/113016
        def fn(x, y):
            return torch.sub(x, y).neg().to(torch.int32)

        x = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
        y = torch.randint(0, 255, (3, 3), dtype=torch.uint8)
        self.common(fn, (x, y))

    def test_non_contiguous_reduction_store(self):
        # https://github.com/pytorch/pytorch/issues/113018
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(39, 1, kernel_size=(1, 17), stride=(2, 2))

            def forward(self, x):
                return self.conv(x.max(3).values)

        m = M()
        x = torch.randn(1, 39, 1, 18, 17)
        self.common(m, (x,))

    def test_embedding_vec(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(64, 128)

            def forward(self, idx, x):
                return self.emb(idx) + x

        idx = torch.randint(0, 64, (4, 32))
        x = torch.randn(4, 32, 128)
        m = M().eval()
        with torch.no_grad():
            metrics.reset()
            self.common(m, (idx, x))
            assert metrics.generated_cpp_vec_kernel_count == 1

    def test_embedding_vec_bf16(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(64, 128)

            def forward(self, idx, x):
                return self.emb(idx)

        idx = torch.randint(0, 64, (4, 32))
        x = torch.randn(4, 32, 128).to(torch.bfloat16)
        m = M().eval()
        with torch.no_grad():
            metrics.reset()
            self.common(m, (idx, x))
            assert metrics.generated_cpp_vec_kernel_count == 1

        # we are doing direct load/store, make sure we do not generate
        # redundant type casts
        m_opt = torch.compile(m)
        _, code = run_and_get_cpp_code(m_opt, idx, x)
        self.assertTrue("Vectorized" in code)
        self.assertTrue("cvt_lowp_fp_to_fp32" not in code)
        self.assertTrue("cvt_fp32_to_lowp_fp" not in code)

    def test_concat_inner_vec(self):
        def fn(x, y):
            return F.relu(torch.cat([x, y], dim=1))

        x = torch.randn(32, 35)
        y = torch.randn(32, 120)
        metrics.reset()
        self.common(fn, (x, y))
        assert metrics.generated_cpp_vec_kernel_count == 3

    def test_expr_vec_non_contiguous(self):
        def fn(x):
            # the pattern from sebotnet33ts_256
            y = torch.nn.functional.pad(x, (0, 31)).reshape(-1, 33, 63)
            y = y[:, :32, 31:].reshape(4, 32, 1, 32, 32).expand(-1, -1, 32, -1, -1)
            y = y.permute(0, 3, 1, 4, 2).clone(memory_format=torch.contiguous_format)
            y = y.view(4, 1024, 1024)
            return y.softmax(dim=-1)

        x = torch.randn(128, 2048)
        opt_fn = torch.compile(fn)
        metrics.reset()
        _, code = run_and_get_cpp_code(opt_fn, x)
        self.assertTrue(same(fn(x), opt_fn(x)))
        # 4 kernels for max, exp, sum and div
        assert metrics.generated_cpp_vec_kernel_count == 4
        FileCheck().check_count(
            "Vectorized<int>::loadu(tmpbuf.data())", 0, exactly=True
        ).run(code)

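    # The FileCheck above asserts zero occurrences of the tmpbuf staging load,
    # i.e. the non-contiguous index expressions are vectorized directly rather
    # than gathered through a temporary buffer first.
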
    def test_vec_contiguous_ModularIndexing(self):
        # https://github.com/pytorch/pytorch/issues/114488
        class M(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.norm = torch.nn.LayerNorm(dim * 4)

            def forward(self, x):
                # the pattern from swin_base_patch4_window7_224
                B, H, W, C = x.shape
                x = (
                    x.reshape(B, H // 2, 2, W // 2, 2, C)
                    .permute(0, 1, 3, 4, 2, 5)
                    .flatten(3)
                )
                x = self.norm(x)
                return x

        x = torch.randn(1, 56, 56, 128)
        m = M(128)
        opt_m = torch.compile(m)
        with torch.no_grad():
            metrics.reset()
            _, code = run_and_get_cpp_code(opt_m, x)
            self.assertTrue(same(m(x), opt_m(x)))
            # Two kernels: one for the reduction, the other for pointwise ops
            assert metrics.generated_cpp_vec_kernel_count == 2
            FileCheck().check_count(
                "Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
            ).run(code)

    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize("shape", ("15,3,13", "4,2048,4096"))
    def test_fp8_cast(self, dtype: torch.dtype, shape: str):
        def fp8_cast(x):
            y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
            y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
            return y0, y1

        shape = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*shape, device="cpu", dtype=dtype)
        self.common(fp8_cast, (x,))

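    # The fp8 round trip above is lossy by construction; self.common checks
    # the compiled result against eager mode performing the same
    # float8_e4m3fn/float8_e5m2 casts, not against the original input.
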

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")