[NEEDS REVIEW] Add nan and inf probability check to multinomial (#7647)

* Add nan and inf probs check to multinomial

* fix bug

* Spawn CUDA test in subprocess

* Make sure invalid input won't pass the test case

* Try to fix error

* Test failure cases in Python 3 only

* Try to fix Windows error

* Move CUDA test to test_cuda.py

* fix issues

* fix module name error

* no need to check for CUDA existence in test_cuda

* Use PY3
Author: Will Feng
Date: 2018-06-06 22:49:12 -04:00
Committer: Edward Z. Yang
Parent: 784c46ba1d
Commit: 89ea6acde2
5 changed files with 85 additions and 0 deletions

File: aten/src/TH/generic/THTensorRandom.cpp

@@ -2,6 +2,8 @@
#define TH_GENERIC_FILE "generic/THTensorRandom.cpp"
#else
#include <cmath>
#ifdef _OPENMP
#include <omp.h>
#endif
@@ -403,6 +405,10 @@ void THTensor_(multinomial)(THLongTensor *self, THGenerator *_generator, THTenso
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
2,
"invalid multinomial distribution (encountering probability entry < 0)");
THArgCheckWithCleanup((std::isfinite(val)),
THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
2,
"invalid multinomial distribution (encountering probability entry = infinity or NaN)");
sum += val;
THDoubleStorage_set(
cum_dist->storage, \
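A note on what this buys callers: the CPU sampling code already walks every probability entry once to accumulate the cumulative distribution (the sum += val / cum_dist update above), so the new std::isfinite check simply rides along with the existing non-negativity check and turns a NaN or Inf entry into an ordinary error instead of a garbage sample. A minimal sketch of the expected behavior, assuming a build that includes this change:

    import torch

    # Every vector with a negative, Inf, -Inf or NaN entry should now be
    # rejected up front by the CPU path.
    for bad_probs in ([0., -1.], [0., float('inf')], [0., float('-inf')], [0., float('nan')]):
        try:
            torch.multinomial(torch.tensor(bad_probs), 1)
            raise AssertionError('invalid probabilities were accepted')
        except RuntimeError as e:
            assert 'invalid multinomial distribution' in str(e)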

File: aten/src/THC/THCNumerics.cuh

@@ -48,6 +48,7 @@ struct THCNumerics<uint8_t> {
static inline __host__ __device__ uint8_t abs(uint8_t a) { return a; }
static inline __host__ __device__ uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); }
static inline __host__ __device__ bool isnan(uint8_t a) { return false; }
static inline __host__ __device__ bool isinf(uint8_t a) { return false; }
};
template <>
@@ -70,6 +71,7 @@ struct THCNumerics<int8_t> {
static inline __host__ __device__ int8_t abs(int8_t a) { return ::abs((int)a); }
static inline __host__ __device__ int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); }
static inline __host__ __device__ bool isnan(int8_t a) { return false; }
static inline __host__ __device__ bool isinf(int8_t a) { return false; }
};
template <>
@@ -92,6 +94,7 @@ struct THCNumerics<int16_t> {
static inline __host__ __device__ int16_t abs(int16_t a) { return ::abs((int)a); }
static inline __host__ __device__ int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); }
static inline __host__ __device__ bool isnan(int16_t a) { return false; }
static inline __host__ __device__ bool isinf(int16_t a) { return false; }
};
template <>
@@ -114,6 +117,7 @@ struct THCNumerics<int32_t> {
static inline __host__ __device__ int32_t abs(int32_t a) { return ::abs(a); }
static inline __host__ __device__ int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); }
static inline __host__ __device__ bool isnan(int32_t a) { return false; }
static inline __host__ __device__ bool isinf(int32_t a) { return false; }
};
template <>
@@ -142,6 +146,7 @@ struct THCNumerics<int64_t> {
static inline __host__ __device__ int64_t abs(int64_t a) { return labs(a); }
static inline __host__ __device__ int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); }
static inline __host__ __device__ bool isnan(int64_t a) { return false; }
static inline __host__ __device__ bool isinf(int64_t a) { return false; }
};
#ifdef CUDA_HALF_TENSOR
@@ -624,6 +629,19 @@ static inline __host__ __device__ half lgamma(half a) {
return ne(a, a);
}
static inline __host__ __device__ bool isinf(half a) {
#ifdef __CUDA_ARCH__
#ifdef CUDA_HALF_INSTRUCTIONS
return __hisinf(a) != 0;
#else
float fa = __half2float(a);
return ::isinf(fa);
#endif
#else // __CUDA_ARCH__
return ::isinf(THC_half2float(a));
#endif
}
};
#endif
@@ -677,6 +695,7 @@ struct THCNumerics<float> {
static inline __host__ __device__ float pow (float a, float b) { return powf(a, b); }
static inline __host__ __device__ float atan2(float a, float b) { return atan2f(a, b); }
static inline __host__ __device__ bool isnan(float a) { return ::isnan(a); }
static inline __host__ __device__ bool isinf(float a) { return ::isinf(a); }
};
template <>
@@ -729,6 +748,7 @@ struct THCNumerics<double> {
static inline __host__ __device__ double pow (double a, double b) { return ::pow(a, b); }
static inline __host__ __device__ double atan2(double a, double b) { return ::atan2(a, b); }
static inline __host__ __device__ bool isnan(double a) { return ::isnan(a); }
static inline __host__ __device__ bool isinf(double a) { return ::isinf(a); }
};
/// `half` has some type conversion issues associated with it, since it
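These THCNumerics additions give every scalar type a uniform isinf predicate so the sampling kernel below can assert on it generically: the integer specializations are trivially false, float and double forward to ::isinf, and half either uses the __hisinf intrinsic when CUDA_HALF_INSTRUCTIONS is available or widens to float first. A host-side Python sketch of that fallback path, using a hypothetical half_isinf helper over raw binary16 bit patterns (numpy only stands in for the half-to-float conversion):

    import math
    import numpy as np

    def half_isinf(bits16):
        # Mirror of the fallback: reinterpret the 16-bit pattern as binary16,
        # widen to an ordinary float, then reuse the standard isinf check.
        f16 = np.array([bits16], dtype=np.uint16).view(np.float16)[0]
        return math.isinf(float(f16))

    assert half_isinf(0x7C00)      # +inf
    assert half_isinf(0xFC00)      # -inf
    assert not half_isinf(0x7E00)  # NaN is not inf
    assert not half_isinf(0x3C00)  # 1.0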

File: aten/src/THC/THCTensorRandom.cuh

@@ -183,6 +183,8 @@ sampleMultinomialOnce(int64_t* dest,
for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
val = dist[curDist * stride_dist + cat * stride_categories];
assert(THCNumerics<T>::ge(val, zero));
assert(!THCNumerics<T>::isinf(val));
assert(!THCNumerics<T>::isnan(val));
sum = THCNumerics<AccT>::add(sum, ScalarConvert<T, AccT>::to(val));
}
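On the GPU the probabilities are only visible inside the kernel, so the equivalent validation happens as device-side asserts next to the existing non-negativity assert. Unlike the CPU path, a failed device assert is reported asynchronously and leaves the CUDA context unusable, which is why the Python test further down isolates each bad input in its own process. A rough sketch of what a caller would observe, assuming a CUDA build with device asserts enabled:

    import torch

    probs = torch.tensor([0., float('nan')], device='cuda')
    try:
        torch.multinomial(probs, 1)
        torch.cuda.synchronize()  # the assert surfaces on a synchronizing call
    except RuntimeError as e:
        assert 'device-side assert triggered' in str(e)
    # From here on, further CUDA calls in this process will also fail,
    # so the real test runs this logic in a spawned subprocess.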

File: test/test_cuda.py

@@ -3,11 +3,13 @@ import math
import tempfile
import re
import unittest
import sys
from itertools import repeat
import torch
import torch.cuda
import torch.cuda.comm as comm
from torch import multiprocessing as mp
from test_torch import TestTorch
from common import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, PY3
@@ -1399,6 +1401,34 @@ class TestCuda(TestCase):
        sample = torch.multinomial(freqs, 1000, True)
        self.assertNotEqual(freqs[sample].min(), 0)

    def _spawn_method(self, method, arg):
        try:
            mp.set_start_method('spawn')
        except RuntimeError:
            pass
        with mp.Pool(1) as pool:
            self.assertTrue(pool.map(method, [arg]))

    @staticmethod
    def _test_multinomial_invalid_probs_cuda(probs):
        try:
            with torch.random.fork_rng(devices=[0]):
                torch.multinomial(probs.to('cuda'), 1)
                torch.cuda.synchronize()
            return False  # Should not be reached
        except RuntimeError as e:
            return 'device-side assert triggered' in str(e)

    @unittest.skipIf(not PY3,
                     "spawn start method is not supported in Python 2, \
                     but we need it for creating another process with CUDA")
    def test_multinomial_invalid_probs_cuda(self):
        test_method = TestCuda._test_multinomial_invalid_probs_cuda
        self._spawn_method(test_method, torch.Tensor([0, -1]))
        self._spawn_method(test_method, torch.Tensor([0, float('inf')]))
        self._spawn_method(test_method, torch.Tensor([0, float('-inf')]))
        self._spawn_method(test_method, torch.Tensor([0, float('nan')]))

    def test_broadcast(self):
        TestTorch._test_broadcast(self, lambda t: t.cuda())
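Two details of this test harness are worth spelling out. The worker has to be created with the spawn start method (per the skip message, this is also why the test is Python 3 only), because the whole point is that the device-side assert wrecks the CUDA context of whichever process triggered it; running each bad input in a throwaway process keeps the main test process healthy. And mp.set_start_method raises RuntimeError when a start method has already been chosen, hence the try/except around it. A self-contained sketch of the same pattern outside the test class (function and variable names here are illustrative, not part of the change):

    import torch
    from torch import multiprocessing as mp

    def trips_device_assert(probs):
        # Runs in a fresh worker; a device-side assert here cannot poison
        # the parent process's CUDA context.
        try:
            torch.multinomial(probs.to('cuda'), 1)
            torch.cuda.synchronize()
            return False
        except RuntimeError as e:
            return 'device-side assert triggered' in str(e)

    if __name__ == '__main__':
        try:
            mp.set_start_method('spawn')
        except RuntimeError:
            pass  # a start method was already set elsewhere
        with mp.Pool(1) as pool:
            print(pool.map(trips_device_assert, [torch.Tensor([0, float('nan')])]))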

File: test/test_torch.py

@@ -17,6 +17,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
from torch._utils import _rebuild_tensor
from itertools import product, combinations
from functools import reduce
from torch import multiprocessing as mp
from common import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
    run_tests, download_file, skipIfNoLapack, suppress_warnings, IS_WINDOWS, PY3
@@ -2368,6 +2369,32 @@ class TestTorch(TestCase):
    def test_multinomial(self):
        self._test_multinomial(self, torch.FloatTensor)

    def _spawn_method(self, method, arg):
        try:
            mp.set_start_method('spawn')
        except RuntimeError:
            pass
        with mp.Pool(1) as pool:
            self.assertTrue(pool.map(method, [arg]))

    @staticmethod
    def _test_multinomial_invalid_probs(probs):
        try:
            torch.multinomial(probs.to('cpu'), 1)
            return False  # Should not be reached
        except RuntimeError as e:
            return 'invalid multinomial distribution' in str(e)

    @unittest.skipIf(not PY3,
                     "spawn start method is not supported in Python 2, \
                     but we need it for testing failure case for CPU RNG on Windows")
    def test_multinomial_invalid_probs(self):
        test_method = TestTorch._test_multinomial_invalid_probs
        self._spawn_method(test_method, torch.Tensor([0, -1]))
        self._spawn_method(test_method, torch.Tensor([0, float('inf')]))
        self._spawn_method(test_method, torch.Tensor([0, float('-inf')]))
        self._spawn_method(test_method, torch.Tensor([0, float('nan')]))

    @suppress_warnings
    def test_range(self):
        res1 = torch.range(0, 1)