mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
First, extend `c10::metal::cast_to` to work correctly with complex dtypes by introducing two more specializations: one that casts complex to scalar, and another that casts scalar to complex (the default Metal typecast would turn `float x` into `float2(x, x)`). Add ComplexHalf and ComplexFloat enum values to `c10::metal::ScalarType` and handle them in `val_at_offs(ptr, offs, type)`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152504 Approved by: https://github.com/dcci ghstack dependencies: #152443, #152466, #152479
49 lines
1.4 KiB
C++
49 lines
1.4 KiB
C++
#pragma once
|
|
// Set of global constants that can be shared between CPU and Metal code
|
|
|
|
// C10_METAL_CONSTEXPR marks a compile-time constant declared in this shared
// header so that the same declaration compiles for both host C++ and Metal.
//  - Metal (__METAL__ defined): the declaration also carries the `constant`
//    address-space qualifier (program-scope variables in MSL are expected to
//    live in the constant address space).
//  - Host C++: plain `constexpr`.
#if defined(__METAL__)
#define C10_METAL_CONSTEXPR constant constexpr
#else
#define C10_METAL_CONSTEXPR constexpr
#endif
|
|
|
|
// X-macro lists of scalar types shared between host C++ and Metal code.
// Each entry is `_(Name, value)`; callers pass a macro `_` that is expanded
// once per entry, in the order listed.
//
// The numeric values are explicit and non-contiguous (7, 10 and 12-14 are
// skipped). NOTE(review): presumably they mirror the host-side
// c10::ScalarType numbering; confirm before renumbering or inserting types.
//
// C10_METAL_BASE_TYPES_FUNCTOR holds the entries that are available on every
// target. Keeping them in a single list means a new type only has to be added
// once, instead of in each preprocessor branch.
#define C10_METAL_BASE_TYPES_FUNCTOR(_) \
  _(Byte, 0)                            \
  _(Char, 1)                            \
  _(Short, 2)                           \
  _(Int, 3)                             \
  _(Long, 4)                            \
  _(Half, 5)                            \
  _(Float, 6)                           \
  _(ComplexHalf, 8)                     \
  _(ComplexFloat, 9)                    \
  _(Bool, 11)

// C10_METAL_ALL_TYPES_FUNCTOR is the full list. BFloat16 is appended only for
// host code and for Metal 3.1 or newer. NOTE(review): presumably because the
// Metal `bfloat` type is unavailable before MSL 3.1. Confirm against the MSL
// specification before relaxing the guard.
#if !defined(__METAL__) || __METAL_VERSION__ >= 310
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
  C10_METAL_BASE_TYPES_FUNCTOR(_)      \
  _(BFloat16, 15)
#else
#define C10_METAL_ALL_TYPES_FUNCTOR(_) C10_METAL_BASE_TYPES_FUNCTOR(_)
#endif
|
|
|
|
namespace c10 {
|
|
namespace metal {
|
|
C10_METAL_CONSTEXPR unsigned max_ndim = 16;
|
|
|
|
enum class ScalarType {
|
|
#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
|
|
C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
|
|
#undef _DEFINE_ENUM_VAL_
|
|
};
|
|
|
|
} // namespace metal
|
|
} // namespace c10
|