diff --git a/test/test_testing.py b/test/test_testing.py index 272b5033c32..a3849b7f7f2 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1204,37 +1204,15 @@ class TestAssertCloseSparseCOO(TestCase): for fn in assert_close_with_inputs(actual, expected): fn() - def test_mismatching_is_coalesced(self): - indices = ( - (0, 1), - (1, 0), - ) - values = (1, 2) - actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)) - expected = actual.clone().coalesce() + def test_mismatching_sparse_dims(self): + t = torch.randn(2, 3, 4) + actual = t.to_sparse() + expected = t.to_sparse(2) for fn in assert_close_with_inputs(actual, expected): - with self.assertRaisesRegex(AssertionError, "is_coalesced"): + with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")): fn() - def test_mismatching_is_coalesced_no_check(self): - actual_indices = ( - (0, 1), - (1, 0), - ) - actual_values = (1, 2) - actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce() - - expected_indices = ( - (0, 1, 1,), - (1, 0, 0,), - ) - expected_values = (1, 1, 1) - expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2)) - - for fn in assert_close_with_inputs(actual, expected): - fn(check_is_coalesced=False) - def test_mismatching_nnz(self): actual_indices = ( (0, 1), diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 06a1fe2b743..1da2a23fda8 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -50,6 +50,14 @@ _DTYPE_PRECISIONS = { torch.complex64: (1.3e-6, 1e-5), torch.complex128: (1e-7, 1e-7), } +# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in +# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values` +_DTYPE_PRECISIONS.update( + { + dtype: _DTYPE_PRECISIONS[torch.float32] + for dtype in (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32) + } +) def default_tolerances(*inputs: Union[torch.Tensor, torch.dtype]) -> Tuple[float, float]: @@ -622,13 +630,12 @@ class TensorLikePair(Pair): - the :attr:`~torch.Tensor.shape`, - whether both inputs are quantized or not, - - and if they are the quantization scheme. + - and if they use the same quantization scheme. Checks for - :attr:`~torch.Tensor.layout`, - :meth:`~torch.Tensor.stride`, - - :meth:`~torch.Tensor.is_coalesced`, - :attr:`~torch.Tensor.device`, and - :attr:`~torch.Tensor.dtype` @@ -652,15 +659,8 @@ class TensorLikePair(Pair): if actual.layout != expected.layout: if self.check_layout: raise_mismatch_error("layout", actual.layout, expected.layout) - else: - if actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride(): - raise_mismatch_error("stride()", actual.stride(), expected.stride()) - elif ( - actual.layout == torch.sparse_coo - and self.check_is_coalesced - and actual.is_coalesced() != expected.is_coalesced() - ): - raise_mismatch_error("is_coalesced()", actual.is_coalesced(), expected.is_coalesced()) + elif actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride(): + raise_mismatch_error("stride()", actual.stride(), expected.stride()) if self.check_device and actual.device != expected.device: raise_mismatch_error("device", actual.device, expected.device) @@ -677,7 +677,6 @@ class TensorLikePair(Pair): - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to :func:`torch.promote_types`). - ... not of the same ``layout``, they are converted to strided tensors. - - ... both sparse COO tensors but only one is coalesced, the other one is coalesced. Args: actual (Tensor): Actual tensor. @@ -699,9 +698,6 @@ class TensorLikePair(Pair): # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided actual = actual.to_dense() if actual.layout != torch.strided else actual expected = expected.to_dense() if expected.layout != torch.strided else expected - elif actual.is_sparse and actual.is_coalesced() != expected.is_coalesced(): - actual = actual.coalesce() - expected = expected.coalesce() return actual, expected @@ -735,10 +731,20 @@ class TensorLikePair(Pair): ) -> None: """Compares sparse COO tensors by comparing + - the number of sparse dimensions, - the number of non-zero elements (nnz) for equality, - the indices for equality, and - the values for closeness. """ + if actual.sparse_dim() != expected.sparse_dim(): + raise self._make_error_meta( + AssertionError, + ( + f"The number of sparse dimensions in sparse COO tensors does not match: " + f"{actual.sparse_dim()} != {expected.sparse_dim()}" + ), + ) + if actual._nnz() != expected._nnz(): raise self._make_error_meta( AssertionError, @@ -1031,7 +1037,6 @@ def assert_close( check_dtype: bool = True, check_layout: bool = True, check_stride: bool = False, - check_is_coalesced: bool = True, msg: Optional[str] = None, ): r"""Asserts that ``actual`` and ``expected`` are close. @@ -1050,8 +1055,6 @@ def assert_close( If ``actual`` and ``expected`` are sparse (either having COO or CSR layout), their strided members are checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout, are always checked for equality whereas the values are checked for closeness according to the definition above. - Sparse COO tensors are only considered close if both are either coalesced or uncoalesced (if - ``check_is_coalesced`` is ``True``). If ``actual`` and ``expected`` are quantized, they are considered close if they have the same :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the @@ -1089,9 +1092,6 @@ def assert_close( check is disabled, tensors with different ``layout``'s are converted to strided tensors before being compared. check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. - check_is_coalesced (bool): If ``True`` (default) and corresponding tensors are sparse COO, checks that both - ``actual`` and ``expected`` are either coalesced or uncoalesced. If this check is disabled, tensors are - :meth:`~torch.Tensor.coalesce`'ed before being compared. msg (Optional[str]): Optional error message to use in case a failure occurs during the comparison. Raises: @@ -1112,8 +1112,6 @@ def assert_close( :attr:`~torch.Tensor.device`. AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``. AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride. - AssertionError: If ``check_is_coalesced`` is ``True``, but corresponding sparse COO tensors are not both - either coalesced or uncoalesced. AssertionError: If the values of corresponding tensors are not close according to the definition above. The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching @@ -1136,6 +1134,16 @@ def assert_close( +---------------------------+------------+----------+ | :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` | +---------------------------+------------+----------+ + | :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ + | :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` | + +---------------------------+------------+----------+ | other | ``0.0`` | ``0.0`` | +---------------------------+------------+----------+ @@ -1255,6 +1263,5 @@ def assert_close( check_dtype=check_dtype, check_layout=check_layout, check_stride=check_stride, - check_is_coalesced=check_is_coalesced, msg=msg, ) diff --git a/torch/testing/_deprecated.py b/torch/testing/_deprecated.py index d58d890cf1d..67826b3a628 100644 --- a/torch/testing/_deprecated.py +++ b/torch/testing/_deprecated.py @@ -85,7 +85,6 @@ def assert_allclose( check_device=True, check_dtype=False, check_stride=False, - check_is_coalesced=False, msg=msg or None, )