[MPS] Fix metal ops with different dtypes (#149974)
By implementing `_cast_` flavors of both dense and strided ops. Adds regression tests that exercise `fmax`/`fmin` for mixed dtypes. I've been dreading writing this PR for a while, as it ended up pretty bulky:
- Adds `C10_METAL_ALL_TYPES_FUNCTOR` and `c10:🤘:ScalarType` to `c10/metal/common.h`, plus a check that its values always match `c10::ScalarType`
- Adds `c10:🤘:cast_to` to `c10/metal/utils.h`, which can be used to cast any scalar Metal dtype to any other one, including complex values
- Implements `val_at_offs<T>(constant void *, long offs, ScalarType dtype)`, which is used to dynamically cast types
- Adds `binary_strided_cast` and `binary_dense_cast` kernels, which are invoked for the output dtype and cast both inputs to that dtype before performing the op

Benchmark collected on an M2 Pro, running `fmax` (and `add` for comparison) over 1M-element tensors. Times are in microseconds.

| | dense-dense | transp-transp | dense-transp | transp-dense | dense-scalar | dense-bcast |
|---|---|---|---|---|---|---|
| fmax (torch.float16, torch.float16) | 160.9 | 159.9 | 270.5 | 270.9 | 236.6 | 293.0 |
| fmax (torch.float32, torch.float32) | 176.9 | 171.0 | 273.7 | 293.5 | 242.6 | 294.2 |
| fmax (torch.float32, torch.float16) | 171.4 | 170.9 | 283.6 | 303.0 | 253.7 | 302.3 |
| add (torch.float16, torch.float16) | 218.0 | 223.6 | 221.0 | 222.0 | 214.9 | 218.3 |
| add (torch.float32, torch.float32) | 227.4 | 233.9 | 228.8 | 231.9 | 218.9 | 221.4 |
| add (torch.float32, torch.float16) | 226.1 | 227.5 | 227.5 | 226.9 | 177.0 | 190.8 |

TODOs:
- Include input and output dtypes in non-cast kernel names
- Make TensorFactory.h use `C10_METAL_ALL_TYPES_FUNCTOR`
- Extend mixed-dtype testing via OpInfo

Fixes https://github.com/pytorch/pytorch/issues/149951

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149974
Approved by: https://github.com/manuelcandales
This commit is contained in: parent aa575cab71, commit de68ddc68e
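In the host-side change below, the dispatch derives the kernel name from contiguity and from whether a cast is needed: non-cast kernels stay `<op>_<dense|strided>_<dtype>`, while cast kernels become `<op>_<dense|strided>_cast_<output dtype>`. A minimal standalone C++ sketch of that naming scheme — the `kernel_name` helper here is hypothetical; the real code uses `fmt::format` inside `exec_binary_kernel`:

```cpp
#include <iostream>
#include <string>

// Hypothetical stand-in for the fmt::format call in exec_binary_kernel.
std::string kernel_name(const std::string& op,
                        bool contiguous,
                        bool cast_needed,
                        const std::string& input_dtype,
                        const std::string& output_dtype) {
  const std::string suffix = contiguous ? "dense" : "strided";
  return op + "_" + suffix + (cast_needed ? "_cast" : "") + "_" +
         (cast_needed ? output_dtype : input_dtype);
}

int main() {
  std::cout << kernel_name("fmax", true, false, "float", "float") << "\n";  // fmax_dense_float
  std::cout << kernel_name("fmax", true, true, "half", "float") << "\n";    // fmax_dense_cast_float
  std::cout << kernel_name("fmax", false, true, "half", "float") << "\n";   // fmax_strided_cast_float
}
```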
@@ -1,6 +1,7 @@
 // Copyright © 2022 Apple Inc.
 #include <ATen/core/TensorBase.h>
 #include <ATen/native/mps/MetalShaderLibrary.h>
+#include <c10/metal/common.h>
 #include <functional>
 #include <stdexcept>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -1023,8 +1024,14 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
   const uint32_t nDim = iter.ndim();
   constexpr uint32_t nOffsets = 3;
   const uint32_t numThreads = iter.numel();
+  const auto cast_needed = input.scalar_type() != other.scalar_type();
   const auto suffix = iter.is_contiguous() ? "dense" : "strided";
-  const auto kernel_name = fmt::format("{}_{}_{}", name, suffix, scalarToMetalTypeString(input));
+  // TODO: Implicitly pass both input and output types to non-cast kernels
+  const auto kernel_name = fmt::format("{}_{}{}_{}",
+                                       name,
+                                       suffix,
+                                       cast_needed ? "_cast" : "",
+                                       cast_needed ? scalarToMetalTypeString(out) : scalarToMetalTypeString(input));
   dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
     @autoreleasepool {
       auto computeEncoder = mpsStream->commandEncoder();
@@ -1036,10 +1043,19 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
       // i.e. it's true for both row-first and column-first tensors
       if (iter.is_contiguous()) {
         mtl_setArgs(computeEncoder, out, input, other);
+        if (cast_needed) {
+          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
+                                               static_cast<int>(c10::elementSize(other.scalar_type())),
+                                               static_cast<int>(input.scalar_type()),
+                                               static_cast<int>(other.scalar_type())};
+          mtl_setBytes(computeEncoder, size_and_types, 3);
+        }
       } else {
         // Please note that shapes and strides of the iterator might be
         // different than that of its operands, for example binary op
         // between 4x4 tensor and scalar will result in 1D 16 element iterator
+        std::array<int, 3> ndim_and_types = {
+            iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
         mtl_setArgs(computeEncoder,
                     out,
                     input,
@@ -1048,7 +1064,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
                     iter.strides(0),
                     iter.strides(1),
                     iter.strides(2),
-                    iter.ndim());
+                    ndim_and_types);
       }
       mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
       getMPSProfiler().endProfileKernel(binaryPSO);
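The `size_and_types`/`ndim_and_types` arrays above are plain `int` arrays on the host side; the shaders further down declare the same bytes as `uint3`/`uint4` and unpack the fields by component (`.x`, `.y`, `.z`, `.w`). A minimal sketch of that round trip, assuming ordinary C++ in place of Metal (the local `uint3` struct is a stand-in for Metal's builtin):

```cpp
#include <array>
#include <cstring>
#include <iostream>

struct uint3 { unsigned x, y, z; };  // stand-in for Metal's uint3

int main() {
  // Host side: ndim plus the two input ScalarType tags, as in the diff.
  std::array<int, 3> ndim_and_types = {2, /*Float*/ 6, /*Half*/ 5};

  // Device side: the same 12 bytes viewed as a uint3.
  uint3 v{};
  std::memcpy(&v, ndim_and_types.data(), sizeof(v));
  std::cout << "ndim=" << v.x << " input dtype=" << v.y
            << " other dtype=" << v.z << "\n";
}
```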
@@ -1132,3 +1148,8 @@ uint64_t MetalKernelFunction::getStaticThreadGroupMemoryLength() const {
 }
 
 } // namespace at::native::mps
+
+// Check that c10::metal::ScalarType is strict subset (with matching values) of c10::ScalarType
+#define DTYPE_CHECKER(_n, _v) \
+  static_assert(static_cast<int>(::c10::ScalarType::_n) == static_cast<int>(::c10::metal::ScalarType::_n));
+C10_METAL_ALL_TYPES_FUNCTOR(DTYPE_CHECKER)
@@ -7,8 +7,38 @@
 #define C10_METAL_CONSTEXPR constexpr
 #endif
 
+#if !defined(__METAL__) || __METAL_VERSION__ >= 310
+#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
+  _(Byte, 0)                           \
+  _(Char, 1)                           \
+  _(Short, 2)                          \
+  _(Int, 3)                            \
+  _(Long, 4)                           \
+  _(Half, 5)                           \
+  _(Float, 6)                          \
+  _(Bool, 11)                          \
+  _(BFloat16, 15)
+#else
+#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
+  _(Byte, 0)                           \
+  _(Char, 1)                           \
+  _(Short, 2)                          \
+  _(Int, 3)                            \
+  _(Long, 4)                           \
+  _(Half, 5)                           \
+  _(Float, 6)                          \
+  _(Bool, 11)
+#endif
+
 namespace c10 {
 namespace metal {
 C10_METAL_CONSTEXPR unsigned max_ndim = 16;
+
+enum class ScalarType {
+#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
+  C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
+#undef _DEFINE_ENUM_VAL_
+};
+
 } // namespace metal
 } // namespace c10
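`C10_METAL_ALL_TYPES_FUNCTOR` above is an X-macro: a single (name, value) list that different callbacks expand into the enum, the `DTYPE_CHECKER` static asserts, and so on. A shortened, self-contained C++ illustration of the pattern — the macro and type list here are abbreviated stand-ins, not the PyTorch definitions:

```cpp
#include <iostream>

// Abbreviated stand-in for C10_METAL_ALL_TYPES_FUNCTOR:
// one authoritative (name, value) list...
#define ALL_TYPES_FUNCTOR(_) \
  _(Byte, 0)                 \
  _(Float, 6)                \
  _(Bool, 11)

// ...stamped into an enum...
enum class ScalarType {
#define DEFINE_ENUM_VAL(_n, _v) _n = _v,
  ALL_TYPES_FUNCTOR(DEFINE_ENUM_VAL)
#undef DEFINE_ENUM_VAL
};

// ...and into a name-printing helper from the same list.
const char* to_string(ScalarType t) {
  switch (t) {
#define DEFINE_CASE(_n, _v) \
  case ScalarType::_n:      \
    return #_n;
    ALL_TYPES_FUNCTOR(DEFINE_CASE)
#undef DEFINE_CASE
  }
  return "?";
}

int main() {
  std::cout << to_string(ScalarType::Float) << " = "
            << static_cast<int>(ScalarType::Float) << "\n";  // Float = 6
}
```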
@@ -1,3 +1,4 @@
+// Metal indexing primitives
 #pragma once
 #include <c10/metal/common.h>
 #include <c10/metal/utils.h>
@@ -104,10 +105,39 @@ kernel void unary_strided(
 }
 
 template <typename T>
-inline constant T& ref_at_offs(constant void* ptr, long offs) {
+inline T val_at_offs(constant void* ptr, long offs) {
   return *reinterpret_cast<constant T*>(
       static_cast<constant char*>(ptr) + offs);
 }
+
+// Value at offset with dynamic cast from provided type
+template <typename T>
+inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
+  switch (type) {
+    case ScalarType::Bool:
+      return val_at_offs<bool>(ptr, offs);
+    case ScalarType::Byte:
+      return val_at_offs<uchar>(ptr, offs);
+    case ScalarType::Char:
+      return val_at_offs<char>(ptr, offs);
+    case ScalarType::Short:
+      return val_at_offs<short>(ptr, offs);
+    case ScalarType::Int:
+      return val_at_offs<int>(ptr, offs);
+    case ScalarType::Long:
+      return val_at_offs<long>(ptr, offs);
+    // Floats
+    case ScalarType::Float:
+      return static_cast<T>(val_at_offs<float>(ptr, offs));
+    case ScalarType::Half:
+      return static_cast<T>(val_at_offs<half>(ptr, offs));
+#if __METAL_VERSION__ >= 310
+    case ScalarType::BFloat16:
+      return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
+#endif
+  }
+}
+
 template <typename T>
 inline device T& ref_at_offs(device void* ptr, long offs) {
   return *reinterpret_cast<device T*>(static_cast<device char*>(ptr) + offs);
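The new `val_at_offs` overload above reads raw bytes at a byte offset and converts them to the kernel's compute type `T` based on a runtime dtype tag. A host-side C++ analog — a sketch, not the shader source — assuming `std::memcpy` in place of Metal address-space casts and omitting the 16-bit float cases, which have no standard C++ type:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Tag values match the C10_METAL_ALL_TYPES_FUNCTOR list above.
enum class ScalarType { Char = 1, Int = 3, Half = 5, Float = 6 };

// Read a value of statically known type T at a byte offset.
template <typename T>
T val_at_offs(const void* ptr, long offs) {
  T v;
  std::memcpy(&v, static_cast<const char*>(ptr) + offs, sizeof(T));
  return v;
}

// Read a value whose stored type is only known at runtime, then cast to T.
template <typename T>
T val_at_offs(const void* ptr, long offs, ScalarType type) {
  switch (type) {
    case ScalarType::Char:
      return static_cast<T>(val_at_offs<int8_t>(ptr, offs));
    case ScalarType::Int:
      return static_cast<T>(val_at_offs<int32_t>(ptr, offs));
    case ScalarType::Float:
      return static_cast<T>(val_at_offs<float>(ptr, offs));
    default:  // Half/BFloat16 omitted in this host-side sketch
      return T{};
  }
}

int main() {
  float buf[2] = {1.5f, 2.5f};
  // Read the second element as a float and promote it to double.
  std::cout << val_at_offs<double>(buf, sizeof(float), ScalarType::Float)
            << "\n";  // 2.5
}
```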
@@ -122,16 +152,40 @@ kernel void binary_strided(
     constant long* output_strides [[buffer(4)]],
     constant long* input_strides [[buffer(5)]],
     constant long* other_strides [[buffer(6)]],
-    constant uint& ndim [[buffer(7)]],
+    constant uint3& ndim [[buffer(7)]],
     uint index [[thread_position_in_grid]]) {
   F f;
   int pos[max_ndim];
-  pos_from_thread_index(int(index), pos, sizes, ndim);
-  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
-  const auto other_offs = offset_from_coord(pos, other_strides, ndim);
-  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
-  const auto a = ref_at_offs<T>(input, input_offs);
-  const auto b = ref_at_offs<T>(other, other_offs);
+  pos_from_thread_index(int(index), pos, sizes, ndim.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
+  const auto a = val_at_offs<T>(input, input_offs);
+  const auto b = val_at_offs<T>(other, other_offs);
+  ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
+}
+
+template <typename T, typename F>
+kernel void binary_strided_cast(
+    device void* output [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant long* sizes [[buffer(3)]],
+    constant long* output_strides [[buffer(4)]],
+    constant long* input_strides [[buffer(5)]],
+    constant long* other_strides [[buffer(6)]],
+    constant uint3& ndim_types [[buffer(7)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim_types.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x);
+  const auto a =
+      val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
+  const auto b =
+      val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
   ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
 }
 
@@ -145,6 +199,21 @@ kernel void binary_dense(
   out[tid] = f(input[tid], other[tid]);
 }
 
+template <typename T, typename F>
+kernel void binary_dense_cast(
+    device result_of<F, T, T>* out [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant uint4& sizes_types [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  F f;
+  const auto a = val_at_offs<T>(
+      input, tid * sizes_types.x, static_cast<ScalarType>(sizes_types.z));
+  const auto b = val_at_offs<T>(
+      other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
+  out[tid] = f(a, b);
+}
+
 #define REGISTER_BINARY_INDEXING_OP(NAME, DTYPE)                            \
   template [[host_name(#NAME "_strided_" #DTYPE)]] kernel void ::c10::metal:: \
       binary_strided<DTYPE, NAME##_functor>(                                \
@@ -155,13 +224,31 @@ kernel void binary_dense(
       constant long* output_strides,                                        \
       constant long* input_strides,                                         \
       constant long* other_strides,                                         \
-      constant uint& ndim,                                                  \
+      constant uint3& ndim,                                                 \
+      uint tid);                                                            \
+  template [[host_name(#NAME "_strided_cast_" #DTYPE)]] kernel void ::c10:: \
+      metal::binary_strided_cast<DTYPE, NAME##_functor>(                    \
+          device void* out,                                                 \
+          constant void* input,                                             \
+          constant void* other,                                             \
+          constant long* sizes,                                             \
+          constant long* output_strides,                                    \
+          constant long* input_strides,                                     \
+          constant long* other_strides,                                     \
+          constant uint3& ndim_types,                                       \
       uint tid);                                                            \
   template [[host_name(#NAME "_dense_" #DTYPE)]] kernel void ::c10::metal:: \
       binary_dense<DTYPE, NAME##_functor>(                                  \
      device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_,   \
      constant DTYPE * input_,                                               \
      constant DTYPE * other_,                                               \
+      uint tid);                                                            \
+  template [[host_name(#NAME "_dense_cast_" #DTYPE)]] kernel void ::c10::   \
+      metal::binary_dense_cast<DTYPE, NAME##_functor>(                      \
+          device ::c10::metal::result_of<NAME##_functor, DTYPE, DTYPE> * out_, \
+          constant void* input,                                             \
+          constant void* other,                                             \
+          constant uint4& sizes_types,                                      \
       uint tid)
 } // namespace metal
 } // namespace c10
@@ -1,5 +1,6 @@
 // Metal helper functions
 #pragma once
+#include <c10/metal/common.h>
 #include <metal_stdlib>
 
 namespace c10 {
@@ -145,5 +146,21 @@ template <typename T>
 constexpr constant bool is_scalar_integral_v =
     ::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
 
+template <typename T, typename U>
+inline ::metal::enable_if_t<::metal::is_same_v<U, T>, T> cast_to(const U from) {
+  return from;
+}
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<is_complex_v<T>, T> cast_to(const U from) {
+  return T(float(from), 0.0);
+}
+
+template <typename T, typename U>
+inline ::metal::enable_if_t<!::metal::is_same_v<U, T> && !is_complex_v<T>, T>
+cast_to(const U from) {
+  return static_cast<T>(from);
+}
+
 } // namespace metal
 } // namespace c10
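The three `cast_to` overloads above are selected by SFINAE: identity when source and destination types match, real-to-complex promotion with a zero imaginary part, and plain `static_cast` otherwise. A host-side C++ analog using `std::enable_if_t` and `std::complex` — Metal's `::metal::enable_if_t` and the shader's `is_complex_v` are assumed to behave like these stand-ins:

```cpp
#include <complex>
#include <iostream>
#include <type_traits>

// Stand-in for the shader's is_complex_v trait.
template <typename T> struct is_complex : std::false_type {};
template <typename U> struct is_complex<std::complex<U>> : std::true_type {};
template <typename T> constexpr bool is_complex_v = is_complex<T>::value;

// Same type: pass through unchanged.
template <typename T, typename U>
std::enable_if_t<std::is_same_v<U, T>, T> cast_to(const U from) {
  return from;
}

// Real scalar to complex: zero imaginary part.
template <typename T, typename U>
std::enable_if_t<is_complex_v<T>, T> cast_to(const U from) {
  return T(float(from), 0.0);
}

// Everything else: plain static_cast.
template <typename T, typename U>
std::enable_if_t<!std::is_same_v<U, T> && !is_complex_v<T>, T>
cast_to(const U from) {
  return static_cast<T>(from);
}

int main() {
  std::cout << cast_to<float>(int16_t{3}) << "\n";          // 3 (static_cast path)
  std::cout << cast_to<std::complex<float>>(2.5f) << "\n";  // (2.5,0) (complex path)
}
```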
@@ -12704,6 +12704,22 @@ class TestConsistency(TestCaseMPS):
                 rtol = 1.5e-3
             self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
 
+    def test_fmax_mixed_dtypes(self, device):
+        # Regression testing for https://github.com/pytorch/pytorch/issues/149951
+        # fmax and fmin are implemented as binary metal shaders and they were implemented
+        # with the assumption that both args have the same dtype
+        x = torch.rand((3, 3), device=device, dtype=torch.float32)
+        x_int = torch.randint(-10, 10, (3, 3), device=device, dtype=torch.int8)
+        y = torch.rand((3, 3), device=device, dtype=torch.float16)
+        for op in [torch.fmax, torch.fmin]:
+            self.assertEqual(op(x, y), op(x.to("mps"), y.to("mps")).cpu())
+            self.assertEqual(op(x_int, y), op(x_int.to("mps"), y.to("mps")).cpu())
+            # Stride
+            self.assertEqual(op(x.t(), y), op(x.to("mps").t(), y.to("mps")).cpu())
+            # Broadcast
+            self.assertEqual(op(x, y[0]), op(x.to("mps"), y.to("mps")[0]).cpu())
+
+
 
 class TestErrorInputs(TestCase):
     _ignore_not_implemented_error = True
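For reference, a libtorch (C++) sketch of the same mixed-dtype check, assuming a C++ build of PyTorch with MPS enabled — the Python test above is the one actually added by this PR:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Sketch: assumes libtorch built with MPS support, as on Apple Silicon.
  // Mixed-dtype operands, as in test_fmax_mixed_dtypes.
  auto x = torch::rand({3, 3}, torch::dtype(torch::kFloat));
  auto y = torch::rand({3, 3}, torch::dtype(torch::kHalf));

  auto cpu = torch::fmax(x, y);
  auto mps = torch::fmax(x.to(torch::kMPS), y.to(torch::kMPS)).cpu();

  // Before this fix, the MPS kernel read both inputs as the first dtype.
  std::cout << std::boolalpha << torch::allclose(cpu, mps) << "\n";  // true
}
```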