[MPS] Fix metal ops with different dtypes (#149974)

By implementing `_cast_` flavors of both dense and strided ops. Adds regression tests that exercise `fmax`/`fmin` with mixed dtypes.

I've been dreading writing this PR for a while, as it ended up pretty bulky:
 - Adds `C10_METAL_ALL_TYPES_FUNCTOR` and `c10::metal::ScalarType` to `c10/metal/common.h`, plus a check that its values always match `c10::ScalarType`
 - Adds `c10::metal::cast_to` to `c10/metal/utils.h`, which can cast any scalar Metal dtype to any other, including complex values
 - Implements `val_at_offs<T>(constant void*, long offs, ScalarType dtype)`, which reads a value of the given runtime dtype at the given offset and casts it to `T`
 - Adds `binary_strided_cast` and `binary_dense_cast` kernels, which are instantiated for the output dtype and cast both inputs to it before performing the op (see the sketch after this list)
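
From the Python side, the semantics these cast kernels implement are just PyTorch's usual type promotion; a minimal illustration (not part of the diff):

```python
import torch

# Both inputs are cast to the common output dtype before the binary op runs.
# torch.result_type computes the promoted dtype, e.g. float32 for
# (float32, float16) inputs.
x = torch.rand(3, 3, dtype=torch.float32)
y = torch.rand(3, 3, dtype=torch.float16)
assert torch.result_type(x, y) == torch.float32
assert torch.fmax(x, y).dtype == torch.float32
```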

Benchmark collected on an M2 Pro, running `fmax` on 1-million-element tensors (times are in microseconds):

| op (dtypes)                          | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|--------------------------------------|-------------|---------------|--------------|--------------|--------------|-------------|
| fmax (torch.float16, torch.float16)  | 160.9       | 159.9         | 270.5        | 270.9        | 236.6        | 293.0       |
| fmax (torch.float32, torch.float32)  | 176.9       | 171.0         | 273.7        | 293.5        | 242.6        | 294.2       |
| fmax (torch.float32, torch.float16)  | 171.4       | 170.9         | 283.6        | 303.0        | 253.7        | 302.3       |
| add (torch.float16, torch.float16)   | 218.0       | 223.6         | 221.0        | 222.0        | 214.9        | 218.3       |
| add (torch.float32, torch.float32)   | 227.4       | 233.9         | 228.8        | 231.9        | 218.9        | 221.4       |
| add (torch.float32, torch.float16)   | 226.1       | 227.5         | 227.5        | 226.9        | 177.0        | 190.8       |
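
A minimal sketch of how numbers like these can be collected (a hypothetical harness using `timeit` and `torch.mps.synchronize`, not the exact script behind the table):

```python
import timeit

import torch

def bench_us(op, a, b, iters=1000):
    # Warm up once and synchronize so previously queued work
    # doesn't leak into the measurement.
    op(a, b)
    torch.mps.synchronize()
    # Synchronize inside the timed region so we measure kernel
    # execution rather than just dispatch.
    total = timeit.timeit(lambda: (op(a, b), torch.mps.synchronize()), number=iters)
    return 1e6 * total / iters  # microseconds per call

a = torch.rand(1_000_000, device="mps", dtype=torch.float32)
b = torch.rand(1_000_000, device="mps", dtype=torch.float16)
print(f"fmax (float32, float16) dense-dense: {bench_us(torch.fmax, a, b):.1f} us")
```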

TODOs:
 - Include input and output dtypes in non-cast kernel names
 - Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
 - Extend mixed-dtype testing via OpInfo

Fixes https://github.com/pytorch/pytorch/issues/149951
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales
Nikita Shulga 2025-03-25 23:18:35 -07:00 committed by PyTorch MergeBot
parent aa575cab71
commit de68ddc68e
5 changed files with 182 additions and 11 deletions

@@ -1,6 +1,7 @@
 // Copyright © 2022 Apple Inc.
 #include <ATen/core/TensorBase.h>
 #include <ATen/native/mps/MetalShaderLibrary.h>
+#include <c10/metal/common.h>
 #include <functional>
 #include <stdexcept>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -1023,8 +1024,14 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
   const uint32_t nDim = iter.ndim();
   constexpr uint32_t nOffsets = 3;
   const uint32_t numThreads = iter.numel();
+  const auto cast_needed = input.scalar_type() != other.scalar_type();
   const auto suffix = iter.is_contiguous() ? "dense" : "strided";
-  const auto kernel_name = fmt::format("{}_{}_{}", name, suffix, scalarToMetalTypeString(input));
+  // TODO: Implicitly pass both input and output types to non-cast kernels
+  const auto kernel_name = fmt::format("{}_{}{}_{}",
+                                       name,
+                                       suffix,
+                                       cast_needed ? "_cast" : "",
+                                       cast_needed ? scalarToMetalTypeString(out) : scalarToMetalTypeString(input));
   dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
     @autoreleasepool {
       auto computeEncoder = mpsStream->commandEncoder();
@@ -1036,10 +1043,19 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
       // i.e. it's true for both row-first and column-first tensors
       if (iter.is_contiguous()) {
         mtl_setArgs(computeEncoder, out, input, other);
+        if (cast_needed) {
+          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
+                                               static_cast<int>(c10::elementSize(other.scalar_type())),
+                                               static_cast<int>(input.scalar_type()),
+                                               static_cast<int>(other.scalar_type())};
+          mtl_setBytes(computeEncoder, size_and_types, 3);
+        }
       } else {
         // Please note that shapes and strides of the iterator might be
         // different than that of its operands, for example binary op
         // between 4x4 tensor and scalar will result in 1D 16 element iterator
+        std::array<int, 3> ndim_and_types = {
+            iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
         mtl_setArgs(computeEncoder,
                     out,
                     input,
@@ -1048,7 +1064,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
                     iter.strides(0),
                     iter.strides(1),
                     iter.strides(2),
-                    iter.ndim());
+                    ndim_and_types);
       }
       mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
       getMPSProfiler().endProfileKernel(binaryPSO);
@@ -1132,3 +1148,8 @@ uint64_t MetalKernelFunction::getStaticThreadGroupMemoryLength() const {
 }
 } // namespace at::native::mps
+
+// Check that c10::metal::ScalarType is strict subset (with matching values) of c10::ScalarType
+#define DTYPE_CHECKER(_n, _v) \
+  static_assert(static_cast<int>(::c10::ScalarType::_n) == static_cast<int>(::c10::metal::ScalarType::_n));
+C10_METAL_ALL_TYPES_FUNCTOR(DTYPE_CHECKER)

@@ -7,8 +7,38 @@
 #define C10_METAL_CONSTEXPR constexpr
 #endif
+
+#if !defined(__METAL__) || __METAL_VERSION__ >= 310
+#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
+  _(Byte, 0)                           \
+  _(Char, 1)                           \
+  _(Short, 2)                          \
+  _(Int, 3)                            \
+  _(Long, 4)                           \
+  _(Half, 5)                           \
+  _(Float, 6)                          \
+  _(Bool, 11)                          \
+  _(BFloat16, 15)
+#else
+#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
+  _(Byte, 0)                           \
+  _(Char, 1)                           \
+  _(Short, 2)                          \
+  _(Int, 3)                            \
+  _(Long, 4)                           \
+  _(Half, 5)                           \
+  _(Float, 6)                          \
+  _(Bool, 11)
+#endif
+
 namespace c10 {
 namespace metal {
 C10_METAL_CONSTEXPR unsigned max_ndim = 16;
+
+enum class ScalarType {
+#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
+  C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
+#undef _DEFINE_ENUM_VAL_
+};
+
 } // namespace metal
 } // namespace c10

@@ -1,3 +1,4 @@
 // Metal indexing primitives
 #pragma once
+#include <c10/metal/common.h>
 #include <c10/metal/utils.h>
@@ -104,10 +105,39 @@ kernel void unary_strided(
 }
 
 template <typename T>
-inline constant T& ref_at_offs(constant void* ptr, long offs) {
+inline T val_at_offs(constant void* ptr, long offs) {
   return *reinterpret_cast<constant T*>(
       static_cast<constant char*>(ptr) + offs);
 }
+
+// Value at offset with dynamic cast from provided type
+template <typename T>
+inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
+  switch (type) {
+    case ScalarType::Bool:
+      return val_at_offs<bool>(ptr, offs);
+    case ScalarType::Byte:
+      return val_at_offs<uchar>(ptr, offs);
+    case ScalarType::Char:
+      return val_at_offs<char>(ptr, offs);
+    case ScalarType::Short:
+      return val_at_offs<short>(ptr, offs);
+    case ScalarType::Int:
+      return val_at_offs<int>(ptr, offs);
+    case ScalarType::Long:
+      return val_at_offs<long>(ptr, offs);
+    // Floats
+    case ScalarType::Float:
+      return static_cast<T>(val_at_offs<float>(ptr, offs));
+    case ScalarType::Half:
+      return static_cast<T>(val_at_offs<half>(ptr, offs));
+#if __METAL_VERSION__ >= 310
+    case ScalarType::BFloat16:
+      return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
+#endif
+  }
+}
+
 template <typename T>
 inline device T& ref_at_offs(device void* ptr, long offs) {
   return *reinterpret_cast<device T*>(static_cast<device char*>(ptr) + offs);
@@ -122,16 +152,40 @@ kernel void binary_strided(
     constant long* output_strides [[buffer(4)]],
     constant long* input_strides [[buffer(5)]],
     constant long* other_strides [[buffer(6)]],
-    constant uint& ndim [[buffer(7)]],
+    constant uint3& ndim [[buffer(7)]],
     uint index [[thread_position_in_grid]]) {
   F f;
   int pos[max_ndim];
-  pos_from_thread_index(int(index), pos, sizes, ndim);
-  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
-  const auto other_offs = offset_from_coord(pos, other_strides, ndim);
-  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
-  const auto a = ref_at_offs<T>(input, input_offs);
-  const auto b = ref_at_offs<T>(other, other_offs);
+  pos_from_thread_index(int(index), pos, sizes, ndim.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
+  const auto a = val_at_offs<T>(input, input_offs);
+  const auto b = val_at_offs<T>(other, other_offs);
   ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
 }
+
+template <typename T, typename F>
+kernel void binary_strided_cast(
+    device void* output [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant long* sizes [[buffer(3)]],
+    constant long* output_strides [[buffer(4)]],
+    constant long* input_strides [[buffer(5)]],
+    constant long* other_strides [[buffer(6)]],
+    constant uint3& ndim_types [[buffer(7)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim_types.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x);
+  const auto a =
+      val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
+  const auto b =
+      val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
+  ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
+}
@@ -145,6 +199,21 @@ kernel void binary_dense(
   out[tid] = f(input[tid], other[tid]);
 }
+
+template <typename T, typename F>
+kernel void binary_dense_cast(
+    device result_of<F, T, T>* out [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant uint4& sizes_types [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  F f;
+  const auto a = val_at_offs<T>(
+      input, tid * sizes_types.x, static_cast<ScalarType>(sizes_types.z));
+  const auto b = val_at_offs<T>(
+      other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
+  out[tid] = f(a, b);
+}
+
 #define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE) \
   template [[host_name(#NAME "_strided_" #DTYPE)]] kernel void ::c10::metal:: \
       binary_strided<DTYPE, NAME##_functor>( \
@@ -155,13 +224,31 @@ kernel void binary_dense(
          constant long* output_strides, \
          constant long* input_strides, \
          constant long* other_strides, \
-         constant uint& ndim, \
+         constant uint3& ndim, \
          uint tid); \
+  template [[host_name(#NAME "_strided_cast_" #DTYPE)]] kernel void ::c10:: \
+      metal::binary_strided_cast<DTYPE, NAME##_functor>( \
+          device void* out, \
+          constant void* input, \
+          constant void* other, \
+          constant long* sizes, \
+          constant long* output_strides, \
+          constant long* input_strides, \
+          constant long* other_strides, \
+          constant uint3& ndim_types, \
+          uint tid); \
   template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void ::c10::metal:: \
       binary_dense<DTYPE, NAME##_functor>( \
          device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
          constant DTYPE * input_, \
          constant DTYPE * other_, \
-         uint tid)
+         uint tid); \
+  template [[host_name(#NAME "_dense_cast_" #DTYPE)]] kernel void ::c10:: \
+      metal::binary_dense_cast<DTYPE, NAME##_functor>( \
+          device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
+          constant void* input, \
+          constant void* other, \
+          constant uint4& sizes_types, \
+          uint tid)
 
 } // namespace metal
 } // namespace c10

@@ -1,5 +1,6 @@
 // Metal helper functions
 #pragma once
+#include <c10/metal/common.h>
 #include <metal_stdlib>
 
 namespace c10 {
@@ -145,5 +146,21 @@ template <typename T>
 constexpr constant bool is_scalar_integral_v =
     ::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<::metal::is_same_v<U, T>, T> cast_to(const U from) {
+  return from;
+}
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<is_complex_v<T>, T> cast_to(const U from) {
+  return T(float(from), 0.0);
+}
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<!::metal::is_same_v<U, T> && !is_complex_v<T>, T>
+cast_to(const U from) {
+  return static_cast<T>(from);
+}
+
 } // namespace metal
 } // namespace c10

@@ -12704,6 +12704,22 @@ class TestConsistency(TestCaseMPS):
                 rtol = 1.5e-3
             self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
 
+    def test_fmax_mixed_dtypes(self, device):
+        # Regression testing for https://github.com/pytorch/pytorch/issues/149951
+        # fmax and fmin are implemented as binary metal shaders and they were implemented
+        # with the assumption that both args have the same dtype
+        x = torch.rand((3, 3), device=device, dtype=torch.float32)
+        x_int = torch.randint(-10, 10, (3, 3), device=device, dtype=torch.int8)
+        y = torch.rand((3, 3), device=device, dtype=torch.float16)
+        for op in [torch.fmax, torch.fmin]:
+            self.assertEqual(op(x, y), op(x.to("mps"), y.to("mps")).cpu())
+            self.assertEqual(op(x_int, y), op(x_int.to("mps"), y.to("mps")).cpu())
+            # Stride
+            self.assertEqual(op(x.t(), y), op(x.to("mps").t(), y.to("mps")).cpu())
+            # Broadcast
+            self.assertEqual(op(x, y[0]), op(x.to("mps"), y.to("mps")[0]).cpu())
+
 class TestErrorInputs(TestCase):
     _ignore_not_implemented_error = True