mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68728 This removes the ability for `assert_close` to `.coalesce()` the tensors internally. Additionally, we now also check `.sparse_dim()`. Sparse team: please make sure that is the behavior you want for all sparse COO comparisons in the future. #67796 will temporarily keep BC by always coalescing, but in the future `TestCase.assertEqual` will no longer do that. cc nikitaved pearu cpuhrsch IvanYashchuk Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D33542996 Pulled By: mruberry fbshipit-source-id: a8d2322c6ee1ca424e3efb14ab21787328cf28fc
104 lines
3.4 KiB
Python
104 lines
3.4 KiB
Python
"""This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
|
|
was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus,
|
|
we don't internalize without warning, but still go through a deprecation cycle.
|
|
"""
|
|
|
|
import functools
|
|
import warnings
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from . import _legacy
|
|
|
|
|
|
# Names exported from this module. This has to remain a mutable list: the deprecated dtype getters are appended
# to it further down in this module.
__all__ = ["rand", "randn", "assert_allclose", "get_all_device_types"]
|
|
|
|
|
|
def warn_deprecated(instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]) -> Callable:
    """Build a decorator that makes the decorated function emit a ``FutureWarning`` on every call.

    The warning text is ``"torch.testing.{name}() is deprecated and will be removed in a future release. "``
    followed by the upgrade instructions, with surrounding whitespace stripped.

    Args:
        instructions: Either the instruction text itself, or a callable that is invoked as
            ``instructions(name, args, kwargs, return_value)`` and returns the instruction text. ``name`` is the
            ``__name__`` of the decorated function, ``args`` / ``kwargs`` are the arguments of the current call, and
            ``return_value`` is what the decorated function returned for them.

    Returns:
        A decorator. The function it produces forwards all arguments to the decorated function, warns, and then
        returns the decorated function's return value unchanged.
    """

    def outer_wrapper(fn: Callable) -> Callable:
        name = fn.__name__
        head = f"torch.testing.{name}() is deprecated and will be removed in a future release. "

        @functools.wraps(fn)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            # The wrapped function has to run first, because callable instructions receive its return value.
            # As a consequence no warning is emitted if the call raises.
            return_value = fn(*args, **kwargs)
            tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
            msg = (head + tail).strip()
            # stacklevel=2 attributes the warning to the code that called the deprecated function rather than to
            # this wrapper, so the reported file and line point at the call that needs to be migrated.
            warnings.warn(msg, FutureWarning, stacklevel=2)
            return return_value

        return inner_wrapper

    return outer_wrapper
|
|
|
|
|
|
# Deprecated aliases of the tensor factories: each call is forwarded to the `torch` function of the same name and
# then triggers the deprecation warning.
rand, randn = (
    warn_deprecated("Use torch.rand instead.")(torch.rand),
    warn_deprecated("Use torch.randn instead.")(torch.randn),
)
|
|
|
|
|
|
# Default (rtol, atol) pairs keyed by dtype. Dtypes that are not listed fall back to (0.0, 0.0) in
# `_get_default_rtol_and_atol`.
_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}


def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    """Return the default ``(rtol, atol)`` for comparing ``actual`` against ``expected``.

    The dtype of each tensor is looked up in ``_DTYPE_PRECISIONS``, using ``(0.0, 0.0)`` for dtypes without an
    entry. The result pairs the larger of the two rtols with the larger of the two atols.
    """
    no_tolerance = (0.0, 0.0)
    per_tensor = (_DTYPE_PRECISIONS.get(tensor.dtype, no_tolerance) for tensor in (actual, expected))
    # Transpose [(rtol, atol), (rtol, atol)] into the two rtols and the two atols.
    rtols, atols = zip(*per_tensor)
    return max(rtols), max(atols)
|
|
|
|
|
|
# TODO: include the deprecation as soon as torch.testing.assert_close is stable
|
|
# @warn_deprecated(
|
|
# "Use torch.testing.assert_close instead. "
|
|
# "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
|
|
# )
|
|
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """Compare ``actual`` and ``expected`` by delegating to ``torch.testing.assert_close``.

    Args:
        actual: Tensor, or anything ``torch.tensor`` accepts.
        expected: Tensor, or anything ``torch.tensor`` accepts. A non-tensor value is converted using the dtype
            of ``actual``.
        rtol: Relative tolerance.
        atol: Absolute tolerance. Defaults are derived from the dtypes of both inputs only if ``rtol`` and
            ``atol`` are both ``None``; otherwise both values are forwarded exactly as given.
        equal_nan: Forwarded to ``torch.testing.assert_close``.
        msg: Custom error message. An empty string is forwarded as ``None``.

    The comparison is run with ``check_device=True``, ``check_dtype=False`` and ``check_stride=False``.
    """
    # Tensors are used as they are; everything else is converted, with `expected` adopting the dtype of `actual`.
    actual_tensor = actual if isinstance(actual, torch.Tensor) else torch.tensor(actual)
    expected_tensor = (
        expected if isinstance(expected, torch.Tensor) else torch.tensor(expected, dtype=actual_tensor.dtype)
    )

    tolerances_omitted = rtol is None and atol is None
    if tolerances_omitted:
        rtol, atol = _get_default_rtol_and_atol(actual_tensor, expected_tensor)

    comparison_options = dict(
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg if msg else None,
    )
    torch.testing.assert_close(actual_tensor, expected_tensor, **comparison_options)
|
|
|
|
|
|
# Every legacy dtype getter is re-exported under its own name, wrapped so that calling it emits a deprecation
# warning stating the value the call can be replaced with. The message builder does not depend on the loop
# variable, so one instance is shared by all getters.
instructions = lambda name, args, kwargs, return_value: (  # noqa: E731
    f"This call to {name}(...) can be replaced with {return_value}."
)
for name in _legacy.__all_dtype_getters__:
    fn = getattr(_legacy, name)
    globals()[name] = warn_deprecated(instructions)(fn)
    __all__.append(name)

# `get_all_device_types` is wrapped the same way, with a message that leaves out the function name.
instructions = lambda name, args, kwargs, return_value: f"This call can be replaced with {return_value}."  # noqa: E731
get_all_device_types = warn_deprecated(instructions)(_legacy.get_all_device_types)
|