Add check-sparse-tensor-invariants flag to Context - 2nd try. (#92094)

This PR is a copy of https://github.com/pytorch/pytorch/pull/90849 that merge was reverted.

The PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:

`torch.sparse.check_sparse_tensor_invariants` class provides different ways to enable/disable the invariant checking.

`torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.

The PR fixes https://github.com/pytorch/pytorch/issues/90833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92094
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson 2023-01-12 21:00:28 +02:00 committed by PyTorch MergeBot
parent a111dd9014
commit b3e4f5029b
18 changed files with 493 additions and 96 deletions

View File

@ -367,6 +367,14 @@ bool Context::isXNNPACKAvailable() {
#endif
}
void Context::setCheckSparseTensorInvariants(bool e) {
enable_sparse_tensor_invariant_checks = e;
}
bool Context::checkSparseTensorInvariants() const {
return enable_sparse_tensor_invariant_checks;
}
bool Context::releaseWeightsWhenPrepacking() const {
return release_original_weights;
}

View File

@ -250,6 +250,8 @@ class TORCH_API Context {
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
static bool isXNNPACKAvailable();
void setCheckSparseTensorInvariants(bool e);
bool checkSparseTensorInvariants() const;
// This method is used to release the original weight after pre-packing.
// It should be called once before loading/running the model.
// NB: By default it is set to true for mobile builds.
@ -305,6 +307,7 @@ class TORCH_API Context {
#endif
bool display_vmap_fallback_warnings_ = false;
c10::optional<at::QEngine> quantized_engine = c10::nullopt;
bool enable_sparse_tensor_invariant_checks = false;
Allocator* prev_allocator_ptr_{nullptr};
};

View File

@ -356,6 +356,9 @@ Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
}
Layout layout_ = layout.value();
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
if (at::globalContext().checkSparseTensorInvariants()) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
}
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
@ -373,6 +376,9 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
c10::optional<bool> pin_memory) {
Layout layout_ = layout.value_or(required_layout);
TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
if (at::globalContext().checkSparseTensorInvariants()) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
}
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
@ -474,8 +480,6 @@ Tensor sparse_compressed_tensor(
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
@ -507,8 +511,6 @@ Tensor sparse_compressed_tensor(
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,

View File

@ -398,8 +398,6 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());
at::native::_validate_sparse_coo_tensor_args(indices, values, size);
return at::native::_sparse_coo_tensor_unsafe(
indices,
values,
@ -415,20 +413,10 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, a
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
Tensor values = expand_values_if_needed(values_);
auto sparse_dim = indices.size(0);
auto dense_dim = values.dim() - 1;
return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim,
dense_dim,
size,
indices,
values,
values.options().layout(kSparse));
if (at::globalContext().checkSparseTensorInvariants()) {
at::native::_validate_sparse_coo_tensor_args(indices, values_, size);
}
return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}
// NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()

View File

@ -321,6 +321,15 @@ invariants:
Dense dimensions always follow sparse dimensions, that is, mixing
of dense and sparse dimensions is not supported.
.. note::
To be sure that a constructed sparse tensor has consistent indices,
values, and size, the invariant checks can be enabled per tensor
creation via ``check_invariants=True`` keyword argument, or
globally using :class:`torch.sparse.check_sparse_tensor_invariants`
context manager instance. By default, the sparse tensor invariants
checks are disabled.
.. _sparse-uncoalesced-coo-docs:
Uncoalesced sparse COO tensors
@ -530,6 +539,13 @@ __ https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_o
where ``plain_dim_size`` is the number of plain dimensions
(orthogonal to compressed dimensions, e.g. columns or rows).
To be sure that a constructed sparse tensor has consistent indices,
values, and size, the invariant checks can be enabled per tensor
creation via ``check_invariants=True`` keyword argument, or
globally using :class:`torch.sparse.check_sparse_tensor_invariants`
context manager instance. By default, the sparse tensor invariants
checks are disabled.
.. note::
The generalization of sparse compressed layouts to N-dimensional
@ -646,9 +662,9 @@ argument is optional and will be deduced from the ``crow_indices`` and
>>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float64)
>>> csr
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64)
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64)
>>> csr.to_dense()
tensor([[1., 2.],
[3., 4.]], dtype=torch.float64)
@ -1160,6 +1176,14 @@ The following :mod:`torch` functions support sparse tensors:
:func:`~torch.zeros`
:func:`~torch.zeros_like`
To manage checking sparse tensor invariants, see:
.. autosummary::
:toctree: generated
:nosignatures:
sparse.check_sparse_tensor_invariants
Unary functions
---------------

View File

@ -48,6 +48,10 @@ Creation Ops
tensor
sparse_coo_tensor
sparse_csr_tensor
sparse_csc_tensor
sparse_bsr_tensor
sparse_bsc_tensor
asarray
as_tensor
as_strided

View File

@ -4072,6 +4072,113 @@ class TestSparseMeta(TestCase):
class TestSparseAny(TestCase):
@onlyCPU
@all_sparse_layouts('layout', include_strided=False)
@torch.sparse.check_sparse_tensor_invariants(enable=False)
def test_check_sparse_tensor_invariants(self, layout):
if layout is torch.sparse_coo:
def create_invalid_tensor(check_invariants=None):
shape = (2, 2)
invalid_indices = torch.tensor([[0], [3]]) # column index is out of range
values = torch.tensor([1])
if check_invariants is None:
return torch.sparse_coo_tensor(invalid_indices, values, shape)
else:
return torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants)
expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3'
elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
def create_invalid_tensor(check_invariants=None):
shape = (2, 2)
compressed_indices = torch.tensor([0, 0, 1])
invalid_plain_indices = torch.tensor([3]) # index is out of range
if layout in {torch.sparse_bsr, torch.sparse_bsc}:
values = torch.tensor([[[1]]])
else:
values = torch.tensor([1])
if check_invariants is None:
return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout)
else:
return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout,
check_invariants=check_invariants)
if layout in {torch.sparse_csr, torch.sparse_bsr}:
expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.'
else:
expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.'
else:
raise NotImplementedError(layout)
# First, consider the case where invariant checks are disabled
# "globally" (read: within the context of this test method
# caller) as defined by check_sparse_tensor_invariants(False)
# decorator:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
# Enable the invariant checks in a local context:
with torch.sparse.check_sparse_tensor_invariants():
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
# Leaving the local context must restore the "global" state of
# the invariant check feature:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
# Since invariant checks are disabled by default, we can
# create an invalid sparse tensor without raising an
# exception:
r = create_invalid_tensor()
self.assertEqual(r.layout, layout)
# Or, when disabling the invariants check explicitly:
r = create_invalid_tensor(check_invariants=False)
self.assertEqual(r.layout, layout)
# Enabling invariant check via constructor's optional argument
# will raise an exception when sparse tensor invariants are
# violated:
with self.assertRaisesRegex(RuntimeError, expected_exception_message):
create_invalid_tensor(check_invariants=True)
# Check that the global invariant check flag has been restored
# after raising the exception above:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
# Next, consider the case where invariant checks are enabled
# within a local context:
with torch.sparse.check_sparse_tensor_invariants():
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
# Since invariant checks are now enabled by default, an
# attempt to create an invalid sparse tensor will lead to
# an exception:
with self.assertRaisesRegex(RuntimeError, expected_exception_message):
create_invalid_tensor()
# Similarly, when enabling the invariant checks
# explicitly, invalid sparse tensor construction will lead
# to an exception:
with self.assertRaisesRegex(RuntimeError, expected_exception_message):
create_invalid_tensor(check_invariants=True)
# However, invariants check can be disabled via
# constructor's optional argument so that the invalid
# tensor is succesfully constructed:
r = create_invalid_tensor(check_invariants=False)
self.assertEqual(r.layout, layout)
# Check that the invariant check flag has been restored
# when leaving the constructor:
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())
# Double-check restoring the global state when leaving the
# local context:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())
def test_generate_simple_inputs(self):
layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]

View File

@ -1363,9 +1363,17 @@ class TestSparseCSR(TestCase):
@onlyCUDA
@unittest.skipIf(not (CUDA11OrLater or TEST_WITH_ROCM), "Only CUDA 11+ is supported")
# hmm, the test passes ok on CUDA when Rocm is not available:
@skipCUDAIfRocmVersionLessThan((5, 2))
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_baddbmm(self, device, dtype):
# TODO: disable the invariant checks within torch.baddbmm that
# constructs unconventional csr tensors leading to
# RuntimeError: tensor dimensionality must be sum of batch,
# base, and dense dimensionalities (=0 + 2 + 0) but got 3
# when invariant checking is enabled. When done, undecorate run_test.
@torch.sparse.check_sparse_tensor_invariants(enable=False)
def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
@ -1388,8 +1396,8 @@ class TestSparseCSR(TestCase):
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
# a_batched is a regular CSR tensor but with a batch dimension in the shape
a_batched = torch._sparse_csr_tensor_unsafe(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k))
a_batched = torch.sparse_csr_tensor(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)
b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
@ -1420,9 +1428,13 @@ class TestSparseCSR(TestCase):
nnz = random.randint(0, m * k)
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
# a_batched is a regular CSR tensor but with a batch dimension in the shape
a_batched = torch._sparse_csr_tensor_unsafe(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k))
# a_batched is a regular CSR tensor but with a batch
# dimension in the shape. It is unorthodox in PyTorch
# to represent a batch sparse tensor in this way,
# hence checking the tensor invariants is locally
# turned off.
a_batched = torch.sparse_csr_tensor(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)
b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
for op_b, op_out in itertools.product([True, False], repeat=2):
@ -1549,8 +1561,8 @@ class TestSparseCSR(TestCase):
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
a_data = a_data.mT if noncontiguous else a_data
a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size))
a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size), check_invariants=False)
b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
for op_b, op_out in itertools.product([True, False], repeat=2):
@ -1585,8 +1597,8 @@ class TestSparseCSR(TestCase):
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks
a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size))
a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size), check_invariants=False)
b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device)
@ -1658,8 +1670,8 @@ class TestSparseCSR(TestCase):
a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks
a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, m * block_size))
a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, m * block_size), check_invariants=False)
b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4):

View File

@ -364,7 +364,8 @@ def gen_pyi(
f"{n2}_indices: Union[Tensor, List],"
" values: Union[Tensor, List], size: Optional[_size]=None,"
" *, dtype: Optional[_dtype]=None,"
" device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
" device: Union[_device, str, None]=None, requires_grad:_bool=False,"
" check_invariants:_bool=None) -> Tensor: ..."
],
f"_sparse_{n}_tensor_unsafe": [
f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List],"
@ -411,7 +412,8 @@ def gen_pyi(
"sparse_coo_tensor": [
"def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],"
" size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,"
" device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
" device: Union[_device, str, None]=None, requires_grad:_bool=False,"
" check_invariants:_bool=None) -> Tensor: ..."
],
"_sparse_coo_tensor_unsafe": [
"def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],"
@ -423,7 +425,8 @@ def gen_pyi(
"plain_indices: Union[Tensor, List],"
" values: Union[Tensor, List], size: Optional[_size]=None,"
" *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None,"
" device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..."
" device: Union[_device, str, None]=None, requires_grad:_bool=False,"
" check_invariants:_bool=None) -> Tensor: ..."
],
"_sparse_compressed_tensor_unsafe": [
"def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List],"

View File

@ -878,6 +878,8 @@ def _get_qengine() -> _int: ... # THPModule_qEngine
def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
def _check_sparse_tensor_invariants() -> _bool: ... # THPModule_checkSparseTensorInvariants
def _set_check_sparse_tensor_invariants(arg: _bool) -> None: ... # THPModule_setCheckSparseTensorInvariants
def _set_default_mobile_cpu_allocator() -> None: ... # THPModule_setDefaultMobileCPUAllocator
def _unset_default_mobile_cpu_allocator() -> None: ... # THPModule_unsetDefaultMobileCPUAllocator
def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction

View File

@ -50,8 +50,7 @@ __all__ = [
'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'sym_int', 'sym_float', 'compile', 'vmap'
]
'sym_int', 'sym_float', 'compile', 'vmap']
################################################################################
# Load the extension module

View File

@ -106,6 +106,9 @@ factory_common_args = merge_dicts(
the pinned memory. Works only for CPU tensors. Default: ``False``.
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
returned Tensor. Default: ``torch.contiguous_format``.
check_invariants (bool, optional): If sparse tensor invariants are checked.
Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`,
initially False.
"""
),
{
@ -10161,7 +10164,7 @@ Example::
add_docstr(
torch.sparse_compressed_tensor,
r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """
r"""*, dtype=None, layout=None, device=None, requires_grad=False) -> Tensor
r"""*, dtype=None, layout=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR,
CSC, BSR, or BSC - <sparse-compressed-docs>` with specified values at
@ -10213,6 +10216,7 @@ Keyword args:
the CPU for CPU tensor types and the current CUDA device for
CUDA tensor types.
{requires_grad}
{check_invariants}
Example::
>>> compressed_indices = [0, 2, 4]
@ -10232,8 +10236,8 @@ Example::
add_docstr(
torch.sparse_csr_tensor,
r"""
sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
r"""sparse_csr_tensor(crow_indices, col_indices, values, size=None, """
r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) <sparse-csr-docs>` with specified
values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations
@ -10273,6 +10277,7 @@ Keyword args:
the CPU for CPU tensor types and the current CUDA device for
CUDA tensor types.
{requires_grad}
{check_invariants}
Example::
>>> crow_indices = [0, 2, 4]
@ -10292,8 +10297,8 @@ Example::
add_docstr(
torch.sparse_csc_tensor,
r"""
sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
r"""sparse_csc_tensor(ccol_indices, row_indices, values, size=None, """
r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column)
<sparse-csc-docs>` with specified values at the given
@ -10335,6 +10340,7 @@ Keyword args:
the CPU for CPU tensor types and the current CUDA device for
CUDA tensor types.
{requires_grad}
{check_invariants}
Example::
>>> ccol_indices = [0, 2, 4]
@ -10354,8 +10360,8 @@ Example::
add_docstr(
torch.sparse_bsr_tensor,
r"""
sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
r"""sparse_bsr_tensor(crow_indices, col_indices, values, size=None, """
r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row))
<sparse-bsr-docs>` with specified 2-dimensional blocks at the given
@ -10399,6 +10405,7 @@ Keyword args:
the CPU for CPU tensor types and the current CUDA device for
CUDA tensor types.
{requires_grad}
{check_invariants}
Example::
>>> crow_indices = [0, 1, 2]
@ -10421,8 +10428,8 @@ Example::
add_docstr(
torch.sparse_bsc_tensor,
r"""
sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
r"""sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, """
r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse
Column)) <sparse-bsc-docs>` with specified 2-dimensional blocks at the
@ -10465,6 +10472,7 @@ Keyword args:
the CPU for CPU tensor types and the current CUDA device for
CUDA tensor types.
{requires_grad}
{check_invariants}
Example::
>>> ccol_indices = [0, 1, 2]
@ -10488,7 +10496,7 @@ Example::
add_docstr(
torch.sparse_coo_tensor,
r"""
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor
Constructs a :ref:`sparse tensor in COO(rdinate) format
<sparse-coo-docs>` with specified values at the given
@ -10520,7 +10528,7 @@ Keyword args:
(see :func:`torch.set_default_tensor_type`). :attr:`device` will be the CPU
for CPU tensor types and the current CUDA device for CUDA tensor types.
{requires_grad}
{check_invariants}
Example::

View File

@ -237,7 +237,7 @@ def _rebuild_sparse_tensor(layout, data):
"""
if layout == torch.sparse_coo:
indices, values, size = data
result = torch._sparse_coo_tensor_unsafe(indices, values, size)
result = torch.sparse_coo_tensor(indices, values, size, check_invariants=False)
_sparse_tensors_to_validate.append(result)
return result
@ -248,8 +248,13 @@ def _rebuild_sparse_tensor(layout, data):
torch.sparse_bsc,
}:
compressed_indices, plain_indices, values, size = data
result = torch._sparse_compressed_tensor_unsafe(
compressed_indices, plain_indices, values, size, layout=layout
result = torch.sparse_compressed_tensor(
compressed_indices,
plain_indices,
values,
size,
layout=layout,
check_invariants=False,
)
_sparse_tensors_to_validate.append(result)
return result

View File

@ -831,6 +831,27 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
Py_RETURN_FALSE;
}
PyObject* THPModule_setCheckSparseTensorInvariants(
PyObject* _unused,
PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
"set_check_sparse_tensor_invariants expects a bool, "
"but got %s",
THPUtils_typename(arg));
at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
Py_RETURN_NONE;
}
PyObject* THPModule_checkSparseTensorInvariants(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().checkSparseTensorInvariants())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
bool isTHPFunction = THPFunction_Check(arg);
@ -1122,6 +1143,14 @@ static PyMethodDef TorchMethods[] = {
{"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
{"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
{"_set_check_sparse_tensor_invariants",
THPModule_setCheckSparseTensorInvariants,
METH_O,
nullptr},
{"_check_sparse_tensor_invariants",
THPModule_checkSparseTensorInvariants,
METH_NOARGS,
nullptr},
{"_will_engine_execute_node",
THPModule_willEngineExecuteNode,
METH_O,

View File

@ -197,29 +197,29 @@ static PyObject* THPVariable_nonzero(
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_compressed_tensor,
9,
({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
10,
({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_csr_tensor,
9,
({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
10,
({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_csc_tensor,
9,
({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
10,
({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_bsr_tensor,
9,
({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
10,
({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
sparse_bsc_tensor,
9,
({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"}))
10,
({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)",
"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(
_sparse_compressed_tensor_unsafe,
@ -248,12 +248,12 @@ static PyObject* THPVariable_sparse_coo_tensor(
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser({
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
"sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
"sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
"sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
"sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)",
});
ParsedArgs<6> parsed_args;
ParsedArgs<7> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(

View File

@ -784,6 +784,19 @@ Tensor indexing_tensor_from_data(
}
}
class CheckSparseTensorInvariantsContext {
public:
CheckSparseTensorInvariantsContext() {
state = at::globalContext().checkSparseTensorInvariants();
}
~CheckSparseTensorInvariantsContext() {
at::globalContext().setCheckSparseTensorInvariants(state);
}
private:
bool state;
};
Tensor sparse_compressed_tensor_ctor_worker(
std::string name,
c10::DispatchKey dispatch_key,
@ -802,6 +815,7 @@ Tensor sparse_compressed_tensor_ctor_worker(
ARG_DEVICE,
ARG_PIN_MEMORY,
ARG_REQUIRES_GRAD,
ARG_CHECK_INVARIANTS,
ARGS_COUNT
};
enum {
@ -811,6 +825,7 @@ Tensor sparse_compressed_tensor_ctor_worker(
ARG_DEVICE1,
ARG_PIN_MEMORY1,
ARG_REQUIRES_GRAD1,
ARG_CHECK_INVARIANTS1,
ARGS_COUNT1
};
@ -840,6 +855,10 @@ Tensor sparse_compressed_tensor_ctor_worker(
at::ScalarType plain_indices_scalar_type = plain_indices_dtype_attr
? reinterpret_cast<THPDtype*>(plain_indices_dtype_attr.get())->scalar_type
: kInt;
CheckSparseTensorInvariantsContext
restores_check_sparse_tensor_invariants_global_state{};
bool default_check_invariants =
at::globalContext().checkSparseTensorInvariants();
if (r.idx == 0) {
bool type_inference = r.isNone(ARG_TYPE);
@ -848,6 +867,10 @@ Tensor sparse_compressed_tensor_ctor_worker(
const auto inferred_scalar_type =
r.scalartypeWithDefault(ARG_TYPE, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
// the global state of invariants check flag will be restored via
// CheckSparseTensorInvariantsContext destructor
at::globalContext().setCheckSparseTensorInvariants(
r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants));
Tensor values = internal_new_from_data(
inferred_options,
@ -900,6 +923,10 @@ Tensor sparse_compressed_tensor_ctor_worker(
const auto inferred_scalar_type =
r.scalartypeWithDefault(ARG_TYPE1, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1));
// the global state of invariants check flag will be restored via
// CheckSparseTensorInvariantsContext destructor
at::globalContext().setCheckSparseTensorInvariants(
r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants));
Tensor values = internal_new_from_data(
inferred_options,
@ -1170,17 +1197,54 @@ Tensor sparse_coo_tensor_ctor(
PythonArgs& r) {
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
enum {
ARG_INDICES = 0,
ARG_VALUES,
ARG_TYPE,
ARG_DEVICE,
ARG_REQUIRES_GRAD,
ARG_CHECK_INVARIANTS,
ARGS_COUNT
};
enum {
ARG_INDICES1 = 0,
ARG_VALUES1,
ARG_SIZE1,
ARG_TYPE1,
ARG_DEVICE1,
ARG_REQUIRES_GRAD1,
ARG_CHECK_INVARIANTS1,
ARGS_COUNT1
};
enum {
ARG_SIZE2 = 0,
ARG_TYPE2,
ARG_DEVICE2,
ARG_REQUIRES_GRAD2,
ARG_CHECK_INVARIANTS2,
ARGS_COUNT2
};
CheckSparseTensorInvariantsContext
restores_check_sparse_tensor_invariants_global_state{};
bool default_check_invariants =
at::globalContext().checkSparseTensorInvariants();
if (r.idx == 0) {
bool type_inference = r.isNone(2);
const auto inferred_options = typeIdWithDefault(r, 3, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(2, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(3));
bool type_inference = r.isNone(ARG_TYPE);
const auto inferred_options =
typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
const auto inferred_scalar_type =
r.scalartypeWithDefault(ARG_TYPE, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
at::globalContext().setCheckSparseTensorInvariants(
r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants));
// if no dtype provided, infer type based on value type.
Tensor values = internal_new_from_data(
inferred_options,
inferred_scalar_type,
r.deviceOptional(3),
r.pyobject(1),
r.deviceOptional(ARG_DEVICE),
r.pyobject(ARG_VALUES),
/*copy_variables=*/false,
/*copy_numpy=*/true,
/*type_inference=*/type_inference);
@ -1188,24 +1252,29 @@ Tensor sparse_coo_tensor_ctor(
Tensor indices = internal_new_from_data(
values.options(),
kLong,
r.deviceOptional(3),
r.pyobject(0),
r.deviceOptional(ARG_DEVICE),
r.pyobject(ARG_INDICES),
/*copy_variables=*/false,
/*copy_numpy=*/true,
/*type_inference=*/false);
return at::sparse_coo_tensor(
indices, values, values.options().layout(at::kSparse))
.set_requires_grad(r.toBool(4));
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
} else if (r.idx == 1) {
bool type_inference = r.isNone(3);
const auto inferred_options = typeIdWithDefault(r, 4, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(4));
bool type_inference = r.isNone(ARG_TYPE1);
const auto inferred_options =
typeIdWithDefault(r, ARG_DEVICE1, dispatch_key);
const auto inferred_scalar_type =
r.scalartypeWithDefault(ARG_TYPE1, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1));
at::globalContext().setCheckSparseTensorInvariants(
r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants));
Tensor values = internal_new_from_data(
inferred_options,
inferred_scalar_type,
r.deviceOptional(4),
r.pyobject(1),
r.deviceOptional(ARG_DEVICE1),
r.pyobject(ARG_VALUES1),
/*copy_variables=*/false,
/*copy_numpy=*/true,
/*type_inference=*/type_inference);
@ -1213,25 +1282,30 @@ Tensor sparse_coo_tensor_ctor(
Tensor indices = internal_new_from_data(
values.options(),
kLong,
r.deviceOptional(4),
r.pyobject(0),
r.deviceOptional(ARG_DEVICE1),
r.pyobject(ARG_INDICES1),
/*copy_variables=*/false,
/*copy_numpy=*/true,
/*type_inference=*/false);
return at::sparse_coo_tensor(
indices,
values,
r.intlist(2),
r.intlist(ARG_SIZE1),
values.options().layout(at::kSparse))
.set_requires_grad(r.toBool(5));
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
} else if (r.idx == 2) {
const auto inferred_options = typeIdWithDefault(r, 2, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(1, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
const auto inferred_options =
typeIdWithDefault(r, ARG_DEVICE2, dispatch_key);
const auto inferred_scalar_type =
r.scalartypeWithDefault(ARG_TYPE2, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE2));
at::globalContext().setCheckSparseTensorInvariants(
r.toBoolWithDefault(ARG_CHECK_INVARIANTS2, default_check_invariants));
return at::sparse_coo_tensor(
r.intlist(0),
r.intlist(ARG_SIZE2),
inferred_options.dtype(inferred_scalar_type).layout(at::kSparse))
.set_requires_grad(r.toBool(3));
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2));
}
throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
}

View File

@ -18,6 +18,7 @@ else:
__all__ = [
'addmm',
'check_sparse_tensor_invariants',
'mm',
'sum',
'softmax',
@ -356,3 +357,108 @@ Specifying a positive offset::
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
""")
class check_sparse_tensor_invariants(object):
"""A tool to control checking sparse tensor invariants.
The following options exists to manage sparsr tensor invariants
checking in sparse tensor construction:
1. Using a context manager:
.. code:: python
with torch.sparse.check_sparse_tensor_invariants():
run_my_model()
2. Using a procedural approach:
.. code:: python
prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
torch.sparse.check_sparse_tensor_invariants.enable()
run_my_model()
if not prev_checks_enabled:
torch.sparse.check_sparse_tensor_invariants.disable()
3. Using function decoration:
.. code:: python
@torch.sparse.check_sparse_tensor_invariants()
def run_my_model():
...
run_my_model()
4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
For example:
>>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
"""
@staticmethod
def is_enabled():
r"""Returns True if the sparse tensor invariants checking is enabled.
.. note::
Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
:func:`torch.sparse.check_sparse_tensor_invariants.disable` to
manage the state of the sparse tensor invariants checks.
"""
return torch._C._check_sparse_tensor_invariants()
@staticmethod
def enable():
r"""Enable sparse tensor invariants checking in sparse tensor constructors.
.. note::
By default, the sparse tensor invariants checks are disabled. Use
:func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
retrieve the current state of sparse tensor invariants checking.
.. note::
The sparse tensor invariants check flag is effective to all sparse
tensor constructors, both in Python and ATen.
The flag can be locally overridden by the ``check_invariants``
optional argument of the sparse tensor constructor functions.
"""
torch._C._set_check_sparse_tensor_invariants(True)
@staticmethod
def disable():
r"""Disable sparse tensor invariants checking in sparse tensor constructors.
See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
"""
torch._C._set_check_sparse_tensor_invariants(False)
# context manager support
def __init__(self, enable=True):
self.state = enable
self.saved_state = self.is_enabled()
def __enter__(self):
torch._C._set_check_sparse_tensor_invariants(self.state)
def __exit__(self, type, value, traceback):
torch._C._set_check_sparse_tensor_invariants(self.saved_state)
# decorator support
def __call__(self, mth):
def test_mth(*args, **kwargs):
with type(self)(self.state):
return mth(*args, **kwargs)
return test_mth

View File

@ -2220,6 +2220,29 @@ class TestCase(expecttest.TestCase):
check_if_enable(self)
set_rng_seed(SEED)
# Save global check sparse tensor invariants state that can be
# restored from tearDown:
self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled()
# Enable invariant checks for all sparse tensors constructions
# including the unsafe ones. If this is not desired for some
# test case, use check_invariants=False optional argument to
# sparse tensor constructors or
# @torch.sparse.check_sparse_tensor_invariants(False)
# decorator to disable the invariant checks.
torch.sparse.check_sparse_tensor_invariants.enable()
def tearDown(self):
# There exists test cases that override TestCase.setUp
# definition, so we cannot assume that _check_invariants
# attribute is defined in general.
if hasattr(self, '_check_invariants'):
# Restore the global check sparse tensor invariants state
if self._check_invariants:
torch.sparse.check_sparse_tensor_invariants.enable()
else:
torch.sparse.check_sparse_tensor_invariants.disable()
@staticmethod
def _make_crow_indices(n_rows, n_cols, nnz,
*, device, dtype, random=True):