Add/fix typing annotations to some functions (#39075)

Summary:
Add missing typing imports to some jit tests
Add typing annotations to `torch.testing._compare_scalars_internal` and `torch.testing._internal.assertTrue`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39075

Differential Revision: D21882468

Pulled By: malfet

fbshipit-source-id: dd9858eb8e11a38411544cc64daf36fced807d76
This commit is contained in:
Nikita Shulga 2020-06-04 13:37:44 -07:00 committed by Facebook GitHub Bot
parent da2f8c9f1f
commit 8811e4d00d
8 changed files with 120 additions and 77 deletions

View File

@ -19,6 +19,7 @@ files =
caffe2,
aten/src/ATen/function_wrapper.py,
test/test_complex.py,
test/test_torch.py,
test/test_type_hints.py,
test/test_type_info.py
@ -43,6 +44,9 @@ ignore_missing_imports = True
# positives as well.
#
[mypy-test_torch]
check_untyped_defs = False
[mypy-torch.distributed.*]
ignore_errors = True
@ -484,3 +488,6 @@ ignore_missing_imports = True
[mypy-skimage.*]
ignore_missing_imports = True
[mypy-librosa.*]
ignore_missing_imports = True

View File

@ -10,7 +10,8 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
from torch.testing._internal.common_utils import TemporaryFileName
from typing import List
from typing import List, Tuple
from torch import Tensor
class TestAsync(JitTestCase):
def test_async_python(self):

View File

@ -1,12 +1,13 @@
import os
import sys
import inspect
from typing import List, Dict
from typing import Dict, List, Optional, Tuple
from textwrap import dedent
from collections import OrderedDict
import torch
from torch.testing import FileCheck
from torch import Tensor
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

View File

@ -34,6 +34,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, onlyCUDA, onlyCPU, \
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \
PYTORCH_CUDA_MEMCHECK, largeCUDATensorTest, largeTensorTest, onlyOnCPUAndCUDA
from typing import Dict, List, Tuple, Union
import torch.backends.quantized
import torch.testing._internal.data
@ -60,7 +61,8 @@ class AbstractTestCases:
# This is intentionally prefixed by an underscore. Otherwise pytest will try to
# run its methods as test cases.
class _TestTorchMixin(TestCase):
def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True, use_complex=False):
def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
use_complex=False) -> Dict[str, List[torch.Tensor]]:
float_types = [torch.double,
torch.float]
int_types = [torch.int64,
@ -70,7 +72,7 @@ class AbstractTestCases:
complex_types = [torch.complex64,
torch.complex128]
def make_contiguous(shape, dtype):
def make_contiguous(shape, dtype) -> torch.Tensor:
if dtype in float_types:
val = torch.randn(shape, dtype=dtype)
val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
@ -81,7 +83,7 @@ class AbstractTestCases:
result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
return result
def make_non_contiguous(shape, dtype):
def make_non_contiguous(shape, dtype) -> torch.Tensor:
contig = make_contiguous(shape, dtype)
non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
non_contig = non_contig.select(-1, -1)
@ -89,7 +91,7 @@ class AbstractTestCases:
self.assertFalse(non_contig.is_contiguous())
return non_contig
def make_contiguous_slice(size, dtype):
def make_contiguous_slice(size, dtype) -> torch.Tensor:
contig = make_contiguous((1, size), dtype)
non_contig = contig[:1, 1:size - 1]
self.assertTrue(non_contig.is_contiguous())
@ -102,7 +104,7 @@ class AbstractTestCases:
types += int_types
if use_complex:
types += complex_types
tensors = {"cont": [], "noncont": [], "slice": []}
tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
for dtype in types:
tensors["cont"].append(make_contiguous(shape, dtype))
tensors["noncont"].append(make_non_contiguous(shape, dtype))
@ -124,7 +126,7 @@ class AbstractTestCases:
self.assertEqual(x.int().dtype, torch.int32)
self.assertEqual(x.bfloat16().dtype, torch.bfloat16)
def test_doc_template(self):
def test_doc_template(self) -> None:
from torch._torch_docs import __file__ as doc_file
from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args
@ -206,7 +208,7 @@ class AbstractTestCases:
# TODO: add torch.* tests when we have proper namespacing on ATen functions
# test_namespace(torch)
def test_linear_algebra_scalar_raises(self):
def test_linear_algebra_scalar_raises(self) -> None:
m = torch.randn(5, 5)
v = torch.randn(5)
s = torch.tensor(7)
@ -492,8 +494,8 @@ class AbstractTestCases:
[1, 1, 1]]))
@slowTest
def test_mv(self):
def _test_mv(m1, v1):
def test_mv(self) -> None:
def _test_mv(m1: torch.Tensor, v1: torch.Tensor) -> None:
res1 = torch.mv(m1, v1)
res2 = res1.clone().zero_()
for i, j in iter_indices(m1):
@ -4055,7 +4057,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''')
def test_sizeof(self):
def test_sizeof(self) -> None:
sizeof_empty = torch.randn(0).storage().__sizeof__()
sizeof_10 = torch.randn(10).storage().__sizeof__()
sizeof_100 = torch.randn(100).storage().__sizeof__()
@ -4068,7 +4070,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
def test_unsqueeze(self):
def test_unsqueeze(self) -> None:
x = torch.randn(2, 3, 4)
y = x.unsqueeze(1)
self.assertEqual(y, x.view(2, 1, 3, 4))
@ -4082,7 +4084,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
y = x.clone().unsqueeze_(2)
self.assertEqual(y, x.contiguous().view(2, 4, 1))
def test_iter(self):
def test_iter(self) -> None:
x = torch.randn(5, 5)
for i, sub in enumerate(x):
self.assertEqual(sub, x[i])
@ -4090,7 +4092,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
x = torch.Tensor()
self.assertEqual(list(x), [])
def test_accreal_type(self):
def test_accreal_type(self) -> None:
x = torch.ones(2, 3, 4)
self.assertIsInstance(x.double().sum().item(), float)
self.assertIsInstance(x.float().sum().item(), float)
@ -4100,7 +4102,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertIsInstance(x.char().sum().item(), int)
self.assertIsInstance(x.byte().sum().item(), int)
def test_assertEqual(self):
def test_assertEqual(self) -> None:
x = torch.FloatTensor([0])
self.assertEqual(x, 0)
xv = torch.autograd.Variable(x)
@ -4114,10 +4116,10 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertRaises(AssertionError,
lambda: self.assertEqual(x, xv, rtol=4))
self.assertRaisesRegex(TypeError, "takes 3 positional arguments",
lambda: self.assertEqual(x, xv, 1.0, ""))
self.assertRaisesRegex(TypeError, "takes from 3 to 4 positional arguments",
lambda: self.assertEqual(x, xv, "", 1.0)) # type: ignore
def test_new(self):
def test_new(self) -> None:
x = torch.autograd.Variable(torch.Tensor())
y = torch.autograd.Variable(torch.randn(4, 4))
z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
@ -4142,7 +4144,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
# TypeError would be better
self.assertRaises(RuntimeError, lambda: x.new(z.storage()))
def test_empty_like(self):
def test_empty_like(self) -> None:
x = torch.autograd.Variable(torch.Tensor())
y = torch.autograd.Variable(torch.randn(4, 4))
z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
@ -4166,7 +4168,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_numpy_unresizable(self):
def test_numpy_unresizable(self) -> None:
x = np.zeros((2, 2))
y = torch.from_numpy(x)
with self.assertRaises(ValueError):
@ -4180,7 +4182,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
w.resize((10, 10))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_to_numpy(self):
def test_to_numpy(self) -> None:
def get_castable_tensor(shape, dtype):
if dtype.is_floating_point:
dtype_info = torch.finfo(dtype)
@ -4284,7 +4286,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertTrue(x[0][1] == 3)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_to_numpy_bool(self):
def test_to_numpy_bool(self) -> None:
x = torch.tensor([True, False], dtype=torch.bool)
self.assertEqual(x.dtype, torch.bool)
@ -4301,7 +4303,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual(x[0], y[0])
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_from_numpy(self):
def test_from_numpy(self) -> None:
dtypes = [
np.double,
np.float,
@ -4374,7 +4376,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertRaises(ValueError, lambda: torch.from_numpy(x))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_ctor_with_numpy_scalar_ctor(self):
def test_ctor_with_numpy_scalar_ctor(self) -> None:
dtypes = [
np.double,
np.float,
@ -4476,7 +4478,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual(geq2_x[i], geq2_array[i])
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_multiplication_numpy_scalar(self):
def test_multiplication_numpy_scalar(self) -> None:
for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]:
for t_dtype in [torch.float, torch.double]:
np_sc = np_dtype(2.0)
@ -15088,7 +15090,7 @@ class TestTorchDeviceType(TestCase):
@onlyCPU
@dtypes(torch.float, torch.double)
def test_hardshrink_edge_cases(self, device, dtype):
def test_hardshrink_edge_cases(self, device, dtype) -> None:
def h(values, l_expected):
for l, expected in l_expected.items():
values_tensor = torch.tensor([float(v) for v in values],
@ -15113,7 +15115,7 @@ class TestTorchDeviceType(TestCase):
@slowTest
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@dtypes(torch.double)
def test_einsum(self, device, dtype):
def test_einsum(self, device: torch.device, dtype: torch.dtype) -> None:
# test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
x = torch.randn(5, dtype=dtype, device=device)
y = torch.randn(7, dtype=dtype, device=device)
@ -15129,7 +15131,9 @@ class TestTorchDeviceType(TestCase):
l = torch.randn(5, 10, dtype=dtype, device=device)
r = torch.randn(5, 20, dtype=dtype, device=device)
w = torch.randn(30, 10, 20, dtype=dtype, device=device)
test_list = [
test_list: List[Union[Tuple[str, torch.Tensor],
Tuple[str, torch.Tensor, torch.Tensor],
Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]]] = [
# -- Vector
("i->", x), # sum
("i,i->", x, x), # dot
@ -15184,8 +15188,8 @@ class TestTorchDeviceType(TestCase):
@onlyCPU
@dtypes(torch.bool, torch.double)
def test_sum_all(self, device, dtype):
def check_sum_all(tensor):
def test_sum_all(self, device, dtype) -> None:
def check_sum_all(tensor: torch.Tensor) -> None:
pylist = tensor.reshape(-1).tolist()
self.assertEqual(tensor.sum(), sum(pylist))
@ -15296,7 +15300,7 @@ class TestTorchDeviceType(TestCase):
@onlyCPU
@dtypes(torch.double)
def test_sum_out(self, device, dtype):
def test_sum_out(self, device, dtype: torch.dtype) -> None:
x = torch.rand(100, 100, dtype=dtype, device=device)
res1 = torch.sum(x, 1)
res2 = torch.tensor((), dtype=dtype, device=device)
@ -17696,8 +17700,8 @@ class TestDevicePrecision(TestCase):
torch.sum(x, dim=[0], dtype=torch.float32, out=y)
@deviceCountAtLeast(1)
def test_advancedindex_mixed_cpu_devices(self, devices):
def test(x, ia, ib):
def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
# test getitem
self.assertEqual(x[:, ia, None, ib, 0].cpu(),
x.cpu()[:, ia.cpu(), None, ib.cpu(), 0])
@ -17746,7 +17750,7 @@ class TestDevicePrecision(TestCase):
ib = ib.to(other_device)
test(x, ia, ib)
def test_copy_broadcast(self, device):
def test_copy_broadcast(self, device) -> None:
x = torch.randn(10, 5)
y = torch.randn(5, device=device)
x.copy_(y)
@ -17793,7 +17797,7 @@ class TestDevicePrecision(TestCase):
output = torch.zeros_like(x)
self.assertEqual(output, expected)
def test_ones_like(self, device):
def test_ones_like(self, device) -> None:
expected = torch.ones(100, 100, device=device)
res1 = torch.ones_like(expected)
@ -17869,7 +17873,7 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].real, a.real[5:])
self.assertEqual(a[5:].imag, a.imag[5:])
def test_diagonal_view(self, device):
def test_diagonal_view(self, device) -> None:
t = torch.ones((5, 5), device=device)
v = torch.diagonal(t)
self.assertTrue(self.is_view_of(t, v))
@ -17884,7 +17888,7 @@ class TestViewOps(TestCase):
v[0, 0] = 0
self.assertEqual(t[0, 0, 1], v[0, 0])
def test_select_view(self, device):
def test_select_view(self, device) -> None:
t = torch.ones((5, 5), device=device)
v = t.select(0, 2)
self.assertTrue(self.is_view_of(t, v))
@ -17892,7 +17896,7 @@ class TestViewOps(TestCase):
v[0] = 0
self.assertEqual(t[2, 0], v[0])
def test_unbind_view(self, device):
def test_unbind_view(self, device) -> None:
t = torch.zeros((5, 5), device=device)
tup = torch.unbind(t)
@ -17902,7 +17906,7 @@ class TestViewOps(TestCase):
v[0] = idx + 1
self.assertEqual(t[idx, 0], v[0])
def test_expand_view(self, device):
def test_expand_view(self, device) -> None:
t = torch.ones((5, 1), device=device)
v = t.expand(5, 5)
self.assertTrue(self.is_view_of(t, v))
@ -17927,7 +17931,7 @@ class TestViewOps(TestCase):
v[0, 0] = 0
self.assertEqual(t[0, 2], v[0, 0])
def test_permute_view(self, device):
def test_permute_view(self, device) -> None:
t = torch.ones((5, 5), device=device)
v = t.permute(1, 0)
self.assertTrue(self.is_view_of(t, v))
@ -18154,7 +18158,7 @@ _signed_types_no_half = [
torch.int8, torch.short, torch.int, torch.long
]
_cpu_types = []
_cpu_types: List[torch.dtype] = []
_unsigned_types = [torch.uint8]
@ -18185,7 +18189,7 @@ def _convert_t(dtype, device):
# Requesting a half CPU tensor returns a float CPU tensor with
# values representable by a half.
# Initialization uses randint for non-float types and randn for float types.
def _make_tensor(shape, dtype, device, fill_ones=False):
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
# Returns a tensor filled with ones
if fill_ones:
return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
@ -18206,7 +18210,7 @@ def _make_tensor(shape, dtype, device, fill_ones=False):
# Default: returns a tensor with random float values
return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype)
def _small_0d(dtype, device):
def _small_0d(dtype, device) -> torch.Tensor:
return _make_tensor((1,), dtype, device).squeeze()
def _small_2d(dtype, device, has_zeros=True, fill_ones=False, oneish=False):
@ -18635,8 +18639,8 @@ def generate_test_function(cls,
float_precision,
dtype_list,
dtype_cpu_list,
decorators):
def fn(self, device, dtype):
decorators) -> None:
def fn(self, device, dtype) -> None:
# Generates the CPU inputs
# Note: CPU tensors are never torch.half
cpu_tensor = tensor_ctor(dtype, 'cpu')
@ -18683,7 +18687,7 @@ def generate_test_function(cls,
setattr(cls, test_name, fn)
# Instantiates variants of tensor_op_tests and adds them to the given class.
def generate_tensor_op_tests(cls):
def generate_tensor_op_tests(cls) -> None:
def caller(cls,
op_str,
@ -18853,7 +18857,7 @@ def generate_torch_test_functions(cls, testmeta, inplace):
torch.testing.assert_allclose(actual, expected, rtol=testmeta.rtol, atol=testmeta.atol)
def fn_non_contig(self, device, dtype):
def fn_non_contig(self, device, dtype) -> None:
shapes = [(5, 7), (1024,)]
for shape in shapes:
contig = _make_tensor(shape, dtype=dtype, device=device)

View File

@ -43,7 +43,8 @@ read gen_pyi for the gory details.
needed_modules = set()
FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False"
DEVICE_PARAM = "device: Union[_device, str, None]=None"
FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"
# this could be more precise w.r.t list contents etc. How to do Ellipsis?
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
@ -534,6 +535,18 @@ def gen_pyi(declarations_path, out):
'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
# new and __init__ have the same signatures differ only in return type
# Adapted from legacy_tensor_ctor and legacy_tensor_new
'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM),
'def new(self, storage: Storage) -> Tensor: ...',
'def new(self, other: Tensor) -> Tensor: ...',
'def new(self, size: {}, *, {}) -> Tensor: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
],
'__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
'def __init__(self, storage: Storage) -> None: ...',
'def __init__(self, other: Tensor) -> None: ...',
'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
],
# clamp has no default values in the Declarations
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
@ -633,7 +646,7 @@ def gen_pyi(declarations_path, out):
legacy_class_hints = []
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
# Generate type signatures for dtype classes

View File

@ -38,7 +38,10 @@ class Size(Tuple[_int, ...]):
# Defined in torch/csrc/Dtype.cpp
class dtype:
# TODO: is_floating_point, is_complex, is_Signed, __reduce__
# TODO: __reduce__
is_floating_point: _bool
is_complex: _bool
is_signed: _bool
...
# Defined in torch/csrc/TypeInfo.cpp
@ -206,4 +209,5 @@ class _TensorBase(object):
layout: _layout
real: Tensor
imag: Tensor
_version: _bool
${tensor_method_hints}

View File

@ -5,6 +5,7 @@ The testing package contains testing-specific utilities.
import torch
import random
import math
from typing import cast, List, Optional, Tuple, Union
FileCheck = torch._C.FileCheck
@ -18,7 +19,7 @@ randn_like = torch.randn_like
# Helper function that returns True when the dtype is an integral dtype,
# False otherwise.
# TODO: implement numpy-like issubdtype
def is_integral(dtype):
def is_integral(dtype: torch.dtype) -> bool:
# Skip complex/quantized types
dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
return dtype in dtypes and not dtype.is_floating_point
@ -40,6 +41,8 @@ def _unravel_index(flat_index, shape):
return res[0]
return tuple(res[::-1])
# (bool, msg) tuple, where msg is None if and only if bool is True.
_compare_return_type = Tuple[bool, Optional[str]]
# Compares two tensors with the same size on the same device and with the same
# dtype for equality.
@ -63,7 +66,8 @@ def _unravel_index(flat_index, shape):
#
# Bool tensors are equal only if they are identical, regardless of
# the rtol and atol values.
def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: bool) -> _compare_return_type:
debug_msg : Optional[str]
# Integer (including bool) comparisons are identity comparisons
# when rtol is zero and atol is less than one
if (is_integral(a.dtype) and rtol == 0 and atol < 1) or a.dtype is torch.bool:
@ -99,7 +103,7 @@ def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
equal_nan=equal_nan)
if not real_result:
debug_msg = "Real parts failed to compare as equal! " + debug_msg
debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
return (real_result, debug_msg)
a_imag = a.imag
@ -109,7 +113,7 @@ def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
equal_nan=equal_nan)
if not imag_result:
debug_msg = "Imaginary parts failed to compare as equal! " + debug_msg
debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
return (imag_result, debug_msg)
return (True, None)
@ -149,8 +153,8 @@ def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
# Checks if two scalars are equal(-ish), returning (True, None)
# when they are and (False, debug_msg) when they are not.
def _compare_scalars_internal(a, b, *, rtol, atol, equal_nan):
def _helper(a, b, s):
def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: bool) -> _compare_return_type:
def _helper(a, b, s) -> _compare_return_type:
# Short-circuits on identity
if a == b or (equal_nan and a != a and b != b):
return (True, None)
@ -194,7 +198,7 @@ def _compare_scalars_internal(a, b, *, rtol, atol, equal_nan):
return _helper(a, b, " ")
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg=''):
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
if not isinstance(actual, torch.Tensor):
actual = torch.tensor(actual)
if not isinstance(expected, torch.Tensor):
@ -218,7 +222,7 @@ def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg=
raise AssertionError(msg)
def make_non_contiguous(tensor):
def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
if tensor.numel() <= 1: # can't make non-contiguous
return tensor.clone()
osize = list(tensor.size())
@ -247,7 +251,7 @@ def make_non_contiguous(tensor):
return input.data
def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True):
def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True) -> List[torch.dtype]:
dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
if include_bool:
dtypes.append(torch.bool)
@ -256,20 +260,20 @@ def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True,
return dtypes
def get_all_math_dtypes(device):
def get_all_math_dtypes(device) -> List[torch.dtype]:
return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
include_bfloat16=False) + get_all_complex_dtypes()
def get_all_complex_dtypes():
def get_all_complex_dtypes() -> List[torch.dtype]:
return [torch.complex64, torch.complex128]
def get_all_int_dtypes():
def get_all_int_dtypes() -> List[torch.dtype]:
return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
def get_all_fp_dtypes(include_half=True, include_bfloat16=True):
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
dtypes = [torch.float32, torch.float64]
if include_half:
dtypes.append(torch.float16)
@ -278,7 +282,7 @@ def get_all_fp_dtypes(include_half=True, include_bfloat16=True):
return dtypes
def get_all_device_types():
def get_all_device_types() -> List[str]:
return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
# 'dtype': (rtol, atol)
@ -289,7 +293,7 @@ _default_tolerances = {
}
def _get_default_tolerance(a, b=None):
def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
if b is None:
dtype = str(a.dtype).split('.')[-1] # e.g. "float32"
return _default_tolerances.get(dtype, (0, 0))

View File

@ -33,9 +33,10 @@ import json
from urllib.request import urlopen
import __main__
import errno
from typing import cast, Any, Iterable, Optional
from torch.testing._internal import expecttest
from torch.testing import _compare_tensors_internal, _compare_scalars_internal
from torch.testing import _compare_tensors_internal, _compare_scalars_internal, _compare_return_type
import torch
import torch.cuda
@ -947,8 +948,8 @@ class TestCase(expecttest.TestCase):
# NOTE: this function checks the tensors' devices, sizes, and dtypes
# and acquires the appropriate device, dtype, rtol and atol to compare
# them with. It then calls _compare_tensors_internal.
def _compareTensors(self, a, b, *, rtol=None, atol=None, equal_nan=True,
exact_dtype=True, exact_device=False):
def _compareTensors(self, a, b, *, rtol: Optional[float] = None, atol=None, equal_nan=True,
exact_dtype=True, exact_device=False) -> _compare_return_type:
assert (atol is None) == (rtol is None)
if not isinstance(a, torch.Tensor):
return (False, "argument a, {0}, to _compareTensors is not a tensor!".format(a))
@ -993,7 +994,8 @@ class TestCase(expecttest.TestCase):
# when they are and (False, debug_msg) when they are not.
# NOTE: this function just acquires rtol and atol
# before calling _compare_scalars_internal.
def _compareScalars(self, a, b, *, rtol=None, atol=None, equal_nan=True):
def _compareScalars(self, a, b, *,
rtol: Optional[float] = None, atol: Optional[float] = None, equal_nan=True) -> _compare_return_type:
# Acquires rtol and atol
assert (atol is None) == (rtol is None)
if rtol is None:
@ -1005,17 +1007,18 @@ class TestCase(expecttest.TestCase):
rtol, atol = 0, 0
atol = max(atol, self.precision)
return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
return _compare_scalars_internal(a, b, rtol=cast(float, rtol), atol=cast(float, atol), equal_nan=equal_nan)
def assertEqualIgnoreType(self, *args, **kwargs):
def assertEqualIgnoreType(self, *args, **kwargs) -> None:
# If you are seeing this function used, that means test is written wrongly
# and deserves detailed investigation
return self.assertEqual(*args, exact_dtype=False, **kwargs)
# Compares x and y
# TODO: default exact_device to True
def assertEqual(self, x, y, *, atol=None, rtol=None, equal_nan=True,
exact_dtype=True, exact_device=False, msg=None):
def assertEqual(self, x, y, msg: Optional[str] = None, *,
atol: Optional[float] = None, rtol: Optional[float] = None,
equal_nan=True, exact_dtype=True, exact_device=False) -> None:
assert (atol is None) == (rtol is None), "If one of atol or rtol is specified the other must be, too"
# Tensor x Number and Number x Tensor comparisons
@ -1045,6 +1048,7 @@ class TestCase(expecttest.TestCase):
exact_device=exact_device)
if not indices_result and msg is None:
assert debug_msg is not None
msg = "Sparse tensor indices failed to compare as equal! " + debug_msg
self.assertTrue(indices_result, msg=msg)
@ -1054,6 +1058,7 @@ class TestCase(expecttest.TestCase):
exact_device=exact_device)
if not values_result and msg is None:
assert debug_msg is not None
msg = "Sparse tensor values failed to compare as equal! " + debug_msg
self.assertTrue(values_result, msg=msg)
elif x.is_quantized and y.is_quantized:
@ -1086,6 +1091,7 @@ class TestCase(expecttest.TestCase):
exact_device=exact_device)
if not result and msg is None:
assert debug_msg is not None
msg = "Quantized representations failed to compare as equal! " + debug_msg
self.assertTrue(result, msg=msg)
else:
@ -1094,6 +1100,7 @@ class TestCase(expecttest.TestCase):
exact_device=exact_device)
if not result and msg is None:
assert debug_msg is not None
msg = "Tensors failed to compare as equal! " + debug_msg
self.assertTrue(result, msg=msg)
elif isinstance(x, string_classes) and isinstance(y, string_classes):
@ -1127,6 +1134,7 @@ class TestCase(expecttest.TestCase):
result, debug_msg = self._compareScalars(x, y, rtol=rtol, atol=atol,
equal_nan=equal_nan)
if not result and msg is None:
assert debug_msg is not None
msg = "Scalars failed to compare as equal! " + debug_msg
self.assertTrue(result, msg=msg)
else:
@ -1139,17 +1147,18 @@ class TestCase(expecttest.TestCase):
rtol = None if prec is None else 0
self.assertEqual(x, y, msg=msg, atol=prec, rtol=rtol)
def assertNotEqual(self, x, y, *, msg=None, atol=None, rtol=None):
def assertNotEqual(self, x, y, msg: Optional[str] = None, *,
atol: Optional[float] = None, rtol: Optional[float] = None) -> None:
with self.assertRaises(AssertionError, msg=msg):
self.assertEqual(x, y, atol=atol, rtol=rtol)
def assertEqualTypeString(self, x, y):
def assertEqualTypeString(self, x, y) -> None:
# This API is used simulate deprecated x.type() == y.type()
self.assertEqual(x.device, y.device)
self.assertEqual(x.dtype, y.dtype)
self.assertEqual(x.is_sparse, y.is_sparse)
def assertObjectIn(self, obj, iterable):
def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
for elem in iterable:
if id(obj) == id(elem):
return