mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
change multinomial to use async asserts instead of a synchronization (#134818)
Fixes https://github.com/pytorch/pytorch/issues/134442 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134818 Approved by: https://github.com/ezyang ghstack dependencies: #134813
This commit is contained in:
parent
db193d1e29
commit
6fce1faa10
|
|
@ -23,6 +23,7 @@
|
|||
#include <ATen/ops/_sample_dirichlet_native.h>
|
||||
#include <ATen/ops/_standard_gamma_grad_native.h>
|
||||
#include <ATen/ops/_standard_gamma_native.h>
|
||||
#include <ATen/ops/_assert_async.h>
|
||||
#include <ATen/ops/argmax.h>
|
||||
#include <ATen/ops/bernoulli_native.h>
|
||||
#include <ATen/ops/binomial_native.h>
|
||||
|
|
@ -585,19 +586,15 @@ Tensor& multinomial_out(const Tensor& self,
|
|||
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
|
||||
if (!with_replacement || n_sample == 1) {
|
||||
// Sanity checks on `self`.
|
||||
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
|
||||
TORCH_CHECK(
|
||||
is_valid.to<bool>(),
|
||||
"probability tensor contains either `inf`, `nan` or element < 0");
|
||||
bool zero_prob_condition = false;
|
||||
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0));
|
||||
at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0");
|
||||
at::Tensor zero_prob_condition;
|
||||
if (self.dim() == 1){
|
||||
zero_prob_condition = (self.sum() == 0).item().to<bool>();
|
||||
zero_prob_condition = (self.sum() == 0);
|
||||
} else {
|
||||
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
|
||||
zero_prob_condition = (self.sum(1) == 0).any();
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!zero_prob_condition,
|
||||
"invalid multinomial distribution (sum of probabilities <= 0)");
|
||||
at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)");
|
||||
|
||||
// The algorithm is from gumbel softmax.
|
||||
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
|
||||
|
|
|
|||
|
|
@ -2849,28 +2849,29 @@ def error_inputs_multinomial(op_info, device, **kwargs):
|
|||
|
||||
rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,)
|
||||
|
||||
for rep in rep_arg:
|
||||
kwargs = {'num_samples': 2, 'replacement': rep}
|
||||
if torch.device(device).type == 'cpu':
|
||||
for rep in rep_arg:
|
||||
kwargs = {'num_samples': 2, 'replacement': rep}
|
||||
|
||||
for shape in inputs:
|
||||
# error case when input tensor contains `inf`, `nan` or negative element
|
||||
yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
|
||||
error_regex=err_msg1 if rep is False else err_msg2)
|
||||
for shape in inputs:
|
||||
# error case when input tensor contains `inf`, `nan` or negative element
|
||||
yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
|
||||
error_regex=err_msg1 if rep is False else err_msg2)
|
||||
|
||||
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
|
||||
x = torch.zeros(3, device=device)
|
||||
yield ErrorInput(SampleInput(x, kwargs=kwargs),
|
||||
error_regex=err_msg2)
|
||||
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
|
||||
x = torch.zeros(3, device=device)
|
||||
yield ErrorInput(SampleInput(x, kwargs=kwargs),
|
||||
error_regex=err_msg2)
|
||||
|
||||
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
|
||||
x = torch.zeros(3, 3, device=device)
|
||||
yield ErrorInput(SampleInput(x, kwargs=kwargs),
|
||||
error_regex=err_msg2)
|
||||
# error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
|
||||
x = torch.zeros(3, 3, device=device)
|
||||
yield ErrorInput(SampleInput(x, kwargs=kwargs),
|
||||
error_regex=err_msg2)
|
||||
|
||||
# error case for the invalid multinomial distribution
|
||||
x[1, :] = 1
|
||||
yield ErrorInput(SampleInput(x, kwargs=kwargs),
|
||||
error_regex=err_msg2)
|
||||
# error case for the invalid multinomial distribution
|
||||
x[1, :] = 1
|
||||
yield ErrorInput(SampleInput(x, kwargs=kwargs),
|
||||
error_regex=err_msg2)
|
||||
|
||||
def error_inputs_gradient(op_info, device, **kwargs):
|
||||
for dtype in [torch.long, torch.float32, torch.complex64]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user