mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Implement randperm for CUDA (#7606)
* Implement randperm for CUDA * Use Thrust to implement randperm * clean up * Fix test * Offload small input scenario to CPU * Fixed test * Try to fix Windows error * Fix Windows error and clean up * Use fork_rng context manager * Move test_randperm_cuda to test_cuda * Add half tensor support * Fix cuda::type error * Fix CPU offloading * Fix issues * No need to check range for n == 0 case
This commit is contained in:
parent
9af3a80cff
commit
edfcbfbe1f
|
|
@ -263,18 +263,15 @@ THGenerator* get_generator(at::Generator* gen) {
|
|||
|
||||
Tensor randperm(const Type& dtype, int64_t n, Generator* generator) {
|
||||
Tensor result = dtype.tensor(n);
|
||||
return at::native::randperm_out(result, n, generator);
|
||||
return at::randperm_out(result, n, generator);
|
||||
}
|
||||
|
||||
Tensor& randperm_out(Tensor& result, int64_t n, Generator* generator) {
|
||||
Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) {
|
||||
if (n < 0) {
|
||||
std::ostringstream oss;
|
||||
oss << "n must be non-negative, got " << n;
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
if (result.type().backend() != at::kCPU) {
|
||||
throw std::runtime_error("randperm is only implemented for CPU");
|
||||
}
|
||||
|
||||
result.resize_({n});
|
||||
auto gen = get_generator(generator);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,14 @@
|
|||
#include "ATen/ATen.h"
|
||||
#include "ATen/NativeFunctions.h"
|
||||
#include "ATen/cuda/CUDATypeConversion.cuh"
|
||||
|
||||
#include <THC/THCGeneral.h>
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/sequence.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
|
|
@ -26,4 +36,57 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
|
|||
return result;
|
||||
}
|
||||
|
||||
Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
|
||||
if (n < 0) {
|
||||
std::ostringstream oss;
|
||||
oss << "n must be non-negative, got " << n;
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
|
||||
if (n > 0) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_HALF(
|
||||
result.type(), "randperm_out_cuda", [&] {
|
||||
AT_CHECK(Scalar(n).to<scalar_t>(),
|
||||
"n is too large for result tensor type: '", result.type().toString(), "'");
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
result.resize_({n});
|
||||
|
||||
if (result.type().scalarType() == at::ScalarType::Half) {
|
||||
auto result_float = CUDA(kFloat).tensor({n});
|
||||
result.copy_(randperm_out_cuda(result_float, n, generator));
|
||||
} else {
|
||||
if (n < 30000) { // For small inputs, we offload it to CPU instead.
|
||||
auto result_cpu = result.type().toBackend(kCPU).tensor({n});
|
||||
randperm_out(result_cpu, n, generator);
|
||||
result.copy_(result_cpu);
|
||||
} else {
|
||||
// Generate random values for the keys array
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
result.type(), "randperm_out_cuda", [&] {
|
||||
using cuda_scalar_t = cuda::into_type<scalar_t>;
|
||||
|
||||
auto keys = result.type().tensor(result.sizes()).random_(generator);
|
||||
|
||||
auto result_data = thrust::device_ptr<cuda_scalar_t>(result.data<cuda_scalar_t>());
|
||||
auto keys_data = thrust::device_ptr<cuda_scalar_t>(keys.data<cuda_scalar_t>());
|
||||
|
||||
auto state = globalContext().getTHCState();
|
||||
THCThrustAllocator thrustAlloc(state);
|
||||
auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state));
|
||||
|
||||
thrust::sequence(policy, result_data, result_data + n);
|
||||
|
||||
// Use the sorted order of keys to rearrange the result array
|
||||
thrust::sort_by_key(policy, keys_data, keys_data + n, result_data);
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -800,6 +800,9 @@
|
|||
|
||||
- func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: randperm_out_cpu
|
||||
CUDA: randperm_out_cuda
|
||||
|
||||
- func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor
|
||||
variants: function
|
||||
|
|
|
|||
|
|
@ -1653,6 +1653,41 @@ class TestCuda(TestCase):
|
|||
torch.cuda.nvtx.mark("bar")
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
def test_randperm_cuda(self):
|
||||
cuda = torch.device('cuda:0')
|
||||
|
||||
# For small inputs, randperm is offloaded to CPU instead
|
||||
with torch.random.fork_rng(devices=[0]):
|
||||
res1 = torch.randperm(100, device=cuda)
|
||||
res2 = torch.cuda.LongTensor()
|
||||
torch.randperm(100, out=res2, device=cuda)
|
||||
self.assertEqual(res1, res2, 0)
|
||||
|
||||
with torch.random.fork_rng(devices=[0]):
|
||||
res1 = torch.randperm(100000, device=cuda)
|
||||
res2 = torch.cuda.LongTensor()
|
||||
torch.randperm(100000, out=res2, device=cuda)
|
||||
self.assertEqual(res1, res2, 0)
|
||||
|
||||
with torch.random.fork_rng(devices=[0]):
|
||||
res1 = torch.randperm(100, dtype=torch.half, device=cuda)
|
||||
res2 = torch.cuda.HalfTensor()
|
||||
torch.randperm(100, out=res2, device=cuda)
|
||||
self.assertEqual(res1, res2, 0)
|
||||
|
||||
with torch.random.fork_rng(devices=[0]):
|
||||
res1 = torch.randperm(50000, dtype=torch.half, device=cuda)
|
||||
res2 = torch.cuda.HalfTensor()
|
||||
torch.randperm(50000, out=res2, device=cuda)
|
||||
self.assertEqual(res1, res2, 0)
|
||||
|
||||
# randperm of 0 elements is an empty tensor
|
||||
res1 = torch.randperm(0, device=cuda)
|
||||
res2 = torch.cuda.LongTensor(5)
|
||||
torch.randperm(0, out=res2, device=cuda)
|
||||
self.assertEqual(res1.numel(), 0)
|
||||
self.assertEqual(res2.numel(), 0)
|
||||
|
||||
def test_random_neg_values(self):
|
||||
TestTorch._test_random_neg_values(self, use_cuda=True)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user