diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index f9f558a3dee..e6fe1ffa2f1 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -263,18 +263,15 @@ THGenerator* get_generator(at::Generator* gen) { Tensor randperm(const Type& dtype, int64_t n, Generator* generator) { Tensor result = dtype.tensor(n); - return at::native::randperm_out(result, n, generator); + return at::randperm_out(result, n, generator); /* goes through at:: so the per-backend dispatch declared for randperm_out picks the kernel */ } -Tensor& randperm_out(Tensor& result, int64_t n, Generator* generator) { +Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) { /* CPU kernel registered for randperm_out; throws std::runtime_error when n is negative */ if (n < 0) { std::ostringstream oss; oss << "n must be non-negative, got " << n; throw std::runtime_error(oss.str()); } - if (result.type().backend() != at::kCPU) { - throw std::runtime_error("randperm is only implemented for CPU"); - } result.resize_({n}); auto gen = get_generator(generator); diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index f6db9bdc681..4f2013c2f19 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -1,4 +1,14 @@ +#include "ATen/ATen.h" #include "ATen/NativeFunctions.h" +#include "ATen/cuda/CUDATypeConversion.cuh" + +#include +#include +#include +#include +#include +#include + #include #include @@ -26,4 +36,57 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { return result; } +Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) { /* CUDA kernel registered for randperm_out; throws std::runtime_error when n is negative and returns result after resizing it to {n} */ + if (n < 0) { + std::ostringstream oss; + oss << "n must be non-negative, got " << n; + throw std::runtime_error(oss.str()); + } + + if (n > 0) { + AT_DISPATCH_ALL_TYPES_AND_HALF( + result.type(), "randperm_out_cuda", [&] { + AT_CHECK(Scalar(n).to(), /* NOTE(review): the checked condition is the converted value itself; presumably the conversion is what rejects an out-of-range n. Confirm. */ + "n is too large for result tensor type: '", result.type().toString(), "'"); + } + ); + } + + result.resize_({n}); + + if (result.type().scalarType() == 
at::ScalarType::Half) { /* Half output: build the permutation in a temporary float CUDA tensor, then copy it into result */ + auto result_float = CUDA(kFloat).tensor({n}); + result.copy_(randperm_out_cuda(result_float, n, generator)); + } else { + if (n < 30000) { // For small inputs, we offload it to CPU instead. + auto result_cpu = result.type().toBackend(kCPU).tensor({n}); + randperm_out(result_cpu, n, generator); + result.copy_(result_cpu); + } else { + // Generate random values for the keys array + AT_DISPATCH_ALL_TYPES( + result.type(), "randperm_out_cuda", [&] { + using cuda_scalar_t = cuda::into_type; + + auto keys = result.type().tensor(result.sizes()).random_(generator); /* NOTE(review): random_ may produce equal keys, in which case the relative order of those entries is decided by the sort rather than by chance. Confirm this is acceptable, especially for narrow integer dtypes. */ + + auto result_data = thrust::device_ptr(result.data()); + auto keys_data = thrust::device_ptr(keys.data()); + + auto state = globalContext().getTHCState(); + THCThrustAllocator thrustAlloc(state); + auto policy = thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)); + + thrust::sequence(policy, result_data, result_data + n); + + // Use the sorted order of keys to rearrange the result array + thrust::sort_by_key(policy, keys_data, keys_data + n, result_data); + } + ); + } + } + + return result; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 82974534c5b..4cd9d347a07 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -800,6 +800,9 @@ - func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor variants: function + dispatch: + CPU: randperm_out_cpu + CUDA: randperm_out_cuda - func: range(Type dtype, Scalar start, Scalar end, Scalar step=1) -> Tensor variants: function diff --git a/test/test_cuda.py b/test/test_cuda.py index 087d1d0e8bb..42706b7044b 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1653,6 +1653,41 @@ class TestCuda(TestCase): torch.cuda.nvtx.mark("bar") torch.cuda.nvtx.range_pop() + def test_randperm_cuda(self): + cuda = torch.device('cuda:0') + + # For small inputs, 
randperm is offloaded to CPU instead + with torch.random.fork_rng(devices=[0]): + res1 = torch.randperm(100, device=cuda) + res2 = torch.cuda.LongTensor() + torch.randperm(100, out=res2, device=cuda) + self.assertEqual(res1, res2, 0) + + with torch.random.fork_rng(devices=[0]): + res1 = torch.randperm(100000, device=cuda) + res2 = torch.cuda.LongTensor() + torch.randperm(100000, out=res2, device=cuda) + self.assertEqual(res1, res2, 0) + + with torch.random.fork_rng(devices=[0]): + res1 = torch.randperm(100, dtype=torch.half, device=cuda) + res2 = torch.cuda.HalfTensor() + torch.randperm(100, out=res2, device=cuda) + self.assertEqual(res1, res2, 0) + + with torch.random.fork_rng(devices=[0]): + res1 = torch.randperm(50000, dtype=torch.half, device=cuda) + res2 = torch.cuda.HalfTensor() + torch.randperm(50000, out=res2, device=cuda) + self.assertEqual(res1, res2, 0) + + # randperm of 0 elements is an empty tensor + res1 = torch.randperm(0, device=cuda) + res2 = torch.cuda.LongTensor(5) + torch.randperm(0, out=res2, device=cuda) + self.assertEqual(res1.numel(), 0) + self.assertEqual(res2.numel(), 0) + def test_random_neg_values(self): TestTorch._test_random_neg_values(self, use_cuda=True)