# pytorch/torch/testing/_asserts.py
import collections.abc
import functools
import sys
from collections import namedtuple
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union, cast
import torch
from torch import Tensor
from ._core import _unravel_index
__all__ = ["assert_equal", "assert_close"]
# A UsageError should be raised when the test function is not used correctly. This lets the user differentiate
# between a test failure (there is a bug in the tested code) and a test error (there is a bug in the test itself).
# If pytest is the test runner, we use its built-in UsageError instead of our custom one.
try:
# The module 'pytest' will be present in sys.modules if the 'pytest' runner is used. This only gives a false
# positive if a previously imported module already imported 'pytest' directly or indirectly, but the test is
# run by another runner such as 'unittest'.
# 'mypy' is not able to handle this within a type annotation
# (see https://mypy.readthedocs.io/en/latest/common_issues.html#variables-vs-type-aliases for details). In case
# 'UsageError' is used in an annotation, add a 'type: ignore[valid-type]' comment.
UsageError: Type[Exception] = sys.modules["pytest"].UsageError # type: ignore[attr-defined]
except (KeyError, AttributeError):
class UsageError(Exception): # type: ignore[no-redef]
pass
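# For illustration (hypothetical session): when the tests run under pytest, the name above resolves to pytest's
# own error type, e.g.
#
#   import pytest
#   assert UsageError is pytest.UsageError
#
# whereas under any other runner the fallback class defined above is used.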
# This is copy-pasted from torch.testing._internal.common_utils.TestCase.dtype_precisions. With this we avoid a
# dependency on torch.testing._internal at import. See
# https://github.com/pytorch/pytorch/pull/54769#issuecomment-813174256 for details.
# {dtype: (rtol, atol)}
_DTYPE_PRECISIONS = {
torch.float16: (0.001, 1e-5),
torch.bfloat16: (0.016, 1e-5),
torch.float32: (1.3e-6, 1e-5),
torch.float64: (1e-7, 1e-7),
torch.complex32: (0.001, 1e-5),
torch.complex64: (1.3e-6, 1e-5),
torch.complex128: (1e-7, 1e-7),
}
def _get_default_rtol_and_atol(actual: Tensor, expected: Tensor) -> Tuple[float, float]:
dtype = actual.dtype if actual.dtype == expected.dtype else torch.promote_types(actual.dtype, expected.dtype)
return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))
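# A minimal sketch of the lookup above (illustrative session; values taken from _DTYPE_PRECISIONS):
#
#   >>> _get_default_rtol_and_atol(torch.empty(2, dtype=torch.float32), torch.empty(2, dtype=torch.float64))
#   (1e-07, 1e-07)
#   >>> _get_default_rtol_and_atol(torch.empty(2, dtype=torch.int32), torch.empty(2, dtype=torch.int64))
#   (0.0, 0.0)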
def _check_are_tensors(actual: Any, expected: Any) -> Optional[AssertionError]:
"""Checks if both inputs are tensors.
Args:
actual (Any): Actual input.
expected (Any): Expected input.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
if not (isinstance(actual, Tensor) and isinstance(expected, Tensor)):
return AssertionError(f"Both inputs have to be tensors, but got {type(actual)} and {type(expected)} instead.")
return None
def _check_supported_tensors(
actual: Tensor,
expected: Tensor,
) -> Optional[UsageError]: # type: ignore[valid-type]
"""Checks if the tensors are supported by the current infrastructure.
All checks are temporary and will be relaxed in the future.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
Returns:
(Optional[UsageError]): If check did not pass.
"""
if any(t.dtype in (torch.complex32, torch.complex64, torch.complex128) for t in (actual, expected)):
return UsageError("Comparison for complex tensors is not supported yet.")
if any(t.is_quantized for t in (actual, expected)):
return UsageError("Comparison for quantized tensors is not supported yet.")
if any(t.is_sparse for t in (actual, expected)):
return UsageError("Comparison for sparse tensors is not supported yet.")
return None
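# For illustration (hypothetical session): unsupported inputs yield a UsageError (a bug in the test) rather
# than an AssertionError (a bug in the tested code):
#
#   >>> exc = _check_supported_tensors(torch.eye(2).to_sparse(), torch.eye(2))
#   >>> str(exc)
#   'Comparison for sparse tensors is not supported yet.'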
def _check_attributes_equal(
actual: Tensor,
expected: Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> Optional[AssertionError]:
"""Checks if the attributes of two tensors match.
Always checks the :attr:`~torch.Tensor.shape`. Checks for :attr:`~torch.Tensor.device`,
:attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
same :attr:`~torch.Tensor.device`.
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:attr:`~torch.Tensor.dtype`.
check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:meth:`~torch.Tensor.stride`.
Returns:
(Optional[AssertionError]): If checks did not pass.
"""
msg_fmtstr = "The values for attribute '{}' do not match: {} != {}."
if actual.shape != expected.shape:
return AssertionError(msg_fmtstr.format("shape", actual.shape, expected.shape))
if check_device and actual.device != expected.device:
return AssertionError(msg_fmtstr.format("device", actual.device, expected.device))
if check_dtype and actual.dtype != expected.dtype:
return AssertionError(msg_fmtstr.format("dtype", actual.dtype, expected.dtype))
if check_stride and actual.stride() != expected.stride():
return AssertionError(msg_fmtstr.format("stride()", actual.stride(), expected.stride()))
return None
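# A minimal sketch of the attribute check above (illustrative session):
#
#   >>> exc = _check_attributes_equal(torch.empty(2, 3), torch.empty(3, 2))
#   >>> str(exc)
#   "The values for attribute 'shape' do not match: torch.Size([2, 3]) != torch.Size([3, 2])."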
def _equalize_attributes(actual: Tensor, expected: Tensor) -> Tuple[Tensor, Tensor]:
"""Equalizes some attributes of two tensors for value comparison.
If :attr:`actual` and :attr:`expected`
- are not on the same :attr:`~torch.Tensor.device`, they are moved to CPU memory, and
- do not have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types`.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
Returns:
Tuple[Tensor, Tensor]: Equalized tensors.
"""
if actual.device != expected.device:
actual = actual.cpu()
expected = expected.cpu()
if actual.dtype != expected.dtype:
dtype = torch.promote_types(actual.dtype, expected.dtype)
actual = actual.to(dtype)
expected = expected.to(dtype)
return actual, expected
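# A minimal sketch of the equalization above (illustrative session):
#
#   >>> a = torch.tensor([1, 2], dtype=torch.int64)
#   >>> b = torch.tensor([1.0, 2.0], dtype=torch.float32)
#   >>> a, b = _equalize_attributes(a, b)
#   >>> a.dtype, b.dtype  # torch.promote_types(torch.int64, torch.float32)
#   (torch.float32, torch.float32)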
_Trace = namedtuple(
"_Trace",
(
"total_elements",
"total_mismatches",
"mismatch_ratio",
"max_abs_diff",
"max_abs_diff_idx",
"max_rel_diff",
"max_rel_diff_idx",
),
)
def _trace_mismatches(actual: Tensor, expected: Tensor, mismatches: Tensor) -> _Trace:
"""Traces mismatches.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
mismatches (Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
the location of mismatches.
Returns:
(NamedTuple): Mismatch diagnostics with the following fields:
- total_elements (int): Total number of values.
- total_mismatches (int): Total number of mismatches.
- mismatch_ratio (float): Quotient of total mismatches and total elements.
- max_abs_diff (Union[int, float]): Greatest absolute difference of :attr:`actual` and :attr:`expected`.
- max_abs_diff_idx (Union[int, Tuple[int, ...]]): Index of greatest absolute difference.
- max_rel_diff (Union[int, float]): Greatest relative difference of :attr:`actual` and :attr:`expected`.
- max_rel_diff_idx (Union[int, Tuple[int, ...]]): Index of greatest relative difference.
The returned type of ``max_abs_diff`` and ``max_rel_diff`` depends on the :attr:`~torch.Tensor.dtype` of
:attr:`actual` and :attr:`expected`.
"""
total_elements = mismatches.numel()
total_mismatches = torch.sum(mismatches).item()
mismatch_ratio = total_mismatches / total_elements
dtype = torch.float64 if actual.dtype.is_floating_point else torch.int64
a_flat = actual.flatten().to(dtype)
b_flat = expected.flatten().to(dtype)
abs_diff = torch.abs(a_flat - b_flat)
max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)
rel_diff = abs_diff / torch.abs(b_flat)
max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
return _Trace(
total_elements=total_elements,
total_mismatches=total_mismatches,
mismatch_ratio=mismatch_ratio,
max_abs_diff=max_abs_diff.item(),
max_abs_diff_idx=_unravel_index(max_abs_diff_flat_idx.item(), mismatches.shape),
max_rel_diff=max_rel_diff.item(),
max_rel_diff_idx=_unravel_index(max_rel_diff_flat_idx.item(), mismatches.shape),
)
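# A minimal sketch of the diagnostics above (illustrative session; the exact index formatting comes from
# _unravel_index):
#
#   >>> actual = torch.tensor([1.0, 2.0, 3.0])
#   >>> expected = torch.tensor([1.0, 2.0, 5.0])
#   >>> trace = _trace_mismatches(actual, expected, torch.ne(actual, expected))
#   >>> trace.total_mismatches, trace.mismatch_ratio  # 1 of 3 values differs
#   (1, 0.3333333333333333)
#   >>> trace.max_abs_diff, trace.max_abs_diff_idx    # |3.0 - 5.0| = 2.0
#   (2.0, (2,))
#   >>> trace.max_rel_diff, trace.max_rel_diff_idx    # 2.0 / |5.0| = 0.4
#   (0.4, (2,))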
def _check_values_equal(actual: Tensor, expected: Tensor) -> Optional[AssertionError]:
"""Checks if the values of two tensors are bitwise equal.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
mismatches = torch.ne(actual, expected)
if not torch.any(mismatches):
return None
trace = _trace_mismatches(actual, expected, mismatches)
return AssertionError(
f"Tensors are not equal!\n\n"
f"Mismatched elements: {trace.total_mismatches} / {trace.total_elements} ({trace.mismatch_ratio:.1%})\n"
f"Greatest absolute difference: {trace.max_abs_diff} at {trace.max_abs_diff_idx}\n"
f"Greatest relative difference: {trace.max_rel_diff} at {trace.max_rel_diff_idx}"
)
def _check_values_close(
actual: Tensor,
expected: Tensor,
*,
rtol,
atol,
) -> Optional[AssertionError]:
"""Checks if the values of two tensors are close up to a desired tolerance.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
Returns:
(Optional[AssertionError]): If check did not pass.
"""
mismatches = ~torch.isclose(actual, expected, rtol=rtol, atol=atol)
if not torch.any(mismatches):
return None
trace = _trace_mismatches(actual, expected, mismatches)
return AssertionError(
f"Tensors are not close!\n\n"
f"Mismatched elements: {trace.total_mismatches} / {trace.total_elements} ({trace.mismatch_ratio:.1%})\n"
f"Greatest absolute difference: {trace.max_abs_diff} at {trace.max_abs_diff_idx} (up to {atol} allowed)\n"
f"Greatest relative difference: {trace.max_rel_diff} at {trace.max_rel_diff_idx} (up to {rtol} allowed)"
)
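# Numerically, with rtol=1e-2 and atol=1e-3 the pair (1.0, 1.01) is close, since
# |1.0 - 1.01| = 0.01 <= 1e-3 + 1e-2 * |1.01| ~= 0.0111, while (1.0, 1.02) is not,
# since 0.02 > 1e-3 + 1e-2 * |1.02| ~= 0.0112. Illustrative session:
#
#   >>> _check_values_close(torch.tensor([1.0]), torch.tensor([1.01]), rtol=1e-2, atol=1e-3) is None
#   True
#   >>> _check_values_close(torch.tensor([1.0]), torch.tensor([1.02]), rtol=1e-2, atol=1e-3) is None
#   False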
def _check_tensors_equal(
actual: Tensor,
expected: Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> Optional[Exception]:
"""Checks that the values of two tensors are bitwise equal.
Optionally, checks that some attributes of both tensors are equal.
For a description of the parameters see :func:`assert_equal`.
Returns:
Optional[Exception]: If checks did not pass.
"""
exc: Optional[Exception] = _check_are_tensors(actual, expected)
if exc:
return exc
exc = _check_supported_tensors(actual, expected)
if exc:
return exc
exc = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
)
if exc:
return exc
actual, expected = _equalize_attributes(actual, expected)
exc = _check_values_equal(actual, expected)
if exc:
return exc
return None
def _check_tensors_close(
actual: Tensor,
expected: Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> Optional[Exception]:
r"""Checks that the values of two tensors are close.
Closeness is defined by
.. math::
\lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
If both tolerances, :attr:`rtol` and :attr:`atol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
bitwise equal.
Optionally, checks that some attributes of both tensors are equal.
For a description of the parameters see :func:`assert_equal`.
Returns:
Optional[Exception]: If checks did not pass.
"""
exc: Optional[Exception] = _check_are_tensors(actual, expected)
if exc:
return exc
exc = _check_supported_tensors(actual, expected)
if exc:
return exc
if (rtol is None) ^ (atol is None):
# We require both tolerances to be omitted or specified, because specifying only one might lead to surprising
# results. Imagine setting atol=0.0 and the tensors still matching because rtol>0.0.
return UsageError(
f"Both 'rtol' and 'atol' must be omitted or specified, but got rtol={rtol} and atol={atol} instead."
)
elif rtol is None:
rtol, atol = _get_default_rtol_and_atol(actual, expected)
exc = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
)
if exc:
return exc
actual, expected = _equalize_attributes(actual, expected)
if (rtol == 0.0) and (atol == 0.0):
exc = _check_values_equal(actual, expected)
else:
exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
if exc:
return exc
return None
def _check_by_type(
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
) -> Optional[Exception]:
"""Delegates tensor checking based on the inputs types.
Currently supports pairs of
- :class:`Tensor`'s,
- :class:`~collections.abc.Sequence`'s of :class:`Tensor`'s, and
- :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
Args:
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if two tensors match.
In case of a mismatch, it should return an :class:`Exception` with an expressive error message.
Returns:
(Optional[Exception]): :class:`UsageError` if the input types are unsupported. Additionally, any exception
returned by :attr:`check_tensors`.
"""
# _check_are_tensors() returns None if both inputs are tensors and an exception otherwise. Thus, the logic
# is inverted here.
are_tensors = not _check_are_tensors(actual, expected)
if are_tensors:
return check_tensors(cast(Tensor, actual), cast(Tensor, expected))
if isinstance(actual, collections.abc.Sequence) and isinstance(expected, collections.abc.Sequence):
return _check_sequence(actual, expected, check_tensors)
elif isinstance(actual, collections.abc.Mapping) and isinstance(expected, collections.abc.Mapping):
return _check_mapping(actual, expected, check_tensors)
return UsageError(
f"Both inputs have to be tensors, or sequences or mappings of tensors, "
f"but got {type(actual)} and {type(expected)} instead."
)
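# A minimal sketch of the dispatch above (illustrative session); the container types of the two inputs do not
# have to match exactly, as long as both are sequences or both are mappings:
#
#   >>> t = torch.tensor([1, 2])
#   >>> _check_by_type([t, t], (t, t), _check_tensors_equal) is None  # list vs. tuple
#   True
#   >>> type(_check_by_type(1, t, _check_tensors_equal)).__name__
#   'UsageError'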
E = TypeVar("E", bound=Exception)
def _amend_error_message(exc: E, msg_fmtstr: str) -> E:
"""Amends an exception message.
Args:
exc (E): Exception.
msg_fmtstr (str): Format string for the amended message.
Returns:
(E): New exception with amended error message.
"""
return type(exc)(msg_fmtstr.format(str(exc)))
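# A minimal sketch of the amendment (illustrative values):
#
#   >>> exc = _amend_error_message(AssertionError("Tensors are not equal!"), "{} (while comparing item 0)")
#   >>> type(exc).__name__, str(exc)
#   ('AssertionError', 'Tensors are not equal! (while comparing item 0)')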
_SEQUENCE_MSG_FMTSTR = "The failure occurred at index {} of the sequences."
def _check_sequence(
actual: Sequence[Tensor], expected: Sequence[Tensor], check_tensors: Callable[[Tensor, Tensor], Optional[Exception]]
) -> Optional[Exception]:
"""Checks if the values of two sequences of tensors match.
Args:
actual (Sequence[Tensor]): Actual sequence of tensors.
expected (Sequence[Tensor]): Expected sequence of tensors.
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the items of
:attr:`actual` and :attr:`expected` match. In case of a mismatch, it should return an :class:`Exception`
with an expressive error message.
Returns:
Optional[Exception]: :class:`AssertionError` if the sequences do not have the same length. Additionally, any
exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
first offending index.
"""
actual_len = len(actual)
expected_len = len(expected)
if actual_len != expected_len:
return AssertionError(f"The length of the sequences mismatch: {actual_len} != {expected_len}")
for idx, (actual_t, expected_t) in enumerate(zip(actual, expected)):
exc = check_tensors(actual_t, expected_t)
if exc:
return _amend_error_message(exc, f"{{}}\n\n{_SEQUENCE_MSG_FMTSTR.format(idx)}")
return None
_MAPPING_MSG_FMTSTR = "The failure occurred for key '{}' of the mappings."
def _check_mapping(
actual: Mapping[Any, Tensor],
expected: Mapping[Any, Tensor],
check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
) -> Optional[Exception]:
"""Checks if the values of two mappings of tensors match.
Args:
actual (Mapping[Any, Tensor]): Actual mapping of tensors.
expected (Mapping[Any, Tensor]): Expected mapping of tensors.
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the values of
:attr:`actual` and :attr:`expected` match. In case of a mismatch, it should return an :class:`Exception`
with an expressive error message.
Returns:
Optional[Exception]: :class:`AssertionError` if the mappings do not have the same set of keys. Additionally,
any exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
first offending key.
"""
actual_keys = set(actual.keys())
expected_keys = set(expected.keys())
if actual_keys != expected_keys:
missing_keys = expected_keys - actual_keys
additional_keys = actual_keys - expected_keys
return AssertionError(
f"The keys of the mappings do not match:\n\n"
f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
f"Additional keys in the actual mapping: {sorted(additional_keys)}\n"
)
for key in sorted(actual_keys):
actual_t = actual[key]
expected_t = expected[key]
exc = check_tensors(actual_t, expected_t)
if exc:
return _amend_error_message(exc, f"{{}}\n\n{_MAPPING_MSG_FMTSTR.format(key)}")
return None
def assert_equal(
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> None:
"""Asserts that the values of tensors are bitwise equal.
Optionally, checks that some attributes of tensors are equal.
Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
Args:
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
:attr:`~torch.Tensor.device`. If this check is disabled **and** they are not on the same
:attr:`~torch.Tensor.device`, they are moved to CPU memory before their values are compared.
check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same
:attr:`~torch.Tensor.dtype`. If this check is disabled **and** they do not have the same
:attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types` before their values are compared.
check_stride (bool): If ``True`` (default), asserts that the tensors have the same stride.
Raises:
UsageError: If the input pair has an unsupported type.
UsageError: If any tensor is complex, quantized, or sparse. This is a temporary restriction and
will be relaxed in the future.
AssertionError: If any corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but any corresponding tensors are not on the same
:attr:`~torch.Tensor.device`.
AssertionError: If :attr:`check_dtype`, but any corresponding tensors do not have the same
:attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but any corresponding tensors do not have the same stride.
AssertionError: If the values of any corresponding tensors are not bitwise equal.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys mismatch.
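Example:
    The following is a minimal, illustrative usage (the import assumes this module's public API is exposed
    via ``torch.testing``)::

        >>> import torch
        >>> from torch.testing import assert_equal
        >>> actual = torch.tensor([1, 2, 3])
        >>> expected = torch.tensor([1, 2, 3])
        >>> assert_equal(actual, expected)
        >>> # Sequences and mappings of tensors are compared pairwise.
        >>> assert_equal([actual, actual], [expected, expected])
        >>> assert_equal({"t": actual}, {"t": expected})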
.. seealso::
To assert that the values in tensors are close but are not required to be bitwise equal, use
:func:`assert_close` instead.
"""
check_tensors = functools.partial(
_check_tensors_equal,
check_device=check_device,
check_dtype=check_dtype,
check_stride=check_stride,
)
exc = _check_by_type(actual, expected, check_tensors)
if exc:
raise exc
def assert_close(
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> None:
r"""Asserts that the values of tensors are close.
Closeness is defined by
.. math::
\lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
If both tolerances, :attr:`rtol` and :attr:`atol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
bitwise equal.
Optionally, checks that some attributes of tensors are equal.
Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
Args:
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
rtol (Optional[float]): Relative tolerance. If specified :attr:`atol` must also be specified. If omitted,
default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` must also be specified. If omitted,
default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
same :attr:`~torch.Tensor.device`. If this check is disabled **and** :attr:`actual` and
:attr:`expected` are not on the same :attr:`~torch.Tensor.device`, they are moved to CPU memory before
their values are compared.
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
:attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`actual` and :attr:`expected` do not
have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types` before their values are compared.
check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
stride.
Raises:
UsageError: If the input pair has an unsupported type.
UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
restriction and will be relaxed in the future.
AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but :attr:`actual` and :attr:`expected` are not on the same
:attr:`~torch.Tensor.device`.
AssertionError: If :attr:`check_dtype`, but :attr:`actual` and :attr:`expected` do not have the same
:attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
AssertionError: If the values of :attr:`actual` and :attr:`expected` are not close up to a desired tolerance.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys mismatch.
The following table displays the default ``rtol``'s and ``atol``'s. Note that the :class:`~torch.dtype` refers to
the promoted type in case :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.dtype`.
+---------------------------+------------+----------+
| :class:`~torch.dtype`     | ``rtol``   | ``atol`` |
+===========================+============+==========+
| :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| other                     | ``0.0``    | ``0.0``  |
+---------------------------+------------+----------+
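Example:
    The following is a minimal, illustrative usage (the import assumes this module's public API is exposed
    via ``torch.testing``)::

        >>> import torch
        >>> from torch.testing import assert_close
        >>> expected = torch.tensor([1.0, 2.0, 3.0])
        >>> actual = expected + 1e-6
        >>> # Uses the float32 defaults from the table above (rtol=1.3e-6, atol=1e-5).
        >>> assert_close(actual, expected)
        >>> # Tolerances must be passed together; passing only one raises a UsageError.
        >>> assert_close(actual, expected, rtol=0.0, atol=1e-5)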
.. seealso::
To assert that the values in tensors are bitwise equal, use :func:`assert_equal` instead.
"""
check_tensors = functools.partial(
_check_tensors_close,
rtol=rtol,
atol=atol,
check_device=check_device,
check_dtype=check_dtype,
check_stride=check_stride,
)
exc = _check_by_type(actual, expected, check_tensors)
if exc:
raise exc