[BE] Move sinc kernels to the same OP family (#148399)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148399
Approved by: https://github.com/dcci
ghstack dependencies: #148398
This commit is contained in:
Nikita Shulga 2025-03-04 07:46:35 -08:00 committed by PyTorch MergeBot
parent 7fcbaff206
commit 67937be673

View File

@ -92,6 +92,22 @@ kernel void tanh_complex_kernel(
vec2type_t<T0>(tanh_x, tan_y), vec2type_t<T0>(T0(1), tanh_x * tan_y));
}
template <typename T0, typename T1>
kernel void sinc_kernel(
device T0* output [[buffer(0)]],
constant T1* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index] = T0(sinc(static_cast<float>(input[index])));
}
template <typename T0>
kernel void sinc_complex_kernel(
device vec2type_t<T0>* output [[buffer(0)]],
constant vec2type_t<T0>* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index] = vec2type_t<T0>(sinc(float2(input[index])));
}
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
template [[host_name("erfinv_" #DTYPE0 "_" #DTYPE1)]] kernel void \
erfinv_kernel( \
@ -107,6 +123,10 @@ kernel void tanh_complex_kernel(
constant DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]]); \
template [[host_name("tanh_" #DTYPE0 "_" #DTYPE1)]] kernel void tanh_kernel( \
device DTYPE0* output [[buffer(0)]], \
constant DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]]); \
template [[host_name("sinc_" #DTYPE0 "_" #DTYPE1)]] kernel void sinc_kernel( \
device DTYPE0* output [[buffer(0)]], \
constant DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]]);
@ -136,6 +156,11 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
uint did [[thread_position_in_grid]]); \
template [[host_name("sqrt_complex_" #DTYPE "_" #DTYPE)]] kernel void \
sqrt_complex_kernel<DTYPE>( \
device vec2type_t<DTYPE> * output [[buffer(0)]], \
constant vec2type_t<DTYPE> * input [[buffer(1)]], \
uint did [[thread_position_in_grid]]); \
template [[host_name("sinc_complex_" #DTYPE "_" #DTYPE)]] kernel void \
sinc_complex_kernel<DTYPE>( \
device vec2type_t<DTYPE> * output [[buffer(0)]], \
constant vec2type_t<DTYPE> * input [[buffer(1)]], \
uint did [[thread_position_in_grid]]);
@ -143,49 +168,6 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
INSTANTIATE_UNARY_KERNELS_VEC2(half);
INSTANTIATE_UNARY_KERNELS_VEC2(float);
template <typename T0, typename T1>
kernel void sinc_kernel(
device T0* output [[buffer(0)]],
constant T1* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index] = T0(sinc(static_cast<float>(input[index])));
}
template <typename T>
kernel void sinc_complex(
device T* output [[buffer(0)]],
constant T* input [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
output[index] = T(sinc(float2(input[index])));
}
#define INSTANTIATE_SINC_KERNEL(DTYPE0, DTYPE1) \
template [[host_name("sinc_" #DTYPE0 "_" #DTYPE1)]] kernel void sinc_kernel( \
device DTYPE0* output [[buffer(0)]], \
constant DTYPE1* input [[buffer(1)]], \
uint id [[thread_position_in_grid]])
#define INSTANTIATE_SINC_COMPLEX_KERNEL(DTYPE) \
template [[host_name("sinc_complex_" #DTYPE "_" #DTYPE)]] kernel void \
sinc_complex( \
device DTYPE##2 * output [[buffer(0)]], \
constant DTYPE##2 * input [[buffer(1)]], \
uint id [[thread_position_in_grid]])
#if __METAL_VERSION__ >= 310
INSTANTIATE_SINC_KERNEL(bfloat, bfloat);
#endif
INSTANTIATE_SINC_KERNEL(half, half);
INSTANTIATE_SINC_KERNEL(float, float);
INSTANTIATE_SINC_KERNEL(float, long);
INSTANTIATE_SINC_KERNEL(float, int);
INSTANTIATE_SINC_KERNEL(float, short);
INSTANTIATE_SINC_KERNEL(float, char);
INSTANTIATE_SINC_KERNEL(float, uchar);
INSTANTIATE_SINC_KERNEL(float, bool);
INSTANTIATE_SINC_COMPLEX_KERNEL(half);
INSTANTIATE_SINC_COMPLEX_KERNEL(float);
template <typename T>
kernel void round_decimals_kernel(
device T* output [[buffer(0)]],