mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[MPS] Fix .item() for multi-dim scalar (#107913)
By refactoring `_local_scalar_dense_mps` to use `_empty_like` to allocate CPU tensor. Also, print a more reasonable error message when dst dim is less than src in mps_copy_ This fixes regression introduced by https://github.com/pytorch/pytorch/pull/105617 and adds regression test. <!-- copilot:poem --> ### <samp>🤖 Generated by Copilot at abd06e6</samp> > _Sing, O Muse, of the valiant deeds of the PyTorch developers_ > _Who strive to improve the performance and usability of tensors_ > _And who, with skill and wisdom, fixed a bug in the MPS backend_ > _That caused confusion and dismay to many a user of `item()`_ Fixes https://github.com/pytorch/pytorch/issues/107867 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107913 Approved by: https://github.com/albanD
This commit is contained in:
parent
5b6ba4110b
commit
bae409388c
|
|
@ -293,7 +293,8 @@ at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
|
|||
dst.resize_as_(src);
|
||||
}
|
||||
|
||||
TORCH_CHECK(dst.dim() >= src.dim());
|
||||
TORCH_CHECK(
|
||||
dst.dim() >= src.dim(), "Destination ", dst.sym_sizes(), " doesn't match the broadcast shape ", src.sym_sizes());
|
||||
if (dst.dim() > src.dim()) {
|
||||
needs_broadcasting = true;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -16,15 +16,14 @@ namespace at::native {
|
|||
Scalar _local_scalar_dense_mps(const Tensor& self) {
|
||||
Scalar r;
|
||||
|
||||
auto output = at::empty_like(self, TensorOptions(kCPU));
|
||||
mps::mps_copy_(output, self, false);
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half,
|
||||
at::ScalarType::Bool,
|
||||
at::ScalarType::BFloat16,
|
||||
self.scalar_type(),
|
||||
"_local_scalar_dense_mps",
|
||||
[&] {
|
||||
Tensor output = at::empty({1}, TensorOptions(at::CPU(self.scalar_type())));
|
||||
|
||||
mps::mps_copy_(output, self, false);
|
||||
scalar_t value = *output.data_ptr<scalar_t>();
|
||||
r = Scalar(value);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -3565,6 +3565,8 @@ class TestMPS(TestCaseMPS):
|
|||
helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
|
||||
helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
|
||||
helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/107867
|
||||
self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/pull/84742
|
||||
# and https://github.com/pytorch/pytorch/pull/78319
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user