""" The testing package contains testing-specific utilities. """ import torch import random import math from typing import cast, List, Optional, Tuple, Union from .check_kernel_launches import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches FileCheck = torch._C.FileCheck __all__ = [ 'assert_allclose', 'make_non_contiguous', 'rand_like', 'randn_like' ] rand_like = torch.rand_like randn_like = torch.randn_like # Helper function that returns True when the dtype is an integral dtype, # False otherwise. # TODO: implement numpy-like issubdtype def is_integral(dtype: torch.dtype) -> bool: # Skip complex/quantized types dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()] return dtype in dtypes and not dtype.is_floating_point def is_quantized(dtype: torch.dtype) -> bool: return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) # Helper function that maps a flattened index back into the given shape # TODO: consider adding torch.unravel_index def _unravel_index(flat_index, shape): res = [] # Short-circuits on zero dim tensors if shape == torch.Size([]): return 0 for size in shape[::-1]: res.append(int(flat_index % size)) flat_index = int(flat_index // size) if len(res) == 1: return res[0] return tuple(res[::-1]) # (bool, msg) tuple, where msg is None if and only if bool is True. _compare_return_type = Tuple[bool, Optional[str]] # Compares two tensors with the same size on the same device and with the same # dtype for equality. # Returns a tuple (bool, msg). The bool value returned is True when the tensors # are "equal" and False otherwise. # The msg value is a debug string, and is None if the tensors are "equal." # NOTE: Test Framework Tensor 'Equality' # Two tensors are "equal" if they are "close", in the sense of torch.allclose. # The only exceptions are complex tensors and bool tensors. # # Complex tensors are "equal" if both the # real and complex parts (separately) are close. This is divergent from # torch.allclose's behavior, which compares the absolute values of the # complex numbers instead. # # Using torch.allclose would be a less strict # comparison that would allow large complex values with # significant real or imaginary differences to be considered "equal," # and would make setting rtol and atol for complex tensors distinct from # other tensor types. # # Bool tensors are equal only if they are identical, regardless of # the rtol and atol values. def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: bool) -> _compare_return_type: debug_msg : Optional[str] # Integer (including bool) comparisons are identity comparisons # when rtol is zero and atol is less than one if ( (is_integral(a.dtype) and rtol == 0 and atol < 1) or a.dtype is torch.bool or is_quantized(a.dtype) ): if (a == b).all().item(): return (True, None) # Gathers debug info for failed integer comparison # NOTE: converts to long to correctly represent differences # (especially between uint8 tensors) identity_mask = a != b a_flat = a.to(torch.long).flatten() b_flat = b.to(torch.long).flatten() count_non_identical = torch.sum(identity_mask, dtype=torch.long) diff = torch.abs(a_flat - b_flat) greatest_diff_index = torch.argmax(diff) debug_msg = ("Found {0} different element(s) (out of {1}), with the greatest " "difference of {2} ({3} vs. 
# Compares two tensors with the same size on the same device and with the same
# dtype for equality.
# Returns a tuple (bool, msg). The bool value returned is True when the tensors
# are "equal" and False otherwise.
# The msg value is a debug string, and is None if the tensors are "equal."
# NOTE: Test Framework Tensor 'Equality'
#   Two tensors are "equal" if they are "close", in the sense of torch.allclose.
#   The only exceptions are complex tensors and bool tensors.
#
#   Complex tensors are "equal" if both the real and imaginary parts (separately)
#   are close. This is divergent from torch.allclose's behavior, which compares
#   the absolute values of the complex numbers instead.
#
#   Using torch.allclose would be a less strict comparison that would allow large
#   complex values with significant real or imaginary differences to be considered
#   "equal," and would make setting rtol and atol for complex tensors distinct from
#   other tensor types.
#
#   Bool tensors are equal only if they are identical, regardless of the rtol and
#   atol values.
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: bool) -> _compare_return_type:
    debug_msg: Optional[str]
    # Integer (including bool) comparisons are identity comparisons
    # when rtol is zero and atol is less than one
    if (
        (is_integral(a.dtype) and rtol == 0 and atol < 1)
        or a.dtype is torch.bool
        or is_quantized(a.dtype)
    ):
        if (a == b).all().item():
            return (True, None)

        # Gathers debug info for failed integer comparison
        # NOTE: converts to long to correctly represent differences
        #   (especially between uint8 tensors)
        identity_mask = a != b
        a_flat = a.to(torch.long).flatten()
        b_flat = b.to(torch.long).flatten()
        count_non_identical = torch.sum(identity_mask, dtype=torch.long)
        diff = torch.abs(a_flat - b_flat)
        greatest_diff_index = torch.argmax(diff)

        debug_msg = ("Found {0} different element(s) (out of {1}), with the greatest "
                     "difference of {2} ({3} vs. {4}) occurring at index "
                     "{5}.".format(count_non_identical.item(),
                                   a.numel(),
                                   diff[greatest_diff_index],
                                   a_flat[greatest_diff_index],
                                   b_flat[greatest_diff_index],
                                   _unravel_index(greatest_diff_index, a.shape)))
        return (False, debug_msg)

    # Compares complex tensors' real and imaginary parts separately.
    # (see NOTE Test Framework Tensor "Equality")
    if a.is_complex():
        a_real = a.real
        b_real = b.real
        real_result, debug_msg = _compare_tensors_internal(a_real, b_real,
                                                           rtol=rtol, atol=atol,
                                                           equal_nan=equal_nan)

        if not real_result:
            debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
            return (real_result, debug_msg)

        a_imag = a.imag
        b_imag = b.imag

        imag_result, debug_msg = _compare_tensors_internal(a_imag, b_imag,
                                                           rtol=rtol, atol=atol,
                                                           equal_nan=equal_nan)

        if not imag_result:
            debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
            return (imag_result, debug_msg)

        return (True, None)

    # All other comparisons use torch.allclose directly
    if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
        return (True, None)

    # Gathers debug info for failed float tensor comparison
    # NOTE: converts to float64 to best represent differences
    a_flat = a.to(torch.float64).flatten()
    b_flat = b.to(torch.float64).flatten()
    diff = torch.abs(a_flat - b_flat)

    # Masks close values
    # NOTE: this avoids (inf - inf) oddities when computing the difference
    close = torch.isclose(a_flat, b_flat, rtol, atol, equal_nan)
    diff[close] = 0
    nans = torch.isnan(diff)
    num_nans = nans.sum()

    outside_range = (diff > (atol + rtol * torch.abs(b_flat))) | (diff == math.inf)
    count_outside_range = torch.sum(outside_range, dtype=torch.long)
    greatest_diff_index = torch.argmax(diff)
    debug_msg = ("With rtol={0} and atol={1}, found {2} element(s) (out of {3}) whose "
                 "difference(s) exceeded the margin of error (including {4} nan comparisons). "
                 "The greatest difference was {5} ({6} vs. {7}), which "
                 "occurred at index {8}.".format(rtol, atol,
                                                 count_outside_range + num_nans,
                                                 a.numel(),
                                                 num_nans,
                                                 diff[greatest_diff_index],
                                                 a_flat[greatest_diff_index],
                                                 b_flat[greatest_diff_index],
                                                 _unravel_index(greatest_diff_index, a.shape)))
    return (False, debug_msg)
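# Illustrative sketch of the "equality" semantics described in the NOTE above
# (hypothetical inputs chosen for the example): complex tensors are compared
# part-by-part, which is stricter than torch.allclose on the complex values,
# and bool tensors ignore rtol/atol entirely.
#
#   >>> a = torch.tensor([1000. + 0.0j])
#   >>> b = torch.tensor([1000. + 0.5j])
#   >>> torch.allclose(a, b, rtol=1e-3, atol=0)   # |a - b| is within rtol * |b|
#   True
#   >>> ok, _ = _compare_tensors_internal(a, b, rtol=1e-3, atol=0, equal_nan=False)
#   >>> ok                                        # imaginary parts alone are not close
#   False
#   >>> _compare_tensors_internal(torch.tensor([True]), torch.tensor([False]),
#   ...                           rtol=10, atol=10, equal_nan=False)[0]
#   False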
# Checks if two scalars are equal(-ish), returning (True, None)
# when they are and (False, debug_msg) when they are not.
def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: bool) -> _compare_return_type:
    def _helper(a, b, s) -> _compare_return_type:
        # Short-circuits on identity
        if a == b or (equal_nan and a != a and b != b):
            return (True, None)

        # Special-case for NaN comparisons when equal_nan=False
        if not equal_nan and (a != a or b != b):
            msg = ("Found {0} and {1} while comparing" + s + "and either one "
                   "is nan and the other isn't, or both are nan and "
                   "equal_nan is False").format(a, b)
            return (False, msg)

        diff = abs(a - b)
        allowed_diff = atol + rtol * abs(b)
        result = diff <= allowed_diff

        # Special-case for infinity comparisons
        # NOTE: if b is inf then allowed_diff will be inf when rtol is not 0
        if ((math.isinf(a) or math.isinf(b)) and a != b):
            result = False

        msg = None
        if not result:
            msg = ("Comparing" + s + "{0} and {1} gives a "
                   "difference of {2}, but the allowed difference "
                   "with rtol={3} and atol={4} is "
                   "only {5}!").format(a, b, diff, rtol, atol, allowed_diff)

        return result, msg

    if isinstance(a, complex) or isinstance(b, complex):
        a = complex(a)
        b = complex(b)

        result, msg = _helper(a.real, b.real, " the real part ")

        if not result:
            return (False, msg)

        return _helper(a.imag, b.imag, " the imaginary part ")

    return _helper(a, b, " ")

def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)
    if expected.shape != actual.shape:
        raise AssertionError("expected tensor shape {0} doesn't match with actual tensor "
                             "shape {1}!".format(expected.shape, actual.shape))

    if rtol is None or atol is None:
        if rtol is not None or atol is not None:
            raise ValueError("rtol and atol must both be specified or both be unspecified")
        rtol, atol = _get_default_tolerance(actual, expected)

    result, debug_msg = _compare_tensors_internal(actual, expected,
                                                  rtol=rtol, atol=atol,
                                                  equal_nan=equal_nan)

    if result:
        return

    if msg is None or msg == '':
        msg = debug_msg

    raise AssertionError(msg)

def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # randomly inflate a few dimensions in osize
    for _ in range(2):
        dim = random.randint(0, len(osize) - 1)
        add = random.randint(4, 15)
        osize[dim] = osize[dim] + add

    # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
    # (which will always happen with a 1-dimensional tensor), so let's make a new
    # right-most dimension and cut it off
    input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    input = input.select(len(input.size()) - 1, random.randint(0, 1))

    # now extract the input of correct size from 'input'
    for i in range(len(osize)):
        if input.size(i) != tensor.size(i):
            bounds = random.randint(1, input.size(i) - tensor.size(i))
            input = input.narrow(i, bounds, tensor.size(i))

    input.copy_(tensor)

    # Use .data here to hide the view relation between input and other temporary Tensors
    return input.data
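# Illustrative sketch of assert_allclose (example values only): with rtol and atol
# both omitted it falls back to the per-dtype defaults in _default_tolerances, and
# passing only one of the two raises a ValueError.
#
#   >>> t = torch.tensor([1.0, 2.0, 3.0])
#   >>> assert_allclose(t, t + 1e-6)        # passes: within the float32 defaults (1e-4, 1e-5)
#   >>> assert_allclose(t, t + 1.0)         # raises AssertionError with a debug message
#   >>> assert_allclose(t, t, rtol=1e-5)    # raises ValueError: atol must also be given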
# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros

# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
    for dtype in dtypes:
        assert isinstance(dtype, torch.dtype)
    return dtypes

# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
    def __add__(self, other):
        assert isinstance(other, tuple)
        return _dispatch_dtypes(tuple.__add__(self, other))

_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
    return _floating_types

_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
    return _floating_types_and_half

def floating_types_and(*dtypes):
    return _floating_types + _validate_dtypes(*dtypes)

_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
    return _floating_and_complex_types

def floating_and_complex_types_and(*dtypes):
    return _floating_and_complex_types + _validate_dtypes(*dtypes)

_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
    return _integral_types

def integral_types_and(*dtypes):
    return _integral_types + _validate_dtypes(*dtypes)

_all_types = _floating_types + _integral_types
def all_types():
    return _all_types

def all_types_and(*dtypes):
    return _all_types + _validate_dtypes(*dtypes)

_complex_types = (torch.cfloat, torch.cdouble)
def complex_types():
    return _complex_types

_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
    return _all_types_and_complex

def all_types_and_complex_and(*dtypes):
    return _all_types_and_complex + _validate_dtypes(*dtypes)

_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
    return _all_types_and_half

def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False
                   ) -> List[torch.dtype]:
    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half,
                                                      include_bfloat16=include_bfloat16)
    if include_bool:
        dtypes.append(torch.bool)
    if include_complex:
        dtypes += get_all_complex_dtypes(include_complex32)
    return dtypes

def get_all_math_dtypes(device) -> List[torch.dtype]:
    return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                    include_bfloat16=False) + get_all_complex_dtypes()

def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]

def get_all_int_dtypes() -> List[torch.dtype]:
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]

def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    dtypes = [torch.float32, torch.float64]
    if include_half:
        dtypes.append(torch.float16)
    if include_bfloat16:
        dtypes.append(torch.bfloat16)
    return dtypes

def get_all_device_types() -> List[str]:
    return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']

# 'dtype': (rtol, atol)
_default_tolerances = {
    'float64': (1e-5, 1e-8),  # NumPy default
    'float32': (1e-4, 1e-5),  # This may need to be changed
    'float16': (1e-3, 1e-3),  # This may need to be changed
}

def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
    if b is None:
        dtype = str(a.dtype).split('.')[-1]  # e.g. "float32"
        return _default_tolerances.get(dtype, (0, 0))
    a_tol = _get_default_tolerance(a)
    b_tol = _get_default_tolerance(b)
    return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
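# Illustrative sketch of the dtype helpers and default tolerances (example values
# only): the *_and helpers extend the dispatch tuples via _dispatch_dtypes.__add__,
# the get_all_* helpers return plain lists, and _get_default_tolerance keys off the
# dtype name.
#
#   >>> floating_types_and(torch.bfloat16)
#   (torch.float32, torch.float64, torch.bfloat16)
#   >>> get_all_fp_dtypes(include_half=False, include_bfloat16=False)
#   [torch.float32, torch.float64]
#   >>> _get_default_tolerance(torch.tensor([1.], dtype=torch.float16))
#   (0.001, 0.001)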