_convert_coo_to_csr CPP and CUDA functionality (#61838)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/57381 and improves https://github.com/pytorch/pytorch/pull/61340 via dedicated `coo_to_csr` conversion routines for CPU and CUDA.
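For context, the new operator maps sorted COO row indices to CSR compressed row pointers (`crow_indices`). A minimal sketch of its behavior, with the output values worked out from the kernels in this diff:

import torch

# Sorted (coalesced) COO row indices of a 4-row matrix with 5 non-zeros.
row_indices = torch.tensor([0, 0, 1, 3, 3])

# crow_indices[i] counts the non-zeros in rows < i; the output has size + 1 entries.
crow_indices = torch._convert_indices_from_coo_to_csr(row_indices, size=4)
print(crow_indices)  # tensor([0, 2, 3, 3, 5])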

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61838

Reviewed By: ezyang

Differential Revision: D30132736

Pulled By: cpuhrsch

fbshipit-source-id: a1fd074c0d70366a524d219a620b94f8bed71d7c
Author: rusty1s (committed by Facebook GitHub Bot), 2021-08-11 11:35:53 -07:00
Parent: b8e6144e0a
Commit: 82123758ba
5 changed files with 128 additions and 2 deletions


@@ -8174,6 +8174,15 @@
    CPU: searchsorted_cpu
    CUDA: searchsorted_cuda

- func: _convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor
  structured_delegate: _convert_indices_from_coo_to_csr.out

- func: _convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU: _convert_indices_from_coo_to_csr_structured_cpu
    CUDA: _convert_indices_from_coo_to_csr_structured_cuda

## NN wrappers

- func: mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
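These `native_functions.yaml` entries register a structured functional/out pair; `out_int32` only selects the output dtype, as the meta function below shows. An illustrative call:

import torch

row_indices = torch.tensor([0, 2, 2])
crow64 = torch._convert_indices_from_coo_to_csr(row_indices, 3)
crow32 = torch._convert_indices_from_coo_to_csr(row_indices, 3, out_int32=True)
print(crow64.dtype, crow32.dtype)  # torch.int64 torch.int32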


@@ -15,6 +15,52 @@
#include <algorithm>

namespace at {
namespace meta {

TORCH_META_FUNC(_convert_indices_from_coo_to_csr)(
    const Tensor& self, const int64_t size, const bool out_int32
) {
  TORCH_CHECK(self.dim() <= 1, "Input is supposed to be a vector");
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options =
      TensorOptions().device(self.options().device()).dtype(scalar_type);
  // The output has size + 1 entries: one pointer per row plus the end offset.
  set_output(size + 1, options);
}

} // namespace meta

namespace {

constexpr int64_t GRAIN_SIZE = at::internal::GRAIN_SIZE;

template <typename input_t, typename output_t>
void convert_indices_from_coo_to_csr_cpu(const Tensor& result, const Tensor& input, const int64_t size) {
  int64_t numel = input.numel();
  const input_t* data_in = input.data_ptr<input_t>();
  output_t* data_out = result.data_ptr<output_t>();

  if (numel == 0) {
    result.zero_();
    return;
  }

  // crow_indices[0] and the pointers of rows before the first non-zero row are 0.
  for (int64_t i = 0; i <= data_in[0]; i++)
    data_out[i] = static_cast<output_t>(0);

  // Each chunk scans a slice of the sorted row indices and writes a row pointer
  // at every row boundary it observes; chunks touch disjoint output entries.
  at::parallel_for(0, numel - 1, GRAIN_SIZE, [&](int64_t start, int64_t end) {
    input_t curr_value = data_in[start], next_value;
    for (int64_t i = start; i < end; i++) {
      next_value = data_in[i + 1];
      for (; curr_value < next_value; curr_value++)
        data_out[curr_value + 1] = static_cast<output_t>(i + 1);
    }
  });

  // Rows after the last non-zero row all point past the end (numel).
  for (int64_t i = data_in[numel - 1] + 1; i < size + 1; i++)
    data_out[i] = static_cast<output_t>(numel);
}

} // end anonymous namespace

namespace native {

using namespace at::sparse_csr;
@@ -322,5 +368,19 @@ Tensor& add_out_sparse_csr_cpu(
  return out;
}

TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_cpu)(
    const Tensor& input, const int64_t size, const bool out_int32, const Tensor& result
) {
  if (out_int32) {
    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
      convert_indices_from_coo_to_csr_cpu<scalar_t, int>(result, input, size);
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
      convert_indices_from_coo_to_csr_cpu<scalar_t, int64_t>(result, input, size);
    });
  }
}
} // namespace native
} // namespace at
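For readers who want the CPU algorithm without the dispatch machinery, here is a rough pure-Python rendition of its three phases (head fill, boundary scan, tail fill); `coo_rows_to_crow` is an illustrative name, not part of this PR:

import torch

def coo_rows_to_crow(row_indices, size):
    # Mirrors convert_indices_from_coo_to_csr_cpu; row_indices must be sorted.
    numel = row_indices.numel()
    out = torch.zeros(size + 1, dtype=torch.int64)
    if numel == 0:
        return out
    data_in = row_indices.tolist()
    # Head: crow[0] and rows up to the first non-zero row start at 0.
    for i in range(data_in[0] + 1):
        out[i] = 0
    # Middle: each boundary between consecutive elements sets a row pointer.
    for i in range(numel - 1):
        for r in range(data_in[i], data_in[i + 1]):
            out[r + 1] = i + 1
    # Tail: rows after the last non-zero row point past the end.
    for i in range(data_in[numel - 1] + 1, size + 1):
        out[i] = numel
    return out

assert torch.equal(
    coo_rows_to_crow(torch.tensor([0, 0, 1, 3, 3]), 4),
    torch.tensor([0, 2, 3, 3, 5]))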


@@ -30,6 +30,44 @@
namespace at {
namespace native {

namespace {

template <typename input_t, typename output_t>
__global__ void convert_indices_from_coo_to_csr_cuda_kernel(output_t* data_out, const input_t* data_in, const int64_t size, const int64_t numel) {
  int64_t tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid == 0) {
    // Thread 0 zeros crow_indices[0] and the rows before the first non-zero row.
    for (int64_t i = 0; i <= data_in[0]; i++)
      data_out[i] = static_cast<output_t>(0);
  } else if (tid < numel) {
    // Thread tid writes the pointers for the row boundaries between
    // element tid - 1 and element tid.
    for (int64_t i = data_in[tid - 1]; i < data_in[tid]; i++)
      data_out[i + 1] = static_cast<output_t>(tid);
  } else if (tid == numel) {
    // The last thread fills the rows after the last non-zero row.
    for (int64_t i = data_in[numel - 1] + 1; i < size + 1; i++)
      data_out[i] = static_cast<output_t>(numel);
  }
}

template <typename input_t, typename output_t>
void convert_indices_from_coo_to_csr_cuda(const Tensor& result, const Tensor& input, const int64_t size) {
  int64_t numel = input.numel();
  const input_t* data_in = input.data_ptr<input_t>();
  output_t* data_out = result.data_ptr<output_t>();

  if (numel == 0) {
    result.zero_();
    return;
  }

  // Run (numel + 1) threads: one per input element plus one for the tail.
  int64_t THREADS = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  int64_t BLOCKS = (numel + THREADS) / THREADS;
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  convert_indices_from_coo_to_csr_cuda_kernel<<<BLOCKS, THREADS, 0, stream>>>(data_out, data_in, size, numel);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

} // namespace

using namespace at::sparse_csr;
// Certain utility functions are usable from sparse COO.
using namespace at::sparse;
@@ -226,5 +264,19 @@ Tensor& add_out_sparse_csr_cuda(
  return out;
}

TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_cuda)(
    const Tensor& input, const int64_t size, const bool out_int32, const Tensor& result
) {
  if (out_int32) {
    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cuda", [&] {
      convert_indices_from_coo_to_csr_cuda<scalar_t, int>(result, input, size);
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "convert_indices_from_coo_to_csr_cuda", [&] {
      convert_indices_from_coo_to_csr_cuda<scalar_t, int64_t>(result, input, size);
    });
  }
}
} // namespace native
} // namespace at
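The CUDA path launches numel + 1 threads, and each thread writes a disjoint slice of the output, so no atomics are needed. A Python emulation of the per-thread work partition (illustrative only; `emulate_cuda_kernel` is not part of this PR):

import torch

def emulate_cuda_kernel(row_indices, size):
    # One "thread" per element plus one tail thread, as in the kernel above;
    # assumes a non-empty, sorted row_indices (the host function guards numel == 0).
    numel = row_indices.numel()
    data_in = row_indices.tolist()
    out = torch.empty(size + 1, dtype=torch.int64)
    for tid in range(numel + 1):
        if tid == 0:
            for i in range(data_in[0] + 1):
                out[i] = 0
        elif tid < numel:
            for i in range(data_in[tid - 1], data_in[tid]):
                out[i + 1] = tid
        else:  # tid == numel
            for i in range(data_in[numel - 1] + 1, size + 1):
                out[i] = numel
    return out

assert torch.equal(
    emulate_cuda_kernel(torch.tensor([0, 0, 1, 3, 3]), 4),
    torch.tensor([0, 2, 3, 3, 5]))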


@@ -318,6 +318,11 @@ class TestSparseCSR(TestCase):

    @coalescedonoff
    @dtypes(torch.double)
    def test_coo_to_csr_convert(self, device, dtype, coalesced):
        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
            torch._convert_indices_from_coo_to_csr(
                torch.randint(100, (5, 5), device=device),
                size=100)

        size = (5, 5)
        sparse_dim = 2
        nnz = 10


@@ -993,8 +993,8 @@ class Tensor(torch._C._TensorBase):
         coalesced_self = self.coalesce()
         row_indices = coalesced_self.indices()[0]
         device = coalesced_self.values().device
-        arange = torch.arange(self.shape[0] + 1, dtype=row_indices.dtype, device=device)
-        crow_indices = torch.bucketize(arange, row_indices, out_int32=row_indices.dtype == torch.int32)
+        crow_indices = torch._convert_indices_from_coo_to_csr(
+            row_indices, self.shape[0], out_int32=row_indices.dtype == torch.int32)
         return torch.sparse_csr_tensor(crow_indices,
                                        coalesced_self.indices()[1].contiguous(),
                                        coalesced_self.values(),
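The replaced `bucketize`-based construction and the new operator should agree on coalesced (sorted) row indices; a quick illustrative check:

import torch

row_indices = torch.tensor([0, 0, 1, 3, 3])  # coalesced COO row indices
n_rows = 4

# Old path: bucketize an arange of row boundaries against the row indices.
arange = torch.arange(n_rows + 1, dtype=row_indices.dtype)
old = torch.bucketize(arange, row_indices)

# New path: the dedicated conversion introduced by this PR.
new = torch._convert_indices_from_coo_to_csr(row_indices, n_rows)

assert torch.equal(old, new)  # tensor([0, 2, 3, 3, 5]) either way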