[MPS][BE] Move common binary ops macros to indexing.h (#149263)

And binary op invocation logic to OperationUtils.mm

This is a no-op change; additional sanity checks and logic improvements will be added as follow-ups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149263
Approved by: https://github.com/dcci
ghstack dependencies: #149262
commit f80bee4934 (parent 21c2edfec8)
Author: Nikita Shulga
Date: 2025-03-15 14:58:39 -07:00
Committed by: PyTorch MergeBot

5 changed files with 109 additions and 104 deletions
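The net effect on call sites: the file-local helper binary_mps_impl in BinaryKernel.mm is deleted, and callers go through the new MetalShaderLibrary::exec_binary_kernel instead. A minimal before/after sketch, assuming a MetalShaderLibrary& lib obtained from getBundledLibrary() as in the diff below:

// before: file-local helper in BinaryKernel.mm
mps::binary_mps_impl(iter, "fmax");
// after: shared method, implemented in OperationUtils.mm
lib.exec_binary_kernel(iter, "fmax");
// ops without a dense kernel variant opt out explicitly
lib.exec_binary_kernel(iter, "complex_mul", /*supports_dense=*/false);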


@@ -133,6 +133,10 @@ class MetalShaderLibrary {
TensorIteratorBase& iter,
const std::string& name,
std::optional<int64_t> extra = std::nullopt);
void exec_binary_kernel(
TensorIteratorBase& iter,
const std::string& name,
const bool supports_dense = true);
protected:
virtual MTLLibrary_t getLibrary();


@@ -1010,6 +1010,49 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
}
}
void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
const std::string& name,
const bool supports_dense) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
Tensor input = iter.input(0);
Tensor other = iter.input(1);
Tensor out = iter.output();
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
auto computeEncoder = mpsStream->commandEncoder();
if (supports_dense && iter.is_contiguous()) {
const auto kernel_name = fmt::format("{}_dense_{}", name, scalarToMetalTypeString(input));
auto binaryPSO = getPipelineStateForFunc(kernel_name);
[computeEncoder setComputePipelineState:binaryPSO];
mtl_setArgs(computeEncoder, input, other, out);
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
return;
}
const auto kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(input));
auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);
auto binaryPSO = getPipelineStateForFunc(kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other});
[computeEncoder setComputePipelineState:binaryPSO];
mtl_setArgs(computeEncoder, input, other, out);
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
getMPSProfiler().endProfileKernel(binaryPSO);
}
});
}
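The kernel-name resolution above is worth spelling out: for contiguous iterators (when supports_dense is true), the method looks up a "_dense_"-suffixed variant and skips the offsets buffer entirely; otherwise it binds per-thread uint3 offsets at buffer index 3. An illustrative sketch of the names produced, assuming float inputs:

const auto dense   = fmt::format("{}_dense_{}", "copysign", "float"); // "copysign_dense_float"
const auto strided = fmt::format("{}_{}", "copysign", "float");       // "copysign_float"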
MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() {
static BundledShaderLibary l;
return l;


@@ -1,8 +1,8 @@
#include <c10/metal/indexing.h>
#include <c10/metal/special_math.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
struct fmax_functor {
template <typename T>
@@ -92,54 +92,6 @@ struct polar_functor {
}
};
template <typename T, typename F>
kernel void binary_indexing(
constant void* input_ [[buffer(0)]],
constant void* other_ [[buffer(1)]],
device void* out_ [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto out =
(device result_of<F, T, T>*)((device uint8_t*)out_ + offsets[tid].x);
auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
F f;
*out = f(*input, *other);
}
template <typename T, typename F>
kernel void binary_dense(
constant T* input [[buffer(0)]],
constant T* other [[buffer(1)]],
device result_of<F, T, T>* out [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
F f;
out[tid] = f(input[tid], other[tid]);
}
#define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE) \
template [[host_name(#NAME "_" #DTYPE)]] kernel void \
binary_indexing<DTYPE, NAME##_functor>( \
constant void* input_, \
constant void* other_, \
device void* out_, \
constant uint3* offsets, \
uint tid); \
template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void \
binary_dense<DTYPE, NAME##_functor>( \
constant DTYPE * input_, \
constant DTYPE * other_, \
device result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
uint tid)
#define REGISTER_BINARY_OP(NAME, DTYPE) \
template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
constant void* input_, \
constant void* other_, \
device void* out_, \
constant uint3* offsets, \
uint tid)
REGISTER_BINARY_INDEXING_OP(copysign, long);
REGISTER_BINARY_INDEXING_OP(copysign, int);
REGISTER_BINARY_INDEXING_OP(copysign, float);
@@ -186,9 +138,7 @@ kernel void complex_mul(
out[1] = input[0] * other[1] + input[1] * other[0];
}
-REGISTER_BINARY_OP(complex_mul, float);
-REGISTER_BINARY_OP(complex_mul, half);
// Constructs complex tensor from real and imaginary planes
template <typename T>
kernel void complex_kernel(
constant void* real_ [[buffer(0)]],
@@ -203,5 +153,15 @@ kernel void complex_kernel(
out[1] = imag[0];
}
+#define REGISTER_BINARY_OP(NAME, DTYPE) \
+template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
+constant void* input_, \
+constant void* other_, \
+device void* out_, \
+constant uint3* offsets, \
+uint tid)
+REGISTER_BINARY_OP(complex_mul, float);
+REGISTER_BINARY_OP(complex_mul, half);
REGISTER_BINARY_OP(complex_kernel, float);
REGISTER_BINARY_OP(complex_kernel, half);
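For reference, a sketch of what the retained macro expands to: REGISTER_BINARY_OP(complex_mul, float) stamps out an explicit instantiation whose host-visible name is exactly what the host-side pipeline lookup uses:

template [[host_name("complex_mul_float")]] kernel void complex_mul<float>(
    constant void* input_,
    constant void* other_,
    device void* out_,
    constant uint3* offsets,
    uint tid);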


@@ -23,54 +23,13 @@
#endif
namespace at::native {
namespace mps {
#ifndef PYTORCH_JIT_COMPILE_SHADERS
-static auto& lib = MetalShaderLibrary::getBundledLibrary();
+static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/BinaryKernel_metallib.h>
#endif
static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name, bool supports_dense = true) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
Tensor input = iter.input(0);
Tensor other = iter.input(1);
Tensor out = iter.output();
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
auto computeEncoder = mpsStream->commandEncoder();
if (supports_dense && iter.is_contiguous()) {
const auto kernel_name = fmt::format("{}_dense_{}", func_name, scalarToMetalTypeString(input));
auto binaryPSO = lib.getPipelineStateForFunc(kernel_name);
[computeEncoder setComputePipelineState:binaryPSO];
mtl_setArgs(computeEncoder, input, other, out);
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
return;
}
const std::string kernel = func_name + "_" + scalarToMetalTypeString(input);
auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);
id<MTLComputePipelineState> binaryPSO = lib.getPipelineStateForFunc(kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other});
[computeEncoder setComputePipelineState:binaryPSO];
mtl_setArgs(computeEncoder, input, other, out);
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
getMPSProfiler().endProfileKernel(binaryPSO);
}
});
}
namespace mps {
void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output) {
TORCH_INTERNAL_ASSERT(c10::isComplexType(input.scalar_type()) || c10::isComplexType(other.scalar_type()));
@@ -89,43 +48,43 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
auto iter =
TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
mps::binary_mps_impl(iter, "complex_mul", false);
lib.exec_binary_kernel(iter, "complex_mul", /*supports_dense=*/false);
}
} // namespace mps
static void fmax_mps_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.common_dtype())) {
mps::binary_mps_impl(iter, "fmax");
lib.exec_binary_kernel(iter, "fmax");
} else {
at::maximum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
}
static void fmin_mps_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.common_dtype())) {
mps::binary_mps_impl(iter, "fmin");
lib.exec_binary_kernel(iter, "fmin");
} else {
at::minimum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
}
static void copysign_mps_kernel(TensorIteratorBase& iter) {
mps::binary_mps_impl(iter, "copysign");
lib.exec_binary_kernel(iter, "copysign");
}
static void nextafter_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "nextafter_mps not implemented for non-floating types");
mps::binary_mps_impl(iter, "nextafter");
lib.exec_binary_kernel(iter, "nextafter");
}
static void zeta_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "zeta_mps not implemented for non-floating types");
mps::binary_mps_impl(iter, "zeta");
lib.exec_binary_kernel(iter, "zeta");
}
static void xlog1py_mps_kernel(TensorIteratorBase& iter) {
TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types");
mps::binary_mps_impl(iter, "xlog1py");
lib.exec_binary_kernel(iter, "xlog1py");
}
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
@@ -147,7 +106,7 @@ Tensor& polar_out_mps(const Tensor& abs, const Tensor& angle, Tensor& output) {
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(abs).add_input(angle).build();
mps::binary_mps_impl(iter, "polar");
lib.exec_binary_kernel(iter, "polar");
return output;
}
@@ -163,7 +122,7 @@ Tensor& complex_out_mps(const Tensor& real, const Tensor& imag, Tensor& output)
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(real).add_input(imag).build();
mps::binary_mps_impl(iter, "complex_kernel", false);
lib.exec_binary_kernel(iter, "complex_kernel", /*supports_dense=*/false);
return output;
}
} // namespace at::native
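With lib now exposing exec_binary_kernel, wiring up a further binary stub follows the same shape as the kernels above. A hypothetical sketch (not part of this change; "hypot" here is illustrative only):

static void hypot_mps_kernel(TensorIteratorBase& iter) {
  TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "hypot_mps not implemented for non-floating types");
  lib.exec_binary_kernel(iter, "hypot"); // resolves "hypot_float" / "hypot_dense_float" etc.
}
REGISTER_DISPATCH(hypot_stub, &hypot_mps_kernel)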


@@ -103,5 +103,44 @@ kernel void unary_strided(
} \
}
template <typename T, typename F>
kernel void binary_indexing(
constant void* input_ [[buffer(0)]],
constant void* other_ [[buffer(1)]],
device void* out_ [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto out =
(device result_of<F, T, T>*)((device uint8_t*)out_ + offsets[tid].x);
auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
F f;
*out = f(*input, *other);
}
template <typename T, typename F>
kernel void binary_dense(
constant T* input [[buffer(0)]],
constant T* other [[buffer(1)]],
device result_of<F, T, T>* out [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
F f;
out[tid] = f(input[tid], other[tid]);
}
#define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE) \
template [[host_name(#NAME "_" #DTYPE)]] kernel void ::c10::metal:: \
binary_indexing<DTYPE, NAME##_functor>( \
constant void* input_, \
constant void* other_, \
device void* out_, \
constant uint3* offsets, \
uint tid); \
template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void ::c10::metal:: \
binary_dense<DTYPE, NAME##_functor>( \
constant DTYPE * input_, \
constant DTYPE * other_, \
device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
uint tid)
} // namespace metal
} // namespace c10
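Taken together with the host-side changes, a new elementwise binary op now only needs a functor plus per-dtype registrations in the shader, since the binary_indexing/binary_dense templates and the registration macro live in c10/metal/indexing.h. A minimal sketch with a hypothetical mul_functor (illustrative, not part of this change):

#include <c10/metal/indexing.h>

struct mul_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return a * b; // elementwise product
  }
};

// registers "mul_float"/"mul_dense_float" and "mul_half"/"mul_dense_half"
REGISTER_BINARY_INDEXING_OP(mul, float);
REGISTER_BINARY_INDEXING_OP(mul, half);

// host side would then invoke: lib.exec_binary_kernel(iter, "mul");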