mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11985 Differential Revision: D10202374 Pulled By: SsnL fbshipit-source-id: 1ab1a07185f78a104f9b05930a87ef5a32f431e4
662 lines
23 KiB
Python
662 lines
23 KiB
Python
r"""Importing this file must **not** initialize CUDA context. test_distributed
|
|
relies on this assumption to properly run. This means that when this is imported
|
|
no CUDA calls shall be made, including torch.cuda.device_count(), etc.
|
|
|
|
common_cuda.py can freely initialize CUDA context when imported.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import platform
|
|
import re
|
|
import gc
|
|
import types
|
|
import inspect
|
|
import argparse
|
|
import unittest
|
|
import warnings
|
|
import random
|
|
import contextlib
|
|
import socket
|
|
from collections import OrderedDict
|
|
from functools import wraps
|
|
from itertools import product
|
|
from copy import deepcopy
|
|
from numbers import Number
|
|
|
|
import __main__
|
|
import errno
|
|
|
|
import torch
|
|
import torch.cuda
|
|
from torch._utils_internal import get_writable_path
|
|
from torch._six import string_classes, inf
|
|
import torch.backends.cudnn
|
|
import torch.backends.mkl
|
|
|
|
|
|
# All tests in this process use double precision as the default tensor type.
torch.set_default_tensor_type('torch.DoubleTensor')
# Turn off global cuDNN flag setting; presumably tests are expected to set
# cuDNN flags locally (e.g. via torch.backends.cudnn flag contexts) — confirm.
torch.backends.cudnn.disable_global_flags()

# Command-line options consumed by this module. `add_help=False` together with
# `parse_known_args` leaves every unrecognized argument for unittest.
parser = argparse.ArgumentParser(add_help=False)
# Seed used below for torch and in TestCase.setUp (via set_rng_seed).
parser.add_argument('--seed', type=int, default=1234)
# When given, TestCase.assertExpected writes/updates .expect files instead of
# failing on a missing or mismatching expect file.
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
SEED = args.seed
ACCEPT = args.accept
# argv handed to unittest.main: the script name plus the arguments not
# consumed by the parser above.
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)
|
|
|
|
|
|
def run_tests(argv=UNITTEST_ARGS):
    """Run the unittest test program.

    ``argv`` defaults to the script name followed by every command-line
    argument this module's own parser did not consume.
    """
    test_program = unittest.main
    test_program(argv=argv)
|
|
|
|
# Interpreter version flags (version_info tuples compare lexicographically).
PY3 = sys.version_info >= (3,)
PY34 = sys.version_info >= (3, 4)

# Host platform flags.
IS_WINDOWS = (sys.platform == "win32")
IS_PPC = (platform.machine() == "ppc64le")
|
|
|
|
|
|
def _check_module_exists(name):
    r"""Returns if a top-level module with :attr:`name` exists *without*
    importing it. This is generally safer than try-catch block around a
    `import X`. It avoids third party libraries breaking assumptions of some of
    our tests, e.g., setting multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).

    Args:
        name: name of a top-level module, e.g. ``'numpy'``.

    Returns:
        ``True`` if the module can be found on the import path, else ``False``.
    """
    if sys.version_info[0] < 3:  # Python 2
        import imp
        try:
            imp.find_module(name)
            return True
        except ImportError:
            return False
    elif sys.version_info < (3, 4):  # Python [3, 3.4)
        # importlib.util.find_spec only exists from 3.4 on.
        import importlib
        loader = importlib.find_loader(name)
        return loader is not None
    else:  # Python >= 3.4
        # importlib.find_loader is deprecated since 3.4 and removed in 3.12.
        # `import importlib` alone does not guarantee the `util` submodule is
        # loaded, so import it explicitly.
        import importlib.util
        spec = importlib.util.find_spec(name)
        return spec is not None
|
|
|
|
# Optional dependencies, probed without importing them.
TEST_NUMPY = _check_module_exists('numpy')
TEST_SCIPY = _check_module_exists('scipy')
TEST_MKL = torch.backends.mkl.is_available()

# On Py2, importing librosa 0.6.1 triggers a TypeError (if using newest joblib)
# see librosa/librosa#729.
# TODO: allow Py2 when librosa 0.6.2 releases
TEST_LIBROSA = _check_module_exists('librosa') and PY3

# Python 2.7 doesn't have spawn
NO_MULTIPROCESSING_SPAWN = (os.getenv('NO_MULTIPROCESSING_SPAWN', '0') == '1'
                            or sys.version_info[0] == 2)

# Special test modes, each enabled by setting its environment variable to '1'.
TEST_WITH_ASAN = os.environ.get('PYTORCH_TEST_WITH_ASAN', '0') == '1'
TEST_WITH_UBSAN = os.environ.get('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
TEST_WITH_ROCM = os.environ.get('PYTORCH_TEST_WITH_ROCM', '0') == '1'

# numpy is only imported once we know it is installed.
if TEST_NUMPY:
    import numpy
|
|
|
|
|
|
def skipIfRocm(fn):
    """Decorator that skips ``fn`` when tests run on the ROCm stack.

    The module-level ``TEST_WITH_ROCM`` flag is read each time the wrapped
    function is called. If it is set, :class:`unittest.SkipTest` is raised.
    Otherwise ``fn`` is called and its return value is passed through.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
        # Propagate the result so the decorator is transparent for callers.
        return fn(*args, **kwargs)
    return wrapper
|
|
|
|
|
|
def skipIfNoLapack(fn):
    """Decorator that skips ``fn`` when PyTorch was built without LAPACK.

    ``torch._C.has_lapack`` is checked each time the wrapped function is
    called. If it is false, :class:`unittest.SkipTest` is raised. Otherwise
    ``fn`` is called and its return value is passed through.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch._C.has_lapack:
            raise unittest.SkipTest('PyTorch compiled without Lapack')
        # Propagate the result so the decorator is transparent for callers.
        return fn(*args, **kwargs)
    return wrapper
|
|
|
|
|
|
def skipCUDAMemoryLeakCheckIf(condition):
    """Decorator factory that disables the CUDA memory leak check for a test
    when ``condition`` is true.

    The flag is only ever lowered. A function whose
    ``_do_cuda_memory_leak_check`` is already falsy keeps it untouched. The
    decorated function itself is returned, not a wrapper.
    """
    def decorator(func):
        check_enabled = getattr(func, '_do_cuda_memory_leak_check', True)
        if check_enabled:
            func._do_cuda_memory_leak_check = not condition
        return func
    return decorator
|
|
|
|
|
|
def suppress_warnings(fn):
    """Decorator that runs ``fn`` with every warning filtered out.

    The filter change is scoped to the call via ``warnings.catch_warnings``.
    The return value of ``fn`` is passed through.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return fn(*args, **kwargs)
    return wrapper
|
|
|
|
|
|
def get_cpu_type(type_name):
    """Map a CUDA tensor type name such as ``'torch.cuda.FloatTensor'`` to the
    corresponding CPU type object from the ``torch`` namespace.

    Raises AssertionError if ``type_name`` does not live in ``torch.cuda``.
    """
    prefix, short_name = type_name.rsplit('.', 1)
    assert prefix == 'torch.cuda'
    cpu_type = getattr(torch, short_name)
    return cpu_type
|
|
|
|
|
|
def get_gpu_type(type_name):
    """Map a CPU tensor type to the matching type object in ``torch.cuda``.

    ``type_name`` may be either a type object (its ``__module__`` and
    ``__name__`` are joined) or a dotted name such as ``'torch.FloatTensor'``.
    Raises AssertionError if the type does not live directly in ``torch``.
    """
    if isinstance(type_name, type):
        type_name = '%s.%s' % (type_name.__module__, type_name.__name__)
    prefix, short_name = type_name.rsplit('.', 1)
    assert prefix == 'torch'
    gpu_type = getattr(torch.cuda, short_name)
    return gpu_type
|
|
|
|
|
|
def to_gpu(obj, type_map=None):
    """Recursively convert ``obj`` for use on the GPU.

    Args:
        obj: a leaf tensor, a storage, a list/tuple of such objects (handled
            recursively), or anything else (deep-copied).
        type_map: optional dict mapping a tensor's ``.type()`` string to the
            target type. Types not in the map are translated with
            ``get_gpu_type``.

    Returns:
        - Tensors: a converted clone that keeps ``requires_grad``.
        - Lists and tuples: the same container type with converted elements.
        - Anything else: a deep copy.
    """
    if type_map is None:
        type_map = {}
    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        key = obj.type()
        # Only fall back to get_gpu_type when the caller supplied no mapping:
        # evaluating it eagerly would fail for types it cannot translate even
        # though type_map covers them.
        t = type_map[key] if key in type_map else get_gpu_type(key)
        with torch.no_grad():
            res = obj.clone().type(t)
            res.requires_grad = obj.requires_grad
        return res
    elif torch.is_storage(obj):
        # NOTE(review): copies into `obj.new()`, which presumably has the same
        # (possibly CPU) storage type as `obj` — confirm this really moves
        # storages to the GPU.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
|
|
|
|
|
|
def get_function_arglist(func):
    """Return the names of ``func``'s positional parameters (excluding
    ``*args``/``**kwargs`` and keyword-only parameters).

    Uses :func:`inspect.getfullargspec` when it exists (Python 3). The legacy
    :func:`inspect.getargspec` was removed in Python 3.11 and rejects functions
    with annotations or keyword-only arguments. It remains the fallback on
    Python 2.
    """
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    return getspec(func).args
|
|
|
|
|
|
def set_rng_seed(seed):
    """Seed every random number generator the tests rely on.

    Seeds Python's ``random`` module and torch and, when numpy is available
    (``TEST_NUMPY``), ``numpy.random``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not TEST_NUMPY:
        return
    numpy.random.seed(seed)
|
|
|
|
|
|
@contextlib.contextmanager
def freeze_rng_state():
    """Context manager that restores torch's RNG state(s) on exit.

    Saves the CPU RNG state and, if CUDA is available, the CUDA RNG state. It
    puts them back when the ``with`` body finishes, so random draws made
    inside the body do not affect later code.
    """
    rng_state = torch.get_rng_state()
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Restore even when the body raises; otherwise a failing test would
        # leak a perturbed RNG state into whatever runs next.
        if has_cuda:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
|
|
|
|
|
|
def iter_indices(tensor):
    """Return an iterable over the indices of ``tensor``.

    - 0-dim tensors: an empty range (nothing is yielded).
    - 1-dim tensors: a ``range`` of plain integer positions.
    - Higher-dim tensors: the cartesian product of per-dimension ranges, as
      tuples in row-major order.
    """
    ndim = tensor.dim()
    if ndim == 0:
        return range(0)
    elif ndim == 1:
        return range(tensor.size(0))
    return product(*map(range, tensor.size()))
|
|
|
|
|
|
def is_iterable(obj):
    """Return ``True`` when ``iter(obj)`` succeeds and ``False`` when it raises
    ``TypeError``."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True
|
|
|
|
|
|
class CudaMemoryLeakCheck():
    """Context manager asserting that per-device CUDA memory usage is the same
    before and after the ``with`` body.

    ``testcase`` provides ``assertEqual`` for the comparison and, when
    ``name`` is omitted, its ``id()`` labels failure messages.
    """

    def __init__(self, testcase, name=None):
        self.name = name if name is not None else testcase.id()
        self.testcase = testcase

        # Initialize the CUDA context and RNG up front: a test that happens to
        # be the first to initialize them would otherwise look like a leak.
        from common_cuda import initialize_cuda_context_rng
        initialize_cuda_context_rng()

    @staticmethod
    def get_cuda_memory_usage():
        # No CUDA synchronize is needed: the statistics are updated when a
        # block is marked free, not when it is actually released.
        device_count = torch.cuda.device_count()
        gc.collect()
        usage = []
        for device in range(device_count):
            usage.append(torch.cuda.memory_allocated(device))
        return tuple(usage)

    def __enter__(self):
        self.befores = self.get_cuda_memory_usage()

    def __exit__(self, exec_type, exec_value, traceback):
        # When the body raised, skip the check entirely; the exception is
        # neither suppressed nor compared against.
        if exec_type is not None:
            return
        afters = self.get_cuda_memory_usage()
        per_device = zip(self.befores, afters)
        for device, (before, after) in enumerate(per_device):
            message = '{} leaked {} bytes CUDA memory on device {}'.format(
                self.name, after - before, device)
            self.testcase.assertEqual(before, after, message)
|
|
|
|
|
|
class TestCase(unittest.TestCase):
    r"""Base class for PyTorch tests.

    Extends :class:`unittest.TestCase` with:

    - tolerance-aware equality checks for tensors, numbers and nested
      containers;
    - expect-file comparisons;
    - warning assertions;
    - per-test RNG reseeding;
    - an optional CUDA memory leak check around GPU/CUDA test methods.
    """

    # Default absolute tolerance for assertEqual / assertNotEqual /
    # assertTensorsSlowEqual when no `prec` is given.
    precision = 1e-5
    # Never truncate the diffs unittest prints on failure.
    maxDiff = None
    # Class-level switch (off here) for wrapping GPU/CUDA test methods with
    # CudaMemoryLeakCheck. A test method can opt out by carrying a falsy
    # `_do_cuda_memory_leak_check` attribute.
    _do_cuda_memory_leak_check = False

    def __init__(self, method_name='runTest'):
        super(TestCase, self).__init__(method_name)
        # Wraps the tested method if we should do CUDA memory check.
        test_method = getattr(self, method_name)
        self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
        # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
        if self._do_cuda_memory_leak_check and not IS_WINDOWS:
            # the import below may initialize CUDA context, so we do it only if
            # self._do_cuda_memory_leak_check is True.
            from common_cuda import TEST_CUDA
            fullname = self.id().lower()  # class_name.method_name
            if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
                setattr(self, method_name, self.wrap_with_cuda_memory_check(test_method))

    def assertLeaksNoCudaTensors(self, name=None):
        """Return a context manager failing the test if the wrapped code changes
        CUDA memory usage. ``name`` (default: the test id) labels the failure."""
        name = self.id() if name is None else name
        return CudaMemoryLeakCheck(self, name)

    def wrap_with_cuda_memory_check(self, method):
        # Assumes that `method` is the tested function in `self`.
        # NOTE: Python Exceptions (e.g., unittest.Skip) keeps objects in scope
        #       alive, so this cannot be done in setUp and tearDown because
        #       tearDown is run unconditionally no matter whether the test
        #       passes or not. For the same reason, we can't wrap the `method`
        #       call in try-finally and always do the check.
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            with self.assertLeaksNoCudaTensors():
                return method(*args, **kwargs)
        return types.MethodType(wrapper, self)

    def setUp(self):
        # Reseed all RNGs so every test starts from the same random state.
        set_rng_seed(SEED)

    def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
        """Element-by-element tensor comparison with absolute tolerance ``prec``
        (defaults to ``self.precision``)."""
        # Without this default, comparing max_err against None raises a
        # TypeError on Python 3.
        if prec is None:
            prec = self.precision
        max_err = 0
        self.assertEqual(x.size(), y.size())
        for index in iter_indices(x):
            max_err = max(max_err, abs(x[index] - y[index]))
        self.assertLessEqual(max_err, prec, message)

    def safeToDense(self, t):
        """Densify sparse tensor ``t`` after a checked coalesce."""
        r = self.safeCoalesce(t)
        return r.to_dense()

    def safeCoalesce(self, t):
        """Coalesce sparse tensor ``t`` and cross-check the result.

        The result of ``t.coalesce()`` is compared against a reference built
        in Python by summing values that share an index and sorting the
        indices. The reference tensor is returned.
        """
        tc = t.coalesce()
        self.assertEqual(tc.to_dense(), t.to_dense())
        self.assertTrue(tc.is_coalesced())

        # Our code below doesn't work when nnz is 0, because
        # then it's a 0D tensor, not a 2D tensor.
        if t._nnz() == 0:
            self.assertEqual(t._indices(), tc._indices())
            self.assertEqual(t._values(), tc._values())
            return tc

        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx.tolist())
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

        new_indices = sorted(list(value_map.keys()))
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)

        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())

        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())

        if t.is_coalesced():
            self.assertEqual(tc._indices(), t._indices())
            self.assertEqual(tc._values(), t._values())

        return tg

    def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
        r"""Assert that ``x`` and ``y`` are equal up to absolute tolerance ``prec``.

        How values are compared:

        - Tensors: element-wise. NaNs must be in the same positions, and with
          ``allow_inf`` infinities must be too. Sparse tensors are coalesced
          first.
        - Numbers: by absolute difference.
        - Dicts and iterables: recursively, forwarding ``prec``, ``message``
          and ``allow_inf``.
        - Anything else: :meth:`unittest.TestCase.assertEqual`.

        Args:
            prec: absolute tolerance, defaults to ``self.precision``. If a
                string is passed here while ``message`` is empty, it is used
                as the message instead.
            message: failure message.
            allow_inf: accept infinite values (they must match) instead of
                failing on them.
        """
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, Number):
            self.assertEqual(x.item(), y, prec, message, allow_inf)
        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
            self.assertEqual(x, y.item(), prec, message, allow_inf)
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            def assertTensorsEqual(a, b):
                super(TestCase, self).assertEqual(a.size(), b.size(), message)
                if a.numel() > 0:
                    b = b.type_as(a)
                    b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
                    # check that NaNs are in the same locations
                    nan_mask = a != a
                    self.assertTrue(torch.equal(nan_mask, b != b), message)
                    diff = a - b
                    diff[nan_mask] = 0
                    # inf check if allow_inf=True
                    if allow_inf:
                        inf_mask = (a == float("inf")) | (a == float("-inf"))
                        self.assertTrue(torch.equal(inf_mask,
                                        (b == float("inf")) | (b == float("-inf"))),
                                        message)
                        diff[inf_mask] = 0
                    # TODO: implement abs on CharTensor
                    if diff.is_signed() and 'CharTensor' not in diff.type():
                        diff = diff.abs()
                    max_err = diff.max()
                    self.assertLessEqual(max_err, prec, message)
            super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
            if x.is_sparse:
                x = self.safeCoalesce(x)
                y = self.safeCoalesce(y)
                assertTensorsEqual(x._indices(), y._indices())
                assertTensorsEqual(x._values(), y._values())
            else:
                assertTensorsEqual(x, y)
        elif isinstance(x, string_classes) and isinstance(y, string_classes):
            super(TestCase, self).assertEqual(x, y, message)
        elif type(x) == set and type(y) == set:
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, dict) and isinstance(y, dict):
            # Forward the caller's tolerance/message/allow_inf so nested values
            # are compared under the same settings as top-level ones.
            if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
                self.assertEqual(x.items(), y.items(), prec, message, allow_inf)
            else:
                self.assertEqual(set(x.keys()), set(y.keys()), prec, message, allow_inf)
                key_list = list(x.keys())
                self.assertEqual([x[k] for k in key_list], [y[k] for k in key_list],
                                 prec, message, allow_inf)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertEqual(len(x), len(y), message)
            for x_, y_ in zip(x, y):
                self.assertEqual(x_, y_, prec, message, allow_inf)
        elif isinstance(x, bool) and isinstance(y, bool):
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, Number) and isinstance(y, Number):
            if abs(x) == inf or abs(y) == inf:
                if allow_inf:
                    super(TestCase, self).assertEqual(x, y, message)
                else:
                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
                return
            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
        else:
            super(TestCase, self).assertEqual(x, y, message)

    def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
        """unittest-compatible wrapper around :meth:`assertEqual`.

        ``places`` (when given, including 0) sets the tolerance to
        ``10 ** -places``. Otherwise ``delta`` is used, falling back to
        ``self.precision``.
        """
        prec = delta
        # `is not None` so that places=0 (tolerance 1) is honored.
        if places is not None:
            prec = 10**(-places)
        self.assertEqual(x, y, prec, msg, allow_inf)

    def assertNotEqual(self, x, y, prec=None, message=''):
        """Assert that ``x`` and ``y`` differ by at least ``prec`` (defaults to
        ``self.precision``). Tensors of different sizes always count as
        different."""
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            if x.size() != y.size():
                # Different sizes already make the tensors unequal. Continuing
                # would fail spuriously (e.g. on an empty `x`) or broadcast.
                return
            self.assertGreater(x.numel(), 0)
            y = y.type_as(x)
            y = y.cuda(device=x.get_device()) if x.is_cuda else y.cpu()
            nan_mask = x != x
            if torch.equal(nan_mask, y != y):
                diff = x - y
                if diff.is_signed():
                    diff = diff.abs()
                diff[nan_mask] = 0
                max_err = diff.max()
                self.assertGreaterEqual(max_err, prec, message)
        elif type(x) == str and type(y) == str:
            super(TestCase, self).assertNotEqual(x, y)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertNotEqual(x, y)
        else:
            try:
                self.assertGreaterEqual(abs(x - y), prec, message)
                return
            except (TypeError, AssertionError):
                pass
            super(TestCase, self).assertNotEqual(x, y, message)

    def assertObjectIn(self, obj, iterable):
        """Assert that ``obj`` itself (by identity, not equality) is in ``iterable``."""
        for elem in iterable:
            if id(obj) == id(elem):
                return
        raise AssertionError("object not found in iterable")

    # TODO: Support context manager interface
    # NB: The kwargs forwarding to callable robs the 'subname' parameter.
    # If you need it, manually apply your callable in a lambda instead.
    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
        """Assert that ``callable(*args, **kwargs)`` raises ``exc_type`` and that
        the exception text matches the expect file (see assertExpected)."""
        subname = None
        if 'subname' in kwargs:
            subname = kwargs['subname']
            del kwargs['subname']
        try:
            callable(*args, **kwargs)
        except exc_type as e:
            self.assertExpected(str(e), subname)
            return
        # Don't put this in the try block; the AssertionError will catch it
        self.fail(msg="Did not raise when expected to")

    def assertWarns(self, callable, msg=''):
        r"""
        Test if :attr:`callable` raises a warning.
        """
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            callable()
            self.assertTrue(len(ws) > 0, msg)

    def assertWarnsRegex(self, callable, regex, msg=''):
        r"""
        Test if :attr:`callable` raises any warning with message that contains
        the regex pattern :attr:`regex`.
        """
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            callable()
            self.assertTrue(len(ws) > 0, msg)
            found = any(re.search(regex, str(w.message)) is not None for w in ws)
            self.assertTrue(found, msg)

    def assertExpected(self, s, subname=None):
        r"""
        Test that a string matches the recorded contents of a file
        derived from the name of this test and subname.  This file
        is placed in the 'expect' directory in the same directory
        as the test script. You can automatically update the recorded test
        output using --accept.

        If you call this multiple times in a single function, you must
        give a unique subname each time.
        """
        if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
            raise TypeError("assertExpected is strings only")

        def remove_prefix(text, prefix):
            if text.startswith(prefix):
                return text[len(prefix):]
            return text
        # NB: we take __file__ from the module that defined the test
        # class, so we place the expect directory where the test script
        # lives, NOT where test/common.py lives.  This doesn't matter in
        # PyTorch where all test scripts are in the same directory as
        # test/common.py, but it matters in onnx-pytorch
        module_id = self.__class__.__module__
        munged_id = remove_prefix(self.id(), module_id + ".")
        test_file = os.path.realpath(sys.modules[module_id].__file__)
        expected_file = os.path.join(os.path.dirname(test_file),
                                     "expect",
                                     munged_id)

        subname_output = ""
        if subname:
            expected_file += "-" + subname
            subname_output = " ({})".format(subname)
        expected_file += ".expect"
        expected = None

        def accept_output(update_type):
            print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, s))
            with open(expected_file, 'w') as f:
                f.write(s)

        try:
            with open(expected_file) as f:
                expected = f.read()
        except IOError as e:
            if e.errno != errno.ENOENT:
                raise
            elif ACCEPT:
                return accept_output("output")
            else:
                raise RuntimeError(
                    ("I got this output for {}{}:\n\n{}\n\n"
                     "No expect file exists; to accept the current output, run:\n"
                     "python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id))

        # a hack for JIT tests
        if IS_WINDOWS:
            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)

        if ACCEPT:
            if expected != s:
                return accept_output("updated output")
        else:
            if hasattr(self, "assertMultiLineEqual"):
                # assertMultiLineEqual (Python 2.7+) prints a readable diff.
                # NB: Python considers lhs "old" and rhs "new".
                self.assertMultiLineEqual(expected, s)
            else:
                self.assertEqual(s, expected)

    if sys.version_info < (3, 2):
        # assertRegexpMatches renamed to assertRegex in 3.2
        assertRegex = unittest.TestCase.assertRegexpMatches
        # assertRaisesRegexp renamed to assertRaisesRegex in 3.2
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
|
|
|
|
|
|
def download_file(url, binary=True):
    """Download ``url`` into the test ``data`` directory and return its path.

    The file name is the last component of the URL path. If that file already
    exists in the data directory, its path is returned without downloading.
    If the download fails, a RuntimeWarning is emitted and
    :class:`unittest.SkipTest` is raised, so the calling test is skipped.

    Args:
        url: URL to fetch.
        binary: write the payload as raw bytes (default). When ``False`` the
            payload is decoded as UTF-8 (on Python 3) and written in text mode.
    """
    if sys.version_info < (3,):
        from urlparse import urlsplit
        import urllib2
        request = urllib2
        error = urllib2
    else:
        from urllib.parse import urlsplit
        from urllib import request, error

    filename = os.path.basename(urlsplit(url)[2])
    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
    path = os.path.join(data_dir, filename)

    if os.path.exists(path):
        return path
    try:
        # closing() releases the connection; py2 responses lack __exit__.
        with contextlib.closing(request.urlopen(url, timeout=15)) as response:
            data = response.read()
    except (error.URLError, socket.timeout):
        # A read that exceeds the timeout raises socket.timeout rather than
        # URLError; treat it as a failed download as well.
        msg = "could not download test file '{}'".format(url)
        warnings.warn(msg, RuntimeWarning)
        raise unittest.SkipTest(msg)
    if not binary and not isinstance(data, str):
        # urlopen yields bytes on Python 3, which a text-mode file rejects.
        data = data.decode('utf-8')
    with open(path, 'wb' if binary else 'w') as f:
        f.write(data)
    return path
|
|
|
|
|
|
def find_free_port():
    """Return a TCP port number on localhost that was free when checked.

    Binds a socket to port 0 on ``'localhost'`` and reads back the port that
    was assigned. The socket is closed before returning, so the port is not
    reserved and another process could take it in the meantime.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # closing() guarantees the socket is released even if setsockopt/bind
    # raises (plain py2 sockets are not context managers).
    with contextlib.closing(sock):
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('localhost', 0))
        sockname = sock.getsockname()
    return sockname[1]
|
|
|
|
|
|
# Methods for matrix generation
# Used in test_autograd.py and test_torch.py
def prod_single_zero(dim_size):
    """Return a random ``dim_size x dim_size`` matrix whose entry ``[0, 1]``
    is zero."""
    matrix = torch.randn(dim_size, dim_size)
    matrix[0][1] = 0
    return matrix
|
|
|
|
|
|
def random_square_matrix_of_rank(l, rank):
    """Return a random ``l x l`` matrix with exactly ``rank`` nonzero singular values.

    A Gaussian matrix is decomposed with SVD, then its singular values are
    adjusted before the matrix is rebuilt:

    - Singular values at positions ``>= rank`` are zeroed.
    - Any zero among the first ``rank`` is replaced by 1.
    """
    assert rank <= l
    u, s, v = torch.randn(l, l).svd()
    positions = torch.arange(l)
    s[positions >= rank] = 0
    s[(positions < rank) & (s == 0)] = 1
    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
|
|
|
|
|
|
def random_symmetric_matrix(l):
    """Return a random symmetric ``l x l`` matrix.

    Entries on and above the diagonal come from ``torch.randn(l, l)``. Each
    entry below the diagonal is a copy of its mirror, ``A[i, j] = A[j, i]``.
    """
    A = torch.randn(l, l)
    # Vectorized form of the per-element mirroring loop: keep A[i, j] where
    # i <= j and take A.t()[i, j] == A[j, i] below the diagonal.
    idx = torch.arange(l)
    upper = idx.unsqueeze(1) <= idx.unsqueeze(0)
    return torch.where(upper, A, A.t())
|
|
|
|
|
|
def random_symmetric_psd_matrix(l):
    """Return a random symmetric positive semi-definite ``l x l`` matrix,
    built as ``F F^T`` from a Gaussian matrix ``F``."""
    factor = torch.randn(l, l)
    return torch.mm(factor, factor.t())
|
|
|
|
|
|
def random_symmetric_pd_matrix(l, eps=1e-5):
    """Return a random symmetric positive-definite ``l x l`` matrix.

    The matrix is ``F F^T + eps * I`` for a Gaussian ``F``. The ``eps``
    diagonal shift keeps it strictly positive definite.
    """
    factor = torch.randn(l, l)
    gram = torch.mm(factor, factor.t())
    return gram + torch.eye(l) * eps
|
|
|
|
|
|
def make_nonzero_det(A, sign=None, min_singular_value=0.1):
    """Return a new matrix derived from ``A`` whose determinant is nonzero.

    ``A`` is decomposed with SVD and rebuilt with every singular value raised
    to at least ``min_singular_value``. When ``sign`` is given and the
    determinant's sign disagrees with it, the first row is negated to flip
    the sign. ``A`` itself is not modified.
    """
    u, s, v = A.svd()
    s[s < min_singular_value] = min_singular_value
    rebuilt = u.mm(torch.diag(s)).mm(v.t())
    det_value = rebuilt.det().item()
    if sign is None:
        return rebuilt
    if (det_value < 0) != (sign < 0):
        rebuilt[0, :].neg_()
    return rebuilt
|
|
|
|
|
|
def _random_matrix_with_singular_values(l, s):
    """Return a random ``l x l`` matrix whose singular values are ``s``.

    The singular vectors come from the SVD of a Gaussian matrix.
    """
    A = torch.randn(l, l)
    u, _, v = A.svd()
    return u.mm(torch.diag(s)).mm(v.t())


def random_fullrank_matrix_distinct_singular_value(l, *batches):
    """Return random full-rank ``l x l`` matrices with distinct singular values.

    Every matrix has singular values ``1/(l+1), 2/(l+1), ..., l/(l+1)``.

    Args:
        l: matrix size.
        *batches: optional batch dimensions. When given, the result has shape
            ``batches + (l, l)``; otherwise a single ``l x l`` matrix is
            returned.
    """
    # Loop-invariant: the same singular values are used for every matrix.
    s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
    if len(batches) == 0:
        return _random_matrix_with_singular_values(l, s)
    num_matrices = torch.prod(torch.as_tensor(batches)).item()
    all_matrices = [_random_matrix_with_singular_values(l, s) for _ in range(0, num_matrices)]
    return torch.stack(all_matrices).reshape(*(batches + (l, l)))
|