diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp
index 0de9632ddb1..bd1d974e9be 100644
--- a/aten/src/ATen/native/Distributions.cpp
+++ b/aten/src/ATen/native/Distributions.cpp
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -585,19 +586,15 @@ Tensor& multinomial_out(const Tensor& self,
   // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
   if (!with_replacement || n_sample == 1) {
     // Sanity checks on `self`.
-    auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
-    TORCH_CHECK(
-        is_valid.to<bool>(),
-        "probability tensor contains either `inf`, `nan` or element < 0");
-    bool zero_prob_condition = false;
+    auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0));
+    at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0");
+    at::Tensor zero_prob_condition;
     if (self.dim() == 1){
-      zero_prob_condition = (self.sum() == 0).item().to<bool>();
+      zero_prob_condition = (self.sum() == 0);
     } else {
-      zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
+      zero_prob_condition = (self.sum(1) == 0).any();
     }
-    TORCH_CHECK(
-        !zero_prob_condition,
-        "invalid multinomial distribution (sum of probabilities <= 0)");
+    at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");
 
     // The algorithm is from gumbel softmax.
     // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a64cb62cede..2120f4f019e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2849,28 +2849,29 @@ def error_inputs_multinomial(op_info, device, **kwargs):
     rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,)
 
-    for rep in rep_arg:
-        kwargs = {'num_samples': 2, 'replacement': rep}
+    if torch.device(device).type == 'cpu':
+        for rep in rep_arg:
+            kwargs = {'num_samples': 2, 'replacement': rep}
 
-        for shape in inputs:
-            # error case when input tensor contains `inf`, `nan` or negative element
-            yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
-                             error_regex=err_msg1 if rep is False else err_msg2)
+            for shape in inputs:
+                # error case when input tensor contains `inf`, `nan` or negative element
+                yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
+                                 error_regex=err_msg1 if rep is False else err_msg2)
 
-        # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
-        x = torch.zeros(3, device=device)
-        yield ErrorInput(SampleInput(x, kwargs=kwargs),
-                         error_regex=err_msg2)
+            # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
+            x = torch.zeros(3, device=device)
+            yield ErrorInput(SampleInput(x, kwargs=kwargs),
+                             error_regex=err_msg2)
 
-        # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
-        x = torch.zeros(3, 3, device=device)
-        yield ErrorInput(SampleInput(x, kwargs=kwargs),
-                         error_regex=err_msg2)
+            # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
+            x = torch.zeros(3, 3, device=device)
+            yield ErrorInput(SampleInput(x, kwargs=kwargs),
+                             error_regex=err_msg2)
 
-        # error case for the invalid multinomial distribution
-        x[1, :] = 1
-        yield ErrorInput(SampleInput(x, kwargs=kwargs),
-                         error_regex=err_msg2)
+            # error case for the invalid multinomial distribution
+            x[1, :] = 1
+            yield ErrorInput(SampleInput(x, kwargs=kwargs),
+                             error_regex=err_msg2)
 
 
 def error_inputs_gradient(op_info, device, **kwargs):
     for dtype in [torch.long, torch.float32, torch.complex64]:
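A minimal sketch, assuming a build that includes this change, of the behavior the CPU-only test restriction relies on: with the sanity checks routed through `_assert_async`, an invalid distribution still raises a catchable `RuntimeError` on CPU, whereas on CUDA the check becomes a device-side assert that only surfaces at a later synchronization point and cannot be matched with `error_regex`, which is why the error inputs above are now generated only for CPU devices.

import torch

# Zero probabilities trigger "invalid multinomial distribution (sum of probabilities <= 0)".
probs = torch.zeros(3)

try:
    # On CPU the _assert_async check fails synchronously, so the error is a normal
    # Python exception that error_regex can match.
    torch.multinomial(probs, 2)
except RuntimeError as e:
    print("CPU raised:", e)

# On CUDA the same call only enqueues the check kernel; no exception is raised here.
# The failure shows up later (e.g. at torch.cuda.synchronize()) as a
# "CUDA error: device-side assert triggered", which error_regex cannot match.
# probs_cuda = probs.cuda()
# torch.multinomial(probs_cuda, 2)
# torch.cuda.synchronize()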