mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Fix index_copy for scalars (#161267)
By `squeezing the input` when copying into a scalar tensor from a 1-D one, and enable `test_index_copy_scalars_mps`. Fixes https://github.com/pytorch/pytorch/issues/160737 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161267 Approved by: https://github.com/manuelcandales, https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #161206
This commit is contained in:
parent
4c36c8a994
commit
c8bb0e4720
|
|
@ -230,7 +230,7 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
|
||||||
index.numel());
|
index.numel());
|
||||||
int64_t idx = index.item<int64_t>();
|
int64_t idx = index.item<int64_t>();
|
||||||
TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx);
|
TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx);
|
||||||
result.copy_(source);
|
result.copy_(source.squeeze());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -254,11 +254,12 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_CHECK(source.size(dim) == index.numel(),
|
const auto source_size_dim = source.dim() > 0 ? source.size(dim) : 1;
|
||||||
|
TORCH_CHECK(index.numel() == source_size_dim,
|
||||||
"index_copy_(): Number of indices (",
|
"index_copy_(): Number of indices (",
|
||||||
index.numel(),
|
index.numel(),
|
||||||
") should be equal to source.size(dim) (",
|
") should be equal to source.size(dim) (",
|
||||||
source.size(dim),
|
source_size_dim,
|
||||||
")");
|
")");
|
||||||
|
|
||||||
auto stream = getCurrentMPSStream();
|
auto stream = getCurrentMPSStream();
|
||||||
|
|
|
||||||
|
|
@ -1913,8 +1913,8 @@ class TestIndexing(TestCase):
|
||||||
# onlyNativeDeviceTypes due to an XLA error:
|
# onlyNativeDeviceTypes due to an XLA error:
|
||||||
# https://github.com/pytorch/pytorch/issues/53256
|
# https://github.com/pytorch/pytorch/issues/53256
|
||||||
@onlyNativeDeviceTypes
|
@onlyNativeDeviceTypes
|
||||||
@expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/160737
|
|
||||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||||
|
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat))
|
||||||
def test_index_copy_scalars(self, device, dtype):
|
def test_index_copy_scalars(self, device, dtype):
|
||||||
# Create the 8 possible combinations of scalar sizes for target / index / source
|
# Create the 8 possible combinations of scalar sizes for target / index / source
|
||||||
scalars = (
|
scalars = (
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user