mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] Move sinc kernels to the same OP family (#148399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148399 Approved by: https://github.com/dcci ghstack dependencies: #148398
This commit is contained in:
parent
7fcbaff206
commit
67937be673
|
|
@ -92,6 +92,22 @@ kernel void tanh_complex_kernel(
|
|||
vec2type_t<T0>(tanh_x, tan_y), vec2type_t<T0>(T0(1), tanh_x * tan_y));
|
||||
}
|
||||
|
||||
template <typename T0, typename T1>
|
||||
kernel void sinc_kernel(
|
||||
device T0* output [[buffer(0)]],
|
||||
constant T1* input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
output[index] = T0(sinc(static_cast<float>(input[index])));
|
||||
}
|
||||
|
||||
template <typename T0>
|
||||
kernel void sinc_complex_kernel(
|
||||
device vec2type_t<T0>* output [[buffer(0)]],
|
||||
constant vec2type_t<T0>* input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
output[index] = vec2type_t<T0>(sinc(float2(input[index])));
|
||||
}
|
||||
|
||||
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
|
||||
template [[host_name("erfinv_" #DTYPE0 "_" #DTYPE1)]] kernel void \
|
||||
erfinv_kernel( \
|
||||
|
|
@ -107,6 +123,10 @@ kernel void tanh_complex_kernel(
|
|||
constant DTYPE1* input [[buffer(1)]], \
|
||||
uint id [[thread_position_in_grid]]); \
|
||||
template [[host_name("tanh_" #DTYPE0 "_" #DTYPE1)]] kernel void tanh_kernel( \
|
||||
device DTYPE0* output [[buffer(0)]], \
|
||||
constant DTYPE1* input [[buffer(1)]], \
|
||||
uint id [[thread_position_in_grid]]); \
|
||||
template [[host_name("sinc_" #DTYPE0 "_" #DTYPE1)]] kernel void sinc_kernel( \
|
||||
device DTYPE0* output [[buffer(0)]], \
|
||||
constant DTYPE1* input [[buffer(1)]], \
|
||||
uint id [[thread_position_in_grid]]);
|
||||
|
|
@ -136,6 +156,11 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
|
|||
uint did [[thread_position_in_grid]]); \
|
||||
template [[host_name("sqrt_complex_" #DTYPE "_" #DTYPE)]] kernel void \
|
||||
sqrt_complex_kernel<DTYPE>( \
|
||||
device vec2type_t<DTYPE> * output [[buffer(0)]], \
|
||||
constant vec2type_t<DTYPE> * input [[buffer(1)]], \
|
||||
uint did [[thread_position_in_grid]]); \
|
||||
template [[host_name("sinc_complex_" #DTYPE "_" #DTYPE)]] kernel void \
|
||||
sinc_complex_kernel<DTYPE>( \
|
||||
device vec2type_t<DTYPE> * output [[buffer(0)]], \
|
||||
constant vec2type_t<DTYPE> * input [[buffer(1)]], \
|
||||
uint did [[thread_position_in_grid]]);
|
||||
|
|
@ -143,49 +168,6 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
|
|||
INSTANTIATE_UNARY_KERNELS_VEC2(half);
|
||||
INSTANTIATE_UNARY_KERNELS_VEC2(float);
|
||||
|
||||
template <typename T0, typename T1>
|
||||
kernel void sinc_kernel(
|
||||
device T0* output [[buffer(0)]],
|
||||
constant T1* input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
output[index] = T0(sinc(static_cast<float>(input[index])));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
kernel void sinc_complex(
|
||||
device T* output [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
output[index] = T(sinc(float2(input[index])));
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SINC_KERNEL(DTYPE0, DTYPE1) \
|
||||
template [[host_name("sinc_" #DTYPE0 "_" #DTYPE1)]] kernel void sinc_kernel( \
|
||||
device DTYPE0* output [[buffer(0)]], \
|
||||
constant DTYPE1* input [[buffer(1)]], \
|
||||
uint id [[thread_position_in_grid]])
|
||||
|
||||
#define INSTANTIATE_SINC_COMPLEX_KERNEL(DTYPE) \
|
||||
template [[host_name("sinc_complex_" #DTYPE "_" #DTYPE)]] kernel void \
|
||||
sinc_complex( \
|
||||
device DTYPE##2 * output [[buffer(0)]], \
|
||||
constant DTYPE##2 * input [[buffer(1)]], \
|
||||
uint id [[thread_position_in_grid]])
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_SINC_KERNEL(bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_SINC_KERNEL(half, half);
|
||||
INSTANTIATE_SINC_KERNEL(float, float);
|
||||
INSTANTIATE_SINC_KERNEL(float, long);
|
||||
INSTANTIATE_SINC_KERNEL(float, int);
|
||||
INSTANTIATE_SINC_KERNEL(float, short);
|
||||
INSTANTIATE_SINC_KERNEL(float, char);
|
||||
INSTANTIATE_SINC_KERNEL(float, uchar);
|
||||
INSTANTIATE_SINC_KERNEL(float, bool);
|
||||
INSTANTIATE_SINC_COMPLEX_KERNEL(half);
|
||||
INSTANTIATE_SINC_COMPLEX_KERNEL(float);
|
||||
|
||||
template <typename T>
|
||||
kernel void round_decimals_kernel(
|
||||
device T* output [[buffer(0)]],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user