Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Extend CSR constructor to support batched indices and values"

This reverts commit c074a53002.
Reverted https://github.com/pytorch/pytorch/pull/74542 on behalf of https://github.com/malfet

parent 2e4152b118
commit cc23725e89
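For context, a minimal sketch (not part of this commit) of the constructor behaviour the revert restores: torch.sparse_csr_tensor goes back to accepting only 1-D crow_indices, col_indices and values together with a 2-element size.

import torch

crow_indices = torch.tensor([0, 2, 4])
col_indices = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1., 2., 3., 4.])

# Plain 2-D CSR matrix: still supported after the revert.
a = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 10))
print(a.shape)  # torch.Size([2, 10])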
@@ -60,22 +60,17 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
}
void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
auto rows = size[size.size() - 2];
auto cols = size[size.size() - 1];
auto rows = size[0];
auto cols = size[1];
auto old_crow_indices_size = crow_indices_.size(-1);
auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
new_crow_indices_size.push_back(rows + 1);
crow_indices_.resize_(new_crow_indices_size);
crow_indices_.resize_({rows + 1});
if (rows + 1 >= old_crow_indices_size) {
crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
} else {
crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
}
auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
col_indices_values_size.push_back(std::min<int64_t>(nnz, rows*cols));
col_indices_.resize_(col_indices_values_size);
values_.resize_(col_indices_values_size);
col_indices_.resize_({std::min<int64_t>(nnz, rows*cols)});
values_.resize_({std::min<int64_t>(nnz, rows*cols)});
sizes_and_strides_.set_sizes(size);
}
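As a reading aid, not part of the diff, a small Python sketch of the resize_ semantics the restored 2-D code above implements: growing the shape adds no specified elements, while shrinking clamps nnz to rows*cols.

import torch

# 2x3 CSR tensor with 6 specified elements
a = torch.sparse_csr_tensor(torch.tensor([0, 3, 6]),
                            torch.tensor([0, 1, 2, 0, 1, 2]),
                            torch.tensor([1., 2., 3., 4., 5., 6.]),
                            size=(2, 3))
a.resize_((4, 5))   # growing: no specified elements are added
print(a._nnz())     # 6
a.resize_((1, 5))   # shrinking: nnz clamped to min(nnz, rows * cols) = 5
print(a._nnz())     # 5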
@@ -43,7 +43,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
const Tensor& crow_indices() const { return crow_indices_; }
const Tensor& col_indices() const { return col_indices_; }
const Tensor& values() const { return values_; }
int nnz() { return col_indices_.size(-1); }
int nnz() { return values_.size(0); }
/**
 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
@@ -101,7 +101,7 @@ class MklSparseCsrDescriptor
sparse_matrix_t raw_descriptor;
// Assuming that the last two dimensions are block elements of the matrix
if (values.dim() == 3 && crow_indices.dim() == 1 && col_indices.dim() == 1) {
if (values.dim() == 3) {
TORCH_CHECK(
values.size(-1) == values.size(-2),
"MKL Sparse doesn't support matrices with non-square blocks.");
@@ -9,7 +9,6 @@
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/LinearAlgebraUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -57,51 +56,29 @@ void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor&
// Shape and Strides invariants
TORCH_CHECK(
size.size() >= 2,
"size of a batched CSR tensor must have length >= 2, but got: ",
size.size() == 2,
"size of a CSR tensor must be of length 2, but got: ",
size.size());
TORCH_CHECK(
crow_indices.dim() >= 1,
"crow_indices must have dim >= 1 but got crow_indices.dim() = ",
crow_indices.dim() == 1,
"crow_indices must have dim=1 but got crow_indices.dim()=",
crow_indices.dim());
TORCH_CHECK(
col_indices.dim() >= 1,
"col_indices must have dim >= 1 but got col_indices.dim() = ",
col_indices.dim() == 1,
"col_indices must have dim=1 but got col_indices.dim()=",
col_indices.dim());
TORCH_CHECK(
values.dim() >= 1,
"values must have dim >= 1 but got values.dim() = ",
values.dim() == 1,
"values must have dim=1 but got values.dim()=",
values.dim());
// Note, this check also enforces `crow_indices.numel() >= 1`
TORCH_CHECK(
crow_indices.dim() == col_indices.dim(),
"Number of dimensions of crow_indices and col_indices must be the same.");
TORCH_CHECK(
crow_indices.dim() == values.dim(),
"Number of dimensions of indices and values must be the same.");
TORCH_CHECK(
crow_indices.dim() == size.size() - 1,
"Number of dimensions of indices must be one less than the number of dimensions of the provided size.");
// All batch sizes must be the same
auto batch_size = size.slice(0, size.size() - 2);
auto crow_indices_batch_size = crow_indices.sizes().slice(0, crow_indices.dim() - 1);
auto col_indices_batch_size = col_indices.sizes().slice(0, col_indices.dim() - 1);
auto values_batch_size = values.sizes().slice(0, values.dim() - 1);
TORCH_CHECK(
batch_size == crow_indices_batch_size &&
batch_size == col_indices_batch_size &&
batch_size == values_batch_size,
"All batch dimensions of the provided size, indices, and values must be the same.");
// Note, this check also enforces `crow_indices.size(-1) >= 1`
TORCH_CHECK(
crow_indices.size(-1) == (size[size.size() - 2] + 1),
"crow_indices.size(-1) must be equal to size[-2] + 1 (that is ", size[size.size() - 2] + 1, "), but got: ",
crow_indices.size(-1));
crow_indices.numel() == (size[0] + 1),
"crow_indices.numel() must be size(0) + 1, but got: ",
crow_indices.numel());
TORCH_CHECK(
col_indices.numel() == values.numel(),
"col_indices and values must have the same number of elements, but got col_indices.numel(): ",
"col_indices and values must have equal sizes, but got col_indices.numel(): ",
col_indices.numel(),
", values.numel(): ",
values.numel());
@@ -109,28 +86,22 @@ void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor&
// Indices invariants
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
Tensor crow_indices_cpu = crow_indices.to(kCPU);
auto crow_indices_data_ptr = crow_indices_cpu.data_ptr<index_t>();
auto batch_stride = crow_indices_cpu.dim() >= 2 ? crow_indices_cpu.stride(-2) : 0;
for (const auto batch_id : c10::irange(batchCount(crow_indices_cpu))) {
TORCH_CHECK(
crow_indices_data_ptr[batch_id*batch_stride] == 0,
"(Batch element ", batch_id, ") ",
": 0th value of crow_indices must be 0, but it is ", crow_indices_data_ptr[batch_id*batch_stride]);
TORCH_CHECK(
crow_indices_data_ptr[batch_id*batch_stride + crow_indices.size(-1) - 1] == col_indices.size(-1),
"(Batch element ", batch_id, ") ",
"last value of crow_indices should be equal to the length of col_indices.");
auto crow_indices_accessor = crow_indices_cpu.accessor<index_t, 1>();
TORCH_CHECK(
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
for (int i = 1; i <= size[size.size() - 2]; i++) {
TORCH_CHECK(
crow_indices_data_ptr[batch_id*batch_stride + i - 1] <= crow_indices_data_ptr[batch_id*batch_stride + i],
"(Batch element ", batch_id, ") ",
"at position i = ", i, ", the condition crow_indices[i - 1] <= crow_indices[i] fails");
}
TORCH_CHECK(
crow_indices_accessor[crow_indices.numel() - 1] == col_indices.numel(),
"last value of crow_indices should be equal to the length of col_indices.");
for (int i = 1; i <= size[0]; i++) {
TORCH_CHECK(
crow_indices_accessor[i - 1] <= crow_indices_accessor[i],
"at position i = ", i, ", this condition crow_indices[i - 1] <= crow_indices[i] fails");
}
if (col_indices.numel() > 0) {
TORCH_CHECK(0 <= col_indices.min().item<index_t>(), "col_indices.min() should be greater or equal to zero");
TORCH_CHECK(size[size.size() - 1] > col_indices.max().item<index_t>(), "size[-1] should be greater than col_indices.max()");
TORCH_CHECK(size[1] > col_indices.max().item<index_t>(), "size(1) should be greater than col_indices.max()");
}
});
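For orientation, not part of the diff, the invariants the restored 2-D validation enforces can be spelled out as a short sketch:

import torch

crow_indices = torch.tensor([0, 2, 4])     # length rows + 1, starts at 0
col_indices = torch.tensor([0, 1, 0, 1])   # same numel as values, all < cols
values = torch.tensor([1., 2., 3., 4.])
size = (2, 10)

assert crow_indices.numel() == size[0] + 1
assert crow_indices[0] == 0 and crow_indices[-1] == col_indices.numel()
assert (crow_indices[:-1] <= crow_indices[1:]).all()  # non-decreasing row offsets
assert col_indices.numel() == values.numel()
assert col_indices.min() >= 0 and col_indices.max() < size[1]

# Inputs satisfying the above pass _validate_sparse_csr_tensor_args.
torch.sparse_csr_tensor(crow_indices, col_indices, values, size)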
@@ -242,10 +213,13 @@ Tensor sparse_csr_tensor(
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
// std::array<int64_t, 2> size = {0, 0};
auto size = DimVector(IntArrayRef(col_indices.sizes().data(), col_indices.dim() - 1));
size.push_back(crow_indices.size(-1) - 1);
size.push_back(col_indices.max().item<int64_t>() + 1);
std::array<int64_t, 2> size = {0, 0};
if (col_indices.numel() > 0) {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_construct_check", [&] {
size[0] = crow_indices.numel() - 1;
size[1] = col_indices.max().item<index_t>() + 1;
});
}
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
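A small sketch, not part of the diff, of the size inference the restored 2-D path performs when no size is given: rows = crow_indices.numel() - 1 and cols = col_indices.max() + 1.

import torch

a = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                            torch.tensor([0, 1, 0, 1]),
                            torch.tensor([1., 2., 3., 4.]))
print(a.shape)  # torch.Size([2, 2]): 3 - 1 rows, max(col_indices) + 1 columns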
@@ -269,21 +243,16 @@ Tensor empty_sparse_csr(
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse CSR matrices are supported, but got size ", size);
TORCH_CHECK(size.size() == 2, "torch.empty: Only 2D sparse CSR tensors are supported.");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::SparseCsr);
auto rows = size[size.size() - 2];
auto rows = size[0];
int64_t nnz = 0;
auto crow_indices_size = DimVector(size.slice(0, size.size() - 2));
crow_indices_size.push_back(rows + 1);
auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
col_indices_values_size.push_back(nnz);
TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
auto crow_indices = at::empty(crow_indices_size, options);
auto col_indices = at::empty(col_indices_values_size, options);
auto values = at::empty(col_indices_values_size, options.dtype(dtype));
auto crow_indices = at::empty({rows + 1}, options);
auto col_indices = at::empty({nnz}, options);
auto values = at::empty({nnz}, options.dtype(dtype));
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
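For reference, a sketch, not part of the diff, of what the restored empty_sparse_csr allocates: an all-empty 2-D CSR tensor with crow_indices of length rows + 1 and empty col_indices and values.

import torch

result = torch.empty((5, 2), dtype=torch.float64, layout=torch.sparse_csr)
print(result.crow_indices().shape)  # torch.Size([6])
print(result.col_indices().shape)   # torch.Size([0])
print(result._nnz())                # 0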
@@ -301,13 +270,13 @@ const Tensor& resize_sparse_csr_(
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
TORCH_CHECK(size.size() == 2, "torch.resize_: Only 2D sparse CSR tensors are supported.");
TORCH_CHECK(
self.size(-1) <= size[size.size() - 1],
self.size(1) <= size[1],
"torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
"The original number of columns is ",
self.size(-1),
" while the requested new number of columns is ", size[size.size() - 1], ".");
self.size(1),
" while the requested new number of columns is ", size[1], ".");
get_sparse_csr_impl(self)->resize_(self._nnz(), size);
return self;
}
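A brief sketch, not part of the diff, of the column restriction resize_sparse_csr_ keeps enforcing: the number of columns may grow but not shrink.

import torch

a = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                            torch.tensor([0, 1, 0, 1]),
                            torch.tensor([1., 2., 3., 4.]),
                            size=(2, 3))
a.resize_((2, 10))  # growing the number of columns is allowed
# a.resize_((2, 2)) would raise: "torch.resize_: Resizing columns of sparse CSR
# tensors to a smaller value is not supported."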
@@ -638,10 +638,13 @@ void add_out_dense_sparse_csr_cpu(
" in add operation");
auto src_values = src.values();
auto src_crow_indices = src.crow_indices();
auto src_col_indices = src.col_indices();
resize_output(out, dense.sizes());
Tensor resultBuffer = out;
Tensor valuesBuffer = src_values.to(commonDtype);
if (out.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
@@ -649,15 +652,6 @@ void add_out_dense_sparse_csr_cpu(
resultBuffer.copy_(dense);
}
if (src._nnz() == 0) {
return;
}
auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf,
kBool,
@@ -677,26 +671,27 @@ void add_out_dense_sparse_csr_cpu(
&alpha,
&src_crow_indices,
&src_col_indices]() {
auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
auto values_accessor = valuesBuffer.accessor<scalar_t, 1>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto crow_indices_accessor =
src_crow_indices.accessor<index_t, 2>();
src_crow_indices.accessor<index_t, 1>();
auto col_indices_accessor =
src_col_indices.accessor<index_t, 2>();
auto out_strides = resultBuffer.strides();
src_col_indices.accessor<index_t, 1>();
auto out_strides0 = resultBuffer.strides()[0];
auto out_strides1 = resultBuffer.strides()[1];
for (const auto batch_idx : c10::irange(batch_count)) {
for (const auto irow : c10::irange(src_crow_indices.size(-1) - 1)) {
index_t start_index = crow_indices_accessor[batch_idx][irow];
index_t end_index = crow_indices_accessor[batch_idx][irow + 1];
for (const auto i : c10::irange(start_index, end_index)) {
auto icol = col_indices_accessor[batch_idx][i];
auto index = batch_idx * out_strides[0] + irow * out_strides[1] + icol * out_strides[2];
out_ptr[index] += cast_value * values_accessor[batch_idx][i];
}
for (index_t irow = 0; irow < src_crow_indices.size(0) - 1;
++irow) {
index_t start_index = crow_indices_accessor[irow];
index_t end_index = crow_indices_accessor[irow + 1];
for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[i];
auto index = resultBuffer.storage_offset() +
irow * out_strides0 + icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
}
});
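To illustrate what the kernel above computes, a small sketch, not part of the diff, of adding a sparse CSR tensor into a dense tensor: crow_indices is walked row by row and alpha * values is scattered into the dense result.

import torch

dense = torch.randn(2, 3, dtype=torch.float64)
csr = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
                              torch.tensor([0, 1, 0, 1]),
                              torch.tensor([1., 2., 3., 4.], dtype=torch.float64),
                              size=(2, 3))
out = torch.add(dense, csr, alpha=0.5)
print(torch.allclose(out, dense + 0.5 * csr.to_dense()))  # True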
@@ -978,18 +978,6 @@ void add_out_sparse_csr(
auto B_col_indices_ptr = B_col_indices.data_ptr<int>();
auto C_col_indices_ptr = C_col_indices.data_ptr<int>();
// Windows compilers don't support nested macros
// so we need this lambda outside of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
// For some reason POINTER_MODE_HOST is not working here
// Let's extract manually the nnz from the C_crow_indices
#if AT_ROCM_ENABLED()
return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
#else
return nnz;
#endif
};
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
C.scalar_type(), "add_out_sparse_csr_cuda_impl", [&] {
auto beta_ = beta.to<scalar_t>();
@@ -1050,8 +1038,6 @@ void add_out_sparse_csr(
&nnzC,
work_data.get());
nnzC = fix_nnz(nnzC);
// Resize result using nnz information from cusparse
col_indices_and_values_resize_(C, nnzC);
C_col_indices = C.col_indices();
@@ -159,26 +159,18 @@ Tensor& add_out_dense_sparse_csr_cuda(
" in add operation");
Tensor src_values = src.values();
Tensor src_crow_indices = src.crow_indices();
Tensor src_col_indices = src.col_indices();
resize_output(output, dense.sizes());
Tensor resultBuffer = output;
Tensor valuesBuffer = src_values.to(commonDtype);
if (output.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
} else if (!is_same_tensor(output, dense)) {
resultBuffer.copy_(dense);
}
if (src._nnz() == 0) {
return output;
}
auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
resultBuffer = resultBuffer.view({-1, output.size(-2), output.size(-1)});
auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16,
commonDtype,
@@ -188,7 +180,6 @@ Tensor& add_out_dense_sparse_csr_cuda(
src_crow_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
@@ -198,11 +189,8 @@ Tensor& add_out_dense_sparse_csr_cuda(
int64_t out_storage_offset = resultBuffer.storage_offset();
auto out_strides = resultBuffer.strides();
auto out_strides0 = out_strides[0];
auto out_strides1 = out_strides[1];
auto crow_stride0 = src_crow_indices.stride(0);
auto col_stride0 = src_col_indices.stride(0);
auto val_stride0 = valuesBuffer.stride(0);
int64_t out_strides0 = out_strides[0];
int64_t out_strides1 = out_strides[1];
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
@@ -212,29 +200,24 @@ Tensor& add_out_dense_sparse_csr_cuda(
thrust::for_each(
policy,
thrust::make_counting_iterator(int64_t(0)),
thrust::make_counting_iterator(int64_t(src_crow_indices.size(-1) - 1)),
thrust::make_counting_iterator(int64_t(src_crow_indices.size(0) - 1)),
[values_accessor,
crow_indices_accessor,
col_indices_accessor,
out_ptr,
cast_value,
out_storage_offset,
out_strides0,
out_strides1,
crow_stride0,
col_stride0,
val_stride0,
batch_count
cast_value,
out_strides1
]__device__(int64_t irow) {
for (index_t batch_idx = 0; batch_idx < batch_count; batch_idx++) {
index_t start_index = crow_indices_accessor[batch_idx*crow_stride0 + irow];
index_t end_index = crow_indices_accessor[batch_idx*crow_stride0 + irow + 1];
index_t start_index = crow_indices_accessor[irow];
index_t end_index = crow_indices_accessor[irow + 1];
for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[batch_idx*col_stride0 + i];
auto index = batch_idx * out_strides0 + irow * out_strides1 + icol;
out_ptr[index] += cast_value * values_accessor[batch_idx*val_stride0 + i];
auto icol = col_indices_accessor[i];
auto index = out_storage_offset + irow * out_strides0 + icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
}
});
});
});
@@ -165,33 +165,6 @@ class TestSparseCSR(TestCase):
self.assertEqual(torch.tensor(col_indices, dtype=index_dtype), sparse.col_indices())
self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values())
@dtypes(*get_all_dtypes())
def test_sparse_csr_batch_constructor(self, device, dtype):
batch_shape = (2, 3)
crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
for index_dtype in [torch.int32, torch.int64]:
sparse = torch.sparse_csr_tensor(crow_indices.to(index_dtype),
col_indices.to(index_dtype),
values,
size=(*batch_shape, 2, 10),
dtype=dtype,
device=device)
self.assertEqual((*batch_shape, 2, 10), sparse.shape)
self.assertEqual(crow_indices.to(index_dtype), sparse.crow_indices())
self.assertEqual(col_indices.to(index_dtype), sparse.col_indices())
self.assertEqual(values, sparse.values())
@dtypes(*get_all_dtypes())
def test_sparse_csr_batch_constructor_shape_inference(self, device, dtype):
batch_shape = (2, 3)
crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
self.assertEqual((*batch_shape, crow_indices.shape[-1] - 1, col_indices.max() + 1), sparse.shape)
@dtypes(*get_all_dtypes())
def test_sparse_csr_constructor_from_lists(self, device, dtype):
# without size
@@ -225,17 +198,15 @@ class TestSparseCSR(TestCase):
@dtypes(*get_all_dtypes())
def test_empty(self, device, dtype):
ns = [5, 2, 0]
batch_shapes = [(), (2,), (2, 3)]
for m, n, b in itertools.product(ns, ns, batch_shapes):
shape = (*b, m, n)
for shape in itertools.product(ns, ns):
result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype)
self.assertEqual(result.device, torch.device(device))
self.assertEqual(result.layout, torch.sparse_csr)
self.assertEqual(result.crow_indices().shape, (*b, shape[-2] + 1,))
self.assertEqual(result.col_indices().shape, (*b, 0,))
self.assertEqual(result.values().shape, (*b, 0,))
self.assertEqual(result.crow_indices().shape, (shape[0] + 1,))
self.assertEqual(result.col_indices().shape, (0,))
self.assertEqual(result.values().shape, (0,))
self.assertEqual(result._nnz(), 0)
self.assertEqual(result.crow_indices().device, torch.device(device))
self.assertEqual(result.col_indices().device, torch.device(device))
@@ -247,22 +218,23 @@ class TestSparseCSR(TestCase):
@skipMeta
@dtypes(*get_all_dtypes())
def test_empty_errors(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only batched sparse CSR matrices are supported, but got size"):
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)
@skipMeta
@dtypes(*get_all_dtypes())
def test_clone(self, device, dtype):
from operator import mul
from functools import reduce
for batch_shape in ((), (2,), (2, 3)):
prod = reduce(mul, batch_shape, 1)
crow_indices = torch.tensor([0, 2, 4], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1)
sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
cloned_sparse = sparse.clone()
self.assertEqual(sparse, cloned_sparse)
x = torch.sparse_csr_tensor([0, 2, 4],
[0, 1, 0, 1],
[1, 2, 3, 4],
dtype=dtype,
device=device)
y = x.clone()
self.assertEqual(x, y)
@skipMeta
@dtypes(*get_all_dtypes())
@@ -277,10 +249,9 @@ class TestSparseCSR(TestCase):
self.assertEqual(a, b)
ns = [5, 2, 0]
batch_shapes = [(), (2,), (2, 3)]
for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
run_test((*b, m, n), 0, index_dtype)
run_test((*b, m, n), m * n, index_dtype)
for shape, index_dtype in zip(itertools.product(ns, ns), [torch.int32, torch.int64]):
run_test(shape, 0, index_dtype)
run_test(shape, shape[0] * shape[1], index_dtype)
@skipMeta
@dtypes(*get_all_dtypes())
@@ -304,31 +275,25 @@ class TestSparseCSR(TestCase):
@skipMeta
@dtypes(*get_all_dtypes())
def test_resize(self, device, dtype):
batch_shapes = [(), (2,), (2, 3)]
for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
shape = (*b, 2, 3)
for index_dtype in [torch.int32, torch.int64]:
shape = (2, 3)
nnz = 6
a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
new_shape = (*b, 4, 5)
new_shape = (4, 5)
a.resize_(new_shape)
self.assertEqual(a.shape, new_shape)
# resize to larger shape doesn't add specified elements
self.assertEqual(a._nnz(), nnz)
new_shape = (*b, 1, 5)
new_shape = (1, 5)
a.resize_(new_shape)
self.assertEqual(a.shape, new_shape)
# resize to smaller shape trims specified elements
self.assertEqual(a._nnz(), 5)
# trim batched dimensions
a.resize_(new_shape[-2], new_shape[-1])
self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
self.assertEqual(a._nnz(), 5)
@skipMeta
@dtypes(*get_all_dtypes())
def test_resize_errors(self, device, dtype):
@@ -337,7 +302,7 @@ class TestSparseCSR(TestCase):
nnz = 6
a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
new_shape = (4,)
a.resize_(new_shape)
@@ -382,62 +347,49 @@ class TestSparseCSR(TestCase):
torch.tensor([1, 2, 3, 4]))
def test_factory_shape_invariants_check(self, device):
crow_indices = torch.tensor([0, 2, 4], device=device)
col_indices = torch.tensor([0, 1, 0, 1], device=device)
values = torch.tensor([1, 2, 3, 4], device=device)
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1, 2, 3, 4]
size = (2, 10)
torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device)
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"size of a batched CSR tensor must have length >= 2, but got: 1"):
torch.sparse_csr_tensor(crow_indices, col_indices, values,
size=(2,),
with self.assertRaisesRegex(RuntimeError, r"size of a CSR tensor must be of length 2, but got: 3"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values),
size=(2, 10, 2),
device=device)
with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim >= 1 but got crow_indices\.dim\(\)\ = 0"):
torch.sparse_csr_tensor(torch.zeros((), device=device, dtype=torch.int64),
col_indices,
values,
with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
torch.tensor(col_indices),
torch.tensor(values),
size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim >= 1 but got col_indices\.dim\(\)\ = 0"):
torch.sparse_csr_tensor(crow_indices,
torch.zeros((), device=device, dtype=torch.int64),
values,
with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices),
torch.tensor(col_indices).repeat(2, 1),
torch.tensor(values),
size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"values must have dim >= 1 but got values\.dim\(\)\ = 0"):
torch.sparse_csr_tensor(crow_indices,
col_indices,
torch.zeros((), device=device, dtype=torch.int64),
with self.assertRaisesRegex(RuntimeError, r"values must have dim\=1 but got values\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices),
torch.tensor(col_indices),
torch.tensor(values).repeat(2, 1),
size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"crow_indices\.size\(-1\) must be equal to size\[-2\] \+ 1 \(that is 2\), but got: 3"):
torch.sparse_csr_tensor(crow_indices, col_indices, values, (1, 1),
r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
device=device)
with self.assertRaisesRegex(RuntimeError,
r"Number of dimensions of crow_indices and col_indices must be the same"):
torch.sparse_csr_tensor(crow_indices, col_indices.repeat(2, 1), values, size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"Number of dimensions of indices and values must be the same"):
torch.sparse_csr_tensor(crow_indices, col_indices, values.repeat(2, 1), size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"Number of dimensions of indices must be one less"):
torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(2, 1), values.repeat(2, 1), size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"All batch dimensions of the provided size, indices, and values must be the same"):
torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(3, 1), values.repeat(4, 1), (2, 2, 10),
r"col_indices and values must have equal sizes, " +
r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values), size,
device=device)
def test_factory_indices_invariants_check(self, device):
@@ -456,7 +408,7 @@ class TestSparseCSR(TestCase):
with self.assertRaisesRegex(RuntimeError,
r"at position i \= 2," +
r" the condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
@@ -464,7 +416,7 @@ class TestSparseCSR(TestCase):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"size\[-1\] should be greater than col_indices\.max\(\)"):
with self.assertRaisesRegex(RuntimeError, r"size\(1\) should be greater than col_indices\.max\(\)"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size,
device=device)
@@ -569,12 +521,12 @@ class TestSparseCSR(TestCase):
sparse = dense.to_sparse_csr()
self.assertEqual(sparse.to_dense(), dense)
batch_shape = (2, 3)
crow_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
col_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(csr.shape)
crow_indices = torch.tensor([0, 3, 5])
col_indices = torch.tensor([0, 1, 2, 0, 1])
values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
csr = torch.sparse_csr_tensor(crow_indices, col_indices,
values, dtype=dtype, device=device)
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
self.assertEqual(csr.to_dense(), dense)
@skipCPUIfNoMklSparse
@@ -1147,9 +1099,6 @@ class TestSparseCSR(TestCase):
@dtypes(torch.float, torch.double)
def test_add(self, device, dtype):
def _test_spadd_shape(nnz, shape):
# sparse.to_dense() uses torch.add internally so if torch.add is wrong,
# the dense tensor will be wrong but this test would still pass
# there's a separate test that checks for the correctness of the .to_dense() call
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
y = torch.randn(*shape, dtype=dtype, device=device)
r = random.random()
@@ -1171,12 +1120,10 @@ class TestSparseCSR(TestCase):
self.assertEqual(res, expected)
ns = [2, 5]
batch_shapes = [(), (2,), (2, 3)]
for b, m, n in itertools.product(batch_shapes, ns, ns):
_test_spadd_shape(0, (*b, m, n))
_test_spadd_shape(m * n // 2, (*b, m, n))
_test_spadd_shape(m * n, (*b, m, n))
_test_spadd_shape(10, [100, 100])
_test_spadd_shape(0, [100, 100])
_test_spadd_shape(10, [100, 1])
_test_spadd_shape(10, [1, 100])
@dtypes(torch.float, torch.double)
def test_mul(self, device, dtype):
@@ -1999,11 +1999,9 @@ class TestCase(expecttest.TestCase):
return crow_indices.to(device=device)
def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype):
from operator import mul
from functools import reduce
sparse_dim = 2
assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
assert len(size) >= sparse_dim
assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
assert len(size) == sparse_dim
def random_sparse_csr(n_rows, n_cols, nnz):
crow_indices = self._make_crow_indices(n_rows, n_cols, nnz, device=device, dtype=index_dtype)
@@ -2017,15 +2015,7 @@ class TestCase(expecttest.TestCase):
values = make_tensor([nnz], device=device, dtype=dtype, low=low, high=high)
return values, crow_indices, col_indices
batch_shape = size[:-2]
n_batch = reduce(mul, batch_shape, 1)
sparse_tensors = [random_sparse_csr(size[-2], size[-1], nnz) for _ in range(n_batch)]
sparse_tensors_it = map(list, zip(*sparse_tensors))
values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
crow_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
col_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
values, crow_indices, col_indices = random_sparse_csr(size[0], size[1], nnz)
return torch.sparse_csr_tensor(crow_indices,
col_indices,
values, size=size, dtype=dtype, device=device)