mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: fix assert_allclose doesnt check shape should fix https://github.com/pytorch/pytorch/issues/47449. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47580 Reviewed By: samestep Differential Revision: D24836399 Pulled By: walterddr fbshipit-source-id: 943f8c83864bc01e1a782048c234e9592d2f1a25
377 lines
14 KiB
Python
377 lines
14 KiB
Python
"""
|
|
The testing package contains testing-specific utilities.
|
|
"""
|
|
|
|
import torch
|
|
import random
|
|
import math
|
|
from typing import cast, List, Optional, Tuple, Union
|
|
from .check_kernel_launches import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches
|
|
|
|
FileCheck = torch._C.FileCheck
|
|
|
|
__all__ = [
|
|
'assert_allclose', 'make_non_contiguous', 'rand_like', 'randn_like'
|
|
]
|
|
|
|
rand_like = torch.rand_like
|
|
randn_like = torch.randn_like
|
|
|
|
# Helper function that returns True when the dtype is an integral dtype,
|
|
# False otherwise.
|
|
# TODO: implement numpy-like issubdtype
|
|
def is_integral(dtype: torch.dtype) -> bool:
|
|
# Skip complex/quantized types
|
|
dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
|
|
return dtype in dtypes and not dtype.is_floating_point
|
|
|
|
def is_quantized(dtype: torch.dtype) -> bool:
|
|
return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2)
|
|
|
|
# Helper function that maps a flattened index back into the given shape
|
|
# TODO: consider adding torch.unravel_index
|
|
def _unravel_index(flat_index, shape):
|
|
res = []
|
|
|
|
# Short-circuits on zero dim tensors
|
|
if shape == torch.Size([]):
|
|
return 0
|
|
|
|
for size in shape[::-1]:
|
|
res.append(int(flat_index % size))
|
|
flat_index = int(flat_index // size)
|
|
|
|
if len(res) == 1:
|
|
return res[0]
|
|
|
|
return tuple(res[::-1])
|
|
# (bool, msg) tuple, where msg is None if and only if bool is True.
|
|
_compare_return_type = Tuple[bool, Optional[str]]
|
|
|
|
# Compares two tensors with the same size on the same device and with the same
|
|
# dtype for equality.
|
|
# Returns a tuple (bool, msg). The bool value returned is True when the tensors
|
|
# are "equal" and False otherwise.
|
|
# The msg value is a debug string, and is None if the tensors are "equal."
|
|
# NOTE: Test Framework Tensor 'Equality'
|
|
# Two tensors are "equal" if they are "close", in the sense of torch.allclose.
|
|
# The only exceptions are complex tensors and bool tensors.
|
|
#
|
|
# Complex tensors are "equal" if both the
|
|
# real and complex parts (separately) are close. This is divergent from
|
|
# torch.allclose's behavior, which compares the absolute values of the
|
|
# complex numbers instead.
|
|
#
|
|
# Using torch.allclose would be a less strict
|
|
# comparison that would allow large complex values with
|
|
# significant real or imaginary differences to be considered "equal,"
|
|
# and would make setting rtol and atol for complex tensors distinct from
|
|
# other tensor types.
|
|
#
|
|
# Bool tensors are equal only if they are identical, regardless of
|
|
# the rtol and atol values.
|
|
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: bool) -> _compare_return_type:
|
|
debug_msg : Optional[str]
|
|
# Integer (including bool) comparisons are identity comparisons
|
|
# when rtol is zero and atol is less than one
|
|
if (
|
|
(is_integral(a.dtype) and rtol == 0 and atol < 1)
|
|
or a.dtype is torch.bool
|
|
or is_quantized(a.dtype)
|
|
):
|
|
if (a == b).all().item():
|
|
return (True, None)
|
|
|
|
# Gathers debug info for failed integer comparison
|
|
# NOTE: converts to long to correctly represent differences
|
|
# (especially between uint8 tensors)
|
|
identity_mask = a != b
|
|
a_flat = a.to(torch.long).flatten()
|
|
b_flat = b.to(torch.long).flatten()
|
|
count_non_identical = torch.sum(identity_mask, dtype=torch.long)
|
|
diff = torch.abs(a_flat - b_flat)
|
|
greatest_diff_index = torch.argmax(diff)
|
|
debug_msg = ("Found {0} different element(s) (out of {1}), with the greatest "
|
|
"difference of {2} ({3} vs. {4}) occuring at index "
|
|
"{5}.".format(count_non_identical.item(),
|
|
a.numel(),
|
|
diff[greatest_diff_index],
|
|
a_flat[greatest_diff_index],
|
|
b_flat[greatest_diff_index],
|
|
_unravel_index(greatest_diff_index, a.shape)))
|
|
return (False, debug_msg)
|
|
|
|
# Compares complex tensors' real and imaginary parts separately.
|
|
# (see NOTE Test Framework Tensor "Equality")
|
|
if a.is_complex():
|
|
a_real = a.real
|
|
b_real = b.real
|
|
real_result, debug_msg = _compare_tensors_internal(a_real, b_real,
|
|
rtol=rtol, atol=atol,
|
|
equal_nan=equal_nan)
|
|
|
|
if not real_result:
|
|
debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
|
|
return (real_result, debug_msg)
|
|
|
|
a_imag = a.imag
|
|
b_imag = b.imag
|
|
imag_result, debug_msg = _compare_tensors_internal(a_imag, b_imag,
|
|
rtol=rtol, atol=atol,
|
|
equal_nan=equal_nan)
|
|
|
|
if not imag_result:
|
|
debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
|
|
return (imag_result, debug_msg)
|
|
|
|
return (True, None)
|
|
|
|
# All other comparisons use torch.allclose directly
|
|
if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
|
|
return (True, None)
|
|
|
|
# Gathers debug info for failed float tensor comparison
|
|
# NOTE: converts to float64 to best represent differences
|
|
a_flat = a.to(torch.float64).flatten()
|
|
b_flat = b.to(torch.float64).flatten()
|
|
diff = torch.abs(a_flat - b_flat)
|
|
|
|
# Masks close values
|
|
# NOTE: this avoids (inf - inf) oddities when computing the difference
|
|
close = torch.isclose(a_flat, b_flat, rtol, atol, equal_nan)
|
|
diff[close] = 0
|
|
nans = torch.isnan(diff)
|
|
num_nans = nans.sum()
|
|
|
|
outside_range = (diff > (atol + rtol * torch.abs(b_flat))) | (diff == math.inf)
|
|
count_outside_range = torch.sum(outside_range, dtype=torch.long)
|
|
greatest_diff_index = torch.argmax(diff)
|
|
debug_msg = ("With rtol={0} and atol={1}, found {2} element(s) (out of {3}) whose "
|
|
"difference(s) exceeded the margin of error (including {4} nan comparisons). "
|
|
"The greatest difference was {5} ({6} vs. {7}), which "
|
|
"occurred at index {8}.".format(rtol, atol,
|
|
count_outside_range + num_nans,
|
|
a.numel(),
|
|
num_nans,
|
|
diff[greatest_diff_index],
|
|
a_flat[greatest_diff_index],
|
|
b_flat[greatest_diff_index],
|
|
_unravel_index(greatest_diff_index, a.shape)))
|
|
return (False, debug_msg)
|
|
|
|
# Checks if two scalars are equal(-ish), returning (True, None)
|
|
# when they are and (False, debug_msg) when they are not.
|
|
def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: bool) -> _compare_return_type:
|
|
def _helper(a, b, s) -> _compare_return_type:
|
|
# Short-circuits on identity
|
|
if a == b or (equal_nan and a != a and b != b):
|
|
return (True, None)
|
|
|
|
# Special-case for NaN comparisions when equal_nan=False
|
|
if not equal_nan and (a != a or b != b):
|
|
msg = ("Found {0} and {1} while comparing" + s + "and either one "
|
|
"is nan and the other isn't, or both are nan and "
|
|
"equal_nan is False").format(a, b)
|
|
return (False, msg)
|
|
|
|
diff = abs(a - b)
|
|
allowed_diff = atol + rtol * abs(b)
|
|
result = diff <= allowed_diff
|
|
|
|
# Special-case for infinity comparisons
|
|
# NOTE: if b is inf then allowed_diff will be inf when rtol is not 0
|
|
if ((math.isinf(a) or math.isinf(b)) and a != b):
|
|
result = False
|
|
|
|
msg = None
|
|
if not result:
|
|
msg = ("Comparing" + s + "{0} and {1} gives a "
|
|
"difference of {2}, but the allowed difference "
|
|
"with rtol={3} and atol={4} is "
|
|
"only {5}!").format(a, b, diff,
|
|
rtol, atol, allowed_diff)
|
|
|
|
return result, msg
|
|
|
|
if isinstance(a, complex) or isinstance(b, complex):
|
|
a = complex(a)
|
|
b = complex(b)
|
|
|
|
result, msg = _helper(a.real, b.real, " the real part ")
|
|
|
|
if not result:
|
|
return (False, msg)
|
|
|
|
return _helper(a.imag, b.imag, " the imaginary part ")
|
|
|
|
return _helper(a, b, " ")
|
|
|
|
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
|
|
if not isinstance(actual, torch.Tensor):
|
|
actual = torch.tensor(actual)
|
|
if not isinstance(expected, torch.Tensor):
|
|
expected = torch.tensor(expected, dtype=actual.dtype)
|
|
if expected.shape != actual.shape:
|
|
raise AssertionError("expected tensor shape {0} doesn't match with actual tensor "
|
|
"shape {1}!".format(expected.shape, actual.shape))
|
|
if rtol is None or atol is None:
|
|
if rtol is not None or atol is not None:
|
|
raise ValueError("rtol and atol must both be specified or both be unspecified")
|
|
rtol, atol = _get_default_tolerance(actual, expected)
|
|
|
|
result, debug_msg = _compare_tensors_internal(actual, expected,
|
|
rtol=rtol, atol=atol,
|
|
equal_nan=equal_nan)
|
|
|
|
if result:
|
|
return
|
|
|
|
if msg is None or msg == '':
|
|
msg = debug_msg
|
|
|
|
raise AssertionError(msg)
|
|
|
|
def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
|
|
if tensor.numel() <= 1: # can't make non-contiguous
|
|
return tensor.clone()
|
|
osize = list(tensor.size())
|
|
|
|
# randomly inflate a few dimensions in osize
|
|
for _ in range(2):
|
|
dim = random.randint(0, len(osize) - 1)
|
|
add = random.randint(4, 15)
|
|
osize[dim] = osize[dim] + add
|
|
|
|
# narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
|
|
# (which will always happen with a 1-dimensional tensor), so let's make a new
|
|
# right-most dimension and cut it off
|
|
|
|
input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
|
|
input = input.select(len(input.size()) - 1, random.randint(0, 1))
|
|
# now extract the input of correct size from 'input'
|
|
for i in range(len(osize)):
|
|
if input.size(i) != tensor.size(i):
|
|
bounds = random.randint(1, input.size(i) - tensor.size(i))
|
|
input = input.narrow(i, bounds, tensor.size(i))
|
|
|
|
input.copy_(tensor)
|
|
|
|
# Use .data here to hide the view relation between input and other temporary Tensors
|
|
return input.data
|
|
|
|
|
|
# Functions and classes for describing the dtypes a function supports
|
|
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
|
|
|
|
# Verifies each given dtype is a torch.dtype
|
|
def _validate_dtypes(*dtypes):
|
|
for dtype in dtypes:
|
|
assert isinstance(dtype, torch.dtype)
|
|
return dtypes
|
|
|
|
# class for tuples corresponding to a PyTorch dispatch macro
|
|
class _dispatch_dtypes(tuple):
|
|
def __add__(self, other):
|
|
assert isinstance(other, tuple)
|
|
return _dispatch_dtypes(tuple.__add__(self, other))
|
|
|
|
_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
|
|
def floating_types():
|
|
return _floating_types
|
|
|
|
_floating_types_and_half = _floating_types + (torch.half,)
|
|
def floating_types_and_half():
|
|
return _floating_types_and_half
|
|
|
|
def floating_types_and(*dtypes):
|
|
return _floating_types + _validate_dtypes(*dtypes)
|
|
|
|
_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
|
|
def floating_and_complex_types():
|
|
return _floating_and_complex_types
|
|
|
|
def floating_and_complex_types_and(*dtypes):
|
|
return _floating_and_complex_types + _validate_dtypes(*dtypes)
|
|
|
|
_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
|
|
def integral_types():
|
|
return _integral_types
|
|
|
|
def integral_types_and(*dtypes):
|
|
return _integral_types + _validate_dtypes(*dtypes)
|
|
|
|
_all_types = _floating_types + _integral_types
|
|
def all_types():
|
|
return _all_types
|
|
|
|
def all_types_and(*dtypes):
|
|
return _all_types + _validate_dtypes(*dtypes)
|
|
|
|
_complex_types = (torch.cfloat, torch.cdouble)
|
|
def complex_types():
|
|
return _complex_types
|
|
|
|
_all_types_and_complex = _all_types + _complex_types
|
|
def all_types_and_complex():
|
|
return _all_types_and_complex
|
|
|
|
def all_types_and_complex_and(*dtypes):
|
|
return _all_types_and_complex + _validate_dtypes(*dtypes)
|
|
|
|
_all_types_and_half = _all_types + (torch.half,)
|
|
def all_types_and_half():
|
|
return _all_types_and_half
|
|
|
|
def get_all_dtypes(include_half=True,
|
|
include_bfloat16=True,
|
|
include_bool=True,
|
|
include_complex=True,
|
|
include_complex32=False
|
|
) -> List[torch.dtype]:
|
|
dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
|
|
if include_bool:
|
|
dtypes.append(torch.bool)
|
|
if include_complex:
|
|
dtypes += get_all_complex_dtypes(include_complex32)
|
|
return dtypes
|
|
|
|
def get_all_math_dtypes(device) -> List[torch.dtype]:
|
|
return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
|
|
include_bfloat16=False) + get_all_complex_dtypes()
|
|
|
|
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
|
|
return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]
|
|
|
|
|
|
def get_all_int_dtypes() -> List[torch.dtype]:
|
|
return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
|
|
|
|
|
|
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
|
|
dtypes = [torch.float32, torch.float64]
|
|
if include_half:
|
|
dtypes.append(torch.float16)
|
|
if include_bfloat16:
|
|
dtypes.append(torch.bfloat16)
|
|
return dtypes
|
|
|
|
|
|
def get_all_device_types() -> List[str]:
|
|
return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
|
|
|
|
# 'dtype': (rtol, atol)
|
|
_default_tolerances = {
|
|
'float64': (1e-5, 1e-8), # NumPy default
|
|
'float32': (1e-4, 1e-5), # This may need to be changed
|
|
'float16': (1e-3, 1e-3), # This may need to be changed
|
|
}
|
|
|
|
|
|
def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
|
|
if b is None:
|
|
dtype = str(a.dtype).split('.')[-1] # e.g. "float32"
|
|
return _default_tolerances.get(dtype, (0, 0))
|
|
a_tol = _get_default_tolerance(a)
|
|
b_tol = _get_default_tolerance(b)
|
|
return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
|