[MPS] Fix .item() for multi-dim scalar (#107913)

By refactoring `_local_scalar_dense_mps` to use `_empty_like` to allocate CPU tensor.
Also, print a more reasonable error message when dst dim is less than src in mps_copy_

This fixes regression introduced by https://github.com/pytorch/pytorch/pull/105617 and adds regression test.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at abd06e6</samp>

> _Sing, O Muse, of the valiant deeds of the PyTorch developers_
> _Who strive to improve the performance and usability of tensors_
> _And who, with skill and wisdom, fixed a bug in the MPS backend_
> _That caused confusion and dismay to many a user of `item()`_

Fixes https://github.com/pytorch/pytorch/issues/107867

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107913
Approved by: https://github.com/albanD
This commit is contained in:
Nikita Shulga 2023-08-31 21:08:29 +00:00 committed by PyTorch MergeBot
parent 5b6ba4110b
commit bae409388c
3 changed files with 6 additions and 4 deletions

View File

@ -293,7 +293,8 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
dst.resize_as_(src);
}
TORCH_CHECK(dst.dim() >= src.dim());
TORCH_CHECK(
dst.dim() >= src.dim(), "Destination ", dst.sym_sizes(), " doesn't match the broadcast shape ", src.sym_sizes());
if (dst.dim() > src.dim()) {
needs_broadcasting = true;
} else {

View File

@ -16,15 +16,14 @@ namespace at::native {
Scalar _local_scalar_dense_mps(const Tensor& self) {
Scalar r;
auto output = at::empty_like(self, TensorOptions(kCPU));
mps::mps_copy_(output, self, false);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half,
at::ScalarType::Bool,
at::ScalarType::BFloat16,
self.scalar_type(),
"_local_scalar_dense_mps",
[&] {
Tensor output = at::empty({1}, TensorOptions(at::CPU(self.scalar_type())));
mps::mps_copy_(output, self, false);
scalar_t value = *output.data_ptr<scalar_t>();
r = Scalar(value);
});

View File

@ -3565,6 +3565,8 @@ class TestMPS(TestCaseMPS):
helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
# Regression test for https://github.com/pytorch/pytorch/issues/107867
self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
# See https://github.com/pytorch/pytorch/pull/84742
# and https://github.com/pytorch/pytorch/pull/78319