[MPS][BE] Move common binary ops macros to indexing.h (#149263)
And binary op invocation logic to OperationUtils.mm. This is a no-op change; additional sanity checks and logic improvements will be added as follow-ups.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149263
Approved by: https://github.com/dcci
ghstack dependencies: #149262
This commit is contained in:
parent 21c2edfec8
commit f80bee4934
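In practice, the refactor turns the file-local BinaryKernel.mm helper into a method on MetalShaderLibrary; a representative call site (taken from the diff below) changes like this:

    // Before: free function local to BinaryKernel.mm
    mps::binary_mps_impl(iter, "fmax");

    // After: shared method on the bundled MetalShaderLibrary instance
    lib.exec_binary_kernel(iter, "fmax");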
aten/src/ATen/native/mps/OperationUtils.h
@@ -133,6 +133,10 @@ class MetalShaderLibrary {
       TensorIteratorBase& iter,
       const std::string& name,
       std::optional<int64_t> extra = std::nullopt);
+  void exec_binary_kernel(
+      TensorIteratorBase& iter,
+      const std::string& name,
+      const bool supports_dense = true);
 
  protected:
   virtual MTLLibrary_t getLibrary();
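The new supports_dense flag lets callers opt out of the contiguous fast path. As the implementation below shows, the flag selects between two Metal function-name patterns; a small self-contained sketch of that naming rule (the helper itself is hypothetical, not part of the PR):

    #include <fmt/format.h>
    #include <string>

    // Hypothetical helper restating the lookup-name logic used by
    // exec_binary_kernel below (the dtype string comes from
    // scalarToMetalTypeString in the actual code).
    std::string binary_kernel_name(const std::string& op,
                                   const std::string& dtype,
                                   bool dense) {
      return dense ? fmt::format("{}_dense_{}", op, dtype)  // e.g. "fmax_dense_float"
                   : fmt::format("{}_{}", op, dtype);       // e.g. "fmax_float"
    }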
aten/src/ATen/native/mps/OperationUtils.mm
@@ -1010,6 +1010,49 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
   }
 }
 
+void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
+                                            const std::string& name,
+                                            const bool supports_dense) {
+  TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
+
+  Tensor input = iter.input(0);
+  Tensor other = iter.input(1);
+  Tensor out = iter.output();
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  const uint32_t nDim = iter.ndim();
+  constexpr uint32_t nOffsets = 3;
+  const uint32_t numThreads = iter.numel();
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      auto computeEncoder = mpsStream->commandEncoder();
+      if (supports_dense && iter.is_contiguous()) {
+        const auto kernel_name = fmt::format("{}_dense_{}", name, scalarToMetalTypeString(input));
+        auto binaryPSO = getPipelineStateForFunc(kernel_name);
+        [computeEncoder setComputePipelineState:binaryPSO];
+        mtl_setArgs(computeEncoder, input, other, out);
+        mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
+        return;
+      }
+      const auto kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(input));
+      auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);
+
+      auto binaryPSO = getPipelineStateForFunc(kernel);
+
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other});
+
+      [computeEncoder setComputePipelineState:binaryPSO];
+      mtl_setArgs(computeEncoder, input, other, out);
+      [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
+      mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
+
+      getMPSProfiler().endProfileKernel(binaryPSO);
+    }
+  });
+}
+
 MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() {
   static BundledShaderLibary l;
   return l;
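On the non-contiguous path, generateKernelDataOffsets (used above) builds one uint3 of byte offsets per output element, bound at buffer index 3; nOffsets = 3 reflects the three operands (output, first input, second input). A simplified CPU model of what those offsets contain, as an illustration only, not the actual MPS implementation:

    #include <array>
    #include <cstdint>
    #include <vector>

    // Simplified model: for each linear output index tid, compute the
    // {out, input, other} byte offsets that binary_indexing (later in this
    // diff) adds to the raw base pointers.
    std::vector<std::array<uint32_t, 3>> compute_offsets(
        const std::vector<uint32_t>& sizes,                   // iterator shape
        const std::array<std::vector<uint32_t>, 3>& strides,  // bytes, per operand
        uint32_t numel) {
      std::vector<std::array<uint32_t, 3>> offsets(numel);    // zero-initialized
      for (uint32_t tid = 0; tid < numel; ++tid) {
        uint32_t rem = tid;
        for (size_t dim = 0; dim < sizes.size(); ++dim) {
          const uint32_t idx = rem % sizes[dim];  // multi-dim index, dim 0 fastest
          rem /= sizes[dim];
          for (int op = 0; op < 3; ++op) {
            offsets[tid][op] += idx * strides[op][dim];
          }
        }
      }
      return offsets;
    }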
aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -1,8 +1,8 @@
+#include <c10/metal/indexing.h>
 #include <c10/metal/special_math.h>
-#include <c10/metal/utils.h>
 #include <metal_stdlib>
 using namespace metal;
 using namespace c10::metal;
 
 struct fmax_functor {
   template <typename T>
@@ -92,54 +92,6 @@ struct polar_functor {
   }
 };
 
-template <typename T, typename F>
-kernel void binary_indexing(
-    constant void* input_ [[buffer(0)]],
-    constant void* other_ [[buffer(1)]],
-    device void* out_ [[buffer(2)]],
-    constant uint3* offsets [[buffer(3)]],
-    uint tid [[thread_position_in_grid]]) {
-  auto out =
-      (device result_of<F, T, T>*)((device uint8_t*)out_ + offsets[tid].x);
-  auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
-  auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
-  F f;
-  *out = f(*input, *other);
-}
-
-template <typename T, typename F>
-kernel void binary_dense(
-    constant T* input [[buffer(0)]],
-    constant T* other [[buffer(1)]],
-    device result_of<F, T, T>* out [[buffer(2)]],
-    uint tid [[thread_position_in_grid]]) {
-  F f;
-  out[tid] = f(input[tid], other[tid]);
-}
-
-#define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE)                \
-  template [[host_name(#NAME "_" #DTYPE)]] kernel void          \
-  binary_indexing<DTYPE, NAME##_functor>(                       \
-      constant void* input_,                                    \
-      constant void* other_,                                    \
-      device void* out_,                                        \
-      constant uint3* offsets,                                  \
-      uint tid);                                                \
-  template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void    \
-  binary_dense<DTYPE, NAME##_functor>(                          \
-      constant DTYPE * input_,                                  \
-      constant DTYPE * other_,                                  \
-      device result_of<NAME##_functor, DTYPE, DTYPE> * out_,    \
-      uint tid)
-
-#define REGISTER_BINARY_OP(NAME, DTYPE)                             \
-  template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
-      constant void* input_,                                        \
-      constant void* other_,                                        \
-      device void* out_,                                            \
-      constant uint3* offsets,                                      \
-      uint tid)
-
 REGISTER_BINARY_INDEXING_OP(copysign, long);
 REGISTER_BINARY_INDEXING_OP(copysign, int);
 REGISTER_BINARY_INDEXING_OP(copysign, float);
@@ -186,9 +138,7 @@ kernel void complex_mul(
   out[1] = input[0] * other[1] + input[1] * other[0];
 }
 
-REGISTER_BINARY_OP(complex_mul, float);
-REGISTER_BINARY_OP(complex_mul, half);
 // Constructs complex tensor from real and imaginary planes
 template <typename T>
 kernel void complex_kernel(
     constant void* real_ [[buffer(0)]],
@@ -203,5 +153,15 @@ kernel void complex_kernel(
   out[1] = imag[0];
 }
 
+#define REGISTER_BINARY_OP(NAME, DTYPE)                             \
+  template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
+      constant void* input_,                                        \
+      constant void* other_,                                        \
+      device void* out_,                                            \
+      constant uint3* offsets,                                      \
+      uint tid)
+
+REGISTER_BINARY_OP(complex_mul, float);
+REGISTER_BINARY_OP(complex_mul, half);
 REGISTER_BINARY_OP(complex_kernel, float);
 REGISTER_BINARY_OP(complex_kernel, half);
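For reference, REGISTER_BINARY_OP(complex_mul, float); expands to an explicit instantiation whose host_name ("complex_mul" "_" "float") is exactly the string the host side later looks up:

    // Approximate expansion of REGISTER_BINARY_OP(complex_mul, float);
    template [[host_name("complex_mul_float")]] kernel void complex_mul<float>(
        constant void* input_,
        constant void* other_,
        device void* out_,
        constant uint3* offsets,
        uint tid);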
aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -23,54 +23,13 @@
 #endif
 
 namespace at::native {
-namespace mps {
 
 #ifndef PYTORCH_JIT_COMPILE_SHADERS
-static auto& lib = MetalShaderLibrary::getBundledLibrary();
+static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
 #else
 #include <ATen/native/mps/BinaryKernel_metallib.h>
 #endif
 
-static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name, bool supports_dense = true) {
-  TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
-
-  Tensor input = iter.input(0);
-  Tensor other = iter.input(1);
-  Tensor out = iter.output();
-
-  id<MTLDevice> device = MPSDevice::getInstance()->device();
-  MPSStream* mpsStream = getCurrentMPSStream();
-  const uint32_t nDim = iter.ndim();
-  constexpr uint32_t nOffsets = 3;
-  const uint32_t numThreads = iter.numel();
-  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
-    @autoreleasepool {
-      auto computeEncoder = mpsStream->commandEncoder();
-      if (supports_dense && iter.is_contiguous()) {
-        const auto kernel_name = fmt::format("{}_dense_{}", func_name, scalarToMetalTypeString(input));
-        auto binaryPSO = lib.getPipelineStateForFunc(kernel_name);
-        [computeEncoder setComputePipelineState:binaryPSO];
-        mtl_setArgs(computeEncoder, input, other, out);
-        mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
-        return;
-      }
-      const std::string kernel = func_name + "_" + scalarToMetalTypeString(input);
-      auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);
-
-      id<MTLComputePipelineState> binaryPSO = lib.getPipelineStateForFunc(kernel);
-
-      // this function call is a no-op if MPS Profiler is not enabled
-      getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other});
-
-      [computeEncoder setComputePipelineState:binaryPSO];
-      mtl_setArgs(computeEncoder, input, other, out);
-      [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
-      mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
-
-      getMPSProfiler().endProfileKernel(binaryPSO);
-    }
-  });
-}
+namespace mps {
 
 void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output) {
   TORCH_INTERNAL_ASSERT(c10::isComplexType(input.scalar_type()) || c10::isComplexType(other.scalar_type()));
@@ -89,43 +48,43 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
   auto iter =
       TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
 
-  mps::binary_mps_impl(iter, "complex_mul", false);
+  lib.exec_binary_kernel(iter, "complex_mul", /*supports_dense=*/false);
 }
 
 } // namespace mps
 
 static void fmax_mps_kernel(TensorIteratorBase& iter) {
   if (isFloatingType(iter.common_dtype())) {
-    mps::binary_mps_impl(iter, "fmax");
+    lib.exec_binary_kernel(iter, "fmax");
   } else {
     at::maximum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
   }
 }
 static void fmin_mps_kernel(TensorIteratorBase& iter) {
   if (isFloatingType(iter.common_dtype())) {
-    mps::binary_mps_impl(iter, "fmin");
+    lib.exec_binary_kernel(iter, "fmin");
   } else {
     at::minimum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
   }
 }
 
 static void copysign_mps_kernel(TensorIteratorBase& iter) {
-  mps::binary_mps_impl(iter, "copysign");
+  lib.exec_binary_kernel(iter, "copysign");
 }
 
 static void nextafter_mps_kernel(TensorIteratorBase& iter) {
   TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "nextafter_mps not implemented for non-floating types");
-  mps::binary_mps_impl(iter, "nextafter");
+  lib.exec_binary_kernel(iter, "nextafter");
 }
 
 static void zeta_mps_kernel(TensorIteratorBase& iter) {
   TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "zeta_mps not implemented for non-floating types");
-  mps::binary_mps_impl(iter, "zeta");
+  lib.exec_binary_kernel(iter, "zeta");
 }
 
 static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
   TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
-  mps::binary_mps_impl(iter, "xlog1py");
+  lib.exec_binary_kernel(iter, "xlog1py");
 }
 
 REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
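The isFloatingType gates in fmax/fmin exist because fmax/fmin differ from maximum/minimum only in NaN handling; for integer dtypes the results coincide, so the existing ops are reused. A quick illustration using the C++ frontend (assuming the usual libtorch API):

    #include <torch/torch.h>
    #include <iostream>
    #include <limits>

    int main() {
      // fmax ignores NaN where maximum propagates it.
      auto a = torch::tensor({1.0f, std::numeric_limits<float>::quiet_NaN()});
      auto b = torch::tensor({2.0f, 0.0f});
      std::cout << torch::fmax(a, b)     // [2., 0.]  - NaN ignored
                << torch::maximum(a, b); // [2., nan] - NaN propagated
    }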
@@ -147,7 +106,7 @@ Tensor& polar_out_mps(const Tensor& abs, const Tensor& angle, Tensor& output) {
   auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
   auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(abs).add_input(angle).build();
 
-  mps::binary_mps_impl(iter, "polar");
+  lib.exec_binary_kernel(iter, "polar");
   return output;
 }
@@ -163,7 +122,7 @@ Tensor& complex_out_mps(const Tensor& real, const Tensor& imag, Tensor& output)
   auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
   auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(real).add_input(imag).build();
 
-  mps::binary_mps_impl(iter, "complex_kernel", false);
+  lib.exec_binary_kernel(iter, "complex_kernel", /*supports_dense=*/false);
   return output;
 }
 } // namespace at::native
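Note that complex_mul and complex_kernel pass /*supports_dense=*/false even when the tensors are contiguous, presumably because the output is a view_as_real slice: each logical element is an interleaved (real, imag) pair, so the per-element offset path is required. This is a reading of the diff, not something the PR states explicitly.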
c10/metal/indexing.h
@@ -103,5 +103,44 @@ kernel void unary_strided(
   }                                                                        \
   }
 
+template <typename T, typename F>
+kernel void binary_indexing(
+    constant void* input_ [[buffer(0)]],
+    constant void* other_ [[buffer(1)]],
+    device void* out_ [[buffer(2)]],
+    constant uint3* offsets [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  auto out =
+      (device result_of<F, T, T>*)((device uint8_t*)out_ + offsets[tid].x);
+  auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
+  auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
+  F f;
+  *out = f(*input, *other);
+}
+
+template <typename T, typename F>
+kernel void binary_dense(
+    constant T* input [[buffer(0)]],
+    constant T* other [[buffer(1)]],
+    device result_of<F, T, T>* out [[buffer(2)]],
+    uint tid [[thread_position_in_grid]]) {
+  F f;
+  out[tid] = f(input[tid], other[tid]);
+}
+
+#define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE)                             \
+  template [[host_name(#NAME "_" #DTYPE)]] kernel void ::c10::metal::        \
+      binary_indexing<DTYPE, NAME##_functor>(                                \
+          constant void* input_,                                             \
+          constant void* other_,                                             \
+          device void* out_,                                                 \
+          constant uint3* offsets,                                           \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void ::c10::metal::  \
+      binary_dense<DTYPE, NAME##_functor>(                                   \
+          constant DTYPE * input_,                                           \
+          constant DTYPE * other_,                                           \
+          device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
+          uint tid)
+
 } // namespace metal
 } // namespace c10
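With the templates and REGISTER_BINARY_INDEXING_OP now living in c10/metal/indexing.h, wiring up a new binary op reduces to a functor plus a registration on the Metal side and one exec_binary_kernel call on the host. A hypothetical end-to-end sketch (the op name "avg" and both snippets are illustrative, not part of this PR):

    // Metal shader file that includes <c10/metal/indexing.h>:
    struct avg_functor {
      template <typename T>
      inline T operator()(const T a, const T b) {
        return (a + b) / 2;  // element-wise average
      }
    };
    REGISTER_BINARY_INDEXING_OP(avg, float);
    REGISTER_BINARY_INDEXING_OP(avg, half);

    // Host-side .mm file holding a MetalShaderLibrary reference named lib:
    static void avg_mps_kernel(TensorIteratorBase& iter) {
      // Resolves "avg_dense_float" on the contiguous fast path,
      // "avg_float" on the offset-indexed path.
      lib.exec_binary_kernel(iter, "avg");
    }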