[MPS] Error checking/bf16 support for torch.normal (#136863)

Before that attempt to run something like
```
% python -c "import torch;dev,dt='mps',torch.int; print(torch.normal(mean=torch.arange(1., 11., device=dev, dtype=dt), std=torch.arange(10, 0, -1, device=dev, dtype=dt)))"
```
Resulted in hard error
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
```
After the change, it raises a nice type error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136863
Approved by: https://github.com/Skylion007
ghstack dependencies: #136754, #136755, #136821, #136822
This commit is contained in:
Nikita Shulga 2024-09-27 13:09:35 -07:00 committed by PyTorch MergeBot
parent f7ab0e9989
commit 283bda01aa
4 changed files with 38 additions and 14 deletions

View File

@ -454,6 +454,10 @@ inline bool supportedFloatingOrComplexType(const Tensor& t) {
return supportedFloatingOrComplexType(t.scalar_type()); return supportedFloatingOrComplexType(t.scalar_type());
} }
inline void checkSupportsBFloat16() {
TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");
}
inline bool needsGather(const Tensor& t) { inline bool needsGather(const Tensor& t) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);

View File

@ -55,11 +55,6 @@ static inline void checkSupportsComplex() {
TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer."); TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
} }
static inline void checkSupportsBFloat16() {
TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");
}
MPSDataType getMPSDataType(ScalarType scalar_type) { MPSDataType getMPSDataType(ScalarType scalar_type) {
switch (scalar_type) { switch (scalar_type) {
case ScalarType::Float: case ScalarType::Float:

View File

@ -68,13 +68,28 @@ Tensor& random_mps_impl(Tensor& self,
newCachedGraph->stateTensor = newCachedGraph->stateTensor =
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);
// FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend. // BF16, FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend.
const MPSDataType inputDataType = [&] { const MPSDataType inputDataType = [&] {
// only for random_mps, we pass interval range of type int64_t // only for random_mps, we pass interval range of type int64_t
if constexpr (std::is_same_v<scalar_t, int64_t>) { if constexpr (std::is_same_v<scalar_t, int64_t>) {
return MPSDataTypeInt32; return MPSDataTypeInt32;
} }
return (self.scalar_type() == ScalarType::Half) ? MPSDataTypeFloat16 : MPSDataTypeFloat32; // for bernoully always use float32
if constexpr (std::is_same_v<scalar_t, bool>) {
return MPSDataTypeFloat32;
}
switch (self.scalar_type()) {
case kHalf:
return MPSDataTypeFloat16;
case kFloat:
return MPSDataTypeFloat32;
case kBFloat16: {
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
}
default:
TORCH_CHECK_TYPE(false, "Unsupported type ", self.scalar_type(), " for operation ", op_name);
}
}(); }();
const MPSDataType outputDataType = std::is_same_v<scalar_t, bool> ? MPSDataTypeBool : inputDataType; const MPSDataType outputDataType = std::is_same_v<scalar_t, bool> ? MPSDataTypeBool : inputDataType;

View File

@ -7873,27 +7873,27 @@ class TestMPS(TestCaseMPS):
# Test normal # Test normal
def test_normal(self): def test_normal(self):
def helper(shape, mean=0.0, std=1.0): def helper(shape, mean=0.0, std=1.0, dtype=torch.float):
mps_out = torch.normal(mean, std, shape, device='mps') mps_out = torch.normal(mean, std, shape, device='mps', dtype=dtype)
mean_array = np.ones(shape) mean_array = np.ones(shape)
mean_array *= mean mean_array *= mean
cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False) cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=dtype, requires_grad=False)
mean_tensor = cpu_mean_tensor.detach().clone().to('mps') mean_tensor = cpu_mean_tensor.detach().clone().to('mps')
std_array = np.ones(shape) std_array = np.ones(shape)
std_array *= std std_array *= std
cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False) cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=dtype, requires_grad=False)
std_tensor = cpu_std_tensor.detach().clone().to('mps') std_tensor = cpu_std_tensor.detach().clone().to('mps')
# test out # test out
mps_out = torch.zeros(shape, device='mps') mps_out = torch.zeros(shape, device='mps', dtype=dtype)
torch.normal(mean_tensor, std, out=mps_out) torch.normal(mean_tensor, std, out=mps_out)
mps_out = torch.zeros(shape, device='mps') mps_out = torch.zeros(shape, device='mps', dtype=dtype)
torch.normal(mean, std_tensor, out=mps_out) torch.normal(mean, std_tensor, out=mps_out)
mps_out = torch.zeros(shape, device='mps') mps_out = torch.zeros(shape, device='mps', dtype=dtype)
torch.normal(mean_tensor, std_tensor, out=mps_out) torch.normal(mean_tensor, std_tensor, out=mps_out)
# test without out # test without out
@ -7910,6 +7910,16 @@ class TestMPS(TestCaseMPS):
helper((2, 3, 4, 5, 6)) helper((2, 3, 4, 5, 6))
helper((100, 100), 2.5, 1.2) helper((100, 100), 2.5, 1.2)
# Test invalid inputs
with self.assertRaises(TypeError):
helper((10, 10), 10, 11, dtype=torch.int32)
if product_version >= 14.0:
helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16)
else:
with self.assertRaises(TypeError):
helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16)
def test_bernoulli(self): def test_bernoulli(self):
shape = (10, 10) shape = (10, 10)
all_ones = torch.ones(shape, device='mps') all_ones = torch.ones(shape, device='mps')