Move THCTensor_(uniform) to ATen (#20292)

Summary:
As a first step of the plan in https://github.com/pytorch/pytorch/issues/19508#issuecomment-485178192, this PR moves `THCTensor_(uniform)` to ATen. The major changes are:
- The `uniform_` CUDA kernel now uses a Philox generator.
- The kernel also uses `TensorIterator`.
- The kernel uses a grid-stride loop to achieve peak effective bandwidth (see the sketch below).
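
For reference, the grid-stride pattern looks roughly like the sketch below. This is a minimal illustration, not the PR's kernel (the actual kernel in the diff further down draws four values at a time via `curand_uniform4` and unrolls, to avoid a register-spilling issue on CUDA < 10); `fill_uniform` and its parameters are illustrative names only:

```
#include <curand_kernel.h>

// Minimal grid-stride sketch: each thread seeds a Philox state with
// (seed, global thread id, philox offset) and then strides over the output,
// so a fixed-size grid can cover arbitrarily large tensors.
__global__ void fill_uniform(float* out, int64_t numel,
                             uint64_t seed, uint64_t offset) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
  for (int64_t i = idx; i < numel; i += (int64_t)blockDim.x * gridDim.x) {
    out[i] = curand_uniform(&state);  // uniform in (0, 1]
  }
}
```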

- Since the engine has changed from `curandStateMTGP32` to `curandStatePhilox4_32_10`, the random numbers generated will now be different.
- Here is the diff showing the codegen changes (this is the BC-breaking change, if any): https://gist.github.com/syed-ahmed/4af9ae0d42b6c7dbaa13b9dd0d1dd1e8

- Philox4_32_10 is known to pass the standard TestU01 BigCrush test (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf), so the quality of the generated random numbers is not a concern compared to the previously used `curandStateMTGP32`.
- I have added a test case in `aten/src/ATen/test/cuda_distributions_test.cu` which verifies that the Philox offset is incremented properly (the offset arithmetic is sketched below).
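
For intuition, here is the offset arithmetic used by `calc_execution_policy` in the `Distributions.cu` diff below, evaluated on the host. This is a standalone sketch (in the real code the grid size is additionally capped by the device's occupancy limits); the numbers match the small- and big-tensor cases exercised by the test:

```
#include <cstdint>
#include <iostream>

// Mirrors the counter_offset computation from calc_execution_policy:
// elements per thread, rounded up, times 4 (curand_uniform4 yields
// four floats per call).
uint64_t counter_offset(uint64_t numel, uint64_t block, uint64_t grid) {
  const uint64_t unroll = 4;
  return ((numel - 1) / (block * grid * unroll) + 1) * 4;
}

int main() {
  // Small tensor of 4 elements: the offset advances by 4.
  std::cout << counter_offset(4, 256, 1) << "\n";                 // 4
  // Big tensor with numel = 8 * block * grid (a grid of 160 is just
  // an example value): the offset advances by 8.
  std::cout << counter_offset(8 * 256 * 160, 256, 160) << "\n";   // 8
}
```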

The benchmark was done on a DGX Station with 4 V100s.
I modified the script from jcjohnson's [multinomial benchmark](https://github.com/jcjohnson/pytorch-multinomial-benchmark) to produce this notebook, which shows a general speedup with this PR and no regression: https://gist.github.com/syed-ahmed/9d26d4e96308aed274d0f2c7be5218ef

To reproduce the notebook:
- Run https://gist.github.com/syed-ahmed/4208c22c541f1d30ad6a9b1efc1d728f in a container with the current PyTorch top of tree, using the command: `python uniform_benchmark.py --stats_json before.json`
- Apply this diff to the current PyTorch top of tree and run the same script in a container with the command: `python uniform_benchmark.py --stats_json after.json`
- Run the notebook attached above with `after.json` and `before.json` in the same directory.

The effective bandwidth was calculated using this script (thanks to ngimel): https://gist.github.com/syed-ahmed/f8b7384d642f4bce484228b508b4bc68
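The table values below correspond to bytes written divided by elapsed time; for example, the last "after" row can be reproduced with the following sketch of the calculation (my own illustration, not the linked script):

```
#include <cstdint>
#include <cstdio>

int main() {
  // Last row of the "after" numbers: 33554432 float32 elements written.
  const uint64_t numel = 33554432;
  const double seconds = 0.00016920566558837892;
  const double gb_per_s = numel * 4.0 / seconds / 1e9;  // 4 bytes per float
  std::printf("%.1f GB/s\n", gb_per_s);  // ~793.2, matching the table
}
```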
The numbers before (first block) and after (second block) follow.
```
uniform, size, elements 65536 forward 5.168914794921875e-06 bandwidth (GB/s) 50.71548098597786
uniform, size, elements 131072 forward 5.056858062744141e-06 bandwidth (GB/s) 103.67860705101367
uniform, size, elements 262144 forward 7.164478302001953e-06 bandwidth (GB/s) 146.357621001797
uniform, size, elements 524288 forward 1.1217594146728515e-05 bandwidth (GB/s) 186.9520302275877
uniform, size, elements 1048576 forward 1.923084259033203e-05 bandwidth (GB/s) 218.10297600317384
uniform, size, elements 2097152 forward 3.640890121459961e-05 bandwidth (GB/s) 230.39992200138826
uniform, size, elements 4194304 forward 6.778717041015625e-05 bandwidth (GB/s) 247.49839679819922
uniform, size, elements 8388608 forward 0.00012810707092285157 bandwidth (GB/s) 261.92490202361347
uniform, size, elements 16777216 forward 0.00025241613388061524 bandwidth (GB/s) 265.86598474620627
uniform, size, elements 33554432 forward 0.000497891902923584 bandwidth (GB/s) 269.5720239913193
```
```
uniform, size, elements 65536 forward 5.550384521484375e-06 bandwidth (GB/s) 47.22988091821306
uniform, size, elements 131072 forward 5.581378936767578e-06 bandwidth (GB/s) 93.93520954942333
uniform, size, elements 262144 forward 6.165504455566406e-06 bandwidth (GB/s) 170.071404141686
uniform, size, elements 524288 forward 6.3276290893554685e-06 bandwidth (GB/s) 331.4277702414469
uniform, size, elements 1048576 forward 8.509159088134765e-06 bandwidth (GB/s) 492.91639239047356
uniform, size, elements 2097152 forward 1.2989044189453124e-05 bandwidth (GB/s) 645.8218077979443
uniform, size, elements 4194304 forward 2.347707748413086e-05 bandwidth (GB/s) 714.6211452997259
uniform, size, elements 8388608 forward 4.4286251068115234e-05 bandwidth (GB/s) 757.6715389250498
uniform, size, elements 16777216 forward 8.672237396240235e-05 bandwidth (GB/s) 773.8356427961071
uniform, size, elements 33554432 forward 0.00016920566558837892 bandwidth (GB/s) 793.2224227438523
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20292

Differential Revision: D15277761

Pulled By: ezyang

fbshipit-source-id: 8bfe31a01eeed77f0ed6e7ec4d2dda4c6472ecaa
Authored by Syed Tousif Ahmed, 2019-05-13 09:35:30 -07:00; committed by Facebook Github Bot
parent 5f7ef09f57
commit 67414714e5
12 changed files with 359 additions and 47 deletions

View File

@@ -2625,7 +2625,6 @@
- floating_point
backends:
- CPU
- CUDA
cname: uniform
variants: function
return: self

View File

@@ -249,7 +249,7 @@ Tensor & random_(Tensor& self, Generator * generator) {
return at::legacy::th::_th_random_(self, generator);
}
Tensor & uniform_(Tensor& self, double from, double to, Generator * generator) {
Tensor & uniform_cpu_(Tensor& self, double from, double to, Generator * generator) {
return at::legacy::th::_th_uniform_(self, from, to, generator);
}

View File

@@ -486,6 +486,14 @@ std::unique_ptr<TensorIterator> TensorIterator::unary_op(Tensor& out, const Tens
return builder.build();
}
std::unique_ptr<TensorIterator> TensorIterator::nullary_op(Tensor& out) {
auto builder = TensorIterator::Builder();
builder.add_output(out);
// FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342
builder.iter_->resize_outputs_ = false;
return builder.build();
}
std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Tensor& a) {
AT_ASSERT(out.defined());
auto builder = TensorIterator::Builder();

View File

@@ -146,6 +146,7 @@ struct CAFFE2_API TensorIterator {
static std::unique_ptr<TensorIterator> binary_op(Tensor& out, const Tensor& a, const Tensor& b);
static std::unique_ptr<TensorIterator> unary_op(Tensor& out, const Tensor& a);
static std::unique_ptr<TensorIterator> nullary_op(Tensor& out);
static std::unique_ptr<TensorIterator> reduce_op(Tensor& out, const Tensor& a);
int ndim() const { return shape_.size(); }

View File

@@ -11,6 +11,8 @@
#include <functional>
#include <ATen/native/Distributions.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
@@ -23,17 +25,160 @@
#include <utility>
#include <type_traits>
/**
* Note [Register spilling in curand call for CUDA < 10]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* For CUDA < 10, curandStatePhilox4_32_10_t engine achieves poor performance (60% SOL bandwidth)
* when called to generate one random number at a time. This is because the line
* unsigned ret = (&state->output.x)[state->STATE++];
* in
* QUALIFIERS unsigned int curand(curandStatePhilox4_32_10_t *state)
* in curand_kernel.h dynamically indexes into state.output, preventing the compiler from ever
* storing state.output in registers.
*
* CUDA 10 fixed this problem. However, for backwards compatibility, the following kernels
* use curand distributions that utilize the curand4 call, which doesn't have the
* register spilling problem.
*/
THCGenerator* THCRandom_getGenerator(THCState* state);
namespace {
// increment should be at least the number of curand() random numbers used in
// each thread.
// Increment should be at least the number of curand() random numbers used in
// each thread. It is the user's responsibility to make sure that the increment for philox is never
// smaller than the number of curand() calls. An increment value greater than the number of
// curand() calls won't harm, but anything less would mean that you would be reusing random
// values from previous calls.
// e.g. In many kernels below, we use distributions that utilize the curand4 call in the kernel.
// Hence, the increment value should be at least 4 for those kernels.
std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen, uint64_t increment) {
auto gen_ = THCRandom_getGenerator(at::globalContext().getTHCState());
uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment);
return std::make_pair(gen_->state.initial_seed, offset);
}
// launch bounds used for kernels utilizing TensorIterator
const uint32_t block_size_bound = 256;
const uint32_t grid_size_bound = 4;
// number of randoms given by distributions like curand_uniform4, curand_uniform2_double
// used in calculating philox offset.
const uint32_t curand4_engine_calls = 4;
// utility function that calculates proper philox_offset
// for distributions utilizing TensorIterator. For distributions using
// TensorIterator, we are using a grid-stride loop with each
// thread yielding one element. For the edge of the grid-stride
// loop, if the tensor size is large, the unroll loop will kick in and the float4
// from curand4 will start getting utilized (for common tensor sizes, we end up
// using rand.x from each thread). Hence, the philox_offset is
// (number of elements per thread * number of engine calls), which makes
// sure that philox offset increment is not less than the number of randoms used
// in each thread.
std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
const uint64_t numel = static_cast<uint64_t>(total_elements);
const uint32_t block_size = block_size_bound;
const uint32_t unroll = curand4_engine_calls;
dim3 dim_block(block_size);
dim3 grid((numel + block_size - 1) / block_size);
uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
grid.x = std::min(
static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
grid.x);
// number of times a random will be generated per thread, to offset the philox counter in the THC random state
uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1)
* curand4_engine_calls;
return std::make_tuple(counter_offset, grid, dim_block);
}
// grid stride loop kernel for distributions
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
std::pair<uint64_t, uint64_t> seeds,
const dist_t dist_func,
const transform_t transform_func) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
idx,
seeds.second,
&state);
int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
blockDim.x * gridDim.x * unroll_factor;
for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
auto rand = dist_func(&state);
#pragma unroll
for (int ii = 0; ii < unroll_factor; ii++) {
int li = linear_index + blockDim.x * gridDim.x * ii;
if (li < numel) {
transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
}
}
__syncthreads();
}
}
template<typename scalar_t,
typename accscalar_t,
int unroll_factor,
typename dist_t,
typename transform_t>
void distribution_nullary_kernel(at::TensorIterator& iter,
at::Generator* gen,
const dist_t& dist_func,
const transform_t transform_func) {
static_assert(unroll_factor >= 1, "unroll_factor must be >= 1.");
int64_t numel = iter.numel();
if (numel == 0) {
return;
}
auto execution_policy = calc_execution_policy(numel);
auto counter_offset = std::get<0>(execution_policy);
auto grid = std::get<1>(execution_policy);
auto block = std::get<2>(execution_policy);
auto seeds = next_philox_seed(gen, counter_offset);
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
distribution_nullary_kernel<scalar_t, accscalar_t, unroll_factor>(sub_iter,
gen, dist_func, transform_func);
}
return;
}
char* out_data = (char*)iter.data_ptr(0);
auto stream = at::cuda::getCurrentCUDAStream();
if (iter.is_trivial_1d()) {
auto strides = iter.get_inner_strides();
int stride0 = strides[0];
distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
numel,
seeds,
dist_func,
[=]__device__(int idx, accscalar_t rand) {
scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
*out = transform_func(rand);
}
);
} else {
auto offset_calc = at::native::make_offset_calculator<1>(iter);
distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
numel,
seeds,
dist_func,
[=]__device__(int idx, accscalar_t rand) {
auto offsets = offset_calc.get(idx);
scalar_t* out = (scalar_t*)&out_data[offsets[0]];
*out = transform_func(rand);
}
);
}
AT_CUDA_CHECK(cudaGetLastError());
}
template <typename scalar_t>
void poisson_cuda_kernel(
at::Tensor& ret,
@@ -117,6 +262,7 @@ void bernoulli_tensor_cuda_kernel(
blockIdx.x * blockDim.x + threadIdx.x,
seeds.second,
&state);
// See Note [Register spilling in curand call for CUDA < 10]
float4 rand = curand_uniform4(&state);
switch (n) {
case 4: {
@@ -159,6 +305,7 @@ void bernoulli_scalar_cuda_kernel(
blockIdx.x * blockDim.x + threadIdx.x,
seeds.second,
&state);
// See Note [Register spilling in curand call for CUDA < 10]
float4 rand = curand_uniform4(&state);
switch (n) {
case 4: {
@@ -256,5 +403,49 @@ Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
return self;
}
void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generator* gen) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "uniform_cuda", [&] {
auto from = static_cast<scalar_t>(from_);
auto to = static_cast<scalar_t>(to_);
AT_CHECK(from <= to,
"uniform_ expects to return a [from, to) range, but found from=", from,
" > to=", to);
AT_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
"uniform_ expects to-from <= std::numeric_limits<", toString(iter.dtype()),
">::max(), but found to=", to, " and from=", from,
" which result in to-from to exceed the limit");
using accscalar_t = at::acc_type<scalar_t, true>;
auto range = static_cast<accscalar_t>(to-from);
from = static_cast<accscalar_t>(from);
// define lambda to reverse bounds, multiply 'range' and add 'from_'
auto uniform_func = [range, from] __device__ (accscalar_t rand) {
// reverse the bounds of curand4 from (0, 1] to [0, 1)
// Note that this method is from legacy THCTensorRandom and is likely to give
// you more 0-s, since the probability of getting 1-s is higher than that of 0-s and
// by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
auto reverse_bound_rand = rand == static_cast<accscalar_t>(1.0) ? static_cast<accscalar_t>(0.0) : rand;
return static_cast<scalar_t>(reverse_bound_rand * range + from);
};
if (std::is_same<scalar_t, double>::value) {
distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
uniform_func);
} else {
distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
uniform_func);
}
});
}
Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
auto iter = TensorIterator::nullary_op(self);
uniform_kernel_cuda(*iter, from, to, gen);
return self;
}
}} // namespace at::native

View File

@@ -2956,6 +2956,9 @@
- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
variants: method
dispatch:
CPU: uniform_cpu_
CUDA: uniform_cuda_
- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
variants: method

View File

@@ -31,6 +31,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_apply_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_stream_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_optional_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_packedtensoraccessor_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_tensor_interop_test.cpp)

View File

@@ -0,0 +1,143 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
__global__ void expected_randoms(float* x, uint64_t counter_offset) {
for(int i=0; i < 4; i++) {
curandStatePhilox4_32_10_t state;
curand_init(
123,
i,
counter_offset,
&state);
auto ret = curand_uniform4(&state);
x[i] = ret.x;
}
}
TEST(DistributionsTest, TestPhiloxIncrementSmallTensor) {
// Test Description:
// In Distributions.cu we mentioned that philox increment
// should be at least the number of curand() random numbers used in
// each thread. In this test, we make sure that uniform_ correctly
// increments philox and doesn't reuse randoms from previous calls
// for a small tensor size of 4.
// - We check that by first getting 4 randoms from uniform_.
// Once we get these 4 randoms, that would mean that philox counter for
// threads 0, 1, 2 and 3 was incremented by 4 (check the calc_execution_policy
// function for details).
// - Now get 4 randoms with offset=4 for thread {0,1,2,3} from expected_randoms
// kernel above.
// - Now get 4 more randoms from uniform_ (note thread {0,1,2,3} for this call would
// start from a philox_offset value of 4)
// - the 4 randoms from expected_randoms and the 4 randoms from the previous call
// of uniform_ should match, signifying that the philox offset was
// incremented properly and no randoms are being reused from previous calls
// if cuda not available, return
if (!at::cuda::is_available()) return;
// manual seed to 123
at::manual_seed(123);
// get 4 randoms from uniform_(), philox offset is now incremented to 4 by this call
at::empty({4}, at::TensorOptions(at::kCUDA)).uniform_();
// allocate 4 floats in unified (managed) memory, accessible from the host
float *x;
cudaMallocManaged(&x, 4*sizeof(float));
// launch kernel to get expected randoms
expected_randoms<<<1, 1>>>(x, 4);
// Wait for GPU to finish before accessing on host
cudaDeviceSynchronize();
// get 4 new floats from uniform_()
auto self = at::empty({4}, at::TensorOptions(at::kCUDA));
self.uniform_();
// check randoms from expected_randoms kernel are equal to the randoms from the second
// call of uniform_()
for (int i = 0; i < 4; i++) {
ASSERT_EQ(self[i].item().to<float>(), x[i]);
}
// Free memory
cudaFree(x);
}
TEST(DistributionsTest, TestPhiloxIncrementBigTensor) {
// Test Description:
// In Distributions.cu we mentioned that philox increment
// should be at least the number of curand() random numbers used in
// each thread. In this test, we make sure that uniform_ correctly
// increments philox and doesn't reuse randoms from previous calls
// for a big tensor.
// - First of all, we come up with what the size of the big tensor
// should be for this test. Our goal is to show that when the uniform_
// kernel runs at full occupancy (i.e. when the number of elements is
// greater than the number of threads launched), it hits the unroll loop in
// the uniform_ kernel.
// - Hence, we set the size of the tensor in this test to be 8 times the
// maximum number of threads we can launch. This means that each thread will
// be yielding 8 elements, and as a result, curand_uniform4 will be called twice
// and all the 8 elements in a thread will consume all the float4 from the
// two calls of curand_uniform4 as a result of the unroll loop. Therefore,
// after this call to uniform_, the counter_offset for the next call to uniform_
// will start from 8. This is what we test next.
// - Now get 4 randoms with offset=8 for thread {0,1,2,3} from expected_randoms
// kernel above.
// - Now get 4 more randoms from uniform_ (note thread {0,1,2,3} for this call would
// start from a philox_offset value of 8)
// - the 4 randoms from expected_randoms kernel and the 4 randoms from the previous call
// of uniform_ should match, signifying that the philox offset was
// incremented properly and no randoms are being reused from previous calls
// if cuda not available, return
if (!at::cuda::is_available()) return;
// manual seed to 123
at::manual_seed(123);
// calculate maximum number of threads that can be launched
// and set the numel to be 8 times that
const int block_size = 256;
dim3 dim_block(block_size);
uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
dim3 grid(static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm);
auto numel = block_size * grid.x * 8;
// get numel randoms from uniform_(), philox offset is now incremented to 8 by this call
at::empty({numel}, at::TensorOptions(at::kCUDA)).uniform_();
// allocate 4 floats in unified (managed) memory, accessible from the host
float *x;
cudaMallocManaged(&x, 4*sizeof(float));
// launch kernel to get expected randoms
expected_randoms<<<1, 1>>>(x, 8);
// Wait for GPU to finish before accessing on host
cudaDeviceSynchronize();
// get 4 new floats from uniform_()
auto self = at::empty({4}, at::TensorOptions(at::kCUDA));
self.uniform_();
// check randoms from expected_randoms kernel are equal to the randoms from the second
// call of uniform_()
for (int i = 0; i < 4; i++) {
ASSERT_EQ(self[i].item().to<float>(), x[i]);
}
// Free memory
cudaFree(x);
}

View File

@@ -101,24 +101,6 @@ THC_API __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_s
}
}
// Goes from (0, 1] to [0, 1). Note 1-x is not sufficient since for some floats
// eps near 0, 1-eps will round to 1.
template <typename T>
__device__ inline T reverse_bounds(T value) {
if (THCNumerics<T>::eq(value, ScalarConvert<int, T>::to(1))) {
return ScalarConvert<int, T>::to(0);
}
return value;
}
__device__ inline at::Half half_uniform_scale_and_shift(float x, double a, double b) {
at::Half width = ScalarConvert<double, at::Half>::to(b - a);
at::Half start = ScalarConvert<double, at::Half>::to(a);
at::Half scaled = THCNumerics<at::Half>::mul(reverse_bounds(ScalarConvert<float, at::Half>::to(x)), width);
return THCNumerics<at::Half>::add(scaled, start);
}
#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \
__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \
{ \
@@ -147,11 +129,6 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)
} \
}
// NOTE: curand_uniform is (0, 1] and we want [a, b)
GENERATE_KERNEL2(generate_uniform, float, float a, float b, float, curand_uniform, reverse_bounds(x) * (b-a) + a)
GENERATE_KERNEL2(generate_uniform, float, double a, double b, float, curand_uniform, reverse_bounds(x) * (b-a) + a)
GENERATE_KERNEL2(generate_uniform, double, double a, double b, double, curand_uniform_double, reverse_bounds(x) * (b-a) + a)
GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)
@@ -161,7 +138,6 @@ GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uni
GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))
GENERATE_KERNEL2(generate_uniform, at::Half, double a, double b, float, curand_uniform, (half_uniform_scale_and_shift(x, a, b)))
GENERATE_KERNEL2(generate_normal, at::Half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, at::Half>::to((x * stdv) + mean)))
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))

View File

@@ -8,21 +8,6 @@
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
void THCTensor_(uniform)(THCState* state, THCTensor *self_, double a, double b)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
ptrdiff_t size = THCTensor_(nElement)(state, self_);
if (size == 0) return;
THCGenerator* gen = THCRandom_getGenerator(state);
THCTensor *self = THCTensor_(newContiguous)(state, self_);
scalar_t *data = THCTensor_(data)(state, self);
generate_uniform<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, a, b);
THCTensor_(freeCopyTo)(state, self, self_);
};
void THCTensor_(normal)(THCState* state, THCTensor *self_, double mean, double stdv)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
@@ -191,7 +176,8 @@ void THCTensor_(multinomial)(struct THCState *state,
// Uniform random samples in a separate kernel launch, into
// temporarily allocated memory. The device RNG is thread-limited
THCTensor *sampled = THCTensor_(newWithSize2d)(state, numDist, n_sample);
THCTensor_(uniform)(state, sampled, 0.0, 1.0);
auto out = THTensor_wrap(sampled);
at::native::uniform_cuda_(out, 0.0, 1.0);
dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
dim3 grid(numDist < numSM * 4 ? numDist : numSM * 4);
@@ -380,8 +366,10 @@ void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, T
THCTensor *uniform = THCTensor_(newWithSize1d)(state, n_sample);
THCTensor *bernoulli = THCTensor_(newWithSize1d)(state, n_sample);
THCTensor_(uniform)(state, uniform, 0, K);
THCTensor_(uniform)(state, bernoulli, 0, 1);
auto out_uniform = THTensor_wrap(uniform);
auto out_bernoulli = THTensor_wrap(bernoulli);
at::native::uniform_cuda_(out_uniform, 0, K);
at::native::uniform_cuda_(out_bernoulli, 0, 1);
multinomialAliasDrawKernel
<<<THCCeilDiv((int)n_sample+BLOCK_SIZE-1, BLOCK_SIZE), BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(

View File

@@ -4,7 +4,6 @@
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
THC_API void THCTensor_(uniform)(struct THCState *state, THCTensor *self, double a, double b);
THC_API void THCTensor_(normal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(normal_means)(struct THCState *state, THCTensor *self, THCTensor *means, double stddev);
THC_API void THCTensor_(normal_stddevs)(struct THCState *state, THCTensor *self, double mean, THCTensor *stddevs);

View File

@@ -34,6 +34,9 @@ fi
if [[ -x ./cuda_half_test ]]; then
./cuda_half_test
fi
if [[ -x ./cuda_distributions_test ]]; then
./cuda_distributions_test
fi
if [[ -x ./cuda_optional_test ]]; then
./cuda_optional_test
fi