mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[EZ][BE] Reuse result_of from c10/metal/utils.h (#149262)
No need for one more implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149262
Approved by: https://github.com/dcci
This commit is contained in:
parent
acf42b0048
commit
3e2c4086ad
|
|
@ -2,6 +2,7 @@
|
|||
#include <c10/metal/utils.h>
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
using namespace c10::metal;
|
||||
|
||||
struct fmax_functor {
|
||||
template <typename T>
|
||||
|
|
@ -91,12 +92,6 @@ struct polar_functor {
|
|||
}
|
||||
};
|
||||
|
||||
// Future BinaryTensorIterator
|
||||
template <typename T, typename F>
|
||||
using result_of = decltype(::metal::declval<F>()(
|
||||
::metal::declval<T>(),
|
||||
::metal::declval<T>()));
|
||||
|
||||
template <typename T, typename F>
|
||||
kernel void binary_indexing(
|
||||
constant void* input_ [[buffer(0)]],
|
||||
|
|
@ -104,7 +99,8 @@ kernel void binary_indexing(
|
|||
device void* out_ [[buffer(2)]],
|
||||
constant uint3* offsets [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto out = (device result_of<T, F>*)((device uint8_t*)out_ + offsets[tid].x);
|
||||
auto out =
|
||||
(device result_of<F, T, T>*)((device uint8_t*)out_ + offsets[tid].x);
|
||||
auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
|
||||
auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
|
||||
F f;
|
||||
|
|
@ -115,7 +111,7 @@ template <typename T, typename F>
|
|||
kernel void binary_dense(
|
||||
constant T* input [[buffer(0)]],
|
||||
constant T* other [[buffer(1)]],
|
||||
device result_of<T, F>* out [[buffer(2)]],
|
||||
device result_of<F, T, T>* out [[buffer(2)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
out[tid] = f(input[tid], other[tid]);
|
||||
|
|
@ -133,7 +129,7 @@ kernel void binary_dense(
|
|||
binary_dense<DTYPE, NAME##_functor>( \
|
||||
constant DTYPE * input_, \
|
||||
constant DTYPE * other_, \
|
||||
device result_of<DTYPE, NAME##_functor> * out_, \
|
||||
device result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
|
||||
uint tid)
|
||||
|
||||
#define REGISTER_BINARY_OP(NAME, DTYPE) \
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user