diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6c6a77aafdc..804900bccc6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8174,6 +8174,15 @@
     CPU: searchsorted_cpu
     CUDA: searchsorted_cuda
 
+- func: _convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor
+  structured_delegate: _convert_indices_from_coo_to_csr.out
+
+- func: _convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: _convert_indices_from_coo_to_csr_structured_cpu
+    CUDA: _convert_indices_from_coo_to_csr_structured_cuda
+
 ## NN wrappers
 
 - func: mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index cba3f5c67a3..2d98eb0a7c3 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -15,6 +15,52 @@
 #include <algorithm>
 
 namespace at {
+namespace meta {
+
+TORCH_META_FUNC(_convert_indices_from_coo_to_csr) (
+  const Tensor& self, const int64_t size, const bool out_int32
+) {
+  TORCH_CHECK(self.dim() <= 1, "Input is supposed to be a vector");
+  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
+  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
+  set_output(size + 1, options);
+}
+
+} // namespace meta
+
+namespace {
+
+constexpr int64_t GRAIN_SIZE = at::internal::GRAIN_SIZE;
+
+template <typename input_t, typename output_t>
+void convert_indices_from_coo_to_csr_cpu(const Tensor& result, const Tensor& input, const int64_t size) {
+  int64_t numel = input.numel();
+  const input_t* data_in = input.data_ptr<input_t>();
+  output_t* data_out = result.data_ptr<output_t>();
+
+  if (numel == 0) {
+    result.zero_();
+    return;
+  }
+
+  for (int64_t i = 0; i <= data_in[0]; i++)
+    data_out[i] = static_cast<output_t>(0);
+
+  at::parallel_for(0, numel - 1, GRAIN_SIZE, [&](int64_t start, int64_t end) {
+    input_t curr_value = data_in[start], next_value;
+    for (int64_t i = start; i < end; i++) {
+      next_value = data_in[i + 1];
+      for (; curr_value < next_value; curr_value++)
+        data_out[curr_value + 1] = static_cast<output_t>(i + 1);
+    }
+  });
+
+  for (int64_t i = data_in[numel - 1] + 1; i < size + 1; i++)
+    data_out[i] = static_cast<output_t>(numel);
+}
+
+} // end anonymous namespace
+
 namespace native {
 
 using namespace at::sparse_csr;
@@ -322,5 +368,19 @@ Tensor& add_out_sparse_csr_cpu(
   return out;
 }
 
+TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_cpu) (
+  const Tensor& input, const int64_t size, const bool out_int32, const Tensor& result
+) {
+  if (out_int32) {
+    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
+      convert_indices_from_coo_to_csr_cpu<scalar_t, int>(result, input, size);
+    });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
+      convert_indices_from_coo_to_csr_cpu<scalar_t, int64_t>(result, input, size);
+    });
+  }
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
index 0a45dcd0cc7..ea765e076fb 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -30,6 +30,44 @@
 namespace at {
 namespace native {
 
+namespace {
+
+template <typename input_t, typename output_t>
+__global__ void convert_indices_from_coo_to_csr_cuda_kernel(output_t* data_out, const input_t* data_in, const int64_t size, const int64_t numel) {
+  int64_t tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid == 0) {
+    for (int64_t i = 0; i <= data_in[0]; i++)
+      data_out[i] = static_cast<output_t>(0);
+  } else if (tid < numel) {
+    for (int64_t i = data_in[tid - 1]; i < data_in[tid]; i++)
+      data_out[i + 1] = static_cast<output_t>(tid);
+  } else if (tid == numel) {
+    for (int64_t i = data_in[numel - 1] + 1; i < size + 1; i++)
+      data_out[i] = static_cast<output_t>(numel);
+  }
+}
+
+template <typename input_t, typename output_t>
+void convert_indices_from_coo_to_csr_cuda(const Tensor& result, const Tensor& input, const int64_t size) {
+  int64_t numel = input.numel();
+  const input_t* data_in = input.data_ptr<input_t>();
+  output_t* data_out = result.data_ptr<output_t>();
+
+  if (numel == 0) {
+    result.zero_();
+    return;
+  }
+
+  // Run (numel + 1) threads...
+  int64_t THREADS = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
+  int64_t BLOCKS = (numel + THREADS) / THREADS;
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+  convert_indices_from_coo_to_csr_cuda_kernel<<<BLOCKS, THREADS, 0, stream>>>(data_out, data_in, size, numel);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+} // namespace
+
 using namespace at::sparse_csr;
 // certain utiliy functions are usable from sparse COO.
 using namespace at::sparse;
@@ -226,5 +264,19 @@ Tensor& add_out_sparse_csr_cuda(
   return out;
 }
 
+TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_cuda) (
+  const Tensor& input, const int64_t size, const bool out_int32, const Tensor& result
+) {
+  if (out_int32) {
+    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cuda", [&] {
+      convert_indices_from_coo_to_csr_cuda<scalar_t, int>(result, input, size);
+    });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cuda", [&] {
+      convert_indices_from_coo_to_csr_cuda<scalar_t, int64_t>(result, input, size);
+    });
+  }
+}
+
 } // namespace native
 } // namespace at
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index e337efceca1..b9f48855e46 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -318,6 +318,11 @@ class TestSparseCSR(TestCase):
     @coalescedonoff
     @dtypes(torch.double)
     def test_coo_to_csr_convert(self, device, dtype, coalesced):
+        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
+            torch._convert_indices_from_coo_to_csr(
+                torch.randint(100, (5, 5), device=device),
+                size=100)
+
         size = (5, 5)
         sparse_dim = 2
         nnz = 10
diff --git a/torch/_tensor.py b/torch/_tensor.py
index ae2ede46969..24811da6c1a 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -993,8 +993,8 @@ class Tensor(torch._C._TensorBase):
         coalesced_self = self.coalesce()
         row_indices = coalesced_self.indices()[0]
         device = coalesced_self.values().device
-        arange = torch.arange(self.shape[0] + 1, dtype=row_indices.dtype, device=device)
-        crow_indices = torch.bucketize(arange, row_indices, out_int32=row_indices.dtype == torch.int32)
+        crow_indices = torch._convert_indices_from_coo_to_csr(
+            row_indices, self.shape[0], out_int32=row_indices.dtype == torch.int32)
         return torch.sparse_csr_tensor(crow_indices,
                                        coalesced_self.indices()[1].contiguous(),
                                        coalesced_self.values(),
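Reviewer note (not part of the patch): a minimal sketch of what the new private op computes on a toy input, assuming a build that includes this change. It compares the result against the bucketize-based path that this diff removes from Tensor.to_sparse_csr; the example tensor and variable names are illustrative only, and, as in the conversion path above, the input row indices are expected to be sorted (coalesced).

    import torch

    # Sorted COO row indices of a 4-row matrix with 6 non-zero entries.
    row_indices = torch.tensor([0, 0, 1, 1, 3, 3])
    num_rows = 4

    # New private op added by this patch: CSR row pointers of length num_rows + 1.
    crow_indices = torch._convert_indices_from_coo_to_csr(row_indices, size=num_rows)

    # Reference computed with the bucketize-based approach removed from torch/_tensor.py.
    arange = torch.arange(num_rows + 1, dtype=row_indices.dtype)
    reference = torch.bucketize(arange, row_indices)

    print(crow_indices)                          # tensor([0, 2, 4, 4, 6])
    print(torch.equal(crow_indices, reference))  # True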