[EZ][MPS] Improve distribution error checking (#166425)

Essentially not allow ops on self-overlapping outputs, by adding
`at::assert_no_internal_overlap(self);` check that already used in CPU
and CUDA builds, see
895795f07c/aten/src/ATen/native/DistributionTemplates.h (L366)

This fixes `test_error_inputs_bernoulli_mps`

Should be landed ahead of https://github.com/pytorch/pytorch/pull/165267
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166425
Approved by: https://github.com/Skylion007, https://github.com/seemethere
This commit is contained in:
Nikita Shulga 2025-10-28 10:44:31 -07:00 committed by PyTorch MergeBot
parent 687c15c0b3
commit 1abfa5f70b
2 changed files with 1 additions and 1 deletions

View File

@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self,
if (self.numel() == 0) {
return self;
}
at::assert_no_internal_overlap(self);
// MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624
const auto need_reshape = self.ndimension() > 4;
auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(gen, at::mps::detail::getDefaultMPSGenerator());

View File

@ -812,7 +812,6 @@ if torch.backends.mps.is_available():
"__rmod__",
"__rsub__",
"__rpow__",
"bernoulli",
"clamp_max",
"clamp_min",
"masked_scatter",