Add spdiags sparse matrix initialization (#78439)

Similar to [scipy.sparse.spdiags](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spdiags.html#scipy-sparse-spdiags)

Part of #70926

In other functions (e.g. [torch.diagonal](https://pytorch.org/docs/stable/generated/torch.diagonal.html#torch.diagonal)), diagonals of a tensor are referenced using the offset and the two dimensions that the diagonal is taken with respect to.

The reference implementation from scipy only considers matrix output, so even if we only support 2-D output at first, it may be useful to consider how the dimensions corresponding to each diagonal would be specified for higher-dimensional output.

The proposed torch signature implies that all offsets refer to the diagonals with respect to the only two dimensions of the output:

```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, int[] shape, Layout? layout=None) -> SparseTensor
```
Above it is required that: `diagonals.ndimension() == 2`, `offsets.ndimension() == 1`, `offsets.shape[0] == diagonals.shape[0]`, and `len(shape) == 2`.
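
For concreteness, a minimal usage sketch of this 2-D signature (the example values are illustrative only):

```
import torch

# Three diagonals (rows of `diagonals`) placed at offsets -1, 0, 1 of a 4x4 output
diagonals = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
offsets = torch.tensor([-1, 0, 1])
s = torch.sparse.spdiags(diagonals, offsets, (4, 4))
print(s.to_dense())
```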

This would need to be altered for the case where `len(shape) > 2`. One option is:
```
torch.sparse.spdiags(Tensor[] diagonals, IntTensor[] offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```

Here `offsets` and `diagonals` become lists of tensors, and the `IntTensor dims` argument is introduced. This would require that `len(diagonals) == len(offsets) == dims.shape[0]`, `dims.ndimension() == 2`, and `dims.shape[1] == 2`; the same restrictions as in the 2-D case above apply pairwise to the elements of `diagonals` and `offsets` (that is, `diagonals[i].ndimension() == 2`, `offsets[i].ndimension() == 1`, and `offsets[i].shape[0] == diagonals[i].shape[0]` for all `i`, as checked in the sketch below). This form of the signature would construct the sparse result by placing the values from `diagonals[i][j]` into the diagonal with offset `offsets[i][j]` taken with respect to dimensions `dims[i]`. The specialization back to the original signature for the 2-D case could be seen as allowing the single row of `dims` to default to `[0, 1]` when a single `diagonals`, `offsets` pair is provided and `shape` is 2-D. This option allows the rows of an input element `diagonals[i]` to have different lengths, which may be appropriate since the maximum length of a diagonal along different dimension pairs will differ.
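
To pin those constraints down, here is a hedged sketch of the validation this option would imply (`check_spdiags_nd_args` is a hypothetical helper; this overload was never implemented):

```
import torch

def check_spdiags_nd_args(diagonals, offsets, dims, shape):
    # Hypothetical checks for the proposed list-based N-D overload
    assert len(diagonals) == len(offsets) == dims.shape[0]
    assert dims.ndimension() == 2 and dims.shape[1] == 2
    for d, o in zip(diagonals, offsets):
        # The 2-D restrictions apply pairwise to each (diagonals, offsets) pair
        assert d.ndimension() == 2
        assert o.ndimension() == 1
        assert o.shape[0] == d.shape[0]
```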

Another option is to specify the dimensions the diagonal is taken with respect to for each offset. This signature would look like:

```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```
Here, `diagonals` is still 2-D with dimension 0 matching the length of the 1-D `offsets`, and the tensor input `dims` is also 2-D, with dimension 0 matching the length of `offsets` and dimension 1 fixed at `2`. In this case the sparse result is constructed by placing the elements from `diagonals[i]` into the output diagonal `output.diagonal(offsets[i], dim0=dims[i][0], dim1=dims[i][1])` (with some additional consideration that makes it more complicated than simply assigning to that view). The specialization from this back to the 2-D form could be seen as assuming `dims = [[0, 1], [0, 1], ...]` (repeated `len(offsets)` times) when `len(shape) == 2`.
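
For reference, writing into a diagonal taken with respect to an arbitrary dimension pair is already expressible on dense tensors via `Tensor.diagonal` (whose actual keyword names are `dim1`/`dim2`); a minimal sketch of the core operation either N-D signature would generalize:

```
import torch

# Write along the offset-1 diagonal of the (0, 2) dimension pair of a 3-D tensor
out = torch.zeros(3, 2, 4)
view = out.diagonal(offset=1, dim1=0, dim2=2)  # view of shape (2, 3): remaining dim first, diagonal last
view.copy_(torch.arange(6, dtype=out.dtype).reshape(2, 3))
print(out)
```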

In both proposed signatures for the N-D case, the specialization back to the 2-D signature is a bit of a stretch for typical default-argument logic; however, I think the first is the better choice, as it offers more flexibility.

I think some discussion is required about:
- [x] Should the N-D output case be implemented from the outset
- [x] If not, should the future addition of the N-D output case be considered when designing the interface.
- [x] Other thoughts on the signature which includes the `dims` information for the N-D output case.

**Resolution**: Since no one has requested N-D output support, I think it is fine to restrict this to sparse matrix generation. Should a request for N-D support come later, an overload accepting the additional `dims` could be added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78439
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/pearu
Author: Andrew M. James, 2022-06-29 16:50:32 -05:00 (committed by PyTorch MergeBot)
parent 45ae244086
commit cfb2034b65
8 changed files with 378 additions and 1 deletions

aten/src/ATen/native/cpu/SparseFactories.cpp

@@ -0,0 +1,74 @@
#include <ATen/Dispatch.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/sparse/SparseFactories.h>
#include <c10/core/Scalar.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/sparse_coo_tensor.h>
#endif
namespace at {
namespace native {
using namespace at::sparse;
namespace {
void _spdiags_kernel_cpu(
    TensorIterator& iter,
    const Tensor& diagonals,
    Tensor& values,
    Tensor& indices) {
  auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
  auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
  const int64_t diagonals_read_stride = diagonals.stride(1);
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      at::ScalarType::Bool,
      at::ScalarType::ComplexHalf,
      diagonals.scalar_type(),
      "spdiags_cpu",
      [&] {
        auto* values_write_ptr = values.data_ptr<scalar_t>();
        cpu_kernel(
            iter,
            // Per diagonal: its row index into `diagonals`, its offset, its
            // starting position in the output buffers, and its nnz count.
            [&](int64_t diag_index,
                int64_t diag_offset,
                int64_t out_offset,
                int64_t n_out) -> int64_t {
              if (n_out > 0) {
                auto* rows_start = row_index_write_ptr + out_offset;
                auto* cols_start = col_index_write_ptr + out_offset;
                auto* vals_start = values_write_ptr + out_offset;
                const int64_t first_col = std::max<int64_t>(diag_offset, 0);
                const int64_t first_row = first_col - diag_offset;
                auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
                    first_col * diagonals_read_stride;
                for (int64_t i = 0; i < n_out; ++i) {
                  rows_start[i] = first_row + i;
                  cols_start[i] = first_col + i;
                  vals_start[i] = data_read[i * diagonals_read_stride];
                }
              }
              // dummy return
              return 0;
            });
      });
}
} // namespace
REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu)
} // namespace native
} // namespace at

aten/src/ATen/native/native_functions.yaml

@@ -5281,6 +5281,11 @@
    SparseCPU: log_softmax_backward_sparse_cpu
    SparseCUDA: log_softmax_backward_sparse_cuda

- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
  python_module: sparse
  dispatch:
    CPU: spdiags

- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method

aten/src/ATen/native/sparse/SparseFactories.cpp

@@ -0,0 +1,95 @@
#include <ATen/Dispatch.h>
#include <ATen/native/sparse/SparseFactories.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_unique.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/where.h>
#endif
namespace at {
namespace native {
DEFINE_DISPATCH(spdiags_kernel_stub);
Tensor spdiags(
    const Tensor& diagonals,
    const Tensor& offsets,
    IntArrayRef shape,
    c10::optional<Layout> layout) {
  auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
  TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
  TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
  auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
  TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
  TORCH_CHECK(
      diagonals_2d.size(0) == offsets_1d.size(0),
      "Number of diagonals (",
      diagonals_2d.size(0),
      ") does not match the number of offsets (",
      offsets_1d.size(0),
      ")");
  if (layout) {
    TORCH_CHECK(
        (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
            (*layout == Layout::SparseCsr),
        "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
        *layout);
  }
  TORCH_CHECK(
      offsets_1d.scalar_type() == at::kLong,
      "Offset Tensor must have dtype Long but got ",
      offsets_1d.scalar_type());
  TORCH_CHECK(
      offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
      "Offset tensor contains duplicate values");
  // Number of stored elements per diagonal, bounded by the output shape and
  // by the row length of `diagonals`
  auto nnz_per_diag = at::where(
      offsets_1d.le(0),
      offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
      offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());
  auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
  const auto nnz = diagonals_2d.size(0) > 0
      ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
      : int64_t{0};
  // Offsets into nnz for each diagonal
  auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
  // coo tensor guts
  auto indices = at::empty({2, nnz}, offsets_1d.options());
  auto values = at::empty({nnz}, diagonals_2d.options());
  // We add this indexer to lookup the row of diagonals we are reading from at
  // each iteration
  const auto n_diag = offsets_1d.size(0);
  Tensor diag_index = at::arange(n_diag, offsets_1d.options());
  // cpu_kernel requires an output
  auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
  auto iter = TensorIteratorConfig()
                  .set_check_mem_overlap(false)
                  .add_output(dummy)
                  .add_input(diag_index)
                  .add_input(offsets_1d)
                  .add_input(result_mem_offsets)
                  .add_input(nnz_per_diag)
                  .build();
  spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
  auto result_coo = at::sparse_coo_tensor(indices, values, shape);
  if (layout) {
    if (*layout == Layout::SparseCsr) {
      return result_coo.to_sparse_csr();
    }
    if (*layout == Layout::SparseCsc) {
      return result_coo.to_sparse_csc();
    }
  }
  return result_coo;
}
} // namespace native
} // namespace at

aten/src/ATen/native/sparse/SparseFactories.h

@@ -0,0 +1,15 @@
#pragma once
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
namespace at {
namespace native {
using spdiags_kernel_fn_t =
    void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);
DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
} // namespace native
} // namespace at

build_variables.bzl

@@ -1155,6 +1155,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
    "aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
    "aten/src/ATen/native/cpu/SparseFactories.cpp",
    "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
]
@@ -1357,6 +1358,7 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/sparse/SparseTensorMath.cpp",
    "aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
    "aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
    "aten/src/ATen/native/sparse/SparseFactories.cpp",
    "aten/src/ATen/native/transformers/attention.cpp",
    "aten/src/ATen/native/transformers/transformer.cpp",
    "aten/src/ATen/native/utils/Factory.cpp",

docs/source/sparse.rst

@@ -599,6 +599,7 @@ Torch functions specific to sparse Tensors
    smm
    sparse.softmax
    sparse.log_softmax
    sparse.spdiags

Other functions
+++++++++++++++

test/test_sparse.py

@@ -8,7 +8,7 @@ import random
import unittest
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
+    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
    DeterministicGuard, first_sample
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number

@@ -26,6 +26,9 @@ from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and, integral_types, floating_types_and,
)

if TEST_SCIPY:
    import scipy.sparse

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@@ -3558,6 +3561,94 @@ class TestSparse(TestCase):
        test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
        test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])

    @unittest.skipIf(not TEST_NUMPY, "NumPy is not available")
    @onlyCPU
    @dtypes(*all_types_and_complex_and(torch.bool))
    def test_sparse_spdiags(self, device, dtype):
        make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
        make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)

        if TEST_SCIPY:
            def reference(diags, offsets, shape):
                return scipy.sparse.spdiags(diags, offsets, *shape).toarray()
        else:
            def reference(diags, offsets, shape):
                result = torch.zeros(shape, dtype=dtype, device=device)
                for i, off in enumerate(offsets):
                    res_view = result.diagonal(off)
                    data = diags[i]
                    if off > 0:
                        data = data[off:]
                    m = min(res_view.shape[0], data.shape[0])
                    res_view[:m] = data[:m]
                return result

        def check_valid(diags, offsets, shape, layout=None):
            ref_out = reference(diags, offsets, shape)
            out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
            if layout is None:
                ex_layout = torch.sparse_coo
            else:
                ex_layout = layout
            out_dense = out.to_dense()
            self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
            self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")

        def check_invalid(args, error):
            with self.assertRaisesRegex(RuntimeError, error):
                torch.sparse.spdiags(*args)

        def valid_cases():
            # some normal cases
            yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
            # noncontiguous diags
            yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
            # noncontiguous offsets
            yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
            # noncontiguous diags + offsets
            yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
            # correct dimensionality (2d, 2d) and matching shapes, but the number of diagonals is zero
            yield (make_diags((0, 3)), make_offsets([]), (3, 3))
            # forward rotation of upper diagonals
            yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
            # rotation exhausts input space to read from
            yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
            # simple cases repeated with special output format
            yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
            # vector diags
            yield (make_diags((3,)), make_offsets([1]), (4, 4))
            # scalar offset
            yield (make_diags((1, 3)), make_offsets(2), (4, 4))
            # offsets out of range
            yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
            yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))

        for case in valid_cases():
            check_valid(*case)

        def invalid_cases():
            yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
            yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
            yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
            yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)), \
                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
            yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)), \
                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
            yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided), \
                r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
            yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
            yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), \
                r"Offset Tensor must have dtype Long but got \w+"

        for case, error_regex in invalid_cases():
            check_invalid(case, error_regex)

class TestSparseOneOff(TestCase):
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')

torch/sparse/__init__.py

@@ -262,3 +262,97 @@ Args:
performed. This is useful for preventing data type
overflows. Default: None
""")
spdiags = _add_docstr(
    _sparse._spdiags,
    r"""
sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor

Creates a sparse 2D tensor by placing the values from rows of
:attr:`diagonals` along specified diagonals of the output.

The :attr:`offsets` tensor controls which diagonals are set:

- If :attr:`offsets[i]` = 0, it is the main diagonal
- If :attr:`offsets[i]` < 0, it is below the main diagonal
- If :attr:`offsets[i]` > 0, it is above the main diagonal

The number of rows in :attr:`diagonals` must match the length of
:attr:`offsets`, and an offset may not be repeated.

Args:
    diagonals (Tensor): Matrix storing diagonals row-wise
    offsets (Tensor): The diagonals to be set, stored as a vector
    shape (2-tuple of ints): The desired shape of the result

Keyword args:
    layout (:class:`torch.layout`, optional): The desired layout of the
        returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and
        ``torch.sparse_csr`` are supported. Default: ``torch.sparse_coo``

Examples:

Set the main and first two lower diagonals of a matrix::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
    >>> s
    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
                           [0, 1, 2, 0, 1, 0]]),
           values=tensor([0, 1, 2, 3, 4, 6]),
           size=(3, 3), nnz=6, layout=torch.sparse_coo)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Change the output layout::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
    >>> s
    tensor(crow_indices=tensor([0, 1, 3, 6]),
           col_indices=tensor([0, 0, 1, 0, 1, 2]),
           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
           layout=torch.sparse_csr)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Set partial diagonals of a large output::

    >>> diags = torch.tensor([[1, 2], [3, 4]])
    >>> offsets = torch.tensor([0, -1])
    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
    tensor([[1, 0, 0, 0, 0],
            [3, 2, 0, 0, 0],
            [0, 4, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])

.. note::
    When setting the values along a given diagonal, the index into the
    diagonal and the index into the row of :attr:`diagonals` are both taken
    as the column index in the output. This has the effect that when setting
    a diagonal with a positive offset `k` the first value along that diagonal
    will be the value in position `k` of the row of :attr:`diagonals`.

Specifying a positive offset::

    >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
    tensor([[1, 2, 3, 0, 0],
            [0, 2, 3, 0, 0],
            [0, 0, 3, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])
""")