mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[EZ][MPS] Improve distribution error checking (#166425)
Essentially not allow ops on self-overlapping outputs, by adding
`at::assert_no_internal_overlap(self);` check that already used in CPU
and CUDA builds, see
895795f07c/aten/src/ATen/native/DistributionTemplates.h (L366)
This fixes `test_error_inputs_bernoulli_mps`
Should be landed ahead of https://github.com/pytorch/pytorch/pull/165267
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166425
Approved by: https://github.com/Skylion007, https://github.com/seemethere
This commit is contained in:
parent
687c15c0b3
commit
1abfa5f70b
|
|
@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self,
|
|||
if (self.numel() == 0) {
|
||||
return self;
|
||||
}
|
||||
at::assert_no_internal_overlap(self);
|
||||
// MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624
|
||||
const auto need_reshape = self.ndimension() > 4;
|
||||
auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(gen, at::mps::detail::getDefaultMPSGenerator());
|
||||
|
|
|
|||
|
|
@ -812,7 +812,6 @@ if torch.backends.mps.is_available():
|
|||
"__rmod__",
|
||||
"__rsub__",
|
||||
"__rpow__",
|
||||
"bernoulli",
|
||||
"clamp_max",
|
||||
"clamp_min",
|
||||
"masked_scatter",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user