[MPS][BE] Remove exec_binary_alpha_kernel (#152485)

Which was almost a complete copy-n-paste from exec_binary_kernel anyway
Just add `Scalar` as an optional argument and figure out kernel name during the invocation rather than in executor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152485
Approved by: https://github.com/Skylion007
ghstack dependencies: #152443, #152466, #152479, #152504
This commit is contained in:
Nikita Shulga 2025-04-29 22:52:57 -07:00 committed by PyTorch MergeBot
parent c90e23eb73
commit cf894b3f1f
5 changed files with 35 additions and 86 deletions

View File

@ -13,6 +13,7 @@ typedef void* MTLComputePipelineState_t;
typedef void* MTLComputeCommandEncoder_t;
#endif
#include <c10/core/Scalar.h>
#include <c10/util/OptionalArrayRef.h>
#include <functional>
#include <optional>
@ -29,9 +30,6 @@ struct TensorIteratorBase;
namespace at::native::mps {
// Forward declaration of MPSScalar - for exec_binary_alpha_kernel()
struct MPSScalar;
namespace detail {
template <typename T>
class has_size_type {
@ -140,11 +138,10 @@ class MetalShaderLibrary {
TensorIteratorBase& iter,
const std::string& name,
std::optional<int64_t> extra = std::nullopt);
void exec_binary_kernel(TensorIteratorBase& iter, const std::string& name);
void exec_binary_alpha_kernel(
void exec_binary_kernel(
TensorIteratorBase& iter,
const std::string& name,
const MPSScalar& alpha);
const std::optional<c10::Scalar> alpha = std::nullopt);
protected:
virtual MTLLibrary_t getLibrary();

View File

@ -1019,7 +1019,9 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
}
}
void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std::string& name) {
void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
const std::string& name,
std::optional<c10::Scalar> alpha) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
@ -1049,12 +1051,15 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
// i.e. it's true for both row-first and column-first tensors
if (iter.is_contiguous()) {
mtl_setArgs(computeEncoder, out, input, other);
if (alpha) {
mtl_setBytes(computeEncoder, getMPSScalar(*alpha, iter.common_dtype()), 3);
}
if (cast_needed) {
std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
static_cast<int>(c10::elementSize(other.scalar_type())),
static_cast<int>(input.scalar_type()),
static_cast<int>(other.scalar_type())};
mtl_setBytes(computeEncoder, size_and_types, 3);
mtl_setBytes(computeEncoder, size_and_types, alpha ? 4 : 3);
}
} else {
// Please note that shapes and strides of the iterator might be
@ -1062,77 +1067,28 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
// between 4x4 tensor and scalar will result in 1D 16 element iterator
std::array<int, 3> ndim_and_types = {
iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
mtl_setArgs(computeEncoder,
out,
input,
other,
iter.shape(),
iter.strides(0),
iter.strides(1),
iter.strides(2),
ndim_and_types);
}
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
getMPSProfiler().endProfileKernel(binaryPSO);
}
});
}
void MetalShaderLibrary::exec_binary_alpha_kernel(TensorIteratorBase& iter,
const std::string& name,
const MPSScalar& alpha) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
Tensor input = iter.input(0);
Tensor other = iter.input(1);
Tensor out = iter.output();
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
const auto cast_needed = input.scalar_type() != other.scalar_type();
const auto suffix = iter.is_contiguous() ? "dense" : "strided";
// TODO: Implicitly pass both input and output types to non-cast kernels
const auto kernel_name = cast_needed
? fmt::format("{}_alpha_{}_cast_{}", name, suffix, scalarToMetalTypeString(out))
: fmt::format("{}_alpha_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input));
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
auto computeEncoder = mpsStream->commandEncoder();
auto binaryPSO = getPipelineStateForFunc(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
[computeEncoder setComputePipelineState:binaryPSO];
// Iterator is contiguous if all of its elements are dense in storage,
// i.e. it's true for both row-first and column-first tensors
if (iter.is_contiguous()) {
mtl_setArgs(computeEncoder, out, input, other, alpha);
if (cast_needed) {
std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
static_cast<int>(c10::elementSize(other.scalar_type())),
static_cast<int>(input.scalar_type()),
static_cast<int>(other.scalar_type())};
mtl_setBytes(computeEncoder, size_and_types, 4);
if (alpha) {
mtl_setArgs(computeEncoder,
out,
input,
other,
getMPSScalar(*alpha, iter.common_dtype()),
iter.shape(),
iter.strides(0),
iter.strides(1),
iter.strides(2),
ndim_and_types);
} else {
mtl_setArgs(computeEncoder,
out,
input,
other,
iter.shape(),
iter.strides(0),
iter.strides(1),
iter.strides(2),
ndim_and_types);
}
} else {
// Please note that shapes and strides of the iterator might be
// different than that of its operands, for example binary op
// between 4x4 tensor and scalar will result in 1D 16 element iterator
std::array<int, 3> ndim_and_types = {
iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
mtl_setArgs(computeEncoder,
out,
input,
other,
alpha,
iter.shape(),
iter.strides(0),
iter.strides(1),
iter.strides(2),
ndim_and_types);
}
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
getMPSProfiler().endProfileKernel(binaryPSO);

View File

@ -6,7 +6,7 @@ void binary_op_kernel(
const Tensor& input,
const Tensor& other,
const Tensor& output,
const std::optional<MPSScalar>& alpha = std::nullopt);
const std::optional<Scalar> alpha = std::nullopt);
void complex_mul_out(
const Tensor& input,
const Tensor& other,

View File

@ -54,7 +54,7 @@ void binary_op_kernel(const std::string func_name,
const Tensor& input,
const Tensor& other,
const Tensor& output,
const std::optional<MPSScalar>& alpha) {
const std::optional<Scalar> alpha) {
auto new_size = at::infer_size(input.sizes(), other.sizes());
if (!output.sizes().equals(new_size)) {
output.resize_(new_size);
@ -67,11 +67,7 @@ void binary_op_kernel(const std::string func_name,
auto iter =
TensorIteratorConfig().add_output(output).add_input(input).add_input(other).check_all_same_dtype(false).build();
if (alpha) {
lib.exec_binary_alpha_kernel(iter, func_name, *alpha);
} else {
lib.exec_binary_kernel(iter, func_name);
}
lib.exec_binary_kernel(iter, func_name, alpha);
}
} // namespace mps

View File

@ -278,7 +278,7 @@ static void add_sub_lerp_template(const Tensor& self,
if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
if (alpha_has_value) {
at::native::alpha_check(commonDtype, alpha);
mps::binary_op_kernel(op_name, self, other, output, getMPSScalar(alpha, commonDtype));
mps::binary_op_kernel(op_name + "_alpha", self, other, output, alpha);
} else {
mps::binary_op_kernel(op_name, self, other, output);
}