[EZ][MPS] Improve distribution error checking (#166425)
Essentially, disallow ops on self-overlapping outputs by adding the
`at::assert_no_internal_overlap(self);` check that is already used in the CPU
and CUDA builds, see
895795f07c/aten/src/ATen/native/DistributionTemplates.h (L366)
This fixes `test_error_inputs_bernoulli_mps`
Should be landed ahead of https://github.com/pytorch/pytorch/pull/165267
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166425
Approved by: https://github.com/Skylion007, https://github.com/seemethere
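
As a rough illustration (not part of the PR): with this check, an in-place distribution op on a tensor whose elements alias each other, e.g. a view produced by `expand`, raises a RuntimeError on MPS just as it already does on CPU and CUDA. The snippet below is a minimal sketch and assumes an MPS-capable machine.

import torch

# Minimal sketch (assumes MPS is available): an expanded view has stride 0,
# so all four elements alias a single storage element; the in-place
# distribution op now rejects such an output with a RuntimeError.
if torch.backends.mps.is_available():
    overlapping = torch.zeros(1, device="mps").expand(4)
    try:
        overlapping.bernoulli_(0.5)
    except RuntimeError as err:
        print(f"rejected as expected: {err}")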
This commit is contained in:
parent 687c15c0b3
commit 1abfa5f70b
@@ -57,6 +57,7 @@ Tensor& random_mps_impl(Tensor& self,
   if (self.numel() == 0) {
     return self;
   }
+  at::assert_no_internal_overlap(self);
   // MPS random is broken for 5D+ tensors, see https://github.com/pytorch/pytorch/issues/147624
   const auto need_reshape = self.ndimension() > 4;
   auto mps_gen = get_generator_or_default<MPSGeneratorImpl>(gen, at::mps::detail::getDefaultMPSGenerator());
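
For context (this note is not part of the diff): "internal overlap" means the tensor's sizes and strides map more than one index to the same storage element, which is exactly what `assert_no_internal_overlap` guards against. A small, purely illustrative sketch:

import torch

# Illustrative only: stride 0 makes every index of the view alias the same
# storage element, which is what counts as internal overlap.
t = torch.zeros(1).expand(3)    # shape (3,), stride (0,)
print(t.stride())               # (0,)
print(t.contiguous().stride())  # (1,) -- materializing a copy removes the overlap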
@@ -812,7 +812,6 @@ if torch.backends.mps.is_available():
         "__rmod__",
         "__rsub__",
         "__rpow__",
-        "bernoulli",
         "clamp_max",
         "clamp_min",
         "masked_scatter",
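
Dropping "bernoulli" from this list (the excerpt does not show the list's name, but it is presumably the set of ops whose MPS error-input checks were expected to fail) is what lets `test_error_inputs_bernoulli_mps` run and pass. Below is a hedged, self-contained sketch of the kind of check it exercises; the class and method names are illustrative, not the actual test.

import unittest
import torch

class BernoulliErrorInputsSketch(unittest.TestCase):
    # Illustrative stand-in for the OpInfo-driven error-input test in test_mps.py.
    @unittest.skipUnless(torch.backends.mps.is_available(), "requires an MPS device")
    def test_overlapping_output_raises(self):
        out = torch.zeros(1, device="mps").expand(4)  # internally overlapping view
        with self.assertRaises(RuntimeError):
            out.bernoulli_(0.5)

if __name__ == "__main__":
    unittest.main()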