use more elements per thread for narrow dtypes (#139449)

Fix a perf issue for narrow dtypes by accessing more elements per thread

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139449
Approved by: https://github.com/Chillee, https://github.com/eqy
Author: Natalia Gimelshein
Date: 2024-11-14 22:50:13 +00:00
Committed by: PyTorch MergeBot
commit 05c3330893 (parent 7621fc5dad)
6 changed files with 142 additions and 77 deletions
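The heart of the change is a compile-time rule that derives the per-thread element count from the combined byte size of a functor's inputs and output, so narrow dtypes get wider per-thread work. Below is a minimal host-side sketch of that rule; the 128-thread block size and the _sketch names are assumptions for illustration, while the real definitions appear in the first hunk of the diff.

#include <cstdio>

// Mirrors the elems_per_thread<io_sizes>() rule added in this commit.
constexpr int elems_per_thread_sketch(int io_sizes) {
  if (io_sizes == 1) return 16;  // e.g. a nullary functor writing a 1-byte result
  if (io_sizes < 4)  return 8;   // e.g. bool -> bool: 1 byte in + 1 byte out
  return 4;                      // previous default, kept for wider dtypes
}

int main() {
  const int num_threads = 128;   // assumed block size for illustration
  const int bool_io  = 1 + 1;    // sizeof(bool) input + sizeof(bool) output
  const int float_io = 4 + 4;    // sizeof(float) input + sizeof(float) output
  std::printf("bool  op: %2d elems/thread, %4d elems/block\n",
              elems_per_thread_sketch(bool_io),
              elems_per_thread_sketch(bool_io) * num_threads);
  std::printf("float op: %2d elems/thread, %4d elems/block\n",
              elems_per_thread_sketch(float_io),
              elems_per_thread_sketch(float_io) * num_threads);
  return 0;
}

Under these assumed numbers a bool elementwise op handles 1024 elements per block instead of 512, which is the intended fix for the narrow-dtype perf issue.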


@@ -52,13 +52,49 @@
namespace at::native {
template <typename args_t, size_t... Is>
constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
if constexpr (sizeof...(Is) == 0) {
return 0;
} else {
return (sizeof(std::tuple_element_t<Is, args_t>) + ...);
}
}
template <int io_sizes>
constexpr auto elems_per_thread(){
if constexpr (io_sizes == 1) {
return 16;
} else if constexpr (io_sizes < 4) {
return 8;
} else {
return 4;
}
}
template <int io_sizes>
constexpr auto io_block_work_size() {
return num_threads() * elems_per_thread<io_sizes>();
}
template <typename func_t>
constexpr auto calc_io_size(){
using traits = function_traits<func_t>;
using args_t = typename traits::ArgsTuple;
constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
constexpr auto output_size = sizeof(typename traits::result_type);
return input_size + output_size;
}
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
using traits = function_traits<func_t>;
int remaining = N - block_work_size() * blockIdx.x;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
if (remaining < block_work_size()) { // if this block handles the reminder,
if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
@@ -69,19 +105,21 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
decltype(input_calc),
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast>(
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t>(data));
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}
template <
typename func_t,
typename array_t,
int elems_per_thread,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
@@ -95,9 +133,9 @@ __global__ void unrolled_elementwise_kernel(
out_calc_t oc,
loader_t l,
storer_t s) {
int remaining = N - block_work_size() * blockIdx.x;
int remaining = N - elems_per_thread * num_threads() * blockIdx.x;
auto policy = memory::policies::
unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t, elems_per_thread>(
data, remaining, ic, oc, l, s);
elementwise_kernel_helper(f, policy);
}
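The vectorized kernel above computes `remaining` against the io-size-dependent block work size and takes the bounds-checked unrolled path only for the partially filled tail block; the unrolled kernel uses the same `remaining` formula for its own bounds checks. A small host-side simulation of that dispatch, with N, the 128-thread block size and the 8-elements-per-thread choice as assumed example values:

#include <cstdio>

int main() {
  const int num_threads = 128;       // assumed block size
  const int elems_per_thread = 8;    // e.g. a bool -> bool op (io_sizes == 2)
  const int block_work = num_threads * elems_per_thread;  // io_block_work_size
  const int N = 5000;                // deliberately not a multiple of block_work

  const int grid = (N + block_work - 1) / block_work;
  for (int block = 0; block < grid; ++block) {
    const int remaining = N - block_work * block;  // same formula as the kernels above
    if (remaining < block_work) {
      std::printf("block %d: remaining=%d -> unrolled tail (bounds-checked)\n",
                  block, remaining);
    } else {
      std::printf("block %d: full block -> vectorized, no bounds checks\n", block);
    }
  }
  return 0;
}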
@@ -110,7 +148,8 @@ static inline void launch_vectorized_kernel(
array_t data) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
using traits = function_traits<func_t>;
int64_t grid = (N + block_work_size() - 1) / block_work_size();
constexpr auto io_size = calc_io_size<func_t>();
int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
auto stream = at::cuda::getCurrentCUDAStream();
int vec_size = memory::can_vectorize_up_to<func_t>(data);
@@ -130,7 +169,7 @@ static inline void launch_vectorized_kernel(
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
unrolled_elementwise_kernel<func_t, array_t>
unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
<<<grid, num_threads(), 0, stream>>>(
N, f, data, input_calc, output_calc, loader, storer);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -159,7 +198,7 @@ static inline void launch_unrolled_kernel(
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
int64_t grid = (N + block_work_size() - 1) / block_work_size();
auto stream = at::cuda::getCurrentCUDAStream();
unrolled_elementwise_kernel<func_t, array_t>
unrolled_elementwise_kernel<func_t, array_t, thread_work_size()>
<<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
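The launcher hunks above size the vectorized grid from the io-size-dependent block work and pass the per-thread element count as a template argument, while launch_unrolled_kernel keeps the old block_work_size() granularity by passing thread_work_size() explicitly. A quick arithmetic comparison under assumed values (128 threads per block, legacy 4 elements per thread versus 16 for a 1-byte-I/O functor):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t N = 1 << 24;          // 16M elements, assumed for illustration
  const int num_threads = 128;        // assumed block size
  const int legacy_epp = 4;           // legacy per-thread work (unrolled launcher)
  const int narrow_epp = 16;          // elems_per_thread for io_sizes == 1

  const int64_t legacy_grid = (N + int64_t{num_threads} * legacy_epp - 1) /
                              (int64_t{num_threads} * legacy_epp);
  const int64_t narrow_grid = (N + int64_t{num_threads} * narrow_epp - 1) /
                              (int64_t{num_threads} * narrow_epp);

  // With 4x more work per thread, the vectorized launch uses 4x fewer blocks.
  std::printf("legacy grid: %lld blocks\n", (long long)legacy_grid);
  std::printf("narrow-dtype grid: %lld blocks\n", (long long)narrow_grid);
  return 0;
}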


@@ -46,18 +46,19 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;
constexpr int elems_per_thread = policy_t::tws;
int idx = blockIdx.x;
return_t results[thread_work_size()];
args_t args[thread_work_size()];
return_t results[elems_per_thread];
args_t args[elems_per_thread];
// load
policy.load(args, idx);
// compute
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
for (int i = 0; i < elems_per_thread; i++) {
if (policy.check_inbounds(i)) {
results[i] = c10::guts::apply(f, args[i]);
}
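The helper above now sizes its per-thread results and args arrays from the policy's static tws member instead of the global thread_work_size(). A reduced sketch of that pattern with a stand-in policy type (the _sketch names are invented for illustration):

#include <cstdio>

// Stand-in for the real memory-access policies: only the compile-time
// tws (per-thread work size) plumbing is shown.
template <int elems_per_thread>
struct policy_sketch {
  static constexpr int tws = elems_per_thread;
};

template <typename policy_t>
void kernel_helper_sketch(policy_t /*policy*/) {
  // Per-thread buffers are sized at compile time from the policy,
  // mirroring elementwise_kernel_helper after this change.
  float results[policy_t::tws];
  std::printf("buffers hold %d elements per thread (%zu bytes)\n",
              policy_t::tws, sizeof(results));
}

int main() {
  kernel_helper_sketch(policy_sketch<16>{});  // narrow-dtype configuration
  kernel_helper_sketch(policy_sketch<4>{});   // previous default
  return 0;
}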


@@ -57,11 +57,11 @@ struct static_unroll<func, end, end> {
template<int arg_index>
struct vectorized_load_helper {
template <typename args_t, typename policy_t>
static __device__ void apply(policy_t &self, args_t *args, int idx) {
static __device__ void apply(policy_t &self, args_t *args, int idx, int block_work_size) {
using arg_t = std::tuple_element_t<arg_index, args_t>;
// `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
// need a +1 offset to get the input
auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size * idx;
auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
self.load_single_arg(args_accessor, ptr);
}
@@ -181,9 +181,7 @@ __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint
namespace policies {
// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int elems_per_thread, int num_outputs=1>
struct unroll {
data_t data;
@@ -192,6 +190,7 @@ struct unroll {
out_calc_t output_offset_calculator;
loader_t loader;
storer_t storer;
static constexpr int tws = elems_per_thread;
__device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
@@ -205,11 +204,11 @@ struct unroll {
constexpr int arity = std::tuple_size_v<args_t>;
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
for (int i = 0; i < elems_per_thread; i++) {
if (thread_idx >= remaining) {
return;
}
int linear_idx = thread_idx + block_work_size() * idx;
int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
auto offset = input_offset_calculator.get(linear_idx);
detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
thread_idx += num_threads();
@@ -220,11 +219,11 @@ struct unroll {
__device__ inline void store(scalar_t *from, int idx) {
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
for (int i = 0; i < elems_per_thread; i++) {
if (thread_idx >= remaining) {
return;
}
int linear_idx = thread_idx + block_work_size() * idx;
int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
int offset = output_offset_calculator.get(linear_idx)[0];
storer.store(from[i], data[0], offset);
thread_idx += num_threads();
@@ -237,11 +236,12 @@ struct unroll {
// Note:
// Functions in vectorized policy does not do boundary check. It assumes the whole block
// has its job to do. So the reminders should be handled by the caller manually.
template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
template <int vec_size, typename data_t, int elems_per_thread> // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {
static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
static constexpr int loop_size = thread_work_size() / vec_size;
static_assert(elems_per_thread % vec_size == 0, "The workload per thread must be a multiple of vec_size");
static constexpr int loop_size = elems_per_thread / vec_size;
static constexpr int tws = elems_per_thread;
data_t data;
@@ -268,13 +268,13 @@ struct vectorized {
template<typename args_t>
__device__ inline void load(args_t *args, int idx) {
constexpr int arity = std::tuple_size_v<args_t>;
detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx, elems_per_thread * num_threads());
}
template<typename scalar_t>
__device__ inline void store(scalar_t *from, int idx) {
using vec_t = aligned_vector<scalar_t, vec_size>;
scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + elems_per_thread * num_threads() * idx;
vec_t *to_ = reinterpret_cast<vec_t *>(to);
int thread_idx = threadIdx.x;
#pragma unroll
@@ -299,6 +299,7 @@ struct multi_outputs_unroll {
out_calc_t output_offset_calculator;
LoadWithoutCast loader;
StoreWithoutCast storer;
static constexpr int tws = thread_work_size();
__device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
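In the unroll and vectorized policies above, the global element index is now thread_idx + elems_per_thread * num_threads() * idx, with each thread striding by num_threads() across its elems_per_thread items. A host-side sketch that prints which global indices one thread touches (block index, thread index and sizes are assumed example values):

#include <cstdio>

int main() {
  const int num_threads = 128;      // assumed block size
  const int elems_per_thread = 8;   // assumed per-thread work for a narrow dtype
  const int block_idx = 2;          // the `idx` passed to the policy
  const int first_thread_idx = 5;   // threadIdx.x of the thread we follow

  int thread_idx = first_thread_idx;
  for (int i = 0; i < elems_per_thread; ++i) {
    // Same formula as unroll::load/store after this change.
    const int linear_idx = thread_idx + elems_per_thread * num_threads * block_idx;
    std::printf("i=%d -> linear_idx=%d\n", i, linear_idx);
    thread_idx += num_threads;      // consecutive threads hit consecutive elements
  }
  return 0;
}

Because the per-iteration stride is still num_threads, widening the per-thread work does not change the access pattern within an iteration.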


@@ -82,7 +82,7 @@ __global__ void vectorized_copy(scalar_t *dst, scalar_t *src) {
data[0] = reinterpret_cast<char *>(dst);
data[1] = reinterpret_cast<char *>(src);
int idx = blockIdx.x;
using vectorized = policies::vectorized<vec_size, array_t>;
using vectorized = policies::vectorized<vec_size, array_t, thread_work_size()>;
auto policy = vectorized(data);
scalar_t buf[thread_work_size()];
#if !defined(USE_ROCM)


@@ -1045,7 +1045,6 @@ class TestReductions(TestCase):
a[:, (shape[1] - 1) // 2:] = True
values, indices = a.mode(-1)
self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
print(indices)
indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
self.assertEqual(values, indexed)


@@ -1,57 +1,57 @@
# Owner(s): ["module: tests"]
import torch
import numpy as np
import math
from numbers import Number
import random
import unittest
from numbers import Number
import numpy as np
import torch
from torch import inf, nan
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
torch_to_numpy_dtype_dict,
numpy_to_torch_dtype_dict,
suppress_warnings,
TEST_SCIPY,
slowTest,
skipIfNoSciPy,
IS_WINDOWS,
gradcheck,
is_iterable_of_tensors,
xfailIfTorchDynamo,
)
from torch.testing._internal.common_methods_invocations import (
unary_ufuncs,
generate_elementwise_unary_tensors,
generate_elementwise_unary_small_value_tensors,
generate_elementwise_unary_large_value_tensors,
generate_elementwise_unary_extremal_value_tensors,
)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
dtypes,
onlyCPU,
onlyNativeDeviceTypes,
onlyCUDA,
dtypesIfCUDA,
precisionOverride,
dtypesIfCPU,
)
from torch.utils import _pytree as pytree
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCPU,
dtypesIfCUDA,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
ops,
precisionOverride,
)
from torch.testing._internal.common_dtype import (
floating_types_and,
all_types_and_complex_and,
integral_types_and,
get_all_math_dtypes,
complex_types,
floating_and_complex_types_and,
floating_types_and,
get_all_math_dtypes,
integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
generate_elementwise_unary_extremal_value_tensors,
generate_elementwise_unary_large_value_tensors,
generate_elementwise_unary_small_value_tensors,
generate_elementwise_unary_tensors,
unary_ufuncs,
)
from torch.testing._internal.common_utils import (
gradcheck,
is_iterable_of_tensors,
IS_WINDOWS,
numpy_to_torch_dtype_dict,
run_tests,
skipIfNoSciPy,
slowTest,
suppress_warnings,
TEST_SCIPY,
TestCase,
torch_to_numpy_dtype_dict,
xfailIfTorchDynamo,
)
from torch.utils import _pytree as pytree
if TEST_SCIPY:
import scipy
@@ -172,7 +172,7 @@ class TestUnaryUfuncs(TestCase):
torch.from_numpy(expected).to(actual.dtype),
msg,
exact_device=False,
**kwargs
**kwargs,
)
else:
self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
@@ -444,7 +444,9 @@ class TestUnaryUfuncs(TestCase):
all_outs = [op(slice, **torch_kwargs) for slice in input]
if is_iterable_of_tensors(actual):
expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
expected = [
torch.stack([out[i] for out in all_outs]) for i in range(len(actual))
]
else:
expected = torch.stack(all_outs)
@@ -510,8 +512,10 @@ class TestUnaryUfuncs(TestCase):
shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]]
for shape in shapes:
x = torch.randn(shape, device=device)
extremals = [float('nan'), float('inf'), -float('inf')]
for id1, id2, extremal in zip(torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals):
extremals = [float("nan"), float("inf"), -float("inf")]
for id1, id2, extremal in zip(
torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals
):
x[0, id1, id2, :] = extremal
test_dtype(func(), x, torch.bfloat16)
@@ -520,9 +524,13 @@ class TestUnaryUfuncs(TestCase):
value_dtype = torch.tensor([], dtype=dtype).real.dtype
def gen_tensor(a):
return torch.view_as_complex(torch.tensor(a, dtype=value_dtype, device=device))
return torch.view_as_complex(
torch.tensor(a, dtype=value_dtype, device=device)
)
for extremal, kwarg_name in zip(['nan', 'inf', '-inf'], ['nan', 'posinf', 'neginf']):
for extremal, kwarg_name in zip(
["nan", "inf", "-inf"], ["nan", "posinf", "neginf"]
):
a = gen_tensor([123, float(extremal)])
res = torch.nan_to_num(a, **{kwarg_name: 12})
res_check = gen_tensor([123, 12])
@@ -1078,17 +1086,21 @@ class TestUnaryUfuncs(TestCase):
(1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
(-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
(0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
(2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
(2.0j, -1.55740761756896972656 + 0.99999988079071044922j),
]
for inp, out in inouts:
res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
res = torch.nn.functional.silu(
torch.tensor(inp, dtype=dtype, device=device)
)
self.assertFalse(torch.any(torch.isnan(res)))
self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
for inp, out in inouts:
res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
res = torch.nn.functional.silu(
torch.tensor(inp, dtype=dtype, device=device), inplace=True
)
self.assertFalse(torch.any(torch.isnan(res)))
self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
@@ -1487,6 +1499,19 @@ class TestUnaryUfuncs(TestCase):
for num in abs_zeros:
self.assertGreater(math.copysign(1.0, num), 0.0)
@onlyCUDA
@dtypes(torch.bool, torch.int8)
def test_narrow_dtypes(self, device, dtype):
x_int = torch.randint(2, (8 * 1024,), device=device, dtype=torch.int)
x = x_int.to(dtype)
# check normal conversion
self.assertEqual(x_int, x.int())
x.fill_(0)
self.assertEqual(x.sum(), 0)
# test unaligned tensor with non-round number of elements
x[1:4000].fill_(1)
self.assertEqual(x.sum(), 3999)
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
# test non-boolean tensors as the `out=` parameters