Add mutex to THC random number generator (#6527)

* Add mutex to THC random number generator

* Add test for CUDA RNG multithread

* fix lint

* Rename gen_state to state and remove unnecessary mutex lock

* Remove RNG test from cpp_extensions

* Add CUDA RNG test to libtorch

* Build test_rng only if CUDA exists

* Move test to aten/src/ATen/test/

* Separate ATen build and test, and run ATen test in CI test phase

* Don't test ATen in ASAN build

* Fix bug in ATen scalar_test

* Fix bug in ATen native_test

* Add FIXME to some CUDA tests in scalar_tensor_test

* Valgrind doesn't work well with CUDA, so seed the CPU and CUDA RNG separately instead
Author: Will Feng
Date: 2018-04-18 15:54:13 -04:00
Committed by: GitHub
Parent: c25f097225
Commit: e089849b4a
25 changed files with 167 additions and 93 deletions
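
In outline (a reviewer's summary, not part of the commit): the generator's mutable fields move into a THCGeneratorState struct that is paired with a std::mutex inside THCGenerator, and the THC entry points take that lock before touching the state. A condensed sketch of the pattern, simplified from THCRandom_getGenerator and THCRandom_manualSeed in the diffs below; the names and trimmed fields here are illustrative, not a drop-in for the THC code:

#include <cstdint>
#include <mutex>

// Mirrors the shape of the new THCGenerator (device-side fields omitted).
struct Generator {
  std::mutex mutex;             // guards every access to `state`
  struct {
    int initf = 0;              // lazy-initialization flag
    uint64_t initial_seed = 0;
    int64_t philox_seed_offset = 0;
  } state;
};

// Lazily initialize under the lock, as THCRandom_getGenerator now does.
Generator& getGenerator(Generator& gen) {
  std::lock_guard<std::mutex> lock(gen.mutex);
  if (gen.state.initf == 0) {
    // initializeGenerator/createGeneratorState would run here in THC
    gen.state.initf = 1;
  }
  return gen;
}

// Reseed under the same lock, as THCRandom_manualSeed now does.
void manualSeed(Generator& gen, uint64_t seed) {
  std::lock_guard<std::mutex> lock(gen.mutex);
  gen.state.initial_seed = seed;
  if (gen.state.initf) {
    // recreate the device-side MTGP32 state with the new seed
  }
}

The point is that lazy initialization and reseeding of a given generator are serialized, which is what the new multithreaded cuda_rng_test below exercises.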

View File

@@ -29,10 +29,13 @@ fi
 python setup.py install
-# Test ATen
+# Add the ATen test binaries so that they won't be git clean'ed away
+git add -f aten/build/src/ATen/test
+# Testing ATen install
 if [[ "$BUILD_ENVIRONMENT" != *cuda* ]]; then
-  echo "Testing ATen"
-  time tools/run_aten_tests.sh
+  echo "Testing ATen install"
+  time tools/test_aten_install.sh
 fi
 # Test C FFI plugins

View File

@@ -29,6 +29,14 @@ fi
 time python test/run_test.py --verbose
+# Test ATen
+if [[ "$BUILD_ENVIRONMENT" != *asan* ]]; then
+  echo "Testing ATen"
+  TORCH_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/torch/lib
+  ln -s "$TORCH_LIB_PATH"/libATen.so aten/build/src/ATen/libATen.so
+  aten/tools/run_tests.sh aten/build
+fi
 rm -rf ninja
 echo "Installing torchvision at branch master"

View File

@@ -10,6 +10,7 @@
 #include <THC/THCGeneral.h>
 #include <THC/THCTensorRandom.h>
+#include <THC/THCGenerator.h>
 #include <THC/THCApply.cuh>
 #include <THC/THCNumerics.cuh>
@@ -21,8 +22,8 @@ THCGenerator* THCRandom_getGenerator(THCState* state);
 namespace {
 std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen) {
   auto gen_ = THCRandom_getGenerator(at::globalContext().thc_state);
-  uint64_t offset = THAtomicAddLong(&gen_->philox_seed_offset, 1);
-  return std::make_pair(gen_->initial_seed, offset);
+  uint64_t offset = THAtomicAddLong(&gen_->state.philox_seed_offset, 1);
+  return std::make_pair(gen_->state.initial_seed, offset);
 }
 template <typename scalar_t>

View File

@@ -45,6 +45,11 @@ if(NOT NO_CUDA)
   target_link_libraries(integer_divider_test ATen)
 endif()
+if(NOT NO_CUDA)
+  cuda_add_executable(cuda_rng_test cuda_rng_test.cpp)
+  target_link_libraries(cuda_rng_test ATen)
+endif()
 if (CUDNN_FOUND)
   add_executable(cudnn_test cudnn_test.cpp)
   target_link_libraries(cudnn_test ATen)

View File

@@ -24,7 +24,8 @@ void trace() {
 TEST_CASE( "atest", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
+  manual_seed(123, at::Backend::CUDA);
   auto foo = rand(CPU(kFloat), {12,6});
   REQUIRE(foo.data<float>() == foo.toFloatData());

View File

@@ -274,13 +274,13 @@ static void test(Type & type) {
 }
 TEST_CASE( "basic tests CPU", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   test(CPU(kFloat));
 }
 TEST_CASE( "basic tests GPU", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
   if(at::hasCUDA()) {
     test(CUDA(kFloat));

View File

@@ -8,7 +8,7 @@ using namespace at;
 TEST_CASE( "broadcast", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   Type & T = CPU(kFloat);

View File

@@ -0,0 +1,27 @@
+#define CATCH_CONFIG_MAIN
+#include "catch.hpp"
+#include "ATen/ATen.h"
+#include "cuda.h"
+#include "cuda_runtime.h"
+#include <thread>
+void makeRandomNumber() {
+  cudaSetDevice(std::rand() % 2);
+  auto x = at::CUDA(at::kFloat).randn({1000});
+}
+void testCudaRNGMultithread() {
+  auto threads = std::vector<std::thread>();
+  for (auto i = 0; i < 1000; i++) {
+    threads.emplace_back(makeRandomNumber);
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+};
+TEST_CASE( "CUDA RNG test", "[cuda]" ) {
+  SECTION( "multithread" )
+    testCudaRNGMultithread();
+}

View File

@@ -10,7 +10,7 @@ using namespace at;
 using namespace at::native;
 TEST_CASE( "cudnn", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
 #if CUDNN_VERSION < 7000
   auto handle = getCudnnHandle();

View File

@@ -13,7 +13,7 @@ using namespace at;
 TEST_CASE( "dlconvertor", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   INFO( "convert ATen to DLTensor" );

View File

@@ -163,9 +163,9 @@ void test(Type & T, Type & AccT) {
     auto ct1 = randn(T, {3, 4});
     auto ct2 = randn(T, {3, 4});
     auto t1 = randn(T.toBackend(Backend::CPU), {3, 4});
-    REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(ct2), "not implemented");
-    REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(t1), "not implemented");
-    REQUIRE_THROWS_WITH(t1._standard_gamma_grad(ct2), "CUDA Backend");
+    REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(ct2), Catch::Contains("not implemented"));
+    REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(t1), Catch::Contains("not implemented"));
+    REQUIRE_THROWS_WITH(t1._standard_gamma_grad(ct2), Catch::Contains("CUDA Backend"));
   }
 }
@@ -189,13 +189,13 @@ void test(Type & T, Type & AccT) {
 }
 TEST_CASE( "native test CPU", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   test(CPU(kFloat), CPU(kDouble));
 }
 TEST_CASE( "native test CUDA", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
   if (at::hasCUDA()) {
     test(CUDA(kFloat), CUDA(kDouble));

View File

@@ -132,8 +132,10 @@ void test(Type &T) {
     if (t.numel() != 0) {
       REQUIRE(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
     } else {
+      if (!T.is_cuda()) { // FIXME: out of range exception in CUDA
       REQUIRE(t.sum(0).equal(T.tensor({0})));
+      }
     }
     // reduce (with dimension argument and with 2 return arguments)
     if (t.numel() != 0) {
@@ -273,13 +275,13 @@ void test(Type &T) {
 }
 TEST_CASE( "scalar tensor test CPU", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   test(CPU(kFloat));
 }
 TEST_CASE( "scalar tensor test CUDA", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
   if (at::hasCUDA()) {
     test(CUDA(kFloat));

View File

@@ -72,7 +72,8 @@ void test_overflow() {
 TEST_CASE( "scalar test", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
+  manual_seed(123, at::Backend::CUDA);
   Scalar what = 257;
   Scalar bar = 3.0;
@@ -83,7 +84,7 @@ TEST_CASE( "scalar test", "[]" ) {
   REQUIRE_NOTHROW(gen.seed());
   auto && C = at::globalContext();
   if(at::hasCUDA()) {
-    auto & CUDAFloat = C.getType(Backend::CPU,ScalarType::Float);
+    auto & CUDAFloat = C.getType(Backend::CUDA,ScalarType::Float);
     auto t2 = zeros(CUDAFloat, {4,4});
     cout << &t2 << "\n";
     cout << "AFTER GET TYPE " << &CUDAFloat << "\n";

View File

@@ -24,7 +24,7 @@ void test(int given_num_threads) {
 }
 int main() {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   test(-1);
   std::thread t1(test, -1);

View File

@@ -13,7 +13,7 @@ using namespace at;
 TEST_CASE( "parallel", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   set_num_threads(1);
   Tensor a = rand(CPU(at::kFloat), {1,3});

View File

@@ -2,10 +2,11 @@
 #include "ATen/ATen.h"
-void manual_seed(uint64_t seed) {
+void manual_seed(uint64_t seed, at::Backend backend) {
+  if (backend == at::Backend::CPU) {
     at::Generator & cpu_gen = at::globalContext().defaultGenerator(at::Backend::CPU);
     cpu_gen.manualSeed(seed);
-  if (at::hasCUDA()) {
+  } else if (backend == at::Backend::CUDA && at::hasCUDA()) {
     at::Generator & cuda_gen = at::globalContext().defaultGenerator(at::Backend::CUDA);
     cuda_gen.manualSeed(seed);
   }

View File

@@ -9,7 +9,7 @@
 using namespace at;
 TEST_CASE( "undefined tensor test", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   // mainly test ops on undefined tensors don't segfault and give a reasonable errror message.
   Tensor und;

View File

@@ -7,7 +7,7 @@
 using namespace at;
 TEST_CASE( "wrapdim test", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   Type & T = CPU(kFloat);

View File

@@ -0,0 +1,19 @@
+#ifndef THC_GENERATOR_INC
+#define THC_GENERATOR_INC
+#include <mutex>
+typedef struct THCGeneratorState {
+  struct curandStateMtgp32* gen_states;
+  struct mtgp32_kernel_params *kernel_params;
+  int initf;
+  uint64_t initial_seed;
+  int64_t philox_seed_offset;
+} THCGeneratorState;
+struct THCGenerator {
+  std::mutex mutex; /* mutex for using this generator */
+  THCGeneratorState state;
+};
+#endif

View File

@@ -1,4 +1,5 @@
 #include "THCTensorRandom.h"
+#include "THCGenerator.h"
 #include <random>
 #include <curand.h>
@@ -11,15 +12,16 @@ void createGeneratorState(THCGenerator* gen, uint64_t seed);
 /* Frees memory allocated during setup. */
 void destroyGenerator(THCState *state, THCGenerator* gen)
 {
-  if (gen->gen_states)
+  std::lock_guard<std::mutex> lock(gen->mutex);
+  if (gen->state.gen_states)
   {
-    THCudaCheck(THCudaFree(state, gen->gen_states));
-    gen->gen_states = NULL;
+    THCudaCheck(THCudaFree(state, gen->state.gen_states));
+    gen->state.gen_states = NULL;
   }
-  if (gen->kernel_params)
+  if (gen->state.kernel_params)
   {
-    THCudaCheck(THCudaFree(state, gen->kernel_params));
-    gen->kernel_params = NULL;
+    THCudaCheck(THCudaFree(state, gen->state.kernel_params));
+    gen->state.kernel_params = NULL;
   }
 }
@@ -39,11 +41,12 @@ void THCRandom_init(THCState* state, int devices, int current_device)
   std::random_device rd;
   for (int i = 0; i < rng_state->num_devices; ++i)
   {
-    rng_state->gen[i].initf = 0;
-    rng_state->gen[i].initial_seed = createSeed(rd);
-    rng_state->gen[i].philox_seed_offset = 0;
-    rng_state->gen[i].gen_states = NULL;
-    rng_state->gen[i].kernel_params = NULL;
+    new (&rng_state->gen[i].mutex) std::mutex();
+    rng_state->gen[i].state.initf = 0;
+    rng_state->gen[i].state.initial_seed = createSeed(rd);
+    rng_state->gen[i].state.philox_seed_offset = 0;
+    rng_state->gen[i].state.gen_states = NULL;
+    rng_state->gen[i].state.kernel_params = NULL;
   }
 }
@@ -74,18 +77,20 @@ static THCGenerator* THCRandom_rawGenerator(THCState* state)
 THCGenerator* THCRandom_getGenerator(THCState* state)
 {
   THCGenerator* gen = THCRandom_rawGenerator(state);
-  if (gen->initf == 0)
+  std::lock_guard<std::mutex> lock(gen->mutex);
+  if (gen->state.initf == 0)
   {
     initializeGenerator(state, gen);
-    createGeneratorState(gen, gen->initial_seed);
-    gen->initf = 1;
+    createGeneratorState(gen, gen->state.initial_seed);
+    gen->state.initf = 1;
   }
   return gen;
 }
 struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state)
 {
-  return THCRandom_getGenerator(state)->gen_states;
+  THCGenerator* gen = THCRandom_getGenerator(state);
+  return gen->state.gen_states;
 }
 /* Random seed */
@@ -109,8 +114,9 @@ uint64_t THCRandom_seedAll(THCState* state)
 void THCRandom_manualSeed(THCState* state, uint64_t seed)
 {
   THCGenerator* gen = THCRandom_rawGenerator(state);
-  gen->initial_seed = seed;
-  if (gen->initf) {
+  std::lock_guard<std::mutex> lock(gen->mutex);
+  gen->state.initial_seed = seed;
+  if (gen->state.initf) {
     createGeneratorState(gen, seed);
   }
 }
@@ -130,5 +136,6 @@ void THCRandom_manualSeedAll(THCState* state, uint64_t seed)
 /* Get the initial seed */
 uint64_t THCRandom_initialSeed(THCState* state)
 {
-  return THCRandom_getGenerator(state)->initial_seed;
+  THCGenerator* gen = THCRandom_getGenerator(state);
+  return gen->state.initial_seed;
 }
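
A note on the new (&rng_state->gen[i].mutex) std::mutex() line in THCRandom_init above: THCGenerator now carries a non-trivial std::mutex member, but the generator array is presumably still allocated C-style (no constructors run), so the mutex has to be constructed in place before it can be locked. A minimal, self-contained sketch of that idiom; the Gen struct and the element count are illustrative only, not THC's:

#include <cstdlib>
#include <mutex>
#include <new>

struct Gen {           // stand-in for THCGenerator
  std::mutex mutex;
  int initf;
};

int main() {
  // Raw storage from a C-style allocator: nothing is constructed yet.
  Gen* gens = static_cast<Gen*>(std::malloc(2 * sizeof(Gen)));
  for (int i = 0; i < 2; ++i) {
    new (&gens[i].mutex) std::mutex();  // placement new: construct the mutex in place
    gens[i].initf = 0;
  }

  {
    std::lock_guard<std::mutex> lock(gens[0].mutex);  // now safe to lock
    gens[0].initf = 1;
  }

  // Clean up: destroy the mutexes explicitly, then free the raw storage.
  for (int i = 0; i < 2; ++i) {
    gens[i].mutex.~mutex();
  }
  std::free(gens);
  return 0;
}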

View File

@@ -5,6 +5,7 @@
 #include "THCTensorMath.h"
 #include "THCReduceApplyUtils.cuh"
 #include "THCTensorRandom.cuh"
+#include "THCGenerator.h"
 #include <thrust/functional.h>
 #include <curand.h>
@@ -18,22 +19,22 @@
 THCGenerator* THCRandom_getGenerator(THCState* state);
-/* Sets up generator. Allocates but does not create the generator states. */
+/* Sets up generator. Allocates but does not create the generator states. Not thread-safe. */
 __host__ void initializeGenerator(THCState *state, THCGenerator* gen)
 {
-  THCudaCheck(THCudaMalloc(state, (void**)&gen->gen_states, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
-  THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
+  THCudaCheck(THCudaMalloc(state, (void**)&gen->state.gen_states, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
+  THCudaCheck(THCudaMalloc(state, (void**)&gen->state.kernel_params, sizeof(mtgp32_kernel_params)));
 }
-/* Creates a new generator state given the seed. */
+/* Creates a new generator state given the seed. Not thread-safe. */
 __host__ void createGeneratorState(THCGenerator* gen, uint64_t seed)
 {
-  if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->kernel_params) != CURAND_STATUS_SUCCESS)
+  if (curandMakeMTGP32Constants(mtgp32dc_params_fast_11213, gen->state.kernel_params) != CURAND_STATUS_SUCCESS)
   {
     THError("Creating MTGP constants failed.");
   }
-  if (curandMakeMTGP32KernelState(gen->gen_states, mtgp32dc_params_fast_11213,
-                                  gen->kernel_params, MAX_NUM_BLOCKS, seed) != CURAND_STATUS_SUCCESS)
+  if (curandMakeMTGP32KernelState(gen->state.gen_states, mtgp32dc_params_fast_11213,
+                                  gen->state.kernel_params, MAX_NUM_BLOCKS, seed) != CURAND_STATUS_SUCCESS)
   {
     THError("Creating MTGP kernel state failed.");
   }
@@ -42,19 +43,20 @@ __host__ void createGeneratorState(THCGenerator* gen, uint64_t seed)
 __host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state)
 {
   THCGenerator* gen = THCRandom_getGenerator(state);
+  std::lock_guard<std::mutex> lock(gen->mutex);
   // The RNG state comprises the MTPG32 states, the seed, and an offset used for Philox
   static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
-  static const size_t seed_size = sizeof(gen->initial_seed);
-  static const size_t offset_size = sizeof(gen->philox_seed_offset);
+  static const size_t seed_size = sizeof(gen->state.initial_seed);
+  static const size_t offset_size = sizeof(gen->state.philox_seed_offset);
   static const size_t total_size = states_size + seed_size + offset_size;
   THByteTensor_resize1d(rng_state, total_size);
   THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
   THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
-  THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), gen->gen_states,
+  THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), gen->state.gen_states,
                          states_size, cudaMemcpyDeviceToHost));
-  memcpy(THByteTensor_data(rng_state) + states_size, &gen->initial_seed, seed_size);
-  memcpy(THByteTensor_data(rng_state) + states_size + seed_size, &gen->philox_seed_offset, offset_size);
+  memcpy(THByteTensor_data(rng_state) + states_size, &gen->state.initial_seed, seed_size);
+  memcpy(THByteTensor_data(rng_state) + states_size + seed_size, &gen->state.philox_seed_offset, offset_size);
 }
 __global__ void set_rngstate_kernel(curandStateMtgp32 *state, mtgp32_kernel_params *kernel)
@@ -65,10 +67,11 @@ __global__ void set_rngstate_kernel(curandStateMtgp32 *state, mtgp32_kernel_para
 __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
 {
   THCGenerator* gen = THCRandom_getGenerator(state);
+  std::lock_guard<std::mutex> lock(gen->mutex);
   static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
-  static const size_t seed_size = sizeof(gen->initial_seed);
-  static const size_t offset_size = sizeof(gen->philox_seed_offset);
+  static const size_t seed_size = sizeof(gen->state.initial_seed);
+  static const size_t offset_size = sizeof(gen->state.philox_seed_offset);
   static const size_t total_size = states_size + seed_size + offset_size;
   bool no_philox_seed = false;
   if (THByteTensor_nElement(rng_state) == total_size - offset_size) {
@@ -79,16 +82,16 @@ __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state)
   }
   THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
-  THCudaCheck(cudaMemcpy(gen->gen_states, THByteTensor_data(rng_state),
+  THCudaCheck(cudaMemcpy(gen->state.gen_states, THByteTensor_data(rng_state),
                          states_size, cudaMemcpyHostToDevice));
   set_rngstate_kernel<<<1, MAX_NUM_BLOCKS, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, gen->kernel_params);
-  memcpy(&gen->initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
+      gen->state.gen_states, gen->state.kernel_params);
+  memcpy(&gen->state.initial_seed, THByteTensor_data(rng_state) + states_size, seed_size);
   if (!no_philox_seed) {
-    memcpy(&gen->philox_seed_offset, THByteTensor_data(rng_state) + states_size + seed_size, offset_size);
+    memcpy(&gen->state.philox_seed_offset, THByteTensor_data(rng_state) + states_size + seed_size, offset_size);
   }
   else {
-    gen->philox_seed_offset = 0;
+    gen->state.philox_seed_offset = 0;
   }
 }

View File

@@ -6,14 +6,7 @@
 #include "generic/THCTensorRandom.h"
 #include "THCGenerateAllTypes.h"
-/* Generator */
-typedef struct _Generator {
-  struct curandStateMtgp32* gen_states;
-  struct mtgp32_kernel_params *kernel_params;
-  int initf;
-  uint64_t initial_seed;
-  int64_t philox_seed_offset;
-} THCGenerator;
+typedef struct THCGenerator THCGenerator;
 typedef struct THCRNGState {
   /* One generator per GPU */

View File

@@ -16,7 +16,7 @@ THC_API void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, do
   real *data = THCTensor_(data)(state, self);
   generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, a, b);
+      gen->state.gen_states, size, data, a, b);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -31,7 +31,7 @@ THC_API void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean,
   real *data = THCTensor_(data)(state, self);
   generate_normal<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, mean, stdv);
+      gen->state.gen_states, size, data, mean, stdv);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -70,7 +70,7 @@ THC_API void THCTensor_(logNormal)(THCState* state, THCTensor *self_, double mea
   real *data = THCTensor_(data)(state, self);
   generateLogNormal<real><<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, mean, stdv);
+      gen->state.gen_states, size, data, mean, stdv);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -86,7 +86,7 @@ THC_API void THCTensor_(exponential)(THCState* state, THCTensor *self_, double l
   real *data = THCTensor_(data)(state, self);
   generate_exponential<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, lambda);
+      gen->state.gen_states, size, data, lambda);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -102,7 +102,7 @@ THC_API void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median
   real *data = THCTensor_(data)(state, self);
   generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, median, sigma);
+      gen->state.gen_states, size, data, median, sigma);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -242,7 +242,7 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
   sampleMultinomialWithReplacement
     <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states,
+      gen->state.gen_states,
       n_sample,
       THCudaLongTensor_data(state, self),
       numDist, numCategories,
@@ -275,7 +275,7 @@ THC_API void THCTensor_(multinomial)(struct THCState *state,
   // recalculate our distribution
   sampleMultinomialWithoutReplacement
     <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states,
+      gen->state.gen_states,
       n_sample,
       sample,
       THCudaLongTensor_data(state, self),
@@ -412,7 +412,7 @@ THC_API void THCTensor_(bernoulli)(THCState* state, THCTensor *self_, double p)
   real *data = THCTensor_(data)(state, self);
   generate_bernoulli<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, p);
+      gen->state.gen_states, size, data, p);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -443,7 +443,7 @@ THC_API void THCTensor_(NAME)(THCState* state, \
   THArgCheck(size == prob_size, 3, "inconsistent tensor size"); \
   \
   generate_bernoulli_tensor<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>( \
-      gen->gen_states, size, result_data, probs_data); \
+      gen->state.gen_states, size, result_data, probs_data); \
   \
   PROB_TYPE##_free(state, probs); \
   THCTensor_(freeCopyTo)(state, self, self_); \
@@ -483,7 +483,7 @@ THC_API void THCTensor_(geometric)(THCState* state, THCTensor *self_, double p)
   real *data = THCTensor_(data)(state, self);
   generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, p);
+      gen->state.gen_states, size, data, p);
   THCTensor_(freeCopyTo)(state, self, self_);
 };
@@ -504,11 +504,11 @@ THC_API void THCTensor_(clampedRandom)(THCState* state, THCTensor *self_, int64_
 #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
   if (range > 1ULL << 32) {
     generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-        gen->gen_states, size, data, min_val, range);
+        gen->state.gen_states, size, data, min_val, range);
   } else {
 #endif
     generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-        gen->gen_states, size, data, min_val, range);
+        gen->state.gen_states, size, data, min_val, range);
 #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
   }
 #endif
@@ -534,19 +534,19 @@ THC_API void THCTensor_(random)(THCState* state, THCTensor *self_)
 #if defined(THC_REAL_IS_HALF)
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, 0UL, (1UL << HLF_MANT_DIG) + 1);
+      gen->state.gen_states, size, data, 0UL, (1UL << HLF_MANT_DIG) + 1);
 #elif defined(THC_REAL_IS_FLOAT)
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, 0UL, (1UL << FLT_MANT_DIG) + 1);
+      gen->state.gen_states, size, data, 0UL, (1UL << FLT_MANT_DIG) + 1);
 #elif defined(THC_REAL_IS_DOUBLE)
   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, 0ULL, (1ULL << DBL_MANT_DIG) + 1);
+      gen->state.gen_states, size, data, 0ULL, (1ULL << DBL_MANT_DIG) + 1);
 #elif defined(THC_REAL_IS_LONG)
   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, 0ULL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
+      gen->state.gen_states, size, data, 0ULL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
 #else
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->gen_states, size, data, 0UL, static_cast<uint32_t>(std::numeric_limits<real>::max()) + 1);
+      gen->state.gen_states, size, data, 0UL, static_cast<uint32_t>(std::numeric_limits<real>::max()) + 1);
 #endif
   THCTensor_(freeCopyTo)(state, self, self_);

View File

@@ -16,6 +16,9 @@ $BUILD_ROOT/src/ATen/test/undefined_tensor_test
 if [[ -x $BUILD_ROOT/src/ATen/test/cudnn_test ]]; then
   $BUILD_ROOT/src/ATen/test/cudnn_test
 fi
+if [[ -x $BUILD_ROOT/src/ATen/test/cuda_rng_test ]]; then
+  $BUILD_ROOT/src/ATen/test/cuda_rng_test
+fi
 if [ "$VALGRIND" == "ON" ]
 then
   valgrind --suppressions=`dirname $0`/valgrind.sup --error-exitcode=1 $BUILD_ROOT/src/ATen/test/basic "[cpu]"

View File

@@ -3,7 +3,7 @@ set -xe
 mkdir aten_build aten_install
 cd aten_build
 cmake ../aten -DNO_CUDA=1 -DCMAKE_INSTALL_PREFIX=../aten_install
-make -j32 install
-../aten/tools/run_tests.sh .
+NUM_JOBS="$(getconf _NPROCESSORS_ONLN)"
+make -j"$NUM_JOBS" install
 cd ..
 aten/tools/test_install.sh $(pwd)/aten_install $(pwd)/aten