Implement randperm for CUDA (#7606)

* Implement randperm for CUDA

* Use Thrust to implement randperm

* clean up

* Fix test

* Offload small input scenario to CPU

* Fixed test

* Try to fix Windows error

* Fix Windows error and clean up

* Use fork_rng context manager

* Move test_randperm_cuda to test_cuda

* Add half tensor support

* Fix cuda::type error

* Fix CPU offloading

* Fix issues

* No need to check range for n == 0 case
Will Feng 2018-06-06 14:30:58 -04:00 committed by GitHub
parent 9af3a80cff
commit edfcbfbe1f
4 changed files with 103 additions and 5 deletions


@@ -263,18 +263,15 @@ THGenerator* get_generator(at::Generator* gen) {
 Tensor randperm(const Type& dtype, int64_t n, Generator* generator) {
   Tensor result = dtype.tensor(n);
-  return at::native::randperm_out(result, n, generator);
+  return at::randperm_out(result, n, generator);
 }
 
-Tensor& randperm_out(Tensor& result, int64_t n, Generator* generator) {
+Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) {
   if (n < 0) {
     std::ostringstream oss;
     oss << "n must be non-negative, got " << n;
     throw std::runtime_error(oss.str());
   }
-  if (result.type().backend() != at::kCPU) {
-    throw std::runtime_error("randperm is only implemented for CPU");
-  }
   result.resize_({n});
   auto gen = get_generator(generator);


@@ -1,4 +1,14 @@
 #include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
+#include "ATen/cuda/CUDATypeConversion.cuh"
+
+#include <THC/THCGeneral.h>
+#include <THC/THCThrustAllocator.cuh>
+
+#include <thrust/device_ptr.h>
+#include <thrust/sort.h>
+#include <thrust/execution_policy.h>
+#include <thrust/sequence.h>
+
 #include <algorithm>
 #include <sstream>
@@ -26,4 +36,57 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
   return result;
 }
 
+Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
+  if (n < 0) {
+    std::ostringstream oss;
+    oss << "n must be non-negative, got " << n;
+    throw std::runtime_error(oss.str());
+  }
+
+  if (n > 0) {
+    AT_DISPATCH_ALL_TYPES_AND_HALF(
+      result.type(), "randperm_out_cuda", [&] {
+        AT_CHECK(Scalar(n).to<scalar_t>(),
+          "n is too large for result tensor type: '", result.type().toString(), "'");
+      }
+    );
+  }
+
+  result.resize_({n});
+
+  if (result.type().scalarType() == at::ScalarType::Half) {
+    auto result_float = CUDA(kFloat).tensor({n});
+    result.copy_(randperm_out_cuda(result_float, n, generator));
+  } else {
+    if (n < 30000) {  // For small inputs, we offload it to CPU instead.
+      auto result_cpu = result.type().toBackend(kCPU).tensor({n});
+      randperm_out(result_cpu, n, generator);
+      result.copy_(result_cpu);
+    } else {
+      // Generate random values for the keys array
+      AT_DISPATCH_ALL_TYPES(
+        result.type(), "randperm_out_cuda", [&] {
+          using cuda_scalar_t = cuda::into_type<scalar_t>;
+          auto keys = result.type().tensor(result.sizes()).random_(generator);
+
+          auto result_data = thrust::device_ptr<cuda_scalar_t>(result.data<cuda_scalar_t>());
+          auto keys_data = thrust::device_ptr<cuda_scalar_t>(keys.data<cuda_scalar_t>());
+
+          auto state = globalContext().getTHCState();
+          THCThrustAllocator thrustAlloc(state);
+          auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state));
+
+          thrust::sequence(policy, result_data, result_data + n);
+
+          // Use the sorted order of keys to rearrange the result array
+          thrust::sort_by_key(policy, keys_data, keys_data + n, result_data);
+        }
+      );
+    }
+  }
+
+  return result;
+}
+
 }} // namespace at::native
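
The kernel above builds the permutation by writing the identity sequence 0..n-1 into result, drawing one random key per element with .random_(generator), and letting thrust::sort_by_key reorder result into the sorted order of the keys. Sorting by independently drawn keys yields an (approximately) uniform random permutation as long as key collisions are rare; for n < 30000 the work is offloaded to the CPU path instead, presumably because the device sort does not pay off at that size. Below is a minimal, self-contained sketch of the same key-sort idea outside of ATen, compilable with nvcc as a .cu file; the random_key functor, the fixed seed, and main() are illustrative scaffolding and not part of this commit, which draws its keys with ATen's own RNG instead.

#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <cstdio>

// Produces one pseudo-random key per index; purely illustrative.
struct random_key {
  unsigned long long seed;
  __host__ __device__ unsigned int operator()(unsigned int i) const {
    thrust::default_random_engine rng(seed);
    rng.discard(i);  // give each index its own position in the stream
    return rng();
  }
};

int main() {
  const int n = 16;
  thrust::device_vector<int> result(n);
  thrust::device_vector<unsigned int> keys(n);

  // result = 0, 1, ..., n-1
  thrust::sequence(result.begin(), result.end());

  // One random key per slot.
  thrust::transform(thrust::counting_iterator<unsigned int>(0),
                    thrust::counting_iterator<unsigned int>(n),
                    keys.begin(),
                    random_key{42ULL});

  // Sorting the identity sequence by the random keys yields a random permutation.
  thrust::sort_by_key(keys.begin(), keys.end(), result.begin());

  for (int i = 0; i < n; ++i) {
    std::printf("%d ", static_cast<int>(result[i]));
  }
  std::printf("\n");
  return 0;
}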


@@ -800,6 +800,9 @@
 - func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor
   variants: function
+  dispatch:
+    CPU: randperm_out_cpu
+    CUDA: randperm_out_cuda
 
 - func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor
   variants: function
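
The "dispatch:" lines added above route the single at::randperm_out declaration to randperm_out_cpu or randperm_out_cuda depending on the backend of the result tensor, which is also why the old explicit kCPU backend check could be dropped from the CPU kernel in the first hunk. Purely as a schematic of that per-backend routing, and not ATen's actual generated dispatch code, a hand-written function table over hypothetical stand-in types could look like this:

#include <cstdint>
#include <cstdio>
#include <map>
#include <stdexcept>

// Hypothetical stand-ins for the real ATen types; illustration only.
enum class Backend { CPU, CUDA };
struct Generator {};
struct Tensor { Backend backend; };

// Stubs standing in for the backend kernels named in the yaml above.
Tensor& randperm_out_cpu(Tensor& r, std::int64_t, Generator*)  { std::puts("cpu kernel");  return r; }
Tensor& randperm_out_cuda(Tensor& r, std::int64_t, Generator*) { std::puts("cuda kernel"); return r; }

using randperm_fn = Tensor& (*)(Tensor&, std::int64_t, Generator*);

// One table entry per backend, mirroring the "dispatch:" section.
Tensor& randperm_out(Tensor& result, std::int64_t n, Generator* generator) {
  static const std::map<Backend, randperm_fn> table = {
      {Backend::CPU, &randperm_out_cpu},
      {Backend::CUDA, &randperm_out_cuda},
  };
  auto it = table.find(result.backend);
  if (it == table.end()) {
    throw std::runtime_error("randperm_out: unsupported backend");
  }
  return it->second(result, n, generator);
}

int main() {
  Tensor t{Backend::CUDA};
  randperm_out(t, 100, nullptr);  // routed to the CUDA stub
}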


@@ -1653,6 +1653,41 @@ class TestCuda(TestCase):
         torch.cuda.nvtx.mark("bar")
         torch.cuda.nvtx.range_pop()
 
+    def test_randperm_cuda(self):
+        cuda = torch.device('cuda:0')
+
+        # For small inputs, randperm is offloaded to CPU instead
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(100, device=cuda)
+            res2 = torch.cuda.LongTensor()
+            torch.randperm(100, out=res2, device=cuda)
+            self.assertEqual(res1, res2, 0)
+
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(100000, device=cuda)
+            res2 = torch.cuda.LongTensor()
+            torch.randperm(100000, out=res2, device=cuda)
+            self.assertEqual(res1, res2, 0)
+
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(100, dtype=torch.half, device=cuda)
+            res2 = torch.cuda.HalfTensor()
+            torch.randperm(100, out=res2, device=cuda)
+            self.assertEqual(res1, res2, 0)
+
+        with torch.random.fork_rng(devices=[0]):
+            res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
+            res2 = torch.cuda.HalfTensor()
+            torch.randperm(50000, out=res2, device=cuda)
+            self.assertEqual(res1, res2, 0)
+
+        # randperm of 0 elements is an empty tensor
+        res1 = torch.randperm(0, device=cuda)
+        res2 = torch.cuda.LongTensor(5)
+        torch.randperm(0, out=res2, device=cuda)
+        self.assertEqual(res1.numel(), 0)
+        self.assertEqual(res2.numel(), 0)
+
     def test_random_neg_values(self):
         TestTorch._test_random_neg_values(self, use_cuda=True)