use more elements per thread for narrow dtypes (#139449)
Fix a perf issue for narrow dtypes by accessing more elements per thread.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139449
Approved by: https://github.com/Chillee, https://github.com/eqy
This commit is contained in:
parent 7621fc5dad
commit 05c3330893
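In short: the per-thread workload of the vectorized elementwise kernels is now derived from the combined input and output element size, so 1- and 2-byte dtypes such as bool and int8 process 16 or 8 elements per thread instead of the previous fixed thread_work_size() of 4, widening each block's memory transactions. A minimal standalone sketch of that heuristic follows; it mirrors the elems_per_thread helper added in this diff, but the snippet itself (including the 128-thread block size) is illustrative only and not part of the patch.

#include <cstdio>

// Sketch of the selection rule: pick elements-per-thread from the total
// bytes of inputs + output per element (what the diff calls io_size).
constexpr int elems_per_thread(int io_size) {
  if (io_size == 1) {        // a single 1-byte value per element
    return 16;
  } else if (io_size < 4) {  // narrow dtypes, e.g. an int8 -> int8 op (1 + 1 bytes)
    return 8;
  } else {                   // 4 bytes or more: keep the previous workload of 4
    return 4;
  }
}

int main() {
  constexpr int threads_per_block = 128;  // assumed block size, for illustration
  const int io_sizes[] = {1, 2, 4, 8};
  for (int io_size : io_sizes) {
    // block work size = threads per block * elements per thread
    std::printf("io_size=%d -> %2d elems/thread, %4d elems/block\n",
                io_size, elems_per_thread(io_size),
                threads_per_block * elems_per_thread(io_size));
  }
  return 0;
}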
@@ -52,13 +52,49 @@
 namespace at::native {
 
+template <typename args_t, size_t... Is>
+constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
+  if constexpr (sizeof...(Is) == 0) {
+    return 0;
+  } else {
+    return (sizeof(std::tuple_element_t<Is, args_t>) + ...);
+  }
+}
+
+template <int io_sizes>
+constexpr auto elems_per_thread(){
+  if constexpr (io_sizes == 1) {
+    return 16;
+  } else if constexpr (io_sizes < 4) {
+    return 8;
+  } else {
+    return 4;
+  }
+}
+
+template <int io_sizes>
+constexpr auto io_block_work_size() {
+  return num_threads() * elems_per_thread<io_sizes>();
+}
+
+template <typename func_t>
+constexpr auto calc_io_size(){
+  using traits = function_traits<func_t>;
+  using args_t = typename traits::ArgsTuple;
+  constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
+  constexpr auto output_size = sizeof(typename traits::result_type);
+  return input_size + output_size;
+}
+
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
   using traits = function_traits<func_t>;
-  int remaining = N - block_work_size() * blockIdx.x;
+  constexpr auto io_size = calc_io_size<func_t>();
+  int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
 
-  if (remaining < block_work_size()) { // if this block handles the reminder,
+  if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
     // just do a naive unrolled loop
     auto input_calc = TrivialOffsetCalculator<traits::arity>();
     auto output_calc = TrivialOffsetCalculator<1>();
@@ -69,19 +105,21 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
         decltype(input_calc),
         decltype(output_calc),
         memory::LoadWithoutCast,
-        memory::StoreWithoutCast>(
+        memory::StoreWithoutCast,
+        elems_per_thread<io_size>()>(
         data, remaining, input_calc, output_calc, loader, storer);
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
     elementwise_kernel_helper(
-        f, memory::policies::vectorized<vec_size, array_t>(data));
+        f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
   }
 }
 
 template <
     typename func_t,
     typename array_t,
+    int elems_per_thread,
     typename inp_calc_t,
     typename out_calc_t,
     typename loader_t,
@@ -95,9 +133,9 @@ __global__ void unrolled_elementwise_kernel(
     out_calc_t oc,
     loader_t l,
     storer_t s) {
-  int remaining = N - block_work_size() * blockIdx.x;
+  int remaining = N - elems_per_thread * num_threads() * blockIdx.x;
   auto policy = memory::policies::
-      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
+      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t, elems_per_thread>(
           data, remaining, ic, oc, l, s);
   elementwise_kernel_helper(f, policy);
 }
@@ -110,7 +148,8 @@ static inline void launch_vectorized_kernel(
     array_t data) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
-  int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  constexpr auto io_size = calc_io_size<func_t>();
+  int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
 
@@ -130,7 +169,7 @@ static inline void launch_vectorized_kernel(
       auto output_calc = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithoutCast();
       auto storer = memory::StoreWithoutCast();
-      unrolled_elementwise_kernel<func_t, array_t>
+      unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
          <<<grid, num_threads(), 0, stream>>>(
              N, f, data, input_calc, output_calc, loader, storer);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -159,7 +198,7 @@ static inline void launch_unrolled_kernel(
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   int64_t grid = (N + block_work_size() - 1) / block_work_size();
   auto stream = at::cuda::getCurrentCUDAStream();
-  unrolled_elementwise_kernel<func_t, array_t>
+  unrolled_elementwise_kernel<func_t, array_t, thread_work_size()>
      <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -46,18 +46,19 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
   using args_t = typename traits::ArgsTuple;
+  constexpr int elems_per_thread = policy_t::tws;
 
   int idx = blockIdx.x;
 
-  return_t results[thread_work_size()];
-  args_t args[thread_work_size()];
+  return_t results[elems_per_thread];
+  args_t args[elems_per_thread];
 
   // load
   policy.load(args, idx);
 
   // compute
   #pragma unroll
-  for (int i = 0; i < thread_work_size(); i++) {
+  for (int i = 0; i < elems_per_thread; i++) {
     if (policy.check_inbounds(i)) {
       results[i] = c10::guts::apply(f, args[i]);
     }
@@ -57,11 +57,11 @@ struct static_unroll<func, end, end> {
 template<int arg_index>
 struct vectorized_load_helper {
   template <typename args_t, typename policy_t>
-  static __device__ void apply(policy_t &self, args_t *args, int idx) {
+  static __device__ void apply(policy_t &self, args_t *args, int idx, int block_work_size) {
     using arg_t = std::tuple_element_t<arg_index, args_t>;
     // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
     // need a +1 offset to get the input
-    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
+    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size * idx;
     auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
     self.load_single_arg(args_accessor, ptr);
   }
@@ -181,9 +181,7 @@ __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint
 
 namespace policies {
 
-// Assumption:
-// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
-template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
+template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int elems_per_thread, int num_outputs=1>
 struct unroll {
 
   data_t data;
@@ -192,6 +190,7 @@ struct unroll {
   out_calc_t output_offset_calculator;
   loader_t loader;
   storer_t storer;
+  static constexpr int tws = elems_per_thread;
 
   __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
     data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
@@ -205,11 +204,11 @@ struct unroll {
     constexpr int arity = std::tuple_size_v<args_t>;
     int thread_idx = threadIdx.x;
     #pragma unroll
-    for (int i = 0; i < thread_work_size(); i++) {
+    for (int i = 0; i < elems_per_thread; i++) {
       if (thread_idx >= remaining) {
         return;
       }
-      int linear_idx = thread_idx + block_work_size() * idx;
+      int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
       auto offset = input_offset_calculator.get(linear_idx);
       detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
       thread_idx += num_threads();
@@ -220,11 +219,11 @@ struct unroll {
   __device__ inline void store(scalar_t *from, int idx) {
     int thread_idx = threadIdx.x;
     #pragma unroll
-    for (int i = 0; i < thread_work_size(); i++) {
+    for (int i = 0; i < elems_per_thread; i++) {
       if (thread_idx >= remaining) {
         return;
       }
-      int linear_idx = thread_idx + block_work_size() * idx;
+      int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
       int offset = output_offset_calculator.get(linear_idx)[0];
       storer.store(from[i], data[0], offset);
       thread_idx += num_threads();
@@ -237,11 +236,12 @@ struct unroll {
 // Note:
 // Functions in vectorized policy does not do boundary check. It assumes the whole block
 // has its job to do. So the reminders should be handled by the caller manually.
-template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
+template <int vec_size, typename data_t, int elems_per_thread> // vec_size: number of scalars, can be 1, 2, or 4.
 struct vectorized {
 
-  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
-  static constexpr int loop_size = thread_work_size() / vec_size;
+  static_assert(elems_per_thread % vec_size == 0, "The workload per thread must be a multiple of vec_size");
+  static constexpr int loop_size = elems_per_thread / vec_size;
+  static constexpr int tws = elems_per_thread;
 
   data_t data;
 
@@ -268,13 +268,13 @@ struct vectorized {
   template<typename args_t>
   __device__ inline void load(args_t *args, int idx) {
     constexpr int arity = std::tuple_size_v<args_t>;
-    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
+    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx, elems_per_thread * num_threads());
   }
 
   template<typename scalar_t>
   __device__ inline void store(scalar_t *from, int idx) {
     using vec_t = aligned_vector<scalar_t, vec_size>;
-    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
+    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + elems_per_thread * num_threads() * idx;
     vec_t *to_ = reinterpret_cast<vec_t *>(to);
     int thread_idx = threadIdx.x;
     #pragma unroll
@@ -299,6 +299,7 @@ struct multi_outputs_unroll {
   out_calc_t output_offset_calculator;
   LoadWithoutCast loader;
   StoreWithoutCast storer;
+  static constexpr int tws = thread_work_size();
 
   __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
     data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
@@ -82,7 +82,7 @@ __global__ void vectorized_copy(scalar_t *dst, scalar_t *src) {
   data[0] = reinterpret_cast<char *>(dst);
   data[1] = reinterpret_cast<char *>(src);
   int idx = blockIdx.x;
-  using vectorized = policies::vectorized<vec_size, array_t>;
+  using vectorized = policies::vectorized<vec_size, array_t, thread_work_size()>;
   auto policy = vectorized(data);
   scalar_t buf[thread_work_size()];
 #if !defined(USE_ROCM)
@@ -1045,7 +1045,6 @@ class TestReductions(TestCase):
         a[:, (shape[1] - 1) // 2:] = True
         values, indices = a.mode(-1)
         self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
-        print(indices)
         indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
         self.assertEqual(values, indexed)
 
@@ -1,57 +1,57 @@
 # Owner(s): ["module: tests"]
 
-import torch
-import numpy as np
-
 import math
-from numbers import Number
 import random
 import unittest
+from numbers import Number
 
+import numpy as np
+
+import torch
 from torch import inf, nan
-from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
-    torch_to_numpy_dtype_dict,
-    numpy_to_torch_dtype_dict,
-    suppress_warnings,
-    TEST_SCIPY,
-    slowTest,
-    skipIfNoSciPy,
-    IS_WINDOWS,
-    gradcheck,
-    is_iterable_of_tensors,
-    xfailIfTorchDynamo,
-)
-from torch.testing._internal.common_methods_invocations import (
-    unary_ufuncs,
-    generate_elementwise_unary_tensors,
-    generate_elementwise_unary_small_value_tensors,
-    generate_elementwise_unary_large_value_tensors,
-    generate_elementwise_unary_extremal_value_tensors,
-)
-from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests,
-    ops,
-    dtypes,
-    onlyCPU,
-    onlyNativeDeviceTypes,
-    onlyCUDA,
-    dtypesIfCUDA,
-    precisionOverride,
-    dtypesIfCPU,
-)
-from torch.utils import _pytree as pytree
-
 from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import (
+    dtypes,
+    dtypesIfCPU,
+    dtypesIfCUDA,
+    instantiate_device_type_tests,
+    onlyCPU,
+    onlyCUDA,
+    onlyNativeDeviceTypes,
+    ops,
+    precisionOverride,
+)
 from torch.testing._internal.common_dtype import (
-    floating_types_and,
     all_types_and_complex_and,
-    integral_types_and,
-    get_all_math_dtypes,
     complex_types,
     floating_and_complex_types_and,
+    floating_types_and,
+    get_all_math_dtypes,
+    integral_types_and,
 )
+from torch.testing._internal.common_methods_invocations import (
+    generate_elementwise_unary_extremal_value_tensors,
+    generate_elementwise_unary_large_value_tensors,
+    generate_elementwise_unary_small_value_tensors,
+    generate_elementwise_unary_tensors,
+    unary_ufuncs,
+)
+from torch.testing._internal.common_utils import (
+    gradcheck,
+    is_iterable_of_tensors,
+    IS_WINDOWS,
+    numpy_to_torch_dtype_dict,
+    run_tests,
+    skipIfNoSciPy,
+    slowTest,
+    suppress_warnings,
+    TEST_SCIPY,
+    TestCase,
+    torch_to_numpy_dtype_dict,
+    xfailIfTorchDynamo,
+)
+from torch.utils import _pytree as pytree
+
 
 if TEST_SCIPY:
     import scipy
@@ -172,7 +172,7 @@ class TestUnaryUfuncs(TestCase):
                     torch.from_numpy(expected).to(actual.dtype),
                     msg,
                     exact_device=False,
-                    **kwargs
+                    **kwargs,
                 )
             else:
                 self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
@@ -444,7 +444,9 @@ class TestUnaryUfuncs(TestCase):
 
         all_outs = [op(slice, **torch_kwargs) for slice in input]
         if is_iterable_of_tensors(actual):
-            expected = [torch.stack([out[i] for out in all_outs]) for i in range(len(actual))]
+            expected = [
+                torch.stack([out[i] for out in all_outs]) for i in range(len(actual))
+            ]
         else:
             expected = torch.stack(all_outs)
 
@@ -510,8 +512,10 @@ class TestUnaryUfuncs(TestCase):
         shapes = [[1, 3, 6, 6], [1, 3, 6, 128], [1, 3, 256, 256]]
         for shape in shapes:
             x = torch.randn(shape, device=device)
-            extremals = [float('nan'), float('inf'), -float('inf')]
-            for id1, id2, extremal in zip(torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals):
+            extremals = [float("nan"), float("inf"), -float("inf")]
+            for id1, id2, extremal in zip(
+                torch.randint(0, 2, (3,)), torch.randint(0, 5, (3,)), extremals
+            ):
                 x[0, id1, id2, :] = extremal
                 test_dtype(func(), x, torch.bfloat16)
@@ -520,9 +524,13 @@ class TestUnaryUfuncs(TestCase):
         value_dtype = torch.tensor([], dtype=dtype).real.dtype
 
         def gen_tensor(a):
-            return torch.view_as_complex(torch.tensor(a, dtype=value_dtype, device=device))
+            return torch.view_as_complex(
+                torch.tensor(a, dtype=value_dtype, device=device)
+            )
 
-        for extremal, kwarg_name in zip(['nan', 'inf', '-inf'], ['nan', 'posinf', 'neginf']):
+        for extremal, kwarg_name in zip(
+            ["nan", "inf", "-inf"], ["nan", "posinf", "neginf"]
+        ):
             a = gen_tensor([123, float(extremal)])
             res = torch.nan_to_num(a, **{kwarg_name: 12})
             res_check = gen_tensor([123, 12])
@@ -1078,17 +1086,21 @@ class TestUnaryUfuncs(TestCase):
             (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j),
             (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j),
             (0.0 + 0.5j, -0.06383547931909561157 + 0.25000000000000000000j),
-            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j)
+            (2.0j, -1.55740761756896972656 + 0.99999988079071044922j),
         ]
 
         for inp, out in inouts:
-            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device))
+            res = torch.nn.functional.silu(
+                torch.tensor(inp, dtype=dtype, device=device)
+            )
             self.assertFalse(torch.any(torch.isnan(res)))
             self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
             self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
 
         for inp, out in inouts:
-            res = torch.nn.functional.silu(torch.tensor(inp, dtype=dtype, device=device), inplace=True)
+            res = torch.nn.functional.silu(
+                torch.tensor(inp, dtype=dtype, device=device), inplace=True
+            )
             self.assertFalse(torch.any(torch.isnan(res)))
             self.assertEqual(res.real, out.real, atol=atol, rtol=rtol)
             self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol)
@@ -1487,6 +1499,19 @@ class TestUnaryUfuncs(TestCase):
         for num in abs_zeros:
             self.assertGreater(math.copysign(1.0, num), 0.0)
 
+    @onlyCUDA
+    @dtypes(torch.bool, torch.int8)
+    def test_narrow_dtypes(self, device, dtype):
+        x_int = torch.randint(2, (8 * 1024,), device=device, dtype=torch.int)
+        x = x_int.to(dtype)
+        # check normal conversion
+        self.assertEqual(x_int, x.int())
+        x.fill_(0)
+        self.assertEqual(x.sum(), 0)
+        # test unaligned tensor with non-round number of elements
+        x[1:4000].fill_(1)
+        self.assertEqual(x.sum(), 3999)
+
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
     def test_isposinf_isneginf_non_boolean_output(self, device, dtype):
         # test non-boolean tensors as the `out=` parameters