change multinomial to use async asserts instead of a synchronization point (#134818)

Fixes https://github.com/pytorch/pytorch/issues/134442

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134818
Approved by: https://github.com/ezyang
ghstack dependencies: #134813
This commit is contained in:
chilli 2024-09-02 19:17:26 -07:00 committed by PyTorch MergeBot
parent db193d1e29
commit 6fce1faa10
2 changed files with 26 additions and 28 deletions

View File

@@ -23,6 +23,7 @@
#include <ATen/ops/_sample_dirichlet_native.h>
#include <ATen/ops/_standard_gamma_grad_native.h>
#include <ATen/ops/_standard_gamma_native.h>
#include <ATen/ops/_assert_async.h>
#include <ATen/ops/argmax.h>
#include <ATen/ops/bernoulli_native.h>
#include <ATen/ops/binomial_native.h>
@@ -585,19 +586,15 @@ Tensor& multinomial_out(const Tensor& self,
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
if (!with_replacement || n_sample == 1) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
bool zero_prob_condition = false;
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0));
at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0");
at::Tensor zero_prob_condition;
if (self.dim() == 1){
zero_prob_condition = (self.sum() == 0).item().to<bool>();
zero_prob_condition = (self.sum() == 0);
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
zero_prob_condition = (self.sum(1) == 0).any();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");
at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)

View File

@@ -2849,28 +2849,29 @@ def error_inputs_multinomial(op_info, device, **kwargs):
rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,)
for rep in rep_arg:
kwargs = {'num_samples': 2, 'replacement': rep}
if torch.device(device).type == 'cpu':
for rep in rep_arg:
kwargs = {'num_samples': 2, 'replacement': rep}
for shape in inputs:
# error case when input tensor contains `inf`, `nan` or negative element
yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
error_regex=err_msg1 if rep is False else err_msg2)
for shape in inputs:
# error case when input tensor contains `inf`, `nan` or negative element
yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
error_regex=err_msg1 if rep is False else err_msg2)
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
x = torch.zeros(3, device=device)
yield ErrorInput(SampleInput(x, kwargs=kwargs),
error_regex=err_msg2)
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
x = torch.zeros(3, device=device)
yield ErrorInput(SampleInput(x, kwargs=kwargs),
error_regex=err_msg2)
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
x = torch.zeros(3, 3, device=device)
yield ErrorInput(SampleInput(x, kwargs=kwargs),
error_regex=err_msg2)
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
x = torch.zeros(3, 3, device=device)
yield ErrorInput(SampleInput(x, kwargs=kwargs),
error_regex=err_msg2)
# error case for the invalid multinomial distribution
x[1, :] = 1
yield ErrorInput(SampleInput(x, kwargs=kwargs),
error_regex=err_msg2)
# error case for the invalid multinomial distribution
x[1, :] = 1
yield ErrorInput(SampleInput(x, kwargs=kwargs),
error_regex=err_msg2)
def error_inputs_gradient(op_info, device, **kwargs):
for dtype in [torch.long, torch.float32, torch.complex64]: