mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix zeros_like for sparse tensors with batch dimensions. Add opinfo-based tests to like-functions. (#101215)
Fixes #101078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101215 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
597e2a11a3
commit
cbe270d233
|
|
@ -101,20 +101,16 @@ void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
|
||||||
refresh_numel();
|
refresh_numel();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size) {
|
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
!has_symbolic_sizes_strides_,
|
!has_symbolic_sizes_strides_,
|
||||||
"resize_and_clear_ called on tensor with symbolic shape");
|
"resize_and_clear_ called on tensor with symbolic shape");
|
||||||
TORCH_CHECK(sparse_dim >= 2, "resize_and_clear_ sparse dimensionality must be at least 2, got ", sparse_dim);
|
TORCH_CHECK(sparse_dim == 2, "resize_and_clear_ sparse dimensionality must be 2, got ", sparse_dim);
|
||||||
TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
|
TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim + dense_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
|
||||||
sparse_dim, "), got ", size.size());
|
sparse_dim, ") plus dense dimensionality (=", dense_dim, "), got ", size.size());
|
||||||
auto batch_dim = sparse_dim - 2;
|
auto batch_dim = size.size() - sparse_dim - dense_dim;
|
||||||
auto batchsize = size.slice(0, batch_dim);
|
auto batchsize = size.slice(0, batch_dim);
|
||||||
auto densesize = size.slice(batch_dim + 2, size.size() - batch_dim - 2);
|
auto densesize = size.slice(batch_dim + sparse_dim, dense_dim);
|
||||||
|
|
||||||
auto values_size = DimVector(batchsize);
|
|
||||||
values_size.push_back(0); // nse
|
|
||||||
values_size.append(densesize.begin(), densesize.end());
|
|
||||||
|
|
||||||
auto col_indices_size = DimVector(batchsize);
|
auto col_indices_size = DimVector(batchsize);
|
||||||
col_indices_size.push_back(0); // nse
|
col_indices_size.push_back(0); // nse
|
||||||
|
|
@ -123,14 +119,26 @@ void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size
|
||||||
[&] () -> int64_t { return size[batch_dim]; },
|
[&] () -> int64_t { return size[batch_dim]; },
|
||||||
[&] () -> int64_t { return size[batch_dim + 1]; }
|
[&] () -> int64_t { return size[batch_dim + 1]; }
|
||||||
);
|
);
|
||||||
|
auto values_size = DimVector(batchsize);
|
||||||
|
values_size.push_back(0); // nse
|
||||||
|
// WARNING: in the case of block tensors, the block size is defined
|
||||||
|
// by the existing values shape.
|
||||||
|
int64_t block_factor = 1;
|
||||||
AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
|
AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
|
||||||
"resize_and_clear_",
|
"resize_and_clear_",
|
||||||
[] () {},
|
[] () {},
|
||||||
[&] () {
|
[&] () {
|
||||||
auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
|
auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
|
||||||
values_size.append(blocksize.begin(), blocksize.end());
|
values_size.append(blocksize.begin(), blocksize.end());
|
||||||
n_compressed_indices /= blocksize[(the_layout == kSparseBsr ? 0 : 1)];
|
block_factor = blocksize[(the_layout == kSparseBsr ? 0 : 1)];
|
||||||
|
|
||||||
});
|
});
|
||||||
|
TORCH_CHECK(n_compressed_indices % block_factor == 0,
|
||||||
|
"The size of the compressed dimension (=", n_compressed_indices,
|
||||||
|
") must be divisible with the corresponding block size (=", block_factor,")");
|
||||||
|
n_compressed_indices /= block_factor;
|
||||||
|
values_size.append(densesize.begin(), densesize.end());
|
||||||
|
|
||||||
auto crow_indices_size = DimVector(batchsize);
|
auto crow_indices_size = DimVector(batchsize);
|
||||||
crow_indices_size.push_back(n_compressed_indices + 1);
|
crow_indices_size.push_back(n_compressed_indices + 1);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,10 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
|
||||||
const caffe2::TypeMeta);
|
const caffe2::TypeMeta);
|
||||||
|
|
||||||
void resize_(int64_t nnz, IntArrayRef size);
|
void resize_(int64_t nnz, IntArrayRef size);
|
||||||
void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
|
void resize_and_clear_(
|
||||||
|
int64_t sparse_dim,
|
||||||
|
int64_t dense_dim,
|
||||||
|
IntArrayRef size);
|
||||||
void resize_as_sparse_compressed_tensor_(const Tensor& src);
|
void resize_as_sparse_compressed_tensor_(const Tensor& src);
|
||||||
void set_member_tensors(
|
void set_member_tensors(
|
||||||
const Tensor& crow_indices,
|
const Tensor& crow_indices,
|
||||||
|
|
|
||||||
|
|
@ -408,7 +408,7 @@ Tensor& zero_sparse_csr_(Tensor& self) {
|
||||||
`result = csr.clone(); result.values.zero_();`
|
`result = csr.clone(); result.values.zero_();`
|
||||||
*/
|
*/
|
||||||
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
|
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
|
||||||
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.sizes());
|
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import \
|
||||||
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
|
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
|
||||||
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
|
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
|
||||||
from torch.testing._internal.common_methods_invocations import \
|
from torch.testing._internal.common_methods_invocations import \
|
||||||
(reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
|
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
|
||||||
from torch.testing._internal.common_dtype import (
|
from torch.testing._internal.common_dtype import (
|
||||||
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
|
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
|
||||||
floating_and_complex_types_and, integral_types, floating_types_and,
|
floating_and_complex_types_and, integral_types, floating_types_and,
|
||||||
|
|
@ -39,6 +39,8 @@ reduction_ops_with_sparse_support = [op for op in reduction_ops if 'masked.' not
|
||||||
|
|
||||||
binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
|
binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
|
||||||
|
|
||||||
|
like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
|
||||||
|
|
||||||
if TEST_SCIPY:
|
if TEST_SCIPY:
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
|
||||||
|
|
@ -4858,6 +4860,62 @@ class TestSparseAny(TestCase):
|
||||||
run_test(m, n, k, device, dtype)
|
run_test(m, n, k, device, dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@onlyNativeDeviceTypes
|
||||||
|
@suppress_warnings
|
||||||
|
@ops(like_fns_with_sparse_support)
|
||||||
|
@all_sparse_layouts('layout', include_strided=False)
|
||||||
|
def test_like_fns(self, layout, device, dtype, op):
|
||||||
|
|
||||||
|
for sample in op.sample_inputs_sparse(layout, device, dtype):
|
||||||
|
t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
|
||||||
|
batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
|
||||||
|
if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||||
|
expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
|
||||||
|
else:
|
||||||
|
expected_blocksize = None
|
||||||
|
expected_dtype = t_kwargs.get('dtype', dtype)
|
||||||
|
expected_device = torch.device(t_kwargs.get('device', device))
|
||||||
|
expected_layout = t_kwargs.get('layout', layout)
|
||||||
|
|
||||||
|
result = op.op(t_inp, *t_args, **t_kwargs)
|
||||||
|
|
||||||
|
self.assertEqual(result.dtype, expected_dtype)
|
||||||
|
self.assertEqual(result.device.type, expected_device.type)
|
||||||
|
self.assertEqual(result.layout, expected_layout)
|
||||||
|
|
||||||
|
if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||||
|
result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
|
||||||
|
blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
|
||||||
|
self.assertEqual(blocksize, expected_blocksize)
|
||||||
|
|
||||||
|
# Check op(inp).shape == inp.shape
|
||||||
|
self.assertEqual(result.shape, t_inp.shape)
|
||||||
|
|
||||||
|
if expected_layout is torch.strided:
|
||||||
|
self.assertEqual(result.sparse_dim(), 0)
|
||||||
|
# Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
|
||||||
|
self.assertEqual(result.dense_dim(), t_inp.dim())
|
||||||
|
elif expected_layout is torch.sparse_coo:
|
||||||
|
# Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
|
||||||
|
self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
|
||||||
|
# Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
|
||||||
|
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
|
||||||
|
|
||||||
|
torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
|
||||||
|
else:
|
||||||
|
# Check op(inp).sparse_dim() == inp.sparse_dim()
|
||||||
|
self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
|
||||||
|
# Check op(inp).dense_dim() == inp.dense_dim()
|
||||||
|
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
|
||||||
|
|
||||||
|
if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||||
|
compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
|
||||||
|
else:
|
||||||
|
compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
|
||||||
|
|
||||||
|
torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
|
||||||
|
result.shape, result.layout)
|
||||||
|
|
||||||
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
|
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
|
||||||
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
|
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -515,6 +515,14 @@ class TestSparseCompressed(TestCase):
|
||||||
if len(samples) == 0:
|
if len(samples) == 0:
|
||||||
raise ValueError("Expected at least one 2 or higher D tensor in samples.")
|
raise ValueError("Expected at least one 2 or higher D tensor in samples.")
|
||||||
|
|
||||||
|
# Re-define atol and rtol for operations that result values
|
||||||
|
# are random (and hence, non-comparable) be we still want to
|
||||||
|
# check the shape, dtype, etc attributes of the results:
|
||||||
|
atol = rtol = None
|
||||||
|
if op.name == 'randn_like':
|
||||||
|
atol = 1e300
|
||||||
|
rtol = 1
|
||||||
|
|
||||||
for sample, sparse_sample in samples:
|
for sample, sparse_sample in samples:
|
||||||
expected = op(sample.input, *sample.args, **sample.kwargs)
|
expected = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
assert torch.is_tensor(expected)
|
assert torch.is_tensor(expected)
|
||||||
|
|
@ -524,7 +532,7 @@ class TestSparseCompressed(TestCase):
|
||||||
if require_mask and sample.kwargs.get('mask') is not None:
|
if require_mask and sample.kwargs.get('mask') is not None:
|
||||||
output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
|
output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
|
||||||
expected.masked_fill_(~output_mask, 0)
|
expected.masked_fill_(~output_mask, 0)
|
||||||
self.assertEqual(strided_output, expected)
|
self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
@skipMeta
|
@skipMeta
|
||||||
@all_sparse_compressed_layouts()
|
@all_sparse_compressed_layouts()
|
||||||
|
|
|
||||||
|
|
@ -134,6 +134,8 @@ from torch.testing._internal.opinfo.definitions._masked import (
|
||||||
sample_inputs_softmax_variant,
|
sample_inputs_softmax_variant,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.opinfo.definitions.sparse import (
|
from torch.testing._internal.opinfo.definitions.sparse import (
|
||||||
|
error_inputs_sparse_like_fns,
|
||||||
|
sample_inputs_sparse_like_fns,
|
||||||
error_inputs_sparse_mul,
|
error_inputs_sparse_mul,
|
||||||
sample_inputs_sparse_mul,
|
sample_inputs_sparse_mul,
|
||||||
error_inputs_sparse_reduction_sum,
|
error_inputs_sparse_reduction_sum,
|
||||||
|
|
@ -15690,6 +15692,12 @@ op_db: List[OpInfo] = [
|
||||||
supports_out=False,
|
supports_out=False,
|
||||||
sample_inputs_func=sample_inputs_like_fns,
|
sample_inputs_func=sample_inputs_like_fns,
|
||||||
supports_autograd=False,
|
supports_autograd=False,
|
||||||
|
error_inputs_sparse_func=error_inputs_sparse_like_fns,
|
||||||
|
sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
|
||||||
|
sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
|
||||||
|
sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
|
||||||
|
sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
|
||||||
|
sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
|
||||||
skips=(
|
skips=(
|
||||||
)),
|
)),
|
||||||
OpInfo('ones_like',
|
OpInfo('ones_like',
|
||||||
|
|
@ -15732,7 +15740,12 @@ op_db: List[OpInfo] = [
|
||||||
supports_out=False,
|
supports_out=False,
|
||||||
sample_inputs_func=sample_inputs_like_fns,
|
sample_inputs_func=sample_inputs_like_fns,
|
||||||
supports_autograd=False,
|
supports_autograd=False,
|
||||||
supports_sparse_csr=True,
|
error_inputs_sparse_func=error_inputs_sparse_like_fns,
|
||||||
|
sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
|
||||||
|
sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
|
||||||
|
sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
|
||||||
|
sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
|
||||||
|
sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
|
||||||
skips=(
|
skips=(
|
||||||
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
|
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
|
||||||
# AssertionError: JIT Test does not execute any logic
|
# AssertionError: JIT Test does not execute any logic
|
||||||
|
|
|
||||||
|
|
@ -973,18 +973,28 @@ class OpInfo:
|
||||||
# corresponding layout support implies the layout support:
|
# corresponding layout support implies the layout support:
|
||||||
if self.supports_sparse is None:
|
if self.supports_sparse is None:
|
||||||
self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
|
self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
|
||||||
|
if self.sample_inputs_sparse_coo_func is None:
|
||||||
|
self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified
|
||||||
|
|
||||||
if self.supports_sparse_csr is None:
|
if self.supports_sparse_csr is None:
|
||||||
self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
|
self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
|
||||||
|
if self.sample_inputs_sparse_csr_func is None:
|
||||||
|
self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
|
||||||
|
|
||||||
if self.supports_sparse_csc is None:
|
if self.supports_sparse_csc is None:
|
||||||
self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
|
self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
|
||||||
|
if self.sample_inputs_sparse_csc_func is None:
|
||||||
|
self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified
|
||||||
|
|
||||||
if self.supports_sparse_bsr is None:
|
if self.supports_sparse_bsr is None:
|
||||||
self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
|
self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
|
||||||
|
if self.sample_inputs_sparse_bsr_func is None:
|
||||||
|
self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified
|
||||||
|
|
||||||
if self.supports_sparse_bsc is None:
|
if self.supports_sparse_bsc is None:
|
||||||
self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
|
self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
|
||||||
|
if self.sample_inputs_sparse_bsc_func is None:
|
||||||
|
self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified
|
||||||
|
|
||||||
# We run the sampling functions without tracking the gradiends of the creation of inputs
|
# We run the sampling functions without tracking the gradiends of the creation of inputs
|
||||||
self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
|
self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
|
||||||
|
|
@ -1228,6 +1238,21 @@ class OpInfo:
|
||||||
sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
|
sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _sample_inputs_unspecified(self, *args, **kwargs):
|
||||||
|
"""Raises an NotImplemented exception in a OpInfo instance creation
|
||||||
|
that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True
|
||||||
|
without specifying the corresponding sample function as
|
||||||
|
sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func.
|
||||||
|
|
||||||
|
To avoid this, either define the corresponding sample function,
|
||||||
|
or re-map unsupported samples to error inputs in an appropiate
|
||||||
|
|
||||||
|
opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op>
|
||||||
|
|
||||||
|
function.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("no sample function specified")
|
||||||
|
|
||||||
def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
|
def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
|
||||||
"""Returns an iterable of SampleInputs that contain inputs with sparse
|
"""Returns an iterable of SampleInputs that contain inputs with sparse
|
||||||
coo layout.
|
coo layout.
|
||||||
|
|
|
||||||
|
|
@ -598,6 +598,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
|
||||||
layout is torch.sparse_csr
|
layout is torch.sparse_csr
|
||||||
and dtype is torch.complex32
|
and dtype is torch.complex32
|
||||||
and t_inp.numel() > 0
|
and t_inp.numel() > 0
|
||||||
|
and t_inp._nnz() > 0
|
||||||
and t_args[0].numel() > 0
|
and t_args[0].numel() > 0
|
||||||
and t_args[0].ndim > 0
|
and t_args[0].ndim > 0
|
||||||
):
|
):
|
||||||
|
|
@ -619,6 +620,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
|
||||||
elif (
|
elif (
|
||||||
layout in {torch.sparse_coo, torch.sparse_csr}
|
layout in {torch.sparse_coo, torch.sparse_csr}
|
||||||
and dtype is torch.bool
|
and dtype is torch.bool
|
||||||
|
and t_inp._nnz() > 0
|
||||||
and t_args[0].ndim > 0
|
and t_args[0].ndim > 0
|
||||||
and t_inp.is_cpu
|
and t_inp.is_cpu
|
||||||
and t_inp.numel() > 0
|
and t_inp.numel() > 0
|
||||||
|
|
@ -649,6 +651,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
|
||||||
elif (
|
elif (
|
||||||
layout is torch.sparse_csr
|
layout is torch.sparse_csr
|
||||||
and t_inp.dense_dim() > 0
|
and t_inp.dense_dim() > 0
|
||||||
|
and t_inp._nnz() > 0
|
||||||
and t_inp.is_cpu
|
and t_inp.is_cpu
|
||||||
and dtype is torch.float16
|
and dtype is torch.float16
|
||||||
and t_args[0].ndim > 0
|
and t_args[0].ndim > 0
|
||||||
|
|
@ -758,6 +761,150 @@ def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_inputs_sparse_like_fns(
|
||||||
|
op_info, device, dtype, requires_grad, layout, **kwargs
|
||||||
|
):
|
||||||
|
from torch.testing._internal.common_utils import TestCase
|
||||||
|
|
||||||
|
for tensor in TestCase().generate_simple_inputs(
|
||||||
|
layout,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
enable_batch=True,
|
||||||
|
enable_hybrid=True,
|
||||||
|
enable_zero_sized=True,
|
||||||
|
enable_non_contiguous_indices=False,
|
||||||
|
enable_non_contiguous_values=False,
|
||||||
|
):
|
||||||
|
yield SampleInput(tensor, args=(), kwargs={})
|
||||||
|
yield SampleInput(
|
||||||
|
tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
|
||||||
|
)
|
||||||
|
|
||||||
|
if dtype is not torch.float64:
|
||||||
|
yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
|
||||||
|
yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
|
||||||
|
|
||||||
|
if layout is torch.sparse_csr:
|
||||||
|
other_layout = torch.sparse_csc
|
||||||
|
elif layout is torch.sparse_csc:
|
||||||
|
other_layout = torch.sparse_csr
|
||||||
|
elif layout is torch.sparse_bsr:
|
||||||
|
other_layout = torch.sparse_bsc
|
||||||
|
elif layout is torch.sparse_bsc:
|
||||||
|
other_layout = torch.sparse_bsr
|
||||||
|
else:
|
||||||
|
other_layout = torch.strided
|
||||||
|
yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
|
||||||
|
|
||||||
|
if layout is not torch.sparse_coo:
|
||||||
|
yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
|
||||||
|
if sample.input.layout in {
|
||||||
|
torch.sparse_csr,
|
||||||
|
torch.sparse_csc,
|
||||||
|
torch.sparse_bsr,
|
||||||
|
torch.sparse_bsc,
|
||||||
|
}:
|
||||||
|
if sample.kwargs.get("device", sample.input.device) != sample.input.device:
|
||||||
|
return ErrorInput(
|
||||||
|
sample,
|
||||||
|
error_regex=(
|
||||||
|
"device of (ccol|crow)_indices \\(=(cpu|cuda.*)\\) must"
|
||||||
|
" match device of values \\(=(cuda.*|cpu)\\)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
|
||||||
|
return ErrorInput(
|
||||||
|
sample,
|
||||||
|
error_regex=(
|
||||||
|
"empty_like with different sparse layout is not supported"
|
||||||
|
" \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if sample.input.layout is torch.sparse_coo:
|
||||||
|
return ErrorInput(
|
||||||
|
sample,
|
||||||
|
error_regex=(
|
||||||
|
"Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if check_validate:
|
||||||
|
_check_validate(op_info, sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_failing_sample_inputs_sparse_like_fns(
|
||||||
|
op_info, device, dtype, requires_grad, layout, **kwargs
|
||||||
|
):
|
||||||
|
if torch.cuda.is_available() and layout is not torch.sparse_coo:
|
||||||
|
other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
|
||||||
|
if layout is torch.sparse_csr:
|
||||||
|
other_layout = torch.sparse_csc
|
||||||
|
elif layout is torch.sparse_csc:
|
||||||
|
other_layout = torch.sparse_csr
|
||||||
|
elif layout is torch.sparse_bsr:
|
||||||
|
other_layout = torch.sparse_bsc
|
||||||
|
elif layout is torch.sparse_bsc:
|
||||||
|
other_layout = torch.sparse_bsr
|
||||||
|
else:
|
||||||
|
other_layout = torch.strided
|
||||||
|
|
||||||
|
blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
|
||||||
|
|
||||||
|
yield SampleInput(
|
||||||
|
torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
|
||||||
|
layout=layout, blocksize=blocksize
|
||||||
|
),
|
||||||
|
kwargs=dict(device=other_device),
|
||||||
|
)
|
||||||
|
|
||||||
|
yield SampleInput(
|
||||||
|
torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
|
||||||
|
layout=layout, blocksize=blocksize
|
||||||
|
),
|
||||||
|
kwargs=dict(layout=other_layout),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_inputs_sparse_like_fns(
|
||||||
|
op_info, device, dtype, requires_grad, layout, **kwargs
|
||||||
|
):
|
||||||
|
"""Sample inputs for like-functions on sparse tensors."""
|
||||||
|
yield from _sample_inputs_sparse(
|
||||||
|
_sample_inputs_sparse_like_fns,
|
||||||
|
_maybe_failing_sample_inputs_sparse_like_fns,
|
||||||
|
_validate_sample_input_sparse_like_fns,
|
||||||
|
op_info,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
requires_grad,
|
||||||
|
layout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
|
||||||
|
"""Error inputs for like-functions on sparse tensors."""
|
||||||
|
dtype = torch.float64
|
||||||
|
requires_grad = False
|
||||||
|
yield from _error_inputs_sparse(
|
||||||
|
_maybe_failing_sample_inputs_sparse_like_fns,
|
||||||
|
_validate_sample_input_sparse_like_fns,
|
||||||
|
op_info,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
requires_grad,
|
||||||
|
layout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
|
def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
|
||||||
if op_info.name == "to_sparse":
|
if op_info.name == "to_sparse":
|
||||||
if (
|
if (
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user