Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[MPS][BE] Remove exec_binary_alpha_kernel (#152485)
That function was almost a complete copy-and-paste of exec_binary_kernel anyway. Instead, add `Scalar` as an optional argument and determine the kernel name at the invocation site rather than in the executor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152485
Approved by: https://github.com/Skylion007
ghstack dependencies: #152443, #152466, #152479, #152504
This commit is contained in:
parent c90e23eb73
commit cf894b3f1f
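The refactor is the familiar collapse-two-overloads-into-one move: the alpha variant becomes an optional trailing parameter, and the "_alpha" kernel-name decision moves to the call site. A minimal, self-contained C++ sketch of the pattern (all names below are illustrative stand-ins, not PyTorch's actual types):

#include <iostream>
#include <optional>
#include <string>

// Stand-in for c10::Scalar; purely illustrative.
struct Scalar {
  double v;
};

// One entry point with a defaulted optional replaces the former pair of
// near-duplicate functions; the caller appends "_alpha" to the name itself.
void exec_binary_kernel(const std::string& name, std::optional<Scalar> alpha = std::nullopt) {
  std::cout << "dispatch " << name;
  if (alpha) {
    std::cout << " with alpha=" << alpha->v;  // extra argument bound only when present
  }
  std::cout << '\n';
}

int main() {
  exec_binary_kernel("add");                    // former exec_binary_kernel path
  exec_binary_kernel("add_alpha", Scalar{2.0}); // former exec_binary_alpha_kernel path
}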
@@ -13,6 +13,7 @@ typedef void* MTLComputePipelineState_t;
 typedef void* MTLComputeCommandEncoder_t;
 #endif
 
+#include <c10/core/Scalar.h>
 #include <c10/util/OptionalArrayRef.h>
 #include <functional>
 #include <optional>
@@ -29,9 +30,6 @@ struct TensorIteratorBase;
 
 namespace at::native::mps {
 
-// Forward declaration of MPSScalar - for exec_binary_alpha_kernel()
-struct MPSScalar;
-
 namespace detail {
 template <typename T>
 class has_size_type {
@@ -140,11 +138,10 @@ class MetalShaderLibrary {
       TensorIteratorBase& iter,
       const std::string& name,
       std::optional<int64_t> extra = std::nullopt);
-  void exec_binary_kernel(TensorIteratorBase& iter, const std::string& name);
-  void exec_binary_alpha_kernel(
+  void exec_binary_kernel(
       TensorIteratorBase& iter,
       const std::string& name,
-      const MPSScalar& alpha);
+      const std::optional<c10::Scalar> alpha = std::nullopt);
 
  protected:
   virtual MTLLibrary_t getLibrary();
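Because the new alpha parameter is defaulted to std::nullopt in the declaration above, existing two-argument callers keep compiling unchanged. A toy model of that source-compatibility property (Lib and the double alpha are stand-ins, not the real MetalShaderLibrary API):

#include <optional>
#include <string>

// Toy model of the declaration change: the defaulted optional keeps every
// existing two-argument caller compiling unchanged.
struct Lib {
  void exec_binary_kernel(const std::string& /*name*/,
                          std::optional<double> /*alpha*/ = std::nullopt) {}
};

int main() {
  Lib lib;
  lib.exec_binary_kernel("add");      // pre-existing call shape, still valid
  lib.exec_binary_kernel("add", 2.0); // new alpha-carrying call
}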
@@ -1019,7 +1019,9 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
   }
 }
 
-void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std::string& name) {
+void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
+                                            const std::string& name,
+                                            std::optional<c10::Scalar> alpha) {
   TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
   TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
 
@@ -1049,12 +1051,15 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
       // i.e. it's true for both row-first and column-first tensors
       if (iter.is_contiguous()) {
         mtl_setArgs(computeEncoder, out, input, other);
+        if (alpha) {
+          mtl_setBytes(computeEncoder, getMPSScalar(*alpha, iter.common_dtype()), 3);
+        }
         if (cast_needed) {
           std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
                                                static_cast<int>(c10::elementSize(other.scalar_type())),
                                                static_cast<int>(input.scalar_type()),
                                                static_cast<int>(other.scalar_type())};
-          mtl_setBytes(computeEncoder, size_and_types, 3);
+          mtl_setBytes(computeEncoder, size_and_types, alpha ? 4 : 3);
         }
       } else {
         // Please note that shapes and strides of the iterator might be
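The `alpha ? 4 : 3` above is plain Metal buffer-index bookkeeping: mtl_setArgs binds out, input, and other at indices 0-2, the alpha bytes (when present) go to index 3, so the cast metadata shifts from 3 to 4. A small sketch of that arithmetic (the index constants are inferred from the diff, not taken from a PyTorch header):

// Index bookkeeping sketch: out, input, other occupy buffers 0-2; alpha (when
// present) is bound at 3, so the size/type metadata moves from 3 to 4.
constexpr int kAlphaBufferIndex = 3;

constexpr int size_and_types_index(bool has_alpha) {
  return has_alpha ? kAlphaBufferIndex + 1 : kAlphaBufferIndex;
}

static_assert(size_and_types_index(false) == 3, "no alpha: metadata stays at 3");
static_assert(size_and_types_index(true) == 4, "alpha at 3 pushes metadata to 4");

int main() {}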
@@ -1062,77 +1067,28 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
         // between 4x4 tensor and scalar will result in 1D 16 element iterator
         std::array<int, 3> ndim_and_types = {
             iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
-        mtl_setArgs(computeEncoder,
-                    out,
-                    input,
-                    other,
-                    iter.shape(),
-                    iter.strides(0),
-                    iter.strides(1),
-                    iter.strides(2),
-                    ndim_and_types);
-      }
-      mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
-      getMPSProfiler().endProfileKernel(binaryPSO);
-    }
-  });
-}
-
-void MetalShaderLibrary::exec_binary_alpha_kernel(TensorIteratorBase& iter,
-                                                  const std::string& name,
-                                                  const MPSScalar& alpha) {
-  TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
-  TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
-
-  Tensor input = iter.input(0);
-  Tensor other = iter.input(1);
-  Tensor out = iter.output();
-
-  id<MTLDevice> device = MPSDevice::getInstance()->device();
-  MPSStream* mpsStream = getCurrentMPSStream();
-  const uint32_t nDim = iter.ndim();
-  constexpr uint32_t nOffsets = 3;
-  const uint32_t numThreads = iter.numel();
-  const auto cast_needed = input.scalar_type() != other.scalar_type();
-  const auto suffix = iter.is_contiguous() ? "dense" : "strided";
-  // TODO: Implicitly pass both input and output types to non-cast kernels
-  const auto kernel_name = cast_needed
-      ? fmt::format("{}_alpha_{}_cast_{}", name, suffix, scalarToMetalTypeString(out))
-      : fmt::format("{}_alpha_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input));
-  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
-    @autoreleasepool {
-      auto computeEncoder = mpsStream->commandEncoder();
-      auto binaryPSO = getPipelineStateForFunc(kernel_name);
-      // this function call is a no-op if MPS Profiler is not enabled
-      getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
-      [computeEncoder setComputePipelineState:binaryPSO];
-      // Iterator is contiguous if all of its elements are dense in storage,
-      // i.e. it's true for both row-first and column-first tensors
-      if (iter.is_contiguous()) {
-        mtl_setArgs(computeEncoder, out, input, other, alpha);
-        if (cast_needed) {
-          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
-                                               static_cast<int>(c10::elementSize(other.scalar_type())),
-                                               static_cast<int>(input.scalar_type()),
-                                               static_cast<int>(other.scalar_type())};
-          mtl_setBytes(computeEncoder, size_and_types, 4);
-        }
-      } else {
-        // Please note that shapes and strides of the iterator might be
-        // different than that of its operands, for example binary op
-        // between 4x4 tensor and scalar will result in 1D 16 element iterator
-        std::array<int, 3> ndim_and_types = {
-            iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
-        mtl_setArgs(computeEncoder,
-                    out,
-                    input,
-                    other,
-                    alpha,
-                    iter.shape(),
-                    iter.strides(0),
-                    iter.strides(1),
-                    iter.strides(2),
-                    ndim_and_types);
+        if (alpha) {
+          mtl_setArgs(computeEncoder,
+                      out,
+                      input,
+                      other,
+                      getMPSScalar(*alpha, iter.common_dtype()),
+                      iter.shape(),
+                      iter.strides(0),
+                      iter.strides(1),
+                      iter.strides(2),
+                      ndim_and_types);
+        } else {
+          mtl_setArgs(computeEncoder,
+                      out,
+                      input,
+                      other,
+                      iter.shape(),
+                      iter.strides(0),
+                      iter.strides(1),
+                      iter.strides(2),
+                      ndim_and_types);
+        }
       }
       mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
       getMPSProfiler().endProfileKernel(binaryPSO);
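For reference, the removed executor above shows the kernel-name scheme: op name, optional "_alpha", a "dense"/"strided" suffix, then either "_cast_" plus the output type or the output/input type pair. Since the name now arrives with "_alpha" already appended by the caller (see the add_sub_lerp_template hunk below), one formatter covers both cases. An illustrative reimplementation, with plain std::string concatenation standing in for fmt::format:

#include <cstdio>
#include <string>

// Name scheme reconstruction: "<op>[_alpha]_<dense|strided>[_cast]_<out>[_<in>]".
// "_alpha" is already part of `name` when the caller requests an alpha kernel.
std::string kernel_name(const std::string& name, bool contiguous, bool cast_needed,
                        const std::string& out_t, const std::string& in_t) {
  const std::string suffix = contiguous ? "dense" : "strided";
  return cast_needed ? name + "_" + suffix + "_cast_" + out_t
                     : name + "_" + suffix + "_" + out_t + "_" + in_t;
}

int main() {
  // prints "add_alpha_dense_float_float", matching the removed formatter's output
  std::printf("%s\n",
              kernel_name("add_alpha", /*contiguous=*/true, /*cast_needed=*/false, "float", "float").c_str());
}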
@@ -6,7 +6,7 @@ void binary_op_kernel(
     const Tensor& input,
     const Tensor& other,
     const Tensor& output,
-    const std::optional<MPSScalar>& alpha = std::nullopt);
+    const std::optional<Scalar> alpha = std::nullopt);
 void complex_mul_out(
     const Tensor& input,
     const Tensor& other,
@@ -54,7 +54,7 @@ void binary_op_kernel(const std::string func_name,
                      const Tensor& input,
                      const Tensor& other,
                      const Tensor& output,
-                     const std::optional<MPSScalar>& alpha) {
+                     const std::optional<Scalar> alpha) {
  auto new_size = at::infer_size(input.sizes(), other.sizes());
  if (!output.sizes().equals(new_size)) {
    output.resize_(new_size);
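Both signature hunks swap the internal MPSScalar for the public c10::Scalar, so the conversion happens only once, at encode time, via getMPSScalar(*alpha, iter.common_dtype()). A toy sketch of that convert-late shape (Scalar, MPSScalar, and toMPSScalar here are stand-ins, not PyTorch's definitions):

#include <iostream>
#include <optional>

struct Scalar { double v; };         // stand-in for c10::Scalar (public API type)
struct MPSScalar { double value; };  // stand-in for the Metal-side POD

// Illustrative conversion point; the real getMPSScalar also takes a dtype.
MPSScalar toMPSScalar(const Scalar& s) { return MPSScalar{s.v}; }

void encode(std::optional<Scalar> alpha) {
  if (alpha) {
    MPSScalar bytes = toMPSScalar(*alpha);  // conversion happens only here
    std::cout << "bind alpha bytes: " << bytes.value << '\n';
  }
}

int main() { encode(Scalar{0.5}); }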
@@ -67,11 +67,7 @@ void binary_op_kernel(const std::string func_name,
   auto iter =
       TensorIteratorConfig().add_output(output).add_input(input).add_input(other).check_all_same_dtype(false).build();
 
-  if (alpha) {
-    lib.exec_binary_alpha_kernel(iter, func_name, *alpha);
-  } else {
-    lib.exec_binary_kernel(iter, func_name);
-  }
+  lib.exec_binary_kernel(iter, func_name, alpha);
 }
 
 } // namespace mps
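With the optional forwarded as-is, the caller-side branch disappears entirely. A compilable miniature of that simplification (stand-in types; double replaces c10::Scalar):

#include <optional>
#include <string>

// Stand-in plumbing; double replaces c10::Scalar.
void exec_binary_kernel(const std::string& /*name*/,
                        std::optional<double> /*alpha*/ = std::nullopt) {}

// The former if/else pair collapses: std::optional forwards through unchanged.
void binary_op_kernel(const std::string& func_name, std::optional<double> alpha) {
  exec_binary_kernel(func_name, alpha);
}

int main() {
  binary_op_kernel("add_alpha", 2.0);
  binary_op_kernel("add", std::nullopt);
}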
@@ -278,7 +278,7 @@ static void add_sub_lerp_template(const Tensor& self,
   if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
     if (alpha_has_value) {
       at::native::alpha_check(commonDtype, alpha);
-      mps::binary_op_kernel(op_name, self, other, output, getMPSScalar(alpha, commonDtype));
+      mps::binary_op_kernel(op_name + "_alpha", self, other, output, alpha);
     } else {
       mps::binary_op_kernel(op_name, self, other, output);
     }