Revert "Extend CSR constructor to support batched indices and values"

This reverts commit c074a53002.

Reverted https://github.com/pytorch/pytorch/pull/74542 on behalf of https://github.com/malfet
This commit is contained in:
PyTorch MergeBot 2022-03-30 19:54:26 +00:00
parent 2e4152b118
commit cc23725e89
9 changed files with 144 additions and 279 deletions

View File

@ -60,22 +60,17 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
}
void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
auto rows = size[size.size() - 2];
auto cols = size[size.size() - 1];
auto rows = size[0];
auto cols = size[1];
auto old_crow_indices_size = crow_indices_.size(-1);
auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
new_crow_indices_size.push_back(rows + 1);
crow_indices_.resize_(new_crow_indices_size);
crow_indices_.resize_({rows + 1});
if (rows + 1 >= old_crow_indices_size) {
crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
} else {
crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
}
auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
col_indices_values_size.push_back(std::min<int64_t>(nnz, rows*cols));
col_indices_.resize_(col_indices_values_size);
values_.resize_(col_indices_values_size);
col_indices_.resize_({std::min<int64_t>(nnz, rows*cols)});
values_.resize_({std::min<int64_t>(nnz, rows*cols)});
sizes_and_strides_.set_sizes(size);
}

View File

@ -43,7 +43,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
const Tensor& crow_indices() const { return crow_indices_; }
const Tensor& col_indices() const { return col_indices_; }
const Tensor& values() const { return values_; }
int nnz() { return col_indices_.size(-1); }
int nnz() { return values_.size(0); }
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.

View File

@ -101,7 +101,7 @@ class MklSparseCsrDescriptor
sparse_matrix_t raw_descriptor;
// Assuming that the last two dimensions are block elements of the matrix
if (values.dim() == 3 && crow_indices.dim() == 1 && col_indices.dim() == 1) {
if (values.dim() == 3) {
TORCH_CHECK(
values.size(-1) == values.size(-2),
"MKL Sparse doesn't support matrices with non-square blocks.");

View File

@ -9,7 +9,6 @@
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/LinearAlgebraUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -57,51 +56,29 @@ void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor&
// Shape and Strides invariants
TORCH_CHECK(
size.size() >= 2,
"size of a batched CSR tensor must have length >= 2, but got: ",
size.size() == 2,
"size of a CSR tensor must be of length 2, but got: ",
size.size());
TORCH_CHECK(
crow_indices.dim() >= 1,
"crow_indices must have dim >= 1 but got crow_indices.dim() = ",
crow_indices.dim() == 1,
"crow_indices must have dim=1 but got crow_indices.dim()=",
crow_indices.dim());
TORCH_CHECK(
col_indices.dim() >= 1,
"col_indices must have dim >= 1 but got col_indices.dim() = ",
col_indices.dim() == 1,
"col_indices must have dim=1 but got col_indices.dim()=",
col_indices.dim());
TORCH_CHECK(
values.dim() >= 1,
"values must have dim >= 1 but got values.dim() = ",
values.dim() == 1,
"values must have dim=1 but got values.dim()=",
values.dim());
// Note, this check also enforces `crow_indices.numel() >= 1`
TORCH_CHECK(
crow_indices.dim() == col_indices.dim(),
"Number of dimensions of crow_indices and col_indices must be the same.");
TORCH_CHECK(
crow_indices.dim() == values.dim(),
"Number of dimensions of indices and values must be the same.");
TORCH_CHECK(
crow_indices.dim() == size.size() - 1,
"Number of dimensions of indices must be one less than the number of dimensions of the provided size.");
// All batch sizes must be the same
auto batch_size = size.slice(0, size.size() - 2);
auto crow_indices_batch_size = crow_indices.sizes().slice(0, crow_indices.dim() - 1);
auto col_indices_batch_size = col_indices.sizes().slice(0, col_indices.dim() - 1);
auto values_batch_size = values.sizes().slice(0, values.dim() - 1);
TORCH_CHECK(
batch_size == crow_indices_batch_size &&
batch_size == col_indices_batch_size &&
batch_size == values_batch_size,
"All batch dimensions of the provided size, indices, and values must be the same.");
// Note, this check also enforces `crow_indices.size(-1) >= 1`
TORCH_CHECK(
crow_indices.size(-1) == (size[size.size() - 2] + 1),
"crow_indices.size(-1) must be equal to size[-2] + 1 (that is ", size[size.size() - 2] + 1, "), but got: ",
crow_indices.size(-1));
crow_indices.numel() == (size[0] + 1),
"crow_indices.numel() must be size(0) + 1, but got: ",
crow_indices.numel());
TORCH_CHECK(
col_indices.numel() == values.numel(),
"col_indices and values must have the same number of elements, but got col_indices.numel(): ",
"col_indices and values must have equal sizes, but got col_indices.numel(): ",
col_indices.numel(),
", values.numel(): ",
values.numel());
@ -109,28 +86,22 @@ void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor&
// Indices invariants
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
Tensor crow_indices_cpu = crow_indices.to(kCPU);
auto crow_indices_data_ptr = crow_indices_cpu.data_ptr<index_t>();
auto batch_stride = crow_indices_cpu.dim() >= 2 ? crow_indices_cpu.stride(-2) : 0;
for (const auto batch_id : c10::irange(batchCount(crow_indices_cpu))) {
TORCH_CHECK(
crow_indices_data_ptr[batch_id*batch_stride] == 0,
"(Batch element ", batch_id, ") ",
": 0th value of crow_indices must be 0, but it is ", crow_indices_data_ptr[batch_id*batch_stride]);
TORCH_CHECK(
crow_indices_data_ptr[batch_id*batch_stride + crow_indices.size(-1) - 1] == col_indices.size(-1),
"(Batch element ", batch_id, ") ",
"last value of crow_indices should be equal to the length of col_indices.");
auto crow_indices_accessor = crow_indices_cpu.accessor<index_t, 1>();
TORCH_CHECK(
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
for (int i = 1; i <= size[size.size() - 2]; i++) {
TORCH_CHECK(
crow_indices_data_ptr[batch_id*batch_stride + i - 1] <= crow_indices_data_ptr[batch_id*batch_stride + i],
"(Batch element ", batch_id, ") ",
"at position i = ", i, ", the condition crow_indices[i - 1] <= crow_indices[i] fails");
}
TORCH_CHECK(
crow_indices_accessor[crow_indices.numel() - 1] == col_indices.numel(),
"last value of crow_indices should be equal to the length of col_indices.");
for (int i = 1; i <= size[0]; i++) {
TORCH_CHECK(
crow_indices_accessor[i - 1] <= crow_indices_accessor[i],
"at position i = ", i, ", this condition crow_indices[i - 1] <= crow_indices[i] fails");
}
if (col_indices.numel() > 0) {
TORCH_CHECK(0 <= col_indices.min().item<index_t>(), "col_indices.min() should be greater or equal to zero");
TORCH_CHECK(size[size.size() - 1] > col_indices.max().item<index_t>(), "size[-1] should be greater than col_indices.max()");
TORCH_CHECK(size[1] > col_indices.max().item<index_t>(), "size(1) should be greater than col_indices.max()");
}
});
@ -242,10 +213,13 @@ Tensor sparse_csr_tensor(
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
// std::array<int64_t, 2> size = {0, 0};
auto size = DimVector(IntArrayRef(col_indices.sizes().data(), col_indices.dim() - 1));
size.push_back(crow_indices.size(-1) - 1);
size.push_back(col_indices.max().item<int64_t>() + 1);
std::array<int64_t, 2> size = {0, 0};
if (col_indices.numel() > 0) {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_construct_check", [&] {
size[0] = crow_indices.numel() - 1;
size[1] = col_indices.max().item<index_t>() + 1;
});
}
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
@ -269,21 +243,16 @@ Tensor empty_sparse_csr(
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse CSR matrices are supported, but got size ", size);
TORCH_CHECK(size.size() == 2, "torch.empty: Only 2D sparse CSR tensors are supported.");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::SparseCsr);
auto rows = size[size.size() - 2];
auto rows = size[0];
int64_t nnz = 0;
auto crow_indices_size = DimVector(size.slice(0, size.size() - 2));
crow_indices_size.push_back(rows + 1);
auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
col_indices_values_size.push_back(nnz);
TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
auto crow_indices = at::empty(crow_indices_size, options);
auto col_indices = at::empty(col_indices_values_size, options);
auto values = at::empty(col_indices_values_size, options.dtype(dtype));
auto crow_indices = at::empty({rows + 1}, options);
auto col_indices = at::empty({nnz}, options);
auto values = at::empty({nnz}, options.dtype(dtype));
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
@ -301,13 +270,13 @@ const Tensor& resize_sparse_csr_(
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
TORCH_CHECK(size.size() == 2, "torch.resize_: Only 2D sparse CSR tensors are supported.");
TORCH_CHECK(
self.size(-1) <= size[size.size() - 1],
self.size(1) <= size[1],
"torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
"The original number of columns is ",
self.size(-1),
" while the requested new number of columns is ", size[size.size() - 1], ".");
self.size(1),
" while the requested new number of columns is ", size[1], ".");
get_sparse_csr_impl(self)->resize_(self._nnz(), size);
return self;
}

View File

@ -638,10 +638,13 @@ void add_out_dense_sparse_csr_cpu(
" in add operation");
auto src_values = src.values();
auto src_crow_indices = src.crow_indices();
auto src_col_indices = src.col_indices();
resize_output(out, dense.sizes());
Tensor resultBuffer = out;
Tensor valuesBuffer = src_values.to(commonDtype);
if (out.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
@ -649,15 +652,6 @@ void add_out_dense_sparse_csr_cpu(
resultBuffer.copy_(dense);
}
if (src._nnz() == 0) {
return;
}
auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf,
kBool,
@ -677,26 +671,27 @@ void add_out_dense_sparse_csr_cpu(
&alpha,
&src_crow_indices,
&src_col_indices]() {
auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
auto values_accessor = valuesBuffer.accessor<scalar_t, 1>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto crow_indices_accessor =
src_crow_indices.accessor<index_t, 2>();
src_crow_indices.accessor<index_t, 1>();
auto col_indices_accessor =
src_col_indices.accessor<index_t, 2>();
auto out_strides = resultBuffer.strides();
src_col_indices.accessor<index_t, 1>();
auto out_strides0 = resultBuffer.strides()[0];
auto out_strides1 = resultBuffer.strides()[1];
for (const auto batch_idx : c10::irange(batch_count)) {
for (const auto irow : c10::irange(src_crow_indices.size(-1) - 1)) {
index_t start_index = crow_indices_accessor[batch_idx][irow];
index_t end_index = crow_indices_accessor[batch_idx][irow + 1];
for (const auto i : c10::irange(start_index, end_index)) {
auto icol = col_indices_accessor[batch_idx][i];
auto index = batch_idx * out_strides[0] + irow * out_strides[1] + icol * out_strides[2];
out_ptr[index] += cast_value * values_accessor[batch_idx][i];
}
for (index_t irow = 0; irow < src_crow_indices.size(0) - 1;
++irow) {
index_t start_index = crow_indices_accessor[irow];
index_t end_index = crow_indices_accessor[irow + 1];
for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[i];
auto index = resultBuffer.storage_offset() +
irow * out_strides0 + icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
}
});

View File

@ -978,18 +978,6 @@ void add_out_sparse_csr(
auto B_col_indices_ptr = B_col_indices.data_ptr<int>();
auto C_col_indices_ptr = C_col_indices.data_ptr<int>();
// Windows compilers don't support nested macros
// so we need this lambda outside of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
// For some reason POINTER_MODE_HOST is not working here
// Let's extract manually the nnz from the C_crow_indices
#if AT_ROCM_ENABLED()
return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
#else
return nnz;
#endif
};
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
C.scalar_type(), "add_out_sparse_csr_cuda_impl", [&] {
auto beta_ = beta.to<scalar_t>();
@ -1050,8 +1038,6 @@ void add_out_sparse_csr(
&nnzC,
work_data.get());
nnzC = fix_nnz(nnzC);
// Resize result using nnz information from cusparse
col_indices_and_values_resize_(C, nnzC);
C_col_indices = C.col_indices();

View File

@ -159,26 +159,18 @@ Tensor& add_out_dense_sparse_csr_cuda(
" in add operation");
Tensor src_values = src.values();
Tensor src_crow_indices = src.crow_indices();
Tensor src_col_indices = src.col_indices();
resize_output(output, dense.sizes());
Tensor resultBuffer = output;
Tensor valuesBuffer = src_values.to(commonDtype);
if (output.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
} else if (!is_same_tensor(output, dense)) {
resultBuffer.copy_(dense);
}
if (src._nnz() == 0) {
return output;
}
auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
resultBuffer = resultBuffer.view({-1, output.size(-2), output.size(-1)});
auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16,
commonDtype,
@ -188,7 +180,6 @@ Tensor& add_out_dense_sparse_csr_cuda(
src_crow_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
@ -198,11 +189,8 @@ Tensor& add_out_dense_sparse_csr_cuda(
int64_t out_storage_offset = resultBuffer.storage_offset();
auto out_strides = resultBuffer.strides();
auto out_strides0 = out_strides[0];
auto out_strides1 = out_strides[1];
auto crow_stride0 = src_crow_indices.stride(0);
auto col_stride0 = src_col_indices.stride(0);
auto val_stride0 = valuesBuffer.stride(0);
int64_t out_strides0 = out_strides[0];
int64_t out_strides1 = out_strides[1];
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
@ -212,29 +200,24 @@ Tensor& add_out_dense_sparse_csr_cuda(
thrust::for_each(
policy,
thrust::make_counting_iterator(int64_t(0)),
thrust::make_counting_iterator(int64_t(src_crow_indices.size(-1) - 1)),
thrust::make_counting_iterator(int64_t(src_crow_indices.size(0) - 1)),
[values_accessor,
crow_indices_accessor,
col_indices_accessor,
out_ptr,
cast_value,
out_storage_offset,
out_strides0,
out_strides1,
crow_stride0,
col_stride0,
val_stride0,
batch_count
cast_value,
out_strides1
]__device__(int64_t irow) {
for (index_t batch_idx = 0; batch_idx < batch_count; batch_idx++) {
index_t start_index = crow_indices_accessor[batch_idx*crow_stride0 + irow];
index_t end_index = crow_indices_accessor[batch_idx*crow_stride0 + irow + 1];
index_t start_index = crow_indices_accessor[irow];
index_t end_index = crow_indices_accessor[irow + 1];
for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[batch_idx*col_stride0 + i];
auto index = batch_idx * out_strides0 + irow * out_strides1 + icol;
out_ptr[index] += cast_value * values_accessor[batch_idx*val_stride0 + i];
auto icol = col_indices_accessor[i];
auto index = out_storage_offset + irow * out_strides0 + icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
}
});
});
});

View File

@ -165,33 +165,6 @@ class TestSparseCSR(TestCase):
self.assertEqual(torch.tensor(col_indices, dtype=index_dtype), sparse.col_indices())
self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values())
@dtypes(*get_all_dtypes())
def test_sparse_csr_batch_constructor(self, device, dtype):
batch_shape = (2, 3)
crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
for index_dtype in [torch.int32, torch.int64]:
sparse = torch.sparse_csr_tensor(crow_indices.to(index_dtype),
col_indices.to(index_dtype),
values,
size=(*batch_shape, 2, 10),
dtype=dtype,
device=device)
self.assertEqual((*batch_shape, 2, 10), sparse.shape)
self.assertEqual(crow_indices.to(index_dtype), sparse.crow_indices())
self.assertEqual(col_indices.to(index_dtype), sparse.col_indices())
self.assertEqual(values, sparse.values())
@dtypes(*get_all_dtypes())
def test_sparse_csr_batch_constructor_shape_inference(self, device, dtype):
batch_shape = (2, 3)
crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
self.assertEqual((*batch_shape, crow_indices.shape[-1] - 1, col_indices.max() + 1), sparse.shape)
@dtypes(*get_all_dtypes())
def test_sparse_csr_constructor_from_lists(self, device, dtype):
# without size
@ -225,17 +198,15 @@ class TestSparseCSR(TestCase):
@dtypes(*get_all_dtypes())
def test_empty(self, device, dtype):
ns = [5, 2, 0]
batch_shapes = [(), (2,), (2, 3)]
for m, n, b in itertools.product(ns, ns, batch_shapes):
shape = (*b, m, n)
for shape in itertools.product(ns, ns):
result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype)
self.assertEqual(result.device, torch.device(device))
self.assertEqual(result.layout, torch.sparse_csr)
self.assertEqual(result.crow_indices().shape, (*b, shape[-2] + 1,))
self.assertEqual(result.col_indices().shape, (*b, 0,))
self.assertEqual(result.values().shape, (*b, 0,))
self.assertEqual(result.crow_indices().shape, (shape[0] + 1,))
self.assertEqual(result.col_indices().shape, (0,))
self.assertEqual(result.values().shape, (0,))
self.assertEqual(result._nnz(), 0)
self.assertEqual(result.crow_indices().device, torch.device(device))
self.assertEqual(result.col_indices().device, torch.device(device))
@ -247,22 +218,23 @@ class TestSparseCSR(TestCase):
@skipMeta
@dtypes(*get_all_dtypes())
def test_empty_errors(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only batched sparse CSR matrices are supported, but got size"):
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)
@skipMeta
@dtypes(*get_all_dtypes())
def test_clone(self, device, dtype):
from operator import mul
from functools import reduce
for batch_shape in ((), (2,), (2, 3)):
prod = reduce(mul, batch_shape, 1)
crow_indices = torch.tensor([0, 2, 4], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1)
sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
cloned_sparse = sparse.clone()
self.assertEqual(sparse, cloned_sparse)
x = torch.sparse_csr_tensor([0, 2, 4],
[0, 1, 0, 1],
[1, 2, 3, 4],
dtype=dtype,
device=device)
y = x.clone()
self.assertEqual(x, y)
@skipMeta
@dtypes(*get_all_dtypes())
@ -277,10 +249,9 @@ class TestSparseCSR(TestCase):
self.assertEqual(a, b)
ns = [5, 2, 0]
batch_shapes = [(), (2,), (2, 3)]
for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
run_test((*b, m, n), 0, index_dtype)
run_test((*b, m, n), m * n, index_dtype)
for shape, index_dtype in zip(itertools.product(ns, ns), [torch.int32, torch.int64]):
run_test(shape, 0, index_dtype)
run_test(shape, shape[0] * shape[1], index_dtype)
@skipMeta
@dtypes(*get_all_dtypes())
@ -304,31 +275,25 @@ class TestSparseCSR(TestCase):
@skipMeta
@dtypes(*get_all_dtypes())
def test_resize(self, device, dtype):
batch_shapes = [(), (2,), (2, 3)]
for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
shape = (*b, 2, 3)
for index_dtype in [torch.int32, torch.int64]:
shape = (2, 3)
nnz = 6
a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
new_shape = (*b, 4, 5)
new_shape = (4, 5)
a.resize_(new_shape)
self.assertEqual(a.shape, new_shape)
# resize to larger shape doesn't add specified elements
self.assertEqual(a._nnz(), nnz)
new_shape = (*b, 1, 5)
new_shape = (1, 5)
a.resize_(new_shape)
self.assertEqual(a.shape, new_shape)
# resize to smaller shape trims specified elements
self.assertEqual(a._nnz(), 5)
# trim batched dimensions
a.resize_(new_shape[-2], new_shape[-1])
self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
self.assertEqual(a._nnz(), 5)
@skipMeta
@dtypes(*get_all_dtypes())
def test_resize_errors(self, device, dtype):
@ -337,7 +302,7 @@ class TestSparseCSR(TestCase):
nnz = 6
a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
new_shape = (4,)
a.resize_(new_shape)
@ -382,62 +347,49 @@ class TestSparseCSR(TestCase):
torch.tensor([1, 2, 3, 4]))
def test_factory_shape_invariants_check(self, device):
crow_indices = torch.tensor([0, 2, 4], device=device)
col_indices = torch.tensor([0, 1, 0, 1], device=device)
values = torch.tensor([1, 2, 3, 4], device=device)
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1, 2, 3, 4]
size = (2, 10)
torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device)
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"size of a batched CSR tensor must have length >= 2, but got: 1"):
torch.sparse_csr_tensor(crow_indices, col_indices, values,
size=(2,),
with self.assertRaisesRegex(RuntimeError, r"size of a CSR tensor must be of length 2, but got: 3"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values),
size=(2, 10, 2),
device=device)
with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim >= 1 but got crow_indices\.dim\(\)\ = 0"):
torch.sparse_csr_tensor(torch.zeros((), device=device, dtype=torch.int64),
col_indices,
values,
with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
torch.tensor(col_indices),
torch.tensor(values),
size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim >= 1 but got col_indices\.dim\(\)\ = 0"):
torch.sparse_csr_tensor(crow_indices,
torch.zeros((), device=device, dtype=torch.int64),
values,
with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices),
torch.tensor(col_indices).repeat(2, 1),
torch.tensor(values),
size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"values must have dim >= 1 but got values\.dim\(\)\ = 0"):
torch.sparse_csr_tensor(crow_indices,
col_indices,
torch.zeros((), device=device, dtype=torch.int64),
with self.assertRaisesRegex(RuntimeError, r"values must have dim\=1 but got values\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices),
torch.tensor(col_indices),
torch.tensor(values).repeat(2, 1),
size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"crow_indices\.size\(-1\) must be equal to size\[-2\] \+ 1 \(that is 2\), but got: 3"):
torch.sparse_csr_tensor(crow_indices, col_indices, values, (1, 1),
r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
device=device)
with self.assertRaisesRegex(RuntimeError,
r"Number of dimensions of crow_indices and col_indices must be the same"):
torch.sparse_csr_tensor(crow_indices, col_indices.repeat(2, 1), values, size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"Number of dimensions of indices and values must be the same"):
torch.sparse_csr_tensor(crow_indices, col_indices, values.repeat(2, 1), size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"Number of dimensions of indices must be one less"):
torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(2, 1), values.repeat(2, 1), size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"All batch dimensions of the provided size, indices, and values must be the same"):
torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(3, 1), values.repeat(4, 1), (2, 2, 10),
r"col_indices and values must have equal sizes, " +
r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values), size,
device=device)
def test_factory_indices_invariants_check(self, device):
@ -456,7 +408,7 @@ class TestSparseCSR(TestCase):
with self.assertRaisesRegex(RuntimeError,
r"at position i \= 2," +
r" the condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
@ -464,7 +416,7 @@ class TestSparseCSR(TestCase):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"size\[-1\] should be greater than col_indices\.max\(\)"):
with self.assertRaisesRegex(RuntimeError, r"size\(1\) should be greater than col_indices\.max\(\)"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size,
device=device)
@ -569,12 +521,12 @@ class TestSparseCSR(TestCase):
sparse = dense.to_sparse_csr()
self.assertEqual(sparse.to_dense(), dense)
batch_shape = (2, 3)
crow_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(csr.shape)
crow_indices = torch.tensor([0, 3, 5])
col_indices = torch.tensor([0, 1, 2, 0, 1])
values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
csr = torch.sparse_csr_tensor(crow_indices, col_indices,
values, dtype=dtype, device=device)
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
self.assertEqual(csr.to_dense(), dense)
@skipCPUIfNoMklSparse
@ -1147,9 +1099,6 @@ class TestSparseCSR(TestCase):
@dtypes(torch.float, torch.double)
def test_add(self, device, dtype):
def _test_spadd_shape(nnz, shape):
# sparse.to_dense() uses torch.add internally so if torch.add is wrong,
# the dense tensor will be wrong but this test would still pass
# there's a separate test that checks for the correctness of the .to_dense() call
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
y = torch.randn(*shape, dtype=dtype, device=device)
r = random.random()
@ -1171,12 +1120,10 @@ class TestSparseCSR(TestCase):
self.assertEqual(res, expected)
ns = [2, 5]
batch_shapes = [(), (2,), (2, 3)]
for b, m, n in itertools.product(batch_shapes, ns, ns):
_test_spadd_shape(0, (*b, m, n))
_test_spadd_shape(m * n // 2, (*b, m, n))
_test_spadd_shape(m * n, (*b, m, n))
_test_spadd_shape(10, [100, 100])
_test_spadd_shape(0, [100, 100])
_test_spadd_shape(10, [100, 1])
_test_spadd_shape(10, [1, 100])
@dtypes(torch.float, torch.double)
def test_mul(self, device, dtype):

View File

@ -1999,11 +1999,9 @@ class TestCase(expecttest.TestCase):
return crow_indices.to(device=device)
def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype):
from operator import mul
from functools import reduce
sparse_dim = 2
assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
assert len(size) >= sparse_dim
assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
assert len(size) == sparse_dim
def random_sparse_csr(n_rows, n_cols, nnz):
crow_indices = self._make_crow_indices(n_rows, n_cols, nnz, device=device, dtype=index_dtype)
@ -2017,15 +2015,7 @@ class TestCase(expecttest.TestCase):
values = make_tensor([nnz], device=device, dtype=dtype, low=low, high=high)
return values, crow_indices, col_indices
batch_shape = size[:-2]
n_batch = reduce(mul, batch_shape, 1)
sparse_tensors = [random_sparse_csr(size[-2], size[-1], nnz) for _ in range(n_batch)]
sparse_tensors_it = map(list, zip(*sparse_tensors))
values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
crow_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
col_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
values, crow_indices, col_indices = random_sparse_csr(size[0], size[1], nnz)
return torch.sparse_csr_tensor(crow_indices,
col_indices,
values, size=size, dtype=dtype, device=device)