diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp
index 472a24276f3..8661a9bd94e 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.cpp
+++ b/aten/src/ATen/SparseCsrTensorImpl.cpp
@@ -101,20 +101,16 @@ void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
   refresh_numel();
 }
 
-void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size) {
+void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
   TORCH_CHECK(
       !has_symbolic_sizes_strides_,
       "resize_and_clear_ called on tensor with symbolic shape");
-  TORCH_CHECK(sparse_dim >= 2, "resize_and_clear_ sparse dimensionality must be at least 2, got ", sparse_dim);
-  TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
-              sparse_dim, "), got ", size.size());
-  auto batch_dim = sparse_dim - 2;
+  TORCH_CHECK(sparse_dim == 2, "resize_and_clear_ sparse dimensionality must be 2, got ", sparse_dim);
+  TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim + dense_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
+              sparse_dim, ") plus dense dimensionality (=", dense_dim, "), got ", size.size());
+  auto batch_dim = size.size() - sparse_dim - dense_dim;
   auto batchsize = size.slice(0, batch_dim);
-  auto densesize = size.slice(batch_dim + 2, size.size() - batch_dim - 2);
-
-  auto values_size = DimVector(batchsize);
-  values_size.push_back(0); // nse
-  values_size.append(densesize.begin(), densesize.end());
+  auto densesize = size.slice(batch_dim + sparse_dim, dense_dim);
 
   auto col_indices_size = DimVector(batchsize);
   col_indices_size.push_back(0); // nse
@@ -123,14 +119,26 @@
       [&] () -> int64_t { return size[batch_dim]; },
       [&] () -> int64_t { return size[batch_dim + 1]; }
   );
+  auto values_size = DimVector(batchsize);
+  values_size.push_back(0); // nse
+  // WARNING: in the case of block tensors, the block size is defined
+  // by the existing values shape.
+  int64_t block_factor = 1;
   AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
       "resize_and_clear_",
       [] () {},
       [&] () {
         auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
         values_size.append(blocksize.begin(), blocksize.end());
-        n_compressed_indices /= blocksize[(the_layout == kSparseBsr ? 0 : 1)];
+        block_factor = blocksize[(the_layout == kSparseBsr ? 0 : 1)];
+      });
+  TORCH_CHECK(n_compressed_indices % block_factor == 0,
+              "The size of the compressed dimension (=", n_compressed_indices,
+              ") must be divisible by the corresponding block size (=", block_factor, ")");
+  n_compressed_indices /= block_factor;
+  values_size.append(densesize.begin(), densesize.end());
+
   auto crow_indices_size = DimVector(batchsize);
   crow_indices_size.push_back(n_compressed_indices + 1);
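The hunk above changes the contract of `resize_and_clear_`: the caller now passes `dense_dim` explicitly, so the batch dimensionality is recovered as `size.size() - sparse_dim - dense_dim`, and for block layouts the compressed index count is checked for divisibility by the block size. A minimal sketch of the resulting Python-level behavior (illustrative only, not part of the patch; it assumes `zero_()` routes through `resize_and_clear_` via `zero_sparse_csr_` as in the hunk below):

```python
import torch

# A 4x4 BSR tensor with 2x2 blocks; eye(4) has two nonzero diagonal blocks.
a = torch.eye(4).to_sparse_bsr((2, 2))
a.zero_()  # dispatches to zero_sparse_csr_, which calls resize_and_clear_

# The cleared tensor keeps its shape and block size but holds no blocks:
# the compressed dimension (4 rows) is divided by the block size (2),
# giving 4 / 2 + 1 = 3 crow_indices entries.
assert a.shape == (4, 4)
assert a.values().shape == (0, 2, 2)
assert a.crow_indices().tolist() == [0, 0, 0]
```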
diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h
index f2d90162ca9..de54c1d41db 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.h
+++ b/aten/src/ATen/SparseCsrTensorImpl.h
@@ -37,7 +37,10 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
       const caffe2::TypeMeta);
 
   void resize_(int64_t nnz, IntArrayRef size);
-  void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
+  void resize_and_clear_(
+      int64_t sparse_dim,
+      int64_t dense_dim,
+      IntArrayRef size);
   void resize_as_sparse_compressed_tensor_(const Tensor& src);
   void set_member_tensors(
       const Tensor& crow_indices,
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 995a8b5193e..f7de8f75b07 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -408,7 +408,7 @@ Tensor& zero_sparse_csr_(Tensor& self) {
     `result = csr.clone(); result.values.zero_();`
   */
   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
-  get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.sizes());
+  get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
   return self;
 }
 
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 7382633f669..a490de4e184 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
      deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
 from torch.testing._internal.common_methods_invocations import \
-    (reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
+    (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
 from torch.testing._internal.common_dtype import (
     all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
     floating_and_complex_types_and, integral_types, floating_types_and,
@@ -39,6 +39,8 @@ reduction_ops_with_sparse_support = [op for op in reduction_ops if 'masked.' not
 binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
 
+like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
+
 if TEST_SCIPY:
     import scipy.sparse
 
@@ -4858,6 +4860,62 @@ class TestSparseAny(TestCase):
 
                 run_test(m, n, k, device, dtype)
 
+    @onlyNativeDeviceTypes
+    @suppress_warnings
+    @ops(like_fns_with_sparse_support)
+    @all_sparse_layouts('layout', include_strided=False)
+    def test_like_fns(self, layout, device, dtype, op):
+
+        for sample in op.sample_inputs_sparse(layout, device, dtype):
+            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
+            batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
+            if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
+                expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
+            else:
+                expected_blocksize = None
+            expected_dtype = t_kwargs.get('dtype', dtype)
+            expected_device = torch.device(t_kwargs.get('device', device))
+            expected_layout = t_kwargs.get('layout', layout)
+
+            result = op.op(t_inp, *t_args, **t_kwargs)
+
+            self.assertEqual(result.dtype, expected_dtype)
+            self.assertEqual(result.device.type, expected_device.type)
+            self.assertEqual(result.layout, expected_layout)
+
+            if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
+                result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
+                blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
+                self.assertEqual(blocksize, expected_blocksize)
+
+            # Check op(inp).shape == inp.shape
+            self.assertEqual(result.shape, t_inp.shape)
+
+            if expected_layout is torch.strided:
+                self.assertEqual(result.sparse_dim(), 0)
+                # Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
+                self.assertEqual(result.dense_dim(), t_inp.dim())
+            elif expected_layout is torch.sparse_coo:
+                # Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
+                self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
+                # Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
+                self.assertEqual(result.dense_dim(), t_inp.dense_dim())
+
+                torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
+            else:
+                # Check op(inp).sparse_dim() == inp.sparse_dim()
+                self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
+                # Check op(inp).dense_dim() == inp.dense_dim()
+                self.assertEqual(result.dense_dim(), t_inp.dense_dim())
+
+                if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
+                    compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
+                else:
+                    compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
+
+                torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
+                                                              result.shape, result.layout)
+
 
 # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
 instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
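To make the new `test_like_fns` concrete, here is a hand-written version of the invariants it checks for a single CSR input (illustrative only; the real test iterates over `sample_inputs_sparse` for every layout, dtype, and kwarg combination):

```python
import torch

inp = torch.tensor([[1., 0.], [0., 2.]]).to_sparse_csr()
out = torch.zeros_like(inp)

# Like-functions preserve shape, layout, and the sparse/dense split...
assert out.layout == torch.sparse_csr
assert out.shape == inp.shape
assert out.sparse_dim() == inp.sparse_dim()
assert out.dense_dim() == inp.dense_dim()

# ...and the result passes the usual sparse compressed invariant checks.
torch._validate_sparse_compressed_tensor_args(
    out.crow_indices(), out.col_indices(), out.values(), out.shape, out.layout)
```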
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index a32957ffc14..f43a71129ef 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -515,6 +515,14 @@ class TestSparseCompressed(TestCase):
         if len(samples) == 0:
             raise ValueError("Expected at least one 2 or higher D tensor in samples.")
 
+        # Re-define atol and rtol for operations whose result values
+        # are random (and hence non-comparable) but we still want to
+        # check the shape, dtype, etc. attributes of the results:
+        atol = rtol = None
+        if op.name == 'randn_like':
+            atol = 1e300
+            rtol = 1
+
         for sample, sparse_sample in samples:
             expected = op(sample.input, *sample.args, **sample.kwargs)
             assert torch.is_tensor(expected)
@@ -524,7 +532,7 @@ class TestSparseCompressed(TestCase):
             if require_mask and sample.kwargs.get('mask') is not None:
                 output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
                 expected.masked_fill_(~output_mask, 0)
-            self.assertEqual(strided_output, expected)
+            self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)
 
     @skipMeta
     @all_sparse_compressed_layouts()
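The `atol = 1e300` trick effectively disables value comparison for `randn_like`, whose output is random, while `assertEqual` still verifies shape, dtype, and layout. A tiny illustration of the same idea with `torch.testing.assert_close` (illustrative only, not part of the patch):

```python
import torch
from torch.testing import assert_close

a = torch.randn(3, 4)
b = torch.randn(3, 4)

# Any finite values pass: |a - b| <= atol + rtol * |b| is trivially true.
assert_close(a, b, atol=1e300, rtol=1)

# Structural mismatches still fail, e.g. a dtype mismatch:
# assert_close(a, b.double(), atol=1e300, rtol=1)  # raises AssertionError
```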
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ecb7ecb3d61..db07177e60c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -134,6 +134,8 @@ from torch.testing._internal.opinfo.definitions._masked import (
     sample_inputs_softmax_variant,
 )
 from torch.testing._internal.opinfo.definitions.sparse import (
+    error_inputs_sparse_like_fns,
+    sample_inputs_sparse_like_fns,
     error_inputs_sparse_mul,
     sample_inputs_sparse_mul,
     error_inputs_sparse_reduction_sum,
@@ -15690,6 +15692,12 @@ op_db: List[OpInfo] = [
            supports_out=False,
            sample_inputs_func=sample_inputs_like_fns,
            supports_autograd=False,
+           error_inputs_sparse_func=error_inputs_sparse_like_fns,
+           sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
+           sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
+           sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
+           sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
+           sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
            skips=(
            )),
    OpInfo('ones_like',
@@ -15732,7 +15740,12 @@ op_db: List[OpInfo] = [
            supports_out=False,
            sample_inputs_func=sample_inputs_like_fns,
            supports_autograd=False,
-           supports_sparse_csr=True,
+           error_inputs_sparse_func=error_inputs_sparse_like_fns,
+           sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
+           sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
+           sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
+           sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
+           sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
            skips=(
                DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
                # AssertionError: JIT Test does not execute any logic
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index 28746910e7c..8f86cbd06af 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -973,18 +973,28 @@ class OpInfo:
         # corresponding layout support implies the layout support:
         if self.supports_sparse is None:
             self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
+        if self.sample_inputs_sparse_coo_func is None:
+            self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified
         if self.supports_sparse_csr is None:
             self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
+        if self.sample_inputs_sparse_csr_func is None:
+            self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
         if self.supports_sparse_csc is None:
             self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
+        if self.sample_inputs_sparse_csc_func is None:
+            self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified
         if self.supports_sparse_bsr is None:
             self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
+        if self.sample_inputs_sparse_bsr_func is None:
+            self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified
         if self.supports_sparse_bsc is None:
             self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
+        if self.sample_inputs_sparse_bsc_func is None:
+            self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified
 
         # We run the sampling functions without tracking the gradients of the creation of inputs
         self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
@@ -1228,6 +1238,21 @@ class OpInfo:
             sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
         )
 
+    def _sample_inputs_unspecified(self, *args, **kwargs):
+        """Raises a NotImplementedError exception in an OpInfo instance creation
+        that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True
+        without specifying the corresponding sample function as
+        sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func.
+
+        To avoid this, either define the corresponding sample function,
+        or re-map unsupported samples to error inputs in an appropriate
+
+          opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op>
+
+        function.
+        """
+        raise NotImplementedError("no sample function specified")
+
     def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
         """Returns an iterable of SampleInputs that contain inputs with sparse
         coo layout.
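The `core.py` hunk pairs each `supports_sparse*` flag with a fallback sample function: support is inferred from the presence of a layout-specific sampler, and a missing sampler is replaced by `_sample_inputs_unspecified`, which fails loudly instead of silently yielding no samples. A self-contained sketch of the pattern (hypothetical class name, one layout shown):

```python
class OpInfoSketch:
    """Sketch of the supports-flag / sample-function fallback in OpInfo."""

    def __init__(self, sample_inputs_sparse_csr_func=None, supports_sparse_csr=None):
        # Support is implied by a user-provided sample function...
        if supports_sparse_csr is None:
            supports_sparse_csr = sample_inputs_sparse_csr_func is not None
        # ...and a claimed-but-missing sample function becomes a loud stub.
        if sample_inputs_sparse_csr_func is None:
            sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
        self.supports_sparse_csr = supports_sparse_csr
        self.sample_inputs_sparse_csr_func = sample_inputs_sparse_csr_func

    def _sample_inputs_unspecified(self, *args, **kwargs):
        # Reached only when support was claimed without providing a sampler.
        raise NotImplementedError("no sample function specified")


op = OpInfoSketch(supports_sparse_csr=None)  # no sampler given
try:
    op.sample_inputs_sparse_csr_func("cpu", "float64")
except NotImplementedError as exc:
    print(exc)  # -> no sample function specified
```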
diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py
index fdb177f96c0..c4ea9593b9c 100644
--- a/torch/testing/_internal/opinfo/definitions/sparse.py
+++ b/torch/testing/_internal/opinfo/definitions/sparse.py
@@ -598,6 +598,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
         layout is torch.sparse_csr
         and dtype is torch.complex32
         and t_inp.numel() > 0
+        and t_inp._nnz() > 0
         and t_args[0].numel() > 0
         and t_args[0].ndim > 0
     ):
@@ -619,6 +620,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
     elif (
         layout in {torch.sparse_coo, torch.sparse_csr}
         and dtype is torch.bool
+        and t_inp._nnz() > 0
         and t_args[0].ndim > 0
         and t_inp.is_cpu
         and t_inp.numel() > 0
@@ -649,6 +651,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
     elif (
         layout is torch.sparse_csr
         and t_inp.dense_dim() > 0
+        and t_inp._nnz() > 0
         and t_inp.is_cpu
         and dtype is torch.float16
         and t_args[0].ndim > 0
@@ -758,6 +761,150 @@ def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
     )
 
 
+def _sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    from torch.testing._internal.common_utils import TestCase
+
+    for tensor in TestCase().generate_simple_inputs(
+        layout,
+        device=device,
+        dtype=dtype,
+        enable_batch=True,
+        enable_hybrid=True,
+        enable_zero_sized=True,
+        enable_non_contiguous_indices=False,
+        enable_non_contiguous_values=False,
+    ):
+        yield SampleInput(tensor, args=(), kwargs={})
+        yield SampleInput(
+            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
+        )
+
+        if dtype is not torch.float64:
+            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
+
+        if torch.cuda.is_available():
+            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
+            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
+
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+        yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
+
+        if layout is not torch.sparse_coo:
+            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
+
+
+def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
+    if sample.input.layout in {
+        torch.sparse_csr,
+        torch.sparse_csc,
+        torch.sparse_bsr,
+        torch.sparse_bsc,
+    }:
+        if sample.kwargs.get("device", sample.input.device) != sample.input.device:
+            return ErrorInput(
+                sample,
+                error_regex=(
+                    "device of (ccol|crow)_indices \\(=(cpu|cuda.*)\\) must"
+                    " match device of values \\(=(cuda.*|cpu)\\)"
+                ),
+            )
+        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
+            return ErrorInput(
+                sample,
+                error_regex=(
+                    "empty_like with different sparse layout is not supported"
+                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
+                ),
+            )
+    if sample.input.layout is torch.sparse_coo:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
+            ),
+        )
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    if torch.cuda.is_available() and layout is not torch.sparse_coo:
+        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+
+        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(device=other_device),
+        )
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(layout=other_layout),
+        )
+
+
+def sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for like-functions on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        _sample_inputs_sparse_like_fns,
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
+    """Error inputs for like-functions on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
 def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
     if op_info.name == "to_sparse":
         if (