make torch.testing asserts importable (#54769)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54769

Follow-up to #53820. This

- makes the `asserts.py` module private, as per the suggestion from rgommers in https://github.com/pytorch/pytorch/pull/53820#issuecomment-802661387. With this the functions should only be accessible through `torch.testing`, giving us the option to change the underlying structure later.
- moves the code from `torch/testing/__init__.py` to `torch/testing/_core.py` (happy to accept other name suggestions). Otherwise we can't import the new `_asserts.py` in `torch/testing/__init__.py` due to circular imports.

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D27438451

Pulled By: mruberry

fbshipit-source-id: c7292b4d5709185b42b4aac8016648562688040e
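As a rough illustration of what this restructuring means for callers (not part of the diff; a minimal sketch assuming the re-exports behave as the summary describes): the public helpers remain importable from torch.testing, while PyTorch-internal code now reaches the implementation through the private torch.testing._core module.

import torch
from torch.testing import assert_allclose            # still public, re-exported via torch/testing/__init__.py

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-9
assert_allclose(a, b)                                 # passes with the per-dtype default tolerances

# Internal callers (e.g. torch/jit/_trace.py below) now import from the private module:
from torch.testing._core import _get_default_tolerance
print(_get_default_tolerance(a, b))                   # (0.0001, 1e-05) for float32 inputs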
parent ffe301846b
commit f4967d68f5
torch/jit/_trace.py
@@ -24,6 +24,8 @@ from torch._jit_internal import _qualified_name, is_scripting, get_callable_argu
from torch.autograd import function
from torch.nn import Module

+from torch.testing._core import _get_default_tolerance

_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten
@@ -487,7 +489,7 @@ def _check_trace(
                    orig.double(),
                    ref.double(),
                    rtol=check_tolerance,
-                   atol=torch.testing._get_default_tolerance(orig, ref)[1],
+                   atol=_get_default_tolerance(orig, ref)[1],
                )
            except AssertionError as e:
                maybe_warn_nondeterministic()
torch/testing/__init__.py
@@ -1,394 +1,3 @@
"""
The testing package contains testing-specific utilities.
"""

import torch
import random
import math
import cmath
from typing import cast, List, Optional, Tuple, Union
from .check_kernel_launches import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches
import operator

FileCheck = torch._C.FileCheck

__all__ = [
    'assert_allclose', 'make_non_contiguous', 'rand_like', 'randn_like'
]

rand_like = torch.rand_like
randn_like = torch.randn_like

# Helper function that returns True when the dtype is an integral dtype,
# False otherwise.
# TODO: implement numpy-like issubdtype
def is_integral(dtype: torch.dtype) -> bool:
    # Skip complex/quantized types
    dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
    return dtype in dtypes and not dtype.is_floating_point

def is_quantized(dtype: torch.dtype) -> bool:
    return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2)

# Helper function that maps a flattened index back into the given shape
# TODO: consider adding torch.unravel_index
def _unravel_index(flat_index, shape):
    flat_index = operator.index(flat_index)
    res = []

    # Short-circuits on zero dim tensors
    if shape == torch.Size([]):
        return 0

    for size in shape[::-1]:
        res.append(flat_index % size)
        flat_index = flat_index // size

    if len(res) == 1:
        return res[0]

    return tuple(res[::-1])
# (bool, msg) tuple, where msg is None if and only if bool is True.
_compare_return_type = Tuple[bool, Optional[str]]

# Compares two tensors with the same size on the same device and with the same
# dtype for equality.
# Returns a tuple (bool, msg). The bool value returned is True when the tensors
# are "equal" and False otherwise.
# The msg value is a debug string, and is None if the tensors are "equal."
# NOTE: Test Framework Tensor 'Equality'
# Two tensors are "equal" if they are "close", in the sense of torch.allclose.
# The only exceptions are complex tensors and bool tensors.
#
# Complex tensors are "equal" if both the
# real and complex parts (separately) are close. This is divergent from
# torch.allclose's behavior, which compares the absolute values of the
# complex numbers instead.
#
# Using torch.allclose would be a less strict
# comparison that would allow large complex values with
# significant real or imaginary differences to be considered "equal,"
# and would make setting rtol and atol for complex tensors distinct from
# other tensor types.
#
# Bool tensors are equal only if they are identical, regardless of
# the rtol and atol values.
#
# The `equal_nan` can be True or False, which maps to the True or False
# in `torch.allclose`. `equal_nan` can also be "relaxed", which means
# the complex will be compared in the relaxed mode:
#   2 + nan j == 3 + nan j ---> False when equal_nan=True
#                               True when equal_nan="relaxed"
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: Union[str, bool]) -> _compare_return_type:
    assert equal_nan in {True, False, "relaxed"}
    debug_msg : Optional[str]
    # Integer (including bool) comparisons are identity comparisons
    # when rtol is zero and atol is less than one
    if (
        (is_integral(a.dtype) and rtol == 0 and atol < 1)
        or a.dtype is torch.bool
        or is_quantized(a.dtype)
    ):
        if (a == b).all().item():
            return (True, None)

        # Gathers debug info for failed integer comparison
        # NOTE: converts to long to correctly represent differences
        # (especially between uint8 tensors)
        identity_mask = a != b
        a_flat = a.to(torch.long).flatten()
        b_flat = b.to(torch.long).flatten()
        count_non_identical = torch.sum(identity_mask, dtype=torch.long)
        diff = torch.abs(a_flat - b_flat)
        greatest_diff_index = torch.argmax(diff)
        debug_msg = ("Found {0} different element(s) (out of {1}), with the greatest "
                     "difference of {2} ({3} vs. {4}) occuring at index "
                     "{5}.".format(count_non_identical.item(),
                                   a.numel(),
                                   diff[greatest_diff_index],
                                   a_flat[greatest_diff_index],
                                   b_flat[greatest_diff_index],
                                   _unravel_index(greatest_diff_index, a.shape)))
        return (False, debug_msg)

    # Compares complex tensors' real and imaginary parts separately.
    # (see NOTE Test Framework Tensor "Equality")
    if a.is_complex():
        if equal_nan == "relaxed":
            a = a.clone()
            b = b.clone()
            a.real[a.imag.isnan()] = math.nan
            a.imag[a.real.isnan()] = math.nan
            b.real[b.imag.isnan()] = math.nan
            b.imag[b.real.isnan()] = math.nan

        real_result, debug_msg = _compare_tensors_internal(a.real, b.real,
                                                           rtol=rtol, atol=atol,
                                                           equal_nan=equal_nan)

        if not real_result:
            debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
            return (real_result, debug_msg)

        imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag,
                                                           rtol=rtol, atol=atol,
                                                           equal_nan=equal_nan)

        if not imag_result:
            debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
            return (imag_result, debug_msg)

        return (True, None)

    # All other comparisons use torch.allclose directly
    if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=(equal_nan in {"relaxed", True})):
        return (True, None)

    # Gathers debug info for failed float tensor comparison
    # NOTE: converts to float64 to best represent differences
    a_flat = a.to(torch.float64).flatten()
    b_flat = b.to(torch.float64).flatten()
    diff = torch.abs(a_flat - b_flat)

    # Masks close values
    # NOTE: this avoids (inf - inf) oddities when computing the difference
    close = torch.isclose(a_flat, b_flat, rtol, atol, (equal_nan in {"relaxed", True}))
    diff[close] = 0
    nans = torch.isnan(diff)
    num_nans = nans.sum()

    outside_range = (diff > (atol + rtol * torch.abs(b_flat))) | (diff == math.inf)
    count_outside_range = torch.sum(outside_range, dtype=torch.long)
    greatest_diff_index = torch.argmax(diff)
    debug_msg = ("With rtol={0} and atol={1}, found {2} element(s) (out of {3}) whose "
                 "difference(s) exceeded the margin of error (including {4} nan comparisons). "
                 "The greatest difference was {5} ({6} vs. {7}), which "
                 "occurred at index {8}.".format(rtol, atol,
                                                 count_outside_range + num_nans,
                                                 a.numel(),
                                                 num_nans,
                                                 diff[greatest_diff_index],
                                                 a_flat[greatest_diff_index],
                                                 b_flat[greatest_diff_index],
                                                 _unravel_index(greatest_diff_index, a.shape)))
    return (False, debug_msg)

# Checks if two scalars are equal(-ish), returning (True, None)
# when they are and (False, debug_msg) when they are not.
def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Union[str, bool]) -> _compare_return_type:
    def _helper(a, b, s) -> _compare_return_type:
        # Short-circuits on identity
        if a == b or ((equal_nan in {"relaxed", True}) and a != a and b != b):
            return (True, None)

        # Special-case for NaN comparisions when equal_nan=False
        if not (equal_nan in {"relaxed", True}) and (a != a or b != b):
            msg = ("Found {0} and {1} while comparing" + s + "and either one "
                   "is nan and the other isn't, or both are nan and "
                   "equal_nan is False").format(a, b)
            return (False, msg)

        diff = abs(a - b)
        allowed_diff = atol + rtol * abs(b)
        result = diff <= allowed_diff

        # Special-case for infinity comparisons
        # NOTE: if b is inf then allowed_diff will be inf when rtol is not 0
        if ((math.isinf(a) or math.isinf(b)) and a != b):
            result = False

        msg = None
        if not result:
            msg = ("Comparing" + s + "{0} and {1} gives a "
                   "difference of {2}, but the allowed difference "
                   "with rtol={3} and atol={4} is "
                   "only {5}!").format(a, b, diff,
                                       rtol, atol, allowed_diff)

        return result, msg

    if isinstance(a, complex) or isinstance(b, complex):
        a = complex(a)
        b = complex(b)

        if equal_nan == "relaxed":
            if cmath.isnan(a) and cmath.isnan(b):
                return (True, None)

        result, msg = _helper(a.real, b.real, " the real part ")

        if not result:
            return (False, msg)

        return _helper(a.imag, b.imag, " the imaginary part ")

    return _helper(a, b, " ")

def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)
    if expected.shape != actual.shape:
        raise AssertionError("expected tensor shape {0} doesn't match with actual tensor "
                             "shape {1}!".format(expected.shape, actual.shape))
    if rtol is None or atol is None:
        if rtol is not None or atol is not None:
            raise ValueError("rtol and atol must both be specified or both be unspecified")
        rtol, atol = _get_default_tolerance(actual, expected)

    result, debug_msg = _compare_tensors_internal(actual, expected,
                                                  rtol=rtol, atol=atol,
                                                  equal_nan=equal_nan)

    if result:
        return

    if msg is None or msg == '':
        msg = debug_msg

    raise AssertionError(msg)

def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # randomly inflate a few dimensions in osize
    for _ in range(2):
        dim = random.randint(0, len(osize) - 1)
        add = random.randint(4, 15)
        osize[dim] = osize[dim] + add

    # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
    # (which will always happen with a 1-dimensional tensor), so let's make a new
    # right-most dimension and cut it off

    input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    input = input.select(len(input.size()) - 1, random.randint(0, 1))
    # now extract the input of correct size from 'input'
    for i in range(len(osize)):
        if input.size(i) != tensor.size(i):
            bounds = random.randint(1, input.size(i) - tensor.size(i))
            input = input.narrow(i, bounds, tensor.size(i))

    input.copy_(tensor)

    # Use .data here to hide the view relation between input and other temporary Tensors
    return input.data


# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros

# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
    for dtype in dtypes:
        assert isinstance(dtype, torch.dtype)
    return dtypes

# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
    def __add__(self, other):
        assert isinstance(other, tuple)
        return _dispatch_dtypes(tuple.__add__(self, other))

_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
    return _floating_types

_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
    return _floating_types_and_half

def floating_types_and(*dtypes):
    return _floating_types + _validate_dtypes(*dtypes)

_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
    return _floating_and_complex_types

def floating_and_complex_types_and(*dtypes):
    return _floating_and_complex_types + _validate_dtypes(*dtypes)

_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
    return _integral_types

def integral_types_and(*dtypes):
    return _integral_types + _validate_dtypes(*dtypes)

_all_types = _floating_types + _integral_types
def all_types():
    return _all_types

def all_types_and(*dtypes):
    return _all_types + _validate_dtypes(*dtypes)

_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
def complex_types():
    return _complex_types

_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
    return _all_types_and_complex

def all_types_and_complex_and(*dtypes):
    return _all_types_and_complex + _validate_dtypes(*dtypes)

_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
    return _all_types_and_half

def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False
                   ) -> List[torch.dtype]:
    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
    if include_bool:
        dtypes.append(torch.bool)
    if include_complex:
        dtypes += get_all_complex_dtypes(include_complex32)
    return dtypes

def get_all_math_dtypes(device) -> List[torch.dtype]:
    return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                    include_bfloat16=False) + get_all_complex_dtypes()

def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]


def get_all_int_dtypes() -> List[torch.dtype]:
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]


def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    dtypes = [torch.float32, torch.float64]
    if include_half:
        dtypes.append(torch.float16)
    if include_bfloat16:
        dtypes.append(torch.bfloat16)
    return dtypes


def get_all_device_types() -> List[str]:
    return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']

# 'dtype': (rtol, atol)
_default_tolerances = {
    'float64': (1e-5, 1e-8),  # NumPy default
    'float32': (1e-4, 1e-5),  # This may need to be changed
    'float16': (1e-3, 1e-3),  # This may need to be changed
}


def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
    if b is None:
        dtype = str(a.dtype).split('.')[-1]  # e.g. "float32"
        return _default_tolerances.get(dtype, (0, 0))
    a_tol = _get_default_tolerance(a)
    b_tol = _get_default_tolerance(b)
    return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
+from ._core import *
+from ._asserts import *
+from ._check_kernel_launches import *
torch/testing/_asserts.py
@@ -3,8 +3,7 @@ from collections import namedtuple
from typing import Any, Optional, Tuple

import torch
-from torch.testing import _unravel_index
-from torch.testing._internal.common_utils import get_comparison_dtype as _get_comparison_dtype, TestCase as _TestCase
+from ._core import _unravel_index

__all__ = ["assert_tensors_equal", "assert_tensors_allclose"]
@@ -24,9 +23,24 @@ except KeyError:
    pass


+# This is copy-pasted from torch.testing._internal.common_utils.TestCase.dtype_precisions. With this we avoid a
+# dependency on torch.testing._internal at import. See
+# https://github.com/pytorch/pytorch/pull/54769#issuecomment-813174256 for details.
+# {dtype: (rtol, atol)}
+_DTYPE_PRECISIONS = {
+    torch.float16: (0.001, 1e-5),
+    torch.bfloat16: (0.016, 1e-5),
+    torch.float32: (1.3e-6, 1e-5),
+    torch.float64: (1e-7, 1e-7),
+    torch.complex32: (0.001, 1e-5),
+    torch.complex64: (1.3e-6, 1e-5),
+    torch.complex128: (1e-7, 1e-7),
+}


def _get_default_rtol_and_atol(a: torch.Tensor, b: torch.Tensor) -> Tuple[float, float]:
-    dtype = a.dtype if a.dtype == b.dtype else _get_comparison_dtype(a, b)
-    return _TestCase.dtype_precisions.get(dtype, (0.0, 0.0))
+    dtype = a.dtype if a.dtype == b.dtype else torch.promote_types(a.dtype, b.dtype)
+    return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))


def _assert_are_tensors(a: Any, b: Any) -> None:
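Illustration only (not part of the diff): a minimal sketch of how the private tolerance lookup above behaves, reusing the copied table and assuming standard torch.promote_types semantics.

import torch

_DTYPE_PRECISIONS = {
    torch.float16: (0.001, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
}

def _get_default_rtol_and_atol(a, b):
    # same dtype: use its entry; mixed dtypes: use the promoted dtype's entry
    dtype = a.dtype if a.dtype == b.dtype else torch.promote_types(a.dtype, b.dtype)
    return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))

half = torch.ones(2, dtype=torch.float16)
single = torch.ones(2, dtype=torch.float32)
print(_get_default_rtol_and_atol(half, half))    # (0.001, 1e-05)
print(_get_default_rtol_and_atol(half, single))  # promotes to float32 -> (1.3e-06, 1e-05)
print(_get_default_rtol_and_atol(torch.ones(2, dtype=torch.int64),
                                 torch.ones(2, dtype=torch.int64)))  # unlisted dtype -> (0.0, 0.0)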
@@ -138,7 +152,7 @@ def _equalize_attributes(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor
        b = b.cpu()

    if a.dtype != b.dtype:
-        dtype = _get_comparison_dtype(a, b)
+        dtype = torch.promote_types(a.dtype, b.dtype)
        a = a.to(dtype)
        b = b.to(dtype)
torch/testing/_check_kernel_launches.py
@@ -2,6 +2,11 @@ import os
import re
import sys

+__all__ = [
+    "check_code_for_cuda_kernel_launches",
+    "check_cuda_kernel_launches",
+]


# Regular expression identifies a kernel launch indicator by
# finding something approximating the pattern ">>>(arguments);"
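Illustration only (not part of the diff): a simplified, hypothetical stand-in for the kind of pattern the comment above describes; the actual regular expression in the module is more involved.

import re

# crude approximation of ">>>(arguments);" following a CUDA triple-chevron launch
kernel_launch_end = re.compile(r">>>\s*\([^;]*\)\s*;")

print(bool(kernel_launch_end.search("kernel<<<blocks, threads>>>(arg1, arg2);")))  # True
print(bool(kernel_launch_end.search("ordinary_call(arg1, arg2);")))                # False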
torch/testing/_core.py (new file)
@@ -0,0 +1,414 @@
"""
The testing package contains testing-specific utilities.
"""

import torch
import random
import math
import cmath
from typing import cast, List, Optional, Tuple, Union
import operator

FileCheck = torch._C.FileCheck

__all__ = [
    "FileCheck",
    "all_types",
    "all_types_and",
    "all_types_and_complex",
    "all_types_and_complex_and",
    "assert_allclose",
    "complex_types",
    "floating_and_complex_types",
    "floating_and_complex_types_and",
    "floating_types",
    "floating_types_and",
    "get_all_complex_dtypes",
    "get_all_dtypes",
    "get_all_device_types",
    "get_all_fp_dtypes",
    "get_all_int_dtypes",
    "get_all_math_dtypes",
    "integral_types",
    "integral_types_and",
    "make_non_contiguous",
    "rand_like",
    "randn_like",
]

rand_like = torch.rand_like
randn_like = torch.randn_like

# Helper function that returns True when the dtype is an integral dtype,
# False otherwise.
# TODO: implement numpy-like issubdtype
def is_integral(dtype: torch.dtype) -> bool:
    # Skip complex/quantized types
    dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
    return dtype in dtypes and not dtype.is_floating_point

def is_quantized(dtype: torch.dtype) -> bool:
    return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2)

# Helper function that maps a flattened index back into the given shape
# TODO: consider adding torch.unravel_index
def _unravel_index(flat_index, shape):
    flat_index = operator.index(flat_index)
    res = []

    # Short-circuits on zero dim tensors
    if shape == torch.Size([]):
        return 0

    for size in shape[::-1]:
        res.append(flat_index % size)
        flat_index = flat_index // size

    if len(res) == 1:
        return res[0]

    return tuple(res[::-1])
# (bool, msg) tuple, where msg is None if and only if bool is True.
_compare_return_type = Tuple[bool, Optional[str]]

# Compares two tensors with the same size on the same device and with the same
# dtype for equality.
# Returns a tuple (bool, msg). The bool value returned is True when the tensors
# are "equal" and False otherwise.
# The msg value is a debug string, and is None if the tensors are "equal."
# NOTE: Test Framework Tensor 'Equality'
# Two tensors are "equal" if they are "close", in the sense of torch.allclose.
# The only exceptions are complex tensors and bool tensors.
#
# Complex tensors are "equal" if both the
# real and complex parts (separately) are close. This is divergent from
# torch.allclose's behavior, which compares the absolute values of the
# complex numbers instead.
#
# Using torch.allclose would be a less strict
# comparison that would allow large complex values with
# significant real or imaginary differences to be considered "equal,"
# and would make setting rtol and atol for complex tensors distinct from
# other tensor types.
#
# Bool tensors are equal only if they are identical, regardless of
# the rtol and atol values.
#
# The `equal_nan` can be True or False, which maps to the True or False
# in `torch.allclose`. `equal_nan` can also be "relaxed", which means
# the complex will be compared in the relaxed mode:
#   2 + nan j == 3 + nan j ---> False when equal_nan=True
#                               True when equal_nan="relaxed"
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: Union[str, bool]) -> _compare_return_type:
    assert equal_nan in {True, False, "relaxed"}
    debug_msg : Optional[str]
    # Integer (including bool) comparisons are identity comparisons
    # when rtol is zero and atol is less than one
    if (
        (is_integral(a.dtype) and rtol == 0 and atol < 1)
        or a.dtype is torch.bool
        or is_quantized(a.dtype)
    ):
        if (a == b).all().item():
            return (True, None)

        # Gathers debug info for failed integer comparison
        # NOTE: converts to long to correctly represent differences
        # (especially between uint8 tensors)
        identity_mask = a != b
        a_flat = a.to(torch.long).flatten()
        b_flat = b.to(torch.long).flatten()
        count_non_identical = torch.sum(identity_mask, dtype=torch.long)
        diff = torch.abs(a_flat - b_flat)
        greatest_diff_index = torch.argmax(diff)
        debug_msg = ("Found {0} different element(s) (out of {1}), with the greatest "
                     "difference of {2} ({3} vs. {4}) occuring at index "
                     "{5}.".format(count_non_identical.item(),
                                   a.numel(),
                                   diff[greatest_diff_index],
                                   a_flat[greatest_diff_index],
                                   b_flat[greatest_diff_index],
                                   _unravel_index(greatest_diff_index, a.shape)))
        return (False, debug_msg)

    # Compares complex tensors' real and imaginary parts separately.
    # (see NOTE Test Framework Tensor "Equality")
    if a.is_complex():
        if equal_nan == "relaxed":
            a = a.clone()
            b = b.clone()
            a.real[a.imag.isnan()] = math.nan
            a.imag[a.real.isnan()] = math.nan
            b.real[b.imag.isnan()] = math.nan
            b.imag[b.real.isnan()] = math.nan

        real_result, debug_msg = _compare_tensors_internal(a.real, b.real,
                                                           rtol=rtol, atol=atol,
                                                           equal_nan=equal_nan)

        if not real_result:
            debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
            return (real_result, debug_msg)

        imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag,
                                                           rtol=rtol, atol=atol,
                                                           equal_nan=equal_nan)

        if not imag_result:
            debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
            return (imag_result, debug_msg)

        return (True, None)

    # All other comparisons use torch.allclose directly
    if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=(equal_nan in {"relaxed", True})):
        return (True, None)

    # Gathers debug info for failed float tensor comparison
    # NOTE: converts to float64 to best represent differences
    a_flat = a.to(torch.float64).flatten()
    b_flat = b.to(torch.float64).flatten()
    diff = torch.abs(a_flat - b_flat)

    # Masks close values
    # NOTE: this avoids (inf - inf) oddities when computing the difference
    close = torch.isclose(a_flat, b_flat, rtol, atol, (equal_nan in {"relaxed", True}))
    diff[close] = 0
    nans = torch.isnan(diff)
    num_nans = nans.sum()

    outside_range = (diff > (atol + rtol * torch.abs(b_flat))) | (diff == math.inf)
    count_outside_range = torch.sum(outside_range, dtype=torch.long)
    greatest_diff_index = torch.argmax(diff)
    debug_msg = ("With rtol={0} and atol={1}, found {2} element(s) (out of {3}) whose "
                 "difference(s) exceeded the margin of error (including {4} nan comparisons). "
                 "The greatest difference was {5} ({6} vs. {7}), which "
                 "occurred at index {8}.".format(rtol, atol,
                                                 count_outside_range + num_nans,
                                                 a.numel(),
                                                 num_nans,
                                                 diff[greatest_diff_index],
                                                 a_flat[greatest_diff_index],
                                                 b_flat[greatest_diff_index],
                                                 _unravel_index(greatest_diff_index, a.shape)))
    return (False, debug_msg)

# Checks if two scalars are equal(-ish), returning (True, None)
# when they are and (False, debug_msg) when they are not.
def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: Union[str, bool]) -> _compare_return_type:
    def _helper(a, b, s) -> _compare_return_type:
        # Short-circuits on identity
        if a == b or ((equal_nan in {"relaxed", True}) and a != a and b != b):
            return (True, None)

        # Special-case for NaN comparisions when equal_nan=False
        if not (equal_nan in {"relaxed", True}) and (a != a or b != b):
            msg = ("Found {0} and {1} while comparing" + s + "and either one "
                   "is nan and the other isn't, or both are nan and "
                   "equal_nan is False").format(a, b)
            return (False, msg)

        diff = abs(a - b)
        allowed_diff = atol + rtol * abs(b)
        result = diff <= allowed_diff

        # Special-case for infinity comparisons
        # NOTE: if b is inf then allowed_diff will be inf when rtol is not 0
        if ((math.isinf(a) or math.isinf(b)) and a != b):
            result = False

        msg = None
        if not result:
            msg = ("Comparing" + s + "{0} and {1} gives a "
                   "difference of {2}, but the allowed difference "
                   "with rtol={3} and atol={4} is "
                   "only {5}!").format(a, b, diff,
                                       rtol, atol, allowed_diff)

        return result, msg

    if isinstance(a, complex) or isinstance(b, complex):
        a = complex(a)
        b = complex(b)

        if equal_nan == "relaxed":
            if cmath.isnan(a) and cmath.isnan(b):
                return (True, None)

        result, msg = _helper(a.real, b.real, " the real part ")

        if not result:
            return (False, msg)

        return _helper(a.imag, b.imag, " the imaginary part ")

    return _helper(a, b, " ")

def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)
    if expected.shape != actual.shape:
        raise AssertionError("expected tensor shape {0} doesn't match with actual tensor "
                             "shape {1}!".format(expected.shape, actual.shape))
    if rtol is None or atol is None:
        if rtol is not None or atol is not None:
            raise ValueError("rtol and atol must both be specified or both be unspecified")
        rtol, atol = _get_default_tolerance(actual, expected)

    result, debug_msg = _compare_tensors_internal(actual, expected,
                                                  rtol=rtol, atol=atol,
                                                  equal_nan=equal_nan)

    if result:
        return

    if msg is None or msg == '':
        msg = debug_msg

    raise AssertionError(msg)

def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # randomly inflate a few dimensions in osize
    for _ in range(2):
        dim = random.randint(0, len(osize) - 1)
        add = random.randint(4, 15)
        osize[dim] = osize[dim] + add

    # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
    # (which will always happen with a 1-dimensional tensor), so let's make a new
    # right-most dimension and cut it off

    input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    input = input.select(len(input.size()) - 1, random.randint(0, 1))
    # now extract the input of correct size from 'input'
    for i in range(len(osize)):
        if input.size(i) != tensor.size(i):
            bounds = random.randint(1, input.size(i) - tensor.size(i))
            input = input.narrow(i, bounds, tensor.size(i))

    input.copy_(tensor)

    # Use .data here to hide the view relation between input and other temporary Tensors
    return input.data


# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros

# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
    for dtype in dtypes:
        assert isinstance(dtype, torch.dtype)
    return dtypes

# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
    def __add__(self, other):
        assert isinstance(other, tuple)
        return _dispatch_dtypes(tuple.__add__(self, other))

_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
    return _floating_types

_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
    return _floating_types_and_half

def floating_types_and(*dtypes):
    return _floating_types + _validate_dtypes(*dtypes)

_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
    return _floating_and_complex_types

def floating_and_complex_types_and(*dtypes):
    return _floating_and_complex_types + _validate_dtypes(*dtypes)

_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
    return _integral_types

def integral_types_and(*dtypes):
    return _integral_types + _validate_dtypes(*dtypes)

_all_types = _floating_types + _integral_types
def all_types():
    return _all_types

def all_types_and(*dtypes):
    return _all_types + _validate_dtypes(*dtypes)

_complex_types = (torch.cfloat, torch.cdouble)
def complex_types():
    return _complex_types

_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
    return _all_types_and_complex

def all_types_and_complex_and(*dtypes):
    return _all_types_and_complex + _validate_dtypes(*dtypes)

_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
    return _all_types_and_half

def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False
                   ) -> List[torch.dtype]:
    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
    if include_bool:
        dtypes.append(torch.bool)
    if include_complex:
        dtypes += get_all_complex_dtypes(include_complex32)
    return dtypes

def get_all_math_dtypes(device) -> List[torch.dtype]:
    return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                    include_bfloat16=False) + get_all_complex_dtypes()

def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]


def get_all_int_dtypes() -> List[torch.dtype]:
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]


def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    dtypes = [torch.float32, torch.float64]
    if include_half:
        dtypes.append(torch.float16)
    if include_bfloat16:
        dtypes.append(torch.bfloat16)
    return dtypes


def get_all_device_types() -> List[str]:
    return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']

# 'dtype': (rtol, atol)
_default_tolerances = {
    'float64': (1e-5, 1e-8),  # NumPy default
    'float32': (1e-4, 1e-5),  # This may need to be changed
    'float16': (1e-3, 1e-3),  # This may need to be changed
}


def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
    if b is None:
        dtype = str(a.dtype).split('.')[-1]  # e.g. "float32"
        return _default_tolerances.get(dtype, (0, 0))
    a_tol = _get_default_tolerance(a)
    b_tol = _get_default_tolerance(b)
    return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
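Illustration only (not part of the diff): how the dtype helpers defined above compose; they mirror PyTorch's C++ dispatch macros and stay importable from torch.testing via the star re-export.

import torch
from torch.testing import floating_types_and, all_types_and_complex_and

print(floating_types_and(torch.half, torch.bfloat16))
# (torch.float32, torch.float64, torch.float16, torch.bfloat16)

print(torch.complex64 in all_types_and_complex_and(torch.half, torch.bool))
# True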
torch/testing/_internal/common_distributed.py
@@ -18,6 +18,7 @@ from functools import wraps

import torch
import torch.distributed as c10d
import torch.cuda.nccl

from functools import partial, reduce
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, FILE_SCHEMA
torch/testing/_internal/common_methods_invocations.py
@@ -14,10 +14,11 @@ import collections.abc
from typing import List, Sequence, Tuple, Dict, Any, Union

from torch.testing import \
-    (make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and,
+    (make_non_contiguous, floating_types, floating_types_and,
     floating_and_complex_types, floating_and_complex_types_and,
     all_types_and_complex_and, all_types_and, all_types_and_complex,
     integral_types_and)
+from .._core import _dispatch_dtypes
from torch.testing._internal.common_device_type import \
    (skipIf, skipMeta, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
     skipCPUIfNoLapack, skipCPUIfNoMkl,
torch/testing/_internal/common_utils.py
@@ -42,10 +42,10 @@ from typing import cast, Any, Dict, Iterable, Iterator, Optional

import numpy as np

+from torch.testing import floating_types_and, integral_types, complex_types
from torch.testing._internal import expecttest
-from torch.testing import \
-    (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type,
-     floating_types_and, integral_types, complex_types)
+from .._core import \
+    (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type)

import torch
import torch.cuda