Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add/fix typing annotations to some functions (#39075)
Summary:
Add missing typing imports to some jit tests.
Add typing annotations to `torch.testing._compare_scalars_internal` and `torch.testing._internal.assertTrue`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39075

Differential Revision: D21882468

Pulled By: malfet

fbshipit-source-id: dd9858eb8e11a38411544cc64daf36fced807d76
This commit is contained in:
parent da2f8c9f1f
commit 8811e4d00d
mypy.ini | 7 +++++++
@@ -19,6 +19,7 @@ files =
     caffe2,
     aten/src/ATen/function_wrapper.py,
     test/test_complex.py,
+    test/test_torch.py,
     test/test_type_hints.py,
     test/test_type_info.py
@@ -43,6 +44,9 @@ ignore_missing_imports = True
 # positives as well.
 #
 
+[mypy-test_torch]
+check_untyped_defs = False
+
 [mypy-torch.distributed.*]
 ignore_errors = True
 
@@ -484,3 +488,6 @@ ignore_missing_imports = True
 
 [mypy-skimage.*]
 ignore_missing_imports = True
+
+[mypy-librosa.*]
+ignore_missing_imports = True
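For context on the mypy settings these sections use: a `[mypy-<module glob>]` section scopes an option to matching modules only. `ignore_missing_imports = True` silences missing-stub errors for third-party packages such as librosa, while `check_untyped_defs = False` keeps mypy out of the bodies of unannotated functions in test_torch. A minimal sketch of the import such an override unblocks:

    # Without the [mypy-librosa.*] override, mypy reports on this line:
    #   error: Cannot find implementation or library stub for module named 'librosa'
    import librosa  # noqa: F401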
@@ -10,7 +10,8 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
 from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
 from torch.testing._internal.common_utils import TemporaryFileName
-from typing import List
+from typing import List, Tuple
+from torch import Tensor
 
 class TestAsync(JitTestCase):
     def test_async_python(self):
@@ -1,12 +1,13 @@
 import os
 import sys
 import inspect
-from typing import List, Dict
+from typing import Dict, List, Optional, Tuple
 from textwrap import dedent
 from collections import OrderedDict
 
 import torch
 from torch.testing import FileCheck
+from torch import Tensor
 
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
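The `typing` and `torch.Tensor` imports added to these jit tests are not cosmetic: TorchScript compiles functions from their Python annotations, so the annotation names must be importable at script time. A minimal sketch of the pattern (illustrative, not code from this PR):

    import torch
    from torch import Tensor
    from typing import List, Tuple

    @torch.jit.script
    def tally(xs: List[Tensor]) -> Tuple[Tensor, int]:
        # TorchScript reads the List/Tuple/Tensor annotations above to
        # type-check and compile this function.
        total = torch.zeros(1)
        for x in xs:
            total = total + x.sum()
        return total, len(xs)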
@@ -34,6 +34,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
     skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, onlyCUDA, onlyCPU, \
     dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \
     PYTORCH_CUDA_MEMCHECK, largeCUDATensorTest, largeTensorTest, onlyOnCPUAndCUDA
+from typing import Dict, List, Tuple, Union
 import torch.backends.quantized
 import torch.testing._internal.data
 
@@ -60,7 +61,8 @@ class AbstractTestCases:
     # This is intentionally prefixed by an underscore. Otherwise pytest will try to
     # run its methods as test cases.
     class _TestTorchMixin(TestCase):
-        def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True, use_complex=False):
+        def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
+                          use_complex=False) -> Dict[str, List[torch.Tensor]]:
             float_types = [torch.double,
                            torch.float]
             int_types = [torch.int64,
@@ -70,7 +72,7 @@ class AbstractTestCases:
             complex_types = [torch.complex64,
                              torch.complex128]
 
-            def make_contiguous(shape, dtype):
+            def make_contiguous(shape, dtype) -> torch.Tensor:
                 if dtype in float_types:
                     val = torch.randn(shape, dtype=dtype)
                     val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
@@ -81,7 +83,7 @@ class AbstractTestCases:
                 result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
                 return result
 
-            def make_non_contiguous(shape, dtype):
+            def make_non_contiguous(shape, dtype) -> torch.Tensor:
                 contig = make_contiguous(shape, dtype)
                 non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
                 non_contig = non_contig.select(-1, -1)
@@ -89,7 +91,7 @@ class AbstractTestCases:
                 self.assertFalse(non_contig.is_contiguous())
                 return non_contig
 
-            def make_contiguous_slice(size, dtype):
+            def make_contiguous_slice(size, dtype) -> torch.Tensor:
                 contig = make_contiguous((1, size), dtype)
                 non_contig = contig[:1, 1:size - 1]
                 self.assertTrue(non_contig.is_contiguous())
@@ -102,7 +104,7 @@ class AbstractTestCases:
                 types += int_types
             if use_complex:
                 types += complex_types
-            tensors = {"cont": [], "noncont": [], "slice": []}
+            tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
             for dtype in types:
                 tensors["cont"].append(make_contiguous(shape, dtype))
                 tensors["noncont"].append(make_non_contiguous(shape, dtype))
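The annotation on `tensors` above addresses a standard mypy limitation: a dict literal whose values are empty lists gives mypy nothing to infer an element type from. The same pattern in isolation:

    import torch
    from typing import Dict, List

    # Without the annotation mypy has no element type for the empty lists and
    # flags the append; with it, each list is known to hold torch.Tensor.
    tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
    tensors["cont"].append(torch.ones(2, 2))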
@@ -124,7 +126,7 @@ class AbstractTestCases:
             self.assertEqual(x.int().dtype, torch.int32)
             self.assertEqual(x.bfloat16().dtype, torch.bfloat16)
 
-        def test_doc_template(self):
+        def test_doc_template(self) -> None:
             from torch._torch_docs import __file__ as doc_file
             from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args
 
@@ -206,7 +208,7 @@ class AbstractTestCases:
             # TODO: add torch.* tests when we have proper namespacing on ATen functions
             # test_namespace(torch)
 
-        def test_linear_algebra_scalar_raises(self):
+        def test_linear_algebra_scalar_raises(self) -> None:
             m = torch.randn(5, 5)
             v = torch.randn(5)
             s = torch.tensor(7)
@@ -492,8 +494,8 @@ class AbstractTestCases:
                                                        [1, 1, 1]]))
 
         @slowTest
-        def test_mv(self):
-            def _test_mv(m1, v1):
+        def test_mv(self) -> None:
+            def _test_mv(m1: torch.Tensor, v1: torch.Tensor) -> None:
                 res1 = torch.mv(m1, v1)
                 res2 = res1.clone().zero_()
                 for i, j in iter_indices(m1):
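Most of the changes in this file just add `-> None` to test methods. That single annotation is what matters under mypy's defaults: a def with no annotations at all is "untyped" and its body is skipped, while any annotation, even just the return type, turns body checking on. A self-contained illustration (hypothetical functions):

    def before():            # untyped def: mypy skips this body by default
        bad = "one" + 1      # never reported

    def after() -> None:     # "-> None" makes this a typed def,
        bad = "one" + 1      # so mypy now reports the str + int error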
@@ -4055,7 +4057,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertEqual(x.__repr__(), str(x))
             self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''')
 
-        def test_sizeof(self):
+        def test_sizeof(self) -> None:
             sizeof_empty = torch.randn(0).storage().__sizeof__()
             sizeof_10 = torch.randn(10).storage().__sizeof__()
             sizeof_100 = torch.randn(100).storage().__sizeof__()
@@ -4068,7 +4070,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
             self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
 
-        def test_unsqueeze(self):
+        def test_unsqueeze(self) -> None:
             x = torch.randn(2, 3, 4)
             y = x.unsqueeze(1)
             self.assertEqual(y, x.view(2, 1, 3, 4))
@@ -4082,7 +4084,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             y = x.clone().unsqueeze_(2)
             self.assertEqual(y, x.contiguous().view(2, 4, 1))
 
-        def test_iter(self):
+        def test_iter(self) -> None:
             x = torch.randn(5, 5)
             for i, sub in enumerate(x):
                 self.assertEqual(sub, x[i])
@@ -4090,7 +4092,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             x = torch.Tensor()
             self.assertEqual(list(x), [])
 
-        def test_accreal_type(self):
+        def test_accreal_type(self) -> None:
             x = torch.ones(2, 3, 4)
             self.assertIsInstance(x.double().sum().item(), float)
             self.assertIsInstance(x.float().sum().item(), float)
@@ -4100,7 +4102,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertIsInstance(x.char().sum().item(), int)
             self.assertIsInstance(x.byte().sum().item(), int)
 
-        def test_assertEqual(self):
+        def test_assertEqual(self) -> None:
             x = torch.FloatTensor([0])
             self.assertEqual(x, 0)
             xv = torch.autograd.Variable(x)
@@ -4114,10 +4116,10 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertRaises(AssertionError,
                               lambda: self.assertEqual(x, xv, rtol=4))
 
-            self.assertRaisesRegex(TypeError, "takes 3 positional arguments",
-                                   lambda: self.assertEqual(x, xv, 1.0, ""))
+            self.assertRaisesRegex(TypeError, "takes from 3 to 4 positional arguments",
+                                   lambda: self.assertEqual(x, xv, "", 1.0))  # type: ignore
 
-        def test_new(self):
+        def test_new(self) -> None:
             x = torch.autograd.Variable(torch.Tensor())
             y = torch.autograd.Variable(torch.randn(4, 4))
             z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
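The expected TypeError text changes above because `assertEqual`'s `msg` parameter moves from keyword-only to positional-or-keyword in this PR (see the TestCase hunks further down), raising the maximum positional arity from 3 to 4. A toy reproduction of the two messages, with hypothetical classes that only mirror the shape of the change:

    class Old:
        def assertEqual(self, x, y, *, msg=None): ...        # at most 3 positional

    class New:
        def assertEqual(self, x, y, msg=None, *, atol=None): ...  # 3 to 4 positional

    New().assertEqual(1, 1, "note")          # fine: msg may now be positional
    # Old().assertEqual(1, 1, "note")        # TypeError: assertEqual() takes 3
    #                                        # positional arguments but 4 were given
    # New().assertEqual(1, 1, "note", 1.0)   # TypeError: assertEqual() takes from 3
    #                                        # to 4 positional arguments but 5 were given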
@@ -4142,7 +4144,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             # TypeError would be better
             self.assertRaises(RuntimeError, lambda: x.new(z.storage()))
 
-        def test_empty_like(self):
+        def test_empty_like(self) -> None:
             x = torch.autograd.Variable(torch.Tensor())
             y = torch.autograd.Variable(torch.randn(4, 4))
             z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
@@ -4166,7 +4168,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
 
         @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-        def test_numpy_unresizable(self):
+        def test_numpy_unresizable(self) -> None:
             x = np.zeros((2, 2))
             y = torch.from_numpy(x)
             with self.assertRaises(ValueError):
@@ -4180,7 +4182,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
                 w.resize((10, 10))
 
         @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-        def test_to_numpy(self):
+        def test_to_numpy(self) -> None:
             def get_castable_tensor(shape, dtype):
                 if dtype.is_floating_point:
                     dtype_info = torch.finfo(dtype)
@@ -4284,7 +4286,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertTrue(x[0][1] == 3)
 
         @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-        def test_to_numpy_bool(self):
+        def test_to_numpy_bool(self) -> None:
             x = torch.tensor([True, False], dtype=torch.bool)
             self.assertEqual(x.dtype, torch.bool)
 
@@ -4301,7 +4303,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertEqual(x[0], y[0])
 
         @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-        def test_from_numpy(self):
+        def test_from_numpy(self) -> None:
             dtypes = [
                 np.double,
                 np.float,
@@ -4374,7 +4376,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             self.assertRaises(ValueError, lambda: torch.from_numpy(x))
 
         @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-        def test_ctor_with_numpy_scalar_ctor(self):
+        def test_ctor_with_numpy_scalar_ctor(self) -> None:
             dtypes = [
                 np.double,
                 np.float,
@@ -4476,7 +4478,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
                 self.assertEqual(geq2_x[i], geq2_array[i])
 
         @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-        def test_multiplication_numpy_scalar(self):
+        def test_multiplication_numpy_scalar(self) -> None:
             for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]:
                 for t_dtype in [torch.float, torch.double]:
                     np_sc = np_dtype(2.0)
@@ -15088,7 +15090,7 @@ class TestTorchDeviceType(TestCase):
 
     @onlyCPU
     @dtypes(torch.float, torch.double)
-    def test_hardshrink_edge_cases(self, device, dtype):
+    def test_hardshrink_edge_cases(self, device, dtype) -> None:
         def h(values, l_expected):
             for l, expected in l_expected.items():
                 values_tensor = torch.tensor([float(v) for v in values],
@@ -15113,7 +15115,7 @@
     @slowTest
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     @dtypes(torch.double)
-    def test_einsum(self, device, dtype):
+    def test_einsum(self, device: torch.device, dtype: torch.dtype) -> None:
         # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
         x = torch.randn(5, dtype=dtype, device=device)
         y = torch.randn(7, dtype=dtype, device=device)
@@ -15129,7 +15131,9 @@ class TestTorchDeviceType(TestCase):
         l = torch.randn(5, 10, dtype=dtype, device=device)
         r = torch.randn(5, 20, dtype=dtype, device=device)
         w = torch.randn(30, 10, 20, dtype=dtype, device=device)
-        test_list = [
+        test_list: List[Union[Tuple[str, torch.Tensor],
+                              Tuple[str, torch.Tensor, torch.Tensor],
+                              Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]]] = [
             # -- Vector
             ("i->", x),  # sum
             ("i,i->", x, x),  # dot
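The einsum test table mixes tuples of different lengths, so the element type of the list is spelled out as a Union of the possible tuple shapes. The same idea in a standalone snippet:

    import torch
    from typing import List, Tuple, Union

    # Each case is a format string plus one or two operand tensors.
    Case = Union[Tuple[str, torch.Tensor],
                 Tuple[str, torch.Tensor, torch.Tensor]]

    cases: List[Case] = [
        ("i->", torch.randn(5)),
        ("i,i->", torch.randn(5), torch.randn(5)),
    ]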
@@ -15184,8 +15188,8 @@
 
     @onlyCPU
     @dtypes(torch.bool, torch.double)
-    def test_sum_all(self, device, dtype):
-        def check_sum_all(tensor):
+    def test_sum_all(self, device, dtype) -> None:
+        def check_sum_all(tensor: torch.Tensor) -> None:
             pylist = tensor.reshape(-1).tolist()
             self.assertEqual(tensor.sum(), sum(pylist))
 
@@ -15296,7 +15300,7 @@
 
     @onlyCPU
     @dtypes(torch.double)
-    def test_sum_out(self, device, dtype):
+    def test_sum_out(self, device, dtype: torch.dtype) -> None:
         x = torch.rand(100, 100, dtype=dtype, device=device)
         res1 = torch.sum(x, 1)
         res2 = torch.tensor((), dtype=dtype, device=device)
@@ -17696,8 +17700,8 @@ class TestDevicePrecision(TestCase):
         torch.sum(x, dim=[0], dtype=torch.float32, out=y)
 
     @deviceCountAtLeast(1)
-    def test_advancedindex_mixed_cpu_devices(self, devices):
-        def test(x, ia, ib):
+    def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
+        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
             # test getitem
             self.assertEqual(x[:, ia, None, ib, 0].cpu(),
                              x.cpu()[:, ia.cpu(), None, ib.cpu(), 0])
@@ -17746,7 +17750,7 @@
             ib = ib.to(other_device)
             test(x, ia, ib)
 
-    def test_copy_broadcast(self, device):
+    def test_copy_broadcast(self, device) -> None:
         x = torch.randn(10, 5)
         y = torch.randn(5, device=device)
         x.copy_(y)
@@ -17793,7 +17797,7 @@
         output = torch.zeros_like(x)
         self.assertEqual(output, expected)
 
-    def test_ones_like(self, device):
+    def test_ones_like(self, device) -> None:
         expected = torch.ones(100, 100, device=device)
 
         res1 = torch.ones_like(expected)
@@ -17869,7 +17873,7 @@ class TestViewOps(TestCase):
         self.assertEqual(a[5:].real, a.real[5:])
         self.assertEqual(a[5:].imag, a.imag[5:])
 
-    def test_diagonal_view(self, device):
+    def test_diagonal_view(self, device) -> None:
         t = torch.ones((5, 5), device=device)
         v = torch.diagonal(t)
         self.assertTrue(self.is_view_of(t, v))
@@ -17884,7 +17888,7 @@
         v[0, 0] = 0
         self.assertEqual(t[0, 0, 1], v[0, 0])
 
-    def test_select_view(self, device):
+    def test_select_view(self, device) -> None:
         t = torch.ones((5, 5), device=device)
         v = t.select(0, 2)
         self.assertTrue(self.is_view_of(t, v))
@@ -17892,7 +17896,7 @@
         v[0] = 0
         self.assertEqual(t[2, 0], v[0])
 
-    def test_unbind_view(self, device):
+    def test_unbind_view(self, device) -> None:
         t = torch.zeros((5, 5), device=device)
         tup = torch.unbind(t)
 
@@ -17902,7 +17906,7 @@
             v[0] = idx + 1
             self.assertEqual(t[idx, 0], v[0])
 
-    def test_expand_view(self, device):
+    def test_expand_view(self, device) -> None:
         t = torch.ones((5, 1), device=device)
         v = t.expand(5, 5)
         self.assertTrue(self.is_view_of(t, v))
@@ -17927,7 +17931,7 @@
         v[0, 0] = 0
         self.assertEqual(t[0, 2], v[0, 0])
 
-    def test_permute_view(self, device):
+    def test_permute_view(self, device) -> None:
         t = torch.ones((5, 5), device=device)
         v = t.permute(1, 0)
         self.assertTrue(self.is_view_of(t, v))
@@ -18154,7 +18158,7 @@ _signed_types_no_half = [
     torch.int8, torch.short, torch.int, torch.long
 ]
 
-_cpu_types = []
+_cpu_types: List[torch.dtype] = []
 
 _unsigned_types = [torch.uint8]
 
@@ -18185,7 +18189,7 @@ def _convert_t(dtype, device):
 # Requesting a half CPU tensor returns a float CPU tensor with
 # values representable by a half.
 # Initialization uses randint for non-float types and randn for float types.
-def _make_tensor(shape, dtype, device, fill_ones=False):
+def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
     # Returns a tensor filled with ones
     if fill_ones:
         return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
@@ -18206,7 +18210,7 @@ def _make_tensor(shape, dtype, device, fill_ones=False):
     # Default: returns a tensor with random float values
     return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype)
 
-def _small_0d(dtype, device):
+def _small_0d(dtype, device) -> torch.Tensor:
     return _make_tensor((1,), dtype, device).squeeze()
 
 def _small_2d(dtype, device, has_zeros=True, fill_ones=False, oneish=False):
@@ -18635,8 +18639,8 @@ def generate_test_function(cls,
                            float_precision,
                            dtype_list,
                            dtype_cpu_list,
-                           decorators):
-    def fn(self, device, dtype):
+                           decorators) -> None:
+    def fn(self, device, dtype) -> None:
         # Generates the CPU inputs
         # Note: CPU tensors are never torch.half
         cpu_tensor = tensor_ctor(dtype, 'cpu')
@@ -18683,7 +18687,7 @@ def generate_test_function(cls,
     setattr(cls, test_name, fn)
 
 # Instantiates variants of tensor_op_tests and adds them to the given class.
-def generate_tensor_op_tests(cls):
+def generate_tensor_op_tests(cls) -> None:
 
     def caller(cls,
                op_str,
@@ -18853,7 +18857,7 @@ def generate_torch_test_functions(cls, testmeta, inplace):
 
         torch.testing.assert_allclose(actual, expected, rtol=testmeta.rtol, atol=testmeta.atol)
 
-    def fn_non_contig(self, device, dtype):
+    def fn_non_contig(self, device, dtype) -> None:
         shapes = [(5, 7), (1024,)]
         for shape in shapes:
             contig = _make_tensor(shape, dtype=dtype, device=device)
@@ -43,7 +43,8 @@ read gen_pyi for the gory details.
 
 needed_modules = set()
 
-FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: _bool=False"
+DEVICE_PARAM = "device: Union[_device, str, None]=None"
+FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"
 
 # this could be more precise w.r.t list contents etc. How to do Ellipsis?
 INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
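The refactor above factors the shared `device=` fragment into `DEVICE_PARAM` so it can be spliced into several signature strings. The pattern in isolation (plain string manipulation, runnable outside gen_pyi):

    DEVICE_PARAM = "device: Union[_device, str, None]=None"
    FACTORY_PARAMS = f"dtype: Optional[_dtype]=None, {DEVICE_PARAM}, requires_grad: _bool=False"

    # gen_pyi then formats fragments like these into stub signatures:
    hint = "def new_ones(self, size: _size, {}) -> Tensor: ...".format(FACTORY_PARAMS)
    print(hint)  # one long "def new_ones(self, size: _size, dtype: ... ) -> Tensor: ..." line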
@@ -534,6 +535,18 @@ def gen_pyi(declarations_path, out):
         'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
                      format(type_to_python('IntArrayRef'), FACTORY_PARAMS)],
         'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
+        # new and __init__ have the same signatures differ only in return type
+        # Adapted from legacy_tensor_ctor and legacy_tensor_new
+        'new': ['def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM),
+                'def new(self, storage: Storage) -> Tensor: ...',
+                'def new(self, other: Tensor) -> Tensor: ...',
+                'def new(self, size: {}, *, {}) -> Tensor: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
+                ],
+        '__init__': ['def __init__(self, *args: Any, {}) -> None: ...'.format(DEVICE_PARAM),
+                     'def __init__(self, storage: Storage) -> None: ...',
+                     'def __init__(self, other: Tensor) -> None: ...',
+                     'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
+                     ],
         # clamp has no default values in the Declarations
         'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                   " *, out: Optional[Tensor]=None) -> Tensor: ..."],
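Because `'new'` and `'__init__'` each map to a list of several signature strings, the generated stub ends up exposing them as typing overloads. Roughly what that section of the .pyi looks like (a sketch of the expected output, not the literal generator result; `Storage` is defined elsewhere in the stub):

    from typing import Any, overload

    class Tensor:
        @overload
        def new(self, *args: Any, device: Any = None) -> 'Tensor': ...
        @overload
        def new(self, storage: 'Storage') -> 'Tensor': ...
        @overload
        def new(self, other: 'Tensor') -> 'Tensor': ...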
@@ -633,7 +646,7 @@ def gen_pyi(declarations_path, out):
 
     legacy_class_hints = []
     for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
-              'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
+              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
         legacy_class_hints.append('class {}(Tensor): ...'.format(c))
 
     # Generate type signatures for dtype classes
@@ -38,7 +38,10 @@ class Size(Tuple[_int, ...]):
 
 # Defined in torch/csrc/Dtype.cpp
 class dtype:
-    # TODO: is_floating_point, is_complex, is_Signed, __reduce__
+    # TODO: __reduce__
+    is_floating_point: _bool
+    is_complex: _bool
+    is_signed: _bool
     ...
 
 # Defined in torch/csrc/TypeInfo.cpp
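The three attributes promoted out of the TODO mirror what already exists at runtime, so the stub now matches behavior such as:

    import torch

    print(torch.float32.is_floating_point)  # True
    print(torch.complex64.is_complex)       # True
    print(torch.uint8.is_signed)            # False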
@@ -206,4 +209,5 @@ class _TensorBase(object):
     layout: _layout
     real: Tensor
     imag: Tensor
+    _version: _bool
     ${tensor_method_hints}
@@ -5,6 +5,7 @@ The testing package contains testing-specific utilities.
 import torch
 import random
 import math
+from typing import cast, List, Optional, Tuple, Union
 
 FileCheck = torch._C.FileCheck
 
@@ -18,7 +19,7 @@ randn_like = torch.randn_like
 # Helper function that returns True when the dtype is an integral dtype,
 # False otherwise.
 # TODO: implement numpy-like issubdtype
-def is_integral(dtype):
+def is_integral(dtype: torch.dtype) -> bool:
     # Skip complex/quantized types
     dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
     return dtype in dtypes and not dtype.is_floating_point
@@ -40,6 +41,8 @@ def _unravel_index(flat_index, shape):
         return res[0]
 
     return tuple(res[::-1])
+# (bool, msg) tuple, where msg is None if and only if bool is True.
+_compare_return_type = Tuple[bool, Optional[str]]
 
 # Compares two tensors with the same size on the same device and with the same
 # dtype for equality.
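`_compare_return_type` is a plain type alias: assigning a typing expression to a name gives every comparison helper one shared spelling of the (ok, debug_msg) convention. Applied to a hypothetical helper of the same shape:

    import torch
    from typing import Optional, Tuple

    _compare_return_type = Tuple[bool, Optional[str]]

    def compare_dtypes(a: torch.Tensor, b: torch.Tensor) -> _compare_return_type:
        # msg is None if and only if the comparison succeeded
        if a.dtype is b.dtype:
            return (True, None)
        return (False, "dtypes differ: {} vs {}".format(a.dtype, b.dtype))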
@@ -63,7 +66,8 @@ def _unravel_index(flat_index, shape):
 #
 # Bool tensors are equal only if they are identical, regardless of
 # the rtol and atol values.
-def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
+def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: bool) -> _compare_return_type:
+    debug_msg : Optional[str]
     # Integer (including bool) comparisons are identity comparisons
     # when rtol is zero and atol is less than one
     if (is_integral(a.dtype) and rtol == 0 and atol < 1) or a.dtype is torch.bool:
@@ -99,7 +103,7 @@ def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
                                            equal_nan=equal_nan)
 
     if not real_result:
-        debug_msg = "Real parts failed to compare as equal! " + debug_msg
+        debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
         return (real_result, debug_msg)
 
     a_imag = a.imag
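`cast` is the narrowing tool used here: `debug_msg` is declared `Optional[str]`, and by the module's own invariant it is a real string whenever the comparison failed, but mypy cannot deduce that. `typing.cast` asserts the narrower type to the checker and does nothing at runtime:

    from typing import Optional, cast

    def describe(ok: bool, msg: Optional[str]) -> str:
        if not ok:
            # cast() performs no runtime check; it only tells mypy that the
            # (ok is False => msg is a str) invariant holds at this point.
            return "failed: " + cast(str, msg)
        return "ok"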
@@ -109,7 +113,7 @@
                                            equal_nan=equal_nan)
 
     if not imag_result:
-        debug_msg = "Imaginary parts failed to compare as equal! " + debug_msg
+        debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
         return (imag_result, debug_msg)
 
     return (True, None)
@@ -149,8 +153,8 @@ def _compare_tensors_internal(a, b, *, rtol, atol, equal_nan):
 
 # Checks if two scalars are equal(-ish), returning (True, None)
 # when they are and (False, debug_msg) when they are not.
-def _compare_scalars_internal(a, b, *, rtol, atol, equal_nan):
-    def _helper(a, b, s):
+def _compare_scalars_internal(a, b, *, rtol: float, atol: float, equal_nan: bool) -> _compare_return_type:
+    def _helper(a, b, s) -> _compare_return_type:
         # Short-circuits on identity
         if a == b or (equal_nan and a != a and b != b):
             return (True, None)
@@ -194,7 +198,7 @@ def _compare_scalars_internal(a, b, *, rtol, atol, equal_nan):
 
     return _helper(a, b, " ")
 
-def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg=''):
+def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None:
     if not isinstance(actual, torch.Tensor):
         actual = torch.tensor(actual)
     if not isinstance(expected, torch.Tensor):
@@ -218,7 +222,7 @@ def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg=
 
     raise AssertionError(msg)
 
-def make_non_contiguous(tensor):
+def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
     if tensor.numel() <= 1:  # can't make non-contiguous
         return tensor.clone()
     osize = list(tensor.size())
@@ -247,7 +251,7 @@ def make_non_contiguous(tensor):
     return input.data
 
 
-def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True):
+def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True) -> List[torch.dtype]:
     dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
     if include_bool:
         dtypes.append(torch.bool)
@@ -256,20 +260,20 @@
     return dtypes
 
 
-def get_all_math_dtypes(device):
+def get_all_math_dtypes(device) -> List[torch.dtype]:
     return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                     include_bfloat16=False) + get_all_complex_dtypes()
 
 
-def get_all_complex_dtypes():
+def get_all_complex_dtypes() -> List[torch.dtype]:
     return [torch.complex64, torch.complex128]
 
 
-def get_all_int_dtypes():
+def get_all_int_dtypes() -> List[torch.dtype]:
     return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
 
 
-def get_all_fp_dtypes(include_half=True, include_bfloat16=True):
+def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
     dtypes = [torch.float32, torch.float64]
     if include_half:
         dtypes.append(torch.float16)
@@ -278,7 +282,7 @@
     return dtypes
 
 
-def get_all_device_types():
+def get_all_device_types() -> List[str]:
     return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
 
 # 'dtype': (rtol, atol)
@@ -289,7 +293,7 @@ _default_tolerances = {
 }
 
 
-def _get_default_tolerance(a, b=None):
+def _get_default_tolerance(a, b=None) -> Tuple[float, float]:
     if b is None:
         dtype = str(a.dtype).split('.')[-1]  # e.g. "float32"
         return _default_tolerances.get(dtype, (0, 0))
@@ -33,9 +33,10 @@ import json
 from urllib.request import urlopen
 import __main__
 import errno
+from typing import cast, Any, Iterable, Optional
 
 from torch.testing._internal import expecttest
-from torch.testing import _compare_tensors_internal, _compare_scalars_internal
+from torch.testing import _compare_tensors_internal, _compare_scalars_internal, _compare_return_type
 
 import torch
 import torch.cuda
@@ -947,8 +948,8 @@ class TestCase(expecttest.TestCase):
     # NOTE: this function checks the tensors' devices, sizes, and dtypes
     # and acquires the appropriate device, dtype, rtol and atol to compare
     # them with. It then calls _compare_tensors_internal.
-    def _compareTensors(self, a, b, *, rtol=None, atol=None, equal_nan=True,
-                        exact_dtype=True, exact_device=False):
+    def _compareTensors(self, a, b, *, rtol: Optional[float] = None, atol=None, equal_nan=True,
+                        exact_dtype=True, exact_device=False) -> _compare_return_type:
         assert (atol is None) == (rtol is None)
         if not isinstance(a, torch.Tensor):
             return (False, "argument a, {0}, to _compareTensors is not a tensor!".format(a))
@@ -993,7 +994,8 @@ class TestCase(expecttest.TestCase):
     # when they are and (False, debug_msg) when they are not.
     # NOTE: this function just acquires rtol and atol
     # before calling _compare_scalars_internal.
-    def _compareScalars(self, a, b, *, rtol=None, atol=None, equal_nan=True):
+    def _compareScalars(self, a, b, *,
+                        rtol: Optional[float] = None, atol: Optional[float] = None, equal_nan=True) -> _compare_return_type:
         # Acquires rtol and atol
         assert (atol is None) == (rtol is None)
         if rtol is None:
@@ -1005,17 +1007,18 @@ class TestCase(expecttest.TestCase):
             rtol, atol = 0, 0
         atol = max(atol, self.precision)
 
-        return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+        return _compare_scalars_internal(a, b, rtol=cast(float, rtol), atol=cast(float, atol), equal_nan=equal_nan)
 
-    def assertEqualIgnoreType(self, *args, **kwargs):
+    def assertEqualIgnoreType(self, *args, **kwargs) -> None:
         # If you are seeing this function used, that means test is written wrongly
         # and deserves detailed investigation
         return self.assertEqual(*args, exact_dtype=False, **kwargs)
 
     # Compares x and y
     # TODO: default exact_device to True
-    def assertEqual(self, x, y, *, atol=None, rtol=None, equal_nan=True,
-                    exact_dtype=True, exact_device=False, msg=None):
+    def assertEqual(self, x, y, msg: Optional[str] = None, *,
+                    atol: Optional[float] = None, rtol: Optional[float] = None,
+                    equal_nan=True, exact_dtype=True, exact_device=False) -> None:
         assert (atol is None) == (rtol is None), "If one of atol or rtol is specified the other must be, too"
 
         # Tensor x Number and Number x Tensor comparisons
@@ -1045,6 +1048,7 @@ class TestCase(expecttest.TestCase):
                                                          exact_device=exact_device)
 
             if not indices_result and msg is None:
+                assert debug_msg is not None
                 msg = "Sparse tensor indices failed to compare as equal! " + debug_msg
             self.assertTrue(indices_result, msg=msg)
 
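Here the PR narrows with `assert` instead of `cast`. The two are interchangeable for the type checker, but the assert also enforces the invariant at runtime, which suits test infrastructure:

    from typing import Optional

    def render(ok: bool, debug_msg: Optional[str]) -> str:
        if not ok:
            # mypy narrows Optional[str] to str after this line, and unlike
            # cast() the assert fails loudly if the invariant is ever broken.
            assert debug_msg is not None
            return "failed: " + debug_msg
        return "ok"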
@@ -1054,6 +1058,7 @@ class TestCase(expecttest.TestCase):
                                                        exact_device=exact_device)
 
             if not values_result and msg is None:
+                assert debug_msg is not None
                 msg = "Sparse tensor values failed to compare as equal! " + debug_msg
             self.assertTrue(values_result, msg=msg)
         elif x.is_quantized and y.is_quantized:
@@ -1086,6 +1091,7 @@ class TestCase(expecttest.TestCase):
                                                       exact_device=exact_device)
 
             if not result and msg is None:
+                assert debug_msg is not None
                 msg = "Quantized representations failed to compare as equal! " + debug_msg
             self.assertTrue(result, msg=msg)
         else:
@@ -1094,6 +1100,7 @@ class TestCase(expecttest.TestCase):
                                                       exact_device=exact_device)
 
             if not result and msg is None:
+                assert debug_msg is not None
                 msg = "Tensors failed to compare as equal! " + debug_msg
             self.assertTrue(result, msg=msg)
         elif isinstance(x, string_classes) and isinstance(y, string_classes):
@@ -1127,6 +1134,7 @@ class TestCase(expecttest.TestCase):
             result, debug_msg = self._compareScalars(x, y, rtol=rtol, atol=atol,
                                                      equal_nan=equal_nan)
             if not result and msg is None:
+                assert debug_msg is not None
                 msg = "Scalars failed to compare as equal! " + debug_msg
             self.assertTrue(result, msg=msg)
         else:
@@ -1139,17 +1147,18 @@ class TestCase(expecttest.TestCase):
         rtol = None if prec is None else 0
         self.assertEqual(x, y, msg=msg, atol=prec, rtol=rtol)
 
-    def assertNotEqual(self, x, y, *, msg=None, atol=None, rtol=None):
+    def assertNotEqual(self, x, y, msg: Optional[str] = None, *,
+                       atol: Optional[float] = None, rtol: Optional[float] = None) -> None:
         with self.assertRaises(AssertionError, msg=msg):
             self.assertEqual(x, y, atol=atol, rtol=rtol)
 
-    def assertEqualTypeString(self, x, y):
+    def assertEqualTypeString(self, x, y) -> None:
         # This API is used simulate deprecated x.type() == y.type()
         self.assertEqual(x.device, y.device)
         self.assertEqual(x.dtype, y.dtype)
         self.assertEqual(x.is_sparse, y.is_sparse)
 
-    def assertObjectIn(self, obj, iterable):
+    def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
         for elem in iterable:
             if id(obj) == id(elem):
                 return