mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Error checking/bf16 support for torch.normal (#136863)
Before this change, an attempt to run something like
```
% python -c "import torch;dev,dt='mps',torch.int; print(torch.normal(mean=torch.arange(1., 11., device=dev, dtype=dt), std=torch.arange(10, 0, -1, device=dev, dtype=dt)))"
```
resulted in a hard error:
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %5 = "mps.multiply"(%2, %arg1) : (tensor<10xf32>, tensor<10xsi32>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
```
After the change, it raises a nice type error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136863
Approved by: https://github.com/Skylion007
ghstack dependencies: #136754, #136755, #136821, #136822
This commit is contained in:
parent
f7ab0e9989
commit
283bda01aa
|
|
@ -454,6 +454,10 @@ inline bool supportedFloatingOrComplexType(const Tensor& t) {
|
|||
return supportedFloatingOrComplexType(t.scalar_type());
|
||||
}
|
||||
|
||||
inline void checkSupportsBFloat16() {
|
||||
TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
|
||||
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");
|
||||
}
|
||||
|
||||
inline bool needsGather(const Tensor& t) {
|
||||
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||
|
|
|
|||
|
|
@ -55,11 +55,6 @@ static inline void checkSupportsComplex() {
|
|||
TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
|
||||
}
|
||||
|
||||
static inline void checkSupportsBFloat16() {
|
||||
TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
|
||||
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");
|
||||
}
|
||||
|
||||
MPSDataType getMPSDataType(ScalarType scalar_type) {
|
||||
switch (scalar_type) {
|
||||
case ScalarType::Float:
|
||||
|
|
|
|||
|
|
@ -68,13 +68,28 @@ Tensor& random_mps_impl(Tensor& self,
|
|||
newCachedGraph->stateTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);
|
||||
|
||||
// FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend.
|
||||
// BF16, FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend.
|
||||
const MPSDataType inputDataType = [&] {
|
||||
// only for random_mps, we pass interval range of type int64_t
|
||||
if constexpr (std::is_same_v<scalar_t, int64_t>) {
|
||||
return MPSDataTypeInt32;
|
||||
}
|
||||
return (self.scalar_type() == ScalarType::Half) ? MPSDataTypeFloat16 : MPSDataTypeFloat32;
|
||||
// for bernoulli always use float32
|
||||
if constexpr (std::is_same_v<scalar_t, bool>) {
|
||||
return MPSDataTypeFloat32;
|
||||
}
|
||||
switch (self.scalar_type()) {
|
||||
case kHalf:
|
||||
return MPSDataTypeFloat16;
|
||||
case kFloat:
|
||||
return MPSDataTypeFloat32;
|
||||
case kBFloat16: {
|
||||
checkSupportsBFloat16();
|
||||
return MPSDataTypeBFloat16;
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_TYPE(false, "Unsupported type ", self.scalar_type(), " for operation ", op_name);
|
||||
}
|
||||
}();
|
||||
const MPSDataType outputDataType = std::is_same_v<scalar_t, bool> ? MPSDataTypeBool : inputDataType;
|
||||
|
||||
|
|
|
|||
|
|
@ -7873,27 +7873,27 @@ class TestMPS(TestCaseMPS):
|
|||
|
||||
# Test normal
|
||||
def test_normal(self):
|
||||
def helper(shape, mean=0.0, std=1.0):
|
||||
mps_out = torch.normal(mean, std, shape, device='mps')
|
||||
def helper(shape, mean=0.0, std=1.0, dtype=torch.float):
|
||||
mps_out = torch.normal(mean, std, shape, device='mps', dtype=dtype)
|
||||
|
||||
mean_array = np.ones(shape)
|
||||
mean_array *= mean
|
||||
cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
|
||||
cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=dtype, requires_grad=False)
|
||||
mean_tensor = cpu_mean_tensor.detach().clone().to('mps')
|
||||
|
||||
std_array = np.ones(shape)
|
||||
std_array *= std
|
||||
cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
|
||||
cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=dtype, requires_grad=False)
|
||||
std_tensor = cpu_std_tensor.detach().clone().to('mps')
|
||||
|
||||
# test out
|
||||
mps_out = torch.zeros(shape, device='mps')
|
||||
mps_out = torch.zeros(shape, device='mps', dtype=dtype)
|
||||
torch.normal(mean_tensor, std, out=mps_out)
|
||||
|
||||
mps_out = torch.zeros(shape, device='mps')
|
||||
mps_out = torch.zeros(shape, device='mps', dtype=dtype)
|
||||
torch.normal(mean, std_tensor, out=mps_out)
|
||||
|
||||
mps_out = torch.zeros(shape, device='mps')
|
||||
mps_out = torch.zeros(shape, device='mps', dtype=dtype)
|
||||
torch.normal(mean_tensor, std_tensor, out=mps_out)
|
||||
|
||||
# test without out
|
||||
|
|
@ -7910,6 +7910,16 @@ class TestMPS(TestCaseMPS):
|
|||
helper((2, 3, 4, 5, 6))
|
||||
helper((100, 100), 2.5, 1.2)
|
||||
|
||||
# Test invalid inputs
|
||||
with self.assertRaises(TypeError):
|
||||
helper((10, 10), 10, 11, dtype=torch.int32)
|
||||
|
||||
if product_version >= 14.0:
|
||||
helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16)
|
||||
else:
|
||||
with self.assertRaises(TypeError):
|
||||
helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16)
|
||||
|
||||
def test_bernoulli(self):
|
||||
shape = (10, 10)
|
||||
all_ones = torch.ones(shape, device='mps')
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user