[EZ][BE] Reuse result_of from c10/metal/utils.h (#149262)

No need for one more implementation: `c10/metal/utils.h` already provides `result_of`, so the local copy is dropped.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149262
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga 2025-03-15 14:58:38 -07:00 committed by PyTorch MergeBot
parent acf42b0048
commit 3e2c4086ad

View File

@ -2,6 +2,7 @@
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
struct fmax_functor {
template <typename T>
@ -91,12 +92,6 @@ struct polar_functor {
}
};
// Future BinaryTensorIterator
// Alias for the type produced by invoking a functor of type F with two
// arguments of type T. The type is computed with decltype/declval, so F is
// never actually constructed or called here.
// NOTE(review): template arguments are ordered <T, F> (operand type first).
// Other lines use the spelling result_of<F, T, T> (functor first), presumably
// the c10/metal/utils.h version — confirm which one is in scope before use.
template <typename T, typename F>
using result_of = decltype(::metal::declval<F>()(
::metal::declval<T>(),
::metal::declval<T>()));
template <typename T, typename F>
kernel void binary_indexing(
constant void* input_ [[buffer(0)]],
@ -104,7 +99,8 @@ kernel void binary_indexing(
device void* out_ [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto out = (device result_of<T, F>*)((device uint8_t*)out_ + offsets[tid].x);
auto out =
(device result_of<F, T, T>*)((device uint8_t*)out_ + offsets[tid].x);
auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
F f;
@ -115,7 +111,7 @@ template <typename T, typename F>
kernel void binary_dense(
constant T* input [[buffer(0)]],
constant T* other [[buffer(1)]],
device result_of<T, F>* out [[buffer(2)]],
device result_of<F, T, T>* out [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
F f;
out[tid] = f(input[tid], other[tid]);
@ -133,7 +129,7 @@ kernel void binary_dense(
binary_dense<DTYPE, NAME##_functor>( \
constant DTYPE * input_, \
constant DTYPE * other_, \
device result_of<DTYPE, NAME##_functor> * out_, \
device result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
uint tid)
#define REGISTER_BINARY_OP(NAME, DTYPE) \