diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 85c22c59caf..ef64233f6fe 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self, if (self.numel() == 0) { return self; } + at::assert_no_internal_overlap(self); // MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624 const auto need_reshape = self.ndimension() > 4; auto mps_gen = get_generator_or_default(gen, at::mps::detail::getDefaultMPSGenerator()); diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index ce64f1d9cdd..b89487f400c 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -812,7 +812,6 @@ if torch.backends.mps.is_available(): "__rmod__", "__rsub__", "__rpow__", - "bernoulli", "clamp_max", "clamp_min", "masked_scatter",