"""This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
|
|
was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus,
|
|
we don't internalize without warning, but still go through a deprecation cycle.
|
|
"""
|
|
|
|
import functools
import random
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch

from . import _legacy

__all__ = [
    "rand",
    "randn",
    "assert_allclose",
    "get_all_device_types",
    "make_non_contiguous",
]

def warn_deprecated(instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]) -> Callable:
    def outer_wrapper(fn: Callable) -> Callable:
        name = fn.__name__
        head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "

        @functools.wraps(fn)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            return_value = fn(*args, **kwargs)
            tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
            msg = (head + tail).strip()
            warnings.warn(msg, FutureWarning)
            return return_value

        return inner_wrapper

    return outer_wrapper

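# Illustration (a sketch, not executed at import time): the wrappers produced
# by `warn_deprecated` call through to the original function first and only
# then emit the warning, so `rand` below behaves like `torch.rand` but warns:
#
#   torch.testing.rand(2, 3)
#   # FutureWarning: torch.testing.rand() is deprecated since 1.12 and will be
#   # removed in 1.14. Use torch.rand() instead.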
rand = warn_deprecated("Use torch.rand() instead.")(torch.rand)
randn = warn_deprecated("Use torch.randn() instead.")(torch.randn)

_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}

def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
    expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
    return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)

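# Worked example from the table above: comparing a float16 tensor against a
# float32 tensor yields rtol = max(1e-3, 1e-4) = 1e-3 and
# atol = max(1e-3, 1e-5) = 1e-3. Dtypes not listed in _DTYPE_PRECISIONS
# (e.g. integral ones) contribute (0.0, 0.0).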
@warn_deprecated(
    "Use torch.testing.assert_close() instead. "
    "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    if rtol is None and atol is None:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,
    )

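# Migration sketch (illustrative, not the authoritative recipe): a call like
#
#   torch.testing.assert_allclose(actual, expected)
#
# can typically be rewritten as
#
#   torch.testing.assert_close(actual, expected, check_dtype=False, equal_nan=True)
#
# which approximately mirrors the flags this shim passes; note that
# assert_close's own default tolerances differ from the legacy per-dtype table
# above. See the issue linked above for the full upgrade instructions.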
getter_instructions = (
    lambda name, args, kwargs, return_value: f"This call can be replaced with {return_value}."  # noqa: E731
)

# Deprecate and expose all dtype getters
|
|
for name in _legacy.__all_dtype_getters__:
|
|
fn = getattr(_legacy, name)
|
|
globals()[name] = warn_deprecated(getter_instructions)(fn)
|
|
__all__.append(name)
|
|
|
|
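# Sketch of the resulting behavior (concrete getter names come from
# _legacy.__all_dtype_getters__): calling one of the re-exported getters still
# returns its value, but the warning embeds the concrete replacement, e.g.
# "... This call can be replaced with [torch.float32, torch.float64, ...]."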
get_all_device_types = warn_deprecated(getter_instructions)(_legacy.get_all_device_types)
@warn_deprecated(
    "Depending on the use case there are different replacement options:\n\n"
    "- If you are using `make_non_contiguous` in combination with a creation function to create a noncontiguous tensor "
    "with random values, use `torch.testing.make_tensor(..., noncontiguous=True)` instead.\n"
    "- If you are using `make_non_contiguous` with a specific tensor, you can replace this call with "
    "`torch.repeat_interleave(input, 2, dim=-1)[..., ::2]`.\n"
    "- If you are using `make_non_contiguous` in the PyTorch test suite, use "
    "`torch.testing._internal.common_utils.noncontiguous_like` instead."
)
def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # randomly inflate a few dimensions in osize
    for _ in range(2):
        dim = random.randint(0, len(osize) - 1)
        add = random.randint(4, 15)
        osize[dim] = osize[dim] + add

    # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th
    # dimension (which will always happen with a 1-dimensional tensor), so
    # let's make a new right-most dimension and cut it off

    input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    input = input.select(len(input.size()) - 1, random.randint(0, 1))
    # now extract the input of correct size from 'input'
    for i in range(len(osize)):
        if input.size(i) != tensor.size(i):
            bounds = random.randint(1, input.size(i) - tensor.size(i))
            input = input.narrow(i, bounds, tensor.size(i))

    input.copy_(tensor)

    # Use .data here to hide the view relation between input and other temporary Tensors
    return input.data
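
# Usage sketch (illustrative only): the returned tensor has the same shape and
# values as the input but, for inputs with more than one element, is carved out
# of a larger freshly allocated buffer, so it is not contiguous:
#
#   t = torch.arange(6.0).reshape(2, 3)
#   nc = make_non_contiguous(t)
#   # torch.equal(nc, t) is True; nc.is_contiguous() is expected to be False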