mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
change sparse COO comparison strategy in assert_close (#68728)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68728 This removes the ability for `assert_close` to `.coalesce()` the tensors internally. Additionally, we now also check `.sparse_dim()`. Sparse team: please make sure that is the behavior you want for all sparse COO comparisons in the future. #67796 will temporarily keep BC by always coalescing, but in the future `TestCase.assertEqual` will no longer do that. cc nikitaved pearu cpuhrsch IvanYashchuk Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D33542996 Pulled By: mruberry fbshipit-source-id: a8d2322c6ee1ca424e3efb14ab21787328cf28fc
This commit is contained in:
parent
8d05174def
commit
802dd2b725
|
|
@ -1204,37 +1204,15 @@ class TestAssertCloseSparseCOO(TestCase):
|
||||||
for fn in assert_close_with_inputs(actual, expected):
|
for fn in assert_close_with_inputs(actual, expected):
|
||||||
fn()
|
fn()
|
||||||
|
|
||||||
def test_mismatching_is_coalesced(self):
|
def test_mismatching_sparse_dims(self):
|
||||||
indices = (
|
t = torch.randn(2, 3, 4)
|
||||||
(0, 1),
|
actual = t.to_sparse()
|
||||||
(1, 0),
|
expected = t.to_sparse(2)
|
||||||
)
|
|
||||||
values = (1, 2)
|
|
||||||
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
|
|
||||||
expected = actual.clone().coalesce()
|
|
||||||
|
|
||||||
for fn in assert_close_with_inputs(actual, expected):
|
for fn in assert_close_with_inputs(actual, expected):
|
||||||
with self.assertRaisesRegex(AssertionError, "is_coalesced"):
|
with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")):
|
||||||
fn()
|
fn()
|
||||||
|
|
||||||
def test_mismatching_is_coalesced_no_check(self):
|
|
||||||
actual_indices = (
|
|
||||||
(0, 1),
|
|
||||||
(1, 0),
|
|
||||||
)
|
|
||||||
actual_values = (1, 2)
|
|
||||||
actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)).coalesce()
|
|
||||||
|
|
||||||
expected_indices = (
|
|
||||||
(0, 1, 1,),
|
|
||||||
(1, 0, 0,),
|
|
||||||
)
|
|
||||||
expected_values = (1, 1, 1)
|
|
||||||
expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
|
|
||||||
|
|
||||||
for fn in assert_close_with_inputs(actual, expected):
|
|
||||||
fn(check_is_coalesced=False)
|
|
||||||
|
|
||||||
def test_mismatching_nnz(self):
|
def test_mismatching_nnz(self):
|
||||||
actual_indices = (
|
actual_indices = (
|
||||||
(0, 1),
|
(0, 1),
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,14 @@ _DTYPE_PRECISIONS = {
|
||||||
torch.complex64: (1.3e-6, 1e-5),
|
torch.complex64: (1.3e-6, 1e-5),
|
||||||
torch.complex128: (1e-7, 1e-7),
|
torch.complex128: (1e-7, 1e-7),
|
||||||
}
|
}
|
||||||
|
# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
|
||||||
|
# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
|
||||||
|
_DTYPE_PRECISIONS.update(
|
||||||
|
{
|
||||||
|
dtype: _DTYPE_PRECISIONS[torch.float32]
|
||||||
|
for dtype in (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_tolerances(*inputs: Union[torch.Tensor, torch.dtype]) -> Tuple[float, float]:
|
def default_tolerances(*inputs: Union[torch.Tensor, torch.dtype]) -> Tuple[float, float]:
|
||||||
|
|
@ -622,13 +630,12 @@ class TensorLikePair(Pair):
|
||||||
|
|
||||||
- the :attr:`~torch.Tensor.shape`,
|
- the :attr:`~torch.Tensor.shape`,
|
||||||
- whether both inputs are quantized or not,
|
- whether both inputs are quantized or not,
|
||||||
- and if they are the quantization scheme.
|
- and if they use the same quantization scheme.
|
||||||
|
|
||||||
Checks for
|
Checks for
|
||||||
|
|
||||||
- :attr:`~torch.Tensor.layout`,
|
- :attr:`~torch.Tensor.layout`,
|
||||||
- :meth:`~torch.Tensor.stride`,
|
- :meth:`~torch.Tensor.stride`,
|
||||||
- :meth:`~torch.Tensor.is_coalesced`,
|
|
||||||
- :attr:`~torch.Tensor.device`, and
|
- :attr:`~torch.Tensor.device`, and
|
||||||
- :attr:`~torch.Tensor.dtype`
|
- :attr:`~torch.Tensor.dtype`
|
||||||
|
|
||||||
|
|
@ -652,15 +659,8 @@ class TensorLikePair(Pair):
|
||||||
if actual.layout != expected.layout:
|
if actual.layout != expected.layout:
|
||||||
if self.check_layout:
|
if self.check_layout:
|
||||||
raise_mismatch_error("layout", actual.layout, expected.layout)
|
raise_mismatch_error("layout", actual.layout, expected.layout)
|
||||||
else:
|
elif actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride():
|
||||||
if actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride():
|
raise_mismatch_error("stride()", actual.stride(), expected.stride())
|
||||||
raise_mismatch_error("stride()", actual.stride(), expected.stride())
|
|
||||||
elif (
|
|
||||||
actual.layout == torch.sparse_coo
|
|
||||||
and self.check_is_coalesced
|
|
||||||
and actual.is_coalesced() != expected.is_coalesced()
|
|
||||||
):
|
|
||||||
raise_mismatch_error("is_coalesced()", actual.is_coalesced(), expected.is_coalesced())
|
|
||||||
|
|
||||||
if self.check_device and actual.device != expected.device:
|
if self.check_device and actual.device != expected.device:
|
||||||
raise_mismatch_error("device", actual.device, expected.device)
|
raise_mismatch_error("device", actual.device, expected.device)
|
||||||
|
|
@ -677,7 +677,6 @@ class TensorLikePair(Pair):
|
||||||
- ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
|
- ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
|
||||||
:func:`torch.promote_types`).
|
:func:`torch.promote_types`).
|
||||||
- ... not of the same ``layout``, they are converted to strided tensors.
|
- ... not of the same ``layout``, they are converted to strided tensors.
|
||||||
- ... both sparse COO tensors but only one is coalesced, the other one is coalesced.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
actual (Tensor): Actual tensor.
|
actual (Tensor): Actual tensor.
|
||||||
|
|
@ -699,9 +698,6 @@ class TensorLikePair(Pair):
|
||||||
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
|
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
|
||||||
actual = actual.to_dense() if actual.layout != torch.strided else actual
|
actual = actual.to_dense() if actual.layout != torch.strided else actual
|
||||||
expected = expected.to_dense() if expected.layout != torch.strided else expected
|
expected = expected.to_dense() if expected.layout != torch.strided else expected
|
||||||
elif actual.is_sparse and actual.is_coalesced() != expected.is_coalesced():
|
|
||||||
actual = actual.coalesce()
|
|
||||||
expected = expected.coalesce()
|
|
||||||
|
|
||||||
return actual, expected
|
return actual, expected
|
||||||
|
|
||||||
|
|
@ -735,10 +731,20 @@ class TensorLikePair(Pair):
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Compares sparse COO tensors by comparing
|
"""Compares sparse COO tensors by comparing
|
||||||
|
|
||||||
|
- the number of sparse dimensions,
|
||||||
- the number of non-zero elements (nnz) for equality,
|
- the number of non-zero elements (nnz) for equality,
|
||||||
- the indices for equality, and
|
- the indices for equality, and
|
||||||
- the values for closeness.
|
- the values for closeness.
|
||||||
"""
|
"""
|
||||||
|
if actual.sparse_dim() != expected.sparse_dim():
|
||||||
|
raise self._make_error_meta(
|
||||||
|
AssertionError,
|
||||||
|
(
|
||||||
|
f"The number of sparse dimensions in sparse COO tensors does not match: "
|
||||||
|
f"{actual.sparse_dim()} != {expected.sparse_dim()}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if actual._nnz() != expected._nnz():
|
if actual._nnz() != expected._nnz():
|
||||||
raise self._make_error_meta(
|
raise self._make_error_meta(
|
||||||
AssertionError,
|
AssertionError,
|
||||||
|
|
@ -1031,7 +1037,6 @@ def assert_close(
|
||||||
check_dtype: bool = True,
|
check_dtype: bool = True,
|
||||||
check_layout: bool = True,
|
check_layout: bool = True,
|
||||||
check_stride: bool = False,
|
check_stride: bool = False,
|
||||||
check_is_coalesced: bool = True,
|
|
||||||
msg: Optional[str] = None,
|
msg: Optional[str] = None,
|
||||||
):
|
):
|
||||||
r"""Asserts that ``actual`` and ``expected`` are close.
|
r"""Asserts that ``actual`` and ``expected`` are close.
|
||||||
|
|
@ -1050,8 +1055,6 @@ def assert_close(
|
||||||
If ``actual`` and ``expected`` are sparse (either having COO or CSR layout), their strided members are
|
If ``actual`` and ``expected`` are sparse (either having COO or CSR layout), their strided members are
|
||||||
checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout,
|
checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout,
|
||||||
are always checked for equality whereas the values are checked for closeness according to the definition above.
|
are always checked for equality whereas the values are checked for closeness according to the definition above.
|
||||||
Sparse COO tensors are only considered close if both are either coalesced or uncoalesced (if
|
|
||||||
``check_is_coalesced`` is ``True``).
|
|
||||||
|
|
||||||
If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
|
If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
|
||||||
:meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
|
:meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
|
||||||
|
|
@ -1089,9 +1092,6 @@ def assert_close(
|
||||||
check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
|
check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
|
||||||
compared.
|
compared.
|
||||||
check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
|
check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
|
||||||
check_is_coalesced (bool): If ``True`` (default) and corresponding tensors are sparse COO, checks that both
|
|
||||||
``actual`` and ``expected`` are either coalesced or uncoalesced. If this check is disabled, tensors are
|
|
||||||
:meth:`~torch.Tensor.coalesce`'ed before being compared.
|
|
||||||
msg (Optional[str]): Optional error message to use in case a failure occurs during the comparison.
|
msg (Optional[str]): Optional error message to use in case a failure occurs during the comparison.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
|
@ -1112,8 +1112,6 @@ def assert_close(
|
||||||
:attr:`~torch.Tensor.device`.
|
:attr:`~torch.Tensor.device`.
|
||||||
AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
|
AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
|
||||||
AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
|
AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
|
||||||
AssertionError: If ``check_is_coalesced`` is ``True``, but corresponding sparse COO tensors are not both
|
|
||||||
either coalesced or uncoalesced.
|
|
||||||
AssertionError: If the values of corresponding tensors are not close according to the definition above.
|
AssertionError: If the values of corresponding tensors are not close according to the definition above.
|
||||||
|
|
||||||
The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
|
The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
|
||||||
|
|
@ -1136,6 +1134,16 @@ def assert_close(
|
||||||
+---------------------------+------------+----------+
|
+---------------------------+------------+----------+
|
||||||
| :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
|
| :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
|
||||||
+---------------------------+------------+----------+
|
+---------------------------+------------+----------+
|
||||||
|
| :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` |
|
||||||
|
+---------------------------+------------+----------+
|
||||||
|
| :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` |
|
||||||
|
+---------------------------+------------+----------+
|
||||||
|
| :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` |
|
||||||
|
+---------------------------+------------+----------+
|
||||||
|
| :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` |
|
||||||
|
+---------------------------+------------+----------+
|
||||||
|
| :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` |
|
||||||
|
+---------------------------+------------+----------+
|
||||||
| other | ``0.0`` | ``0.0`` |
|
| other | ``0.0`` | ``0.0`` |
|
||||||
+---------------------------+------------+----------+
|
+---------------------------+------------+----------+
|
||||||
|
|
||||||
|
|
@ -1255,6 +1263,5 @@ def assert_close(
|
||||||
check_dtype=check_dtype,
|
check_dtype=check_dtype,
|
||||||
check_layout=check_layout,
|
check_layout=check_layout,
|
||||||
check_stride=check_stride,
|
check_stride=check_stride,
|
||||||
check_is_coalesced=check_is_coalesced,
|
|
||||||
msg=msg,
|
msg=msg,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,6 @@ def assert_allclose(
|
||||||
check_device=True,
|
check_device=True,
|
||||||
check_dtype=False,
|
check_dtype=False,
|
||||||
check_stride=False,
|
check_stride=False,
|
||||||
check_is_coalesced=False,
|
|
||||||
msg=msg or None,
|
msg=msg or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user