mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This reverts commit 5669e10d37.
Reverted https://github.com/pytorch/pytorch/pull/87974 on behalf of https://github.com/mehtanirav due to Internal breakages from method removal
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
"""
|
|
This file only exists since `torch.testing.assert_allclose` is deprecated, but used extensively throughout the tests in
|
|
this package. The replacement `torch.testing.assert_close` doesn't support one feature that is needed here: comparison
|
|
between numpy arrays and torch tensors. See https://github.com/pytorch/pytorch/issues/61844 for the reasoning why this
|
|
was removed.
|
|
"""
|
|
|
|
import torch
|
|
from typing import Tuple, Any, Optional
|
|
|
|
# Default (rtol, atol) pairs per floating-point dtype. `_get_default_rtol_and_atol`
# consults this table when the caller passes neither tolerance; a dtype that is
# absent here is given (0.0, 0.0) by that lookup.
_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),  # half precision
    torch.float32: (1e-4, 1e-5),  # single precision
    torch.float64: (1e-5, 1e-8),  # double precision
}
|
|
|
|
|
|
def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    """Pick default tolerances for comparing ``actual`` against ``expected``.

    Each tensor's dtype is looked up in ``_DTYPE_PRECISIONS``. A dtype without an
    entry contributes ``(0.0, 0.0)``. The larger rtol and the larger atol are
    selected independently, so the looser tolerance of the two dtypes governs.

    Returns:
        A ``(rtol, atol)`` tuple.
    """
    no_tolerance = (0.0, 0.0)
    rtols, atols = zip(
        *(_DTYPE_PRECISIONS.get(tensor.dtype, no_tolerance) for tensor in (actual, expected))
    )
    return max(rtols), max(atols)
|
|
|
|
|
|
def _expected_as_tensor(expected: Any, actual: torch.Tensor) -> torch.Tensor:
    """Convert a non-tensor ``expected`` into a tensor comparable with ``actual``.

    The value is normally built directly in ``actual``'s dtype. For example, a
    Python float compared against a float64 tensor keeps full precision instead
    of passing through the default dtype. When that cast would cross into a
    lower dtype category (float -> integer, integer -> bool, complex -> real),
    the conversion is lossy: ``[1.5]`` built as int64 becomes ``[1]``, and a
    genuine mismatch would pass. In that case the inferred dtype is kept, and
    ``assert_close`` (called with ``check_dtype=False``) promotes both sides to
    a common dtype before comparing.
    """
    try:
        inferred = torch.tensor(expected)
    except (TypeError, RuntimeError, OverflowError):
        # Some inputs may only convert with an explicit dtype (e.g. Python ints
        # outside the int64 range); keep the original conversion for them.
        return torch.tensor(expected, dtype=actual.dtype)
    if torch.can_cast(inferred.dtype, actual.dtype):
        return torch.tensor(expected, dtype=actual.dtype)
    return inferred


def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """Assert that ``actual`` and ``expected`` are close, accepting non-tensor inputs.

    This is a thin wrapper over ``torch.testing.assert_close`` that mirrors the
    deprecated ``torch.testing.assert_allclose``:

    - Non-tensor inputs (Python scalars or lists, NumPy arrays) are converted
      with ``torch.tensor``.
    - Dtypes do not have to match (``check_dtype=False``).
    - Strides are ignored (``check_stride=False``).
    - Devices must match (``check_device=True``).

    Args:
        actual: Value under test. It is converted with ``torch.tensor`` if it is
            not already a tensor.
        expected: Reference value. If it is not a tensor, it is converted in
            ``actual``'s dtype, unless that cast would lose information (for
            example, fractional values into an integer dtype).
        rtol: Relative tolerance. If both ``rtol`` and ``atol`` are ``None``,
            per-dtype defaults from ``_get_default_rtol_and_atol`` are used.
        atol: Absolute tolerance; see ``rtol``.
        equal_nan: Whether NaNs in matching positions compare equal.
        msg: Custom failure message. An empty string keeps ``assert_close``'s
            default message.

    Raises:
        AssertionError: If the inputs are not close (raised by ``assert_close``).
        ValueError: If exactly one of ``rtol``/``atol`` is given (raised by
            ``assert_close``).
    """
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = _expected_as_tensor(expected, actual)

    # Defaults are filled in only when neither tolerance was supplied.
    if rtol is None and atol is None:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,  # "" -> None so assert_close uses its default message
    )