diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 7016292e6ef..1e71d9d8819 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -454,6 +454,10 @@ inline bool supportedFloatingOrComplexType(const Tensor& t) { return supportedFloatingOrComplexType(t.scalar_type()); } +inline void checkSupportsBFloat16() { + TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), + "MPS bfloat16 type is supported on MacOS 14.0 or newer."); +} inline bool needsGather(const Tensor& t) { static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 08e9ce18ff7..90879a026ed 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -55,11 +55,6 @@ static inline void checkSupportsComplex() { TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer."); } -static inline void checkSupportsBFloat16() { - TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), - "MPS bfloat16 type is supported on MacOS 14.0 or newer."); -} - MPSDataType getMPSDataType(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Float: diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 889a9a87d20..a91a5c3d717 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -68,13 +68,28 @@ Tensor& random_mps_impl(Tensor& self, newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); - // FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend. 
+ // BF16, FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend. const MPSDataType inputDataType = [&] { // only for random_mps, we pass interval range of type int64_t if constexpr (std::is_same_v<scalar_t, int64_t>) { return MPSDataTypeInt32; } - return (self.scalar_type() == ScalarType::Half) ? MPSDataTypeFloat16 : MPSDataTypeFloat32; + // for bernoulli always use float32 + if constexpr (std::is_same_v<scalar_t, bool>) { + return MPSDataTypeFloat32; + } + switch (self.scalar_type()) { + case kHalf: + return MPSDataTypeFloat16; + case kFloat: + return MPSDataTypeFloat32; + case kBFloat16: { + checkSupportsBFloat16(); + return MPSDataTypeBFloat16; + } + default: + TORCH_CHECK_TYPE(false, "Unsupported type ", self.scalar_type(), " for operation ", op_name); + } }(); const MPSDataType outputDataType = std::is_same_v<scalar_t, bool> ? MPSDataTypeBool : inputDataType; diff --git a/test/test_mps.py b/test/test_mps.py index c1f8835002b..8693049dcc7 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7873,27 +7873,27 @@ class TestMPS(TestCaseMPS): # Test normal def test_normal(self): - def helper(shape, mean=0.0, std=1.0): - mps_out = torch.normal(mean, std, shape, device='mps') + def helper(shape, mean=0.0, std=1.0, dtype=torch.float): + mps_out = torch.normal(mean, std, shape, device='mps', dtype=dtype) mean_array = np.ones(shape) mean_array *= mean - cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False) + cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=dtype, requires_grad=False) mean_tensor = cpu_mean_tensor.detach().clone().to('mps') std_array = np.ones(shape) std_array *= std - cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False) + cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=dtype, requires_grad=False) std_tensor = cpu_std_tensor.detach().clone().to('mps') # test out - mps_out = torch.zeros(shape, device='mps') + mps_out = torch.zeros(shape, device='mps', 
dtype=dtype) torch.normal(mean_tensor, std, out=mps_out) - mps_out = torch.zeros(shape, device='mps') + mps_out = torch.zeros(shape, device='mps', dtype=dtype) torch.normal(mean, std_tensor, out=mps_out) - mps_out = torch.zeros(shape, device='mps') + mps_out = torch.zeros(shape, device='mps', dtype=dtype) torch.normal(mean_tensor, std_tensor, out=mps_out) # test without out @@ -7910,6 +7910,16 @@ class TestMPS(TestCaseMPS): helper((2, 3, 4, 5, 6)) helper((100, 100), 2.5, 1.2) + # Test invalid inputs + with self.assertRaises(TypeError): + helper((10, 10), 10, 11, dtype=torch.int32) + + if product_version >= 14.0: + helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16) + else: + with self.assertRaises(TypeError): + helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16) + def test_bernoulli(self): shape = (10, 10) all_ones = torch.ones(shape, device='mps')