mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Kill the test_torch.py mixin and creates test_scatter_gather_ops (#71691)
Summary:
Per title.
Also annotates test_torch.py with additional cleanup tasks and adds empty sample inputs to elementwise unary and binary OpInfos.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71691
Reviewed By: ngimel
Differential Revision: D33735126
Pulled By: mruberry
fbshipit-source-id: 8cc097a7581a8b620540c95b2a5889c1165ecf23
(cherry picked from commit 5c6a245a3f)
This commit is contained in:
parent
3a03af2f50
commit
e0d829a266
|
|
@ -20,9 +20,6 @@ import torch.cuda.comm as comm
|
|||
from torch.nn.parallel import scatter_gather
|
||||
from torch.utils.checkpoint import checkpoint_sequential
|
||||
from torch._six import inf, nan
|
||||
|
||||
from test_torch import AbstractTestCases
|
||||
|
||||
from torch.testing._internal.common_methods_invocations import tri_tests_args, tri_large_tests_args, \
|
||||
_compare_trilu_indices, _compare_large_trilu_indices
|
||||
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
|
||||
|
|
@ -59,18 +56,6 @@ if TEST_CUDA:
|
|||
TEST_BF16 = torch.cuda.is_bf16_supported()
|
||||
|
||||
|
||||
types = [
|
||||
torch.FloatTensor,
|
||||
torch.DoubleTensor,
|
||||
torch.LongTensor,
|
||||
torch.IntTensor,
|
||||
torch.ShortTensor,
|
||||
torch.CharTensor,
|
||||
torch.ByteTensor,
|
||||
torch.HalfTensor,
|
||||
]
|
||||
|
||||
|
||||
def make_sparse_tensor(t, n, *sizes):
|
||||
assert t.is_sparse
|
||||
tensor = t()
|
||||
|
|
@ -1540,34 +1525,6 @@ except RuntimeError as e:
|
|||
res_cpu = src.cpu()[idx.cpu()]
|
||||
self.assertEqual(res.cpu(), res_cpu)
|
||||
|
||||
def test_tensor_gather(self):
|
||||
AbstractTestCases._TestTorchMixin._test_gather(self, lambda t: t.cuda(), False)
|
||||
|
||||
def test_tensor_scatter(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_base(self, lambda t: t.cuda(), 'scatter_', test_bounds=False)
|
||||
|
||||
def test_tensor_scatterAdd(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_base(self, lambda t: t.cuda(), 'scatter_add_', test_bounds=False)
|
||||
|
||||
def test_scatter_add_mult_index_base(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_add_mult_index_base(self, lambda t: t.cuda())
|
||||
|
||||
def test_tensor_scatterFill(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_base(self, lambda t: t.cuda(),
|
||||
'scatter_', True, test_bounds=False)
|
||||
|
||||
def test_tensor_scatter_complex(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_base(self, lambda t: t.cuda(),
|
||||
'scatter_', test_bounds=False, test_complex=True)
|
||||
|
||||
def test_tensor_scatterAdd_complex(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_base(self, lambda t: t.cuda(),
|
||||
'scatter_add_', test_bounds=False, test_complex=True)
|
||||
|
||||
def test_tensor_scatterFill_complex(self):
|
||||
AbstractTestCases._TestTorchMixin._test_scatter_base(self, lambda t: t.cuda(),
|
||||
'scatter_', True, test_bounds=False, test_complex=True)
|
||||
|
||||
def test_min_max_inits(self):
|
||||
# Testing if THC_reduceAll received the correct index initialization.
|
||||
# This affects the result of THC_reduceAll operations at extreme values
|
||||
|
|
|
|||
168
test/test_scatter_gather_ops.py
Normal file
168
test/test_scatter_gather_ops.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Owner(s): ["module: scatter & gather ops"]
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import \
|
||||
(run_tests, TestCase,)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, dtypes, dtypesIfCUDA,
|
||||
toleranceOverride, tol)
|
||||
|
||||
# Protects against includes accidentally setting the default dtype
|
||||
assert torch.get_default_dtype() is torch.float32
|
||||
|
||||
|
||||
# Note: test_scatter_gather_ops.py
|
||||
# This test file tests scatter and gather operations,
|
||||
# like torch.scatter and torch.gather.
|
||||
|
||||
class TestScatterGather(TestCase):
|
||||
# Fills an index tensor with valid indices
|
||||
def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
|
||||
for i in range(1 if dim == 0 else m):
|
||||
for j in range(1 if dim == 1 else n):
|
||||
for k in range(1 if dim == 2 else o):
|
||||
ii = [i, j, k]
|
||||
ii[dim] = slice(0, idx.size(dim) + 1)
|
||||
idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
|
||||
|
||||
@dtypes(torch.float32, torch.complex64)
|
||||
def test_gather(self, device, dtype):
|
||||
m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
|
||||
elems_per_row = random.randint(1, 10)
|
||||
dim = random.randrange(3)
|
||||
|
||||
src = make_tensor((m, n, o), device=device, dtype=dtype)
|
||||
idx_size = [m, n, o]
|
||||
idx_size[dim] = elems_per_row
|
||||
idx = make_tensor(idx_size, device=device, dtype=torch.long)
|
||||
self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)
|
||||
|
||||
actual = torch.gather(src, dim, idx)
|
||||
expected = torch.zeros(idx_size, device=device, dtype=dtype)
|
||||
for i in range(idx_size[0]):
|
||||
for j in range(idx_size[1]):
|
||||
for k in range(idx_size[2]):
|
||||
ii = [i, j, k]
|
||||
ii[dim] = idx[i, j, k]
|
||||
expected[i, j, k] = src[tuple(ii)]
|
||||
self.assertEqual(actual, expected, atol=0, rtol=0)
|
||||
|
||||
# Guarded because torch.max isn't defined for complex types
|
||||
if not dtype.is_complex:
|
||||
src = make_tensor((3, 4, 5), device=device, dtype=dtype)
|
||||
expected, idx = src.max(2, True)
|
||||
actual = torch.gather(src, 2, idx)
|
||||
self.assertEqual(actual, expected, atol=0, rtol=0)
|
||||
|
||||
@dtypes(torch.bool)
|
||||
def test_gather_bool(self, device, dtype):
|
||||
src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
|
||||
idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
|
||||
actual = torch.gather(src, 1, idx)
|
||||
expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
|
||||
self.assertEqual(actual, expected, atol=0, rtol=0)
|
||||
|
||||
def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction):
|
||||
m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
|
||||
elems_per_row = random.randint(1, 10)
|
||||
dim = random.randrange(3)
|
||||
|
||||
idx_size = [m, n, o]
|
||||
idx_size[dim] = elems_per_row
|
||||
idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
|
||||
self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o)
|
||||
|
||||
if is_scalar:
|
||||
src = random.random()
|
||||
else:
|
||||
src_size = [random.randint(1, 5) + s for s in idx_size]
|
||||
src = make_tensor(tuple(src_size), device=device, dtype=dtype)
|
||||
|
||||
base = make_tensor((m, n, o), device=device, dtype=dtype)
|
||||
if reduction is not None:
|
||||
actual = fn(base.clone(), dim, idx, src, reduce=reduction)
|
||||
else:
|
||||
actual = fn(base.clone(), dim, idx, src)
|
||||
|
||||
expected = base.clone()
|
||||
for i in range(idx_size[0]):
|
||||
for j in range(idx_size[1]):
|
||||
for k in range(idx_size[2]):
|
||||
ii = [i, j, k]
|
||||
ii[dim] = idx[i, j, k]
|
||||
if fn is torch.Tensor.scatter_add_:
|
||||
expected[tuple(ii)] += src[i, j, k]
|
||||
else:
|
||||
# method may be 'scatter_' or 'scatter'
|
||||
# both might have a reduction argument
|
||||
value = src if is_scalar else src[i, j, k]
|
||||
|
||||
if reduction == "add":
|
||||
expected[tuple(ii)] += value
|
||||
elif reduction == "multiply":
|
||||
expected[tuple(ii)] *= value
|
||||
else:
|
||||
expected[tuple(ii)] = value
|
||||
|
||||
self.assertEqual(actual, expected, atol=0, rtol=0)
|
||||
|
||||
# Tests empty index
|
||||
dst = make_tensor((2, 2), device=device, dtype=dtype)
|
||||
idx = torch.tensor((), device=device, dtype=torch.long)
|
||||
src = make_tensor((2, 2), device=device, dtype=dtype)
|
||||
if reduction is not None:
|
||||
actual = fn(dst, 0, idx, src, reduce=reduction)
|
||||
else:
|
||||
actual = fn(dst, 0, idx, src)
|
||||
self.assertEqual(actual, dst, atol=0, rtol=0)
|
||||
|
||||
@dtypes(torch.float16, torch.float32, torch.complex64)
|
||||
def test_scatter_(self, device, dtype):
|
||||
self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
|
||||
is_scalar=False, reduction=None)
|
||||
|
||||
@dtypes(torch.float16, torch.float32, torch.complex64)
|
||||
def test_scatter__scalar(self, device, dtype):
|
||||
self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
|
||||
is_scalar=True, reduction=None)
|
||||
|
||||
# FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
|
||||
@toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
|
||||
@dtypesIfCUDA(torch.float16, torch.float32)
|
||||
@dtypes(torch.float16, torch.float32, torch.complex64)
|
||||
def test_scatter__reductions(self, device, dtype):
|
||||
for reduction in ("add", "multiply"):
|
||||
self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
|
||||
is_scalar=False, reduction=reduction)
|
||||
self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
|
||||
is_scalar=True, reduction=reduction)
|
||||
|
||||
@dtypes(torch.float16, torch.float32, torch.complex64)
|
||||
def test_scatter_add_(self, device, dtype):
|
||||
self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
|
||||
is_scalar=False, reduction=None)
|
||||
|
||||
@dtypes(torch.float32)
|
||||
def test_scatter_add_mult_index_base(self, device, dtype):
|
||||
m, n = 30, 40
|
||||
idx = torch.zeros(m, n, device=device, dtype=torch.long)
|
||||
src = torch.ones(m, n, device=device, dtype=dtype)
|
||||
res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
|
||||
res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)
|
||||
|
||||
self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
|
||||
self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)
|
||||
|
||||
|
||||
# Generic Device Test Framework instantation, see
|
||||
# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
|
||||
# for details.
|
||||
instantiate_device_type_tests(TestScatterGather, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
@ -1208,6 +1208,16 @@ class TestTensorCreation(TestCase):
|
|||
d = torch.tensor((2, 3), device=device, dtype=torch.double)
|
||||
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
|
||||
|
||||
# FIXME: Create an OpInfo-based tensor creation method test that verifies this for all tensor
|
||||
# creation methods and verify all dtypes and layouts
|
||||
@dtypes(torch.bool, torch.uint8, torch.int16, torch.int64, torch.float16, torch.float32, torch.complex64)
|
||||
def test_zeros_dtype_layout_device_match(self, device, dtype):
|
||||
layout = torch.strided
|
||||
t = torch.zeros((2, 3), device=device, dtype=dtype, layout=layout)
|
||||
self.assertIs(dtype, t.dtype)
|
||||
self.assertIs(layout, t.layout)
|
||||
self.assertEqual(torch.device(device), t.device)
|
||||
|
||||
# TODO: update to work on CUDA, too
|
||||
@onlyCPU
|
||||
def test_trilu_indices(self, device):
|
||||
|
|
|
|||
5649
test/test_torch.py
5649
test/test_torch.py
File diff suppressed because it is too large
Load Diff
|
|
@ -1186,7 +1186,8 @@ def sample_inputs_unary(op_info, device, dtype, requires_grad, **kwargs):
|
|||
low=low, high=high,
|
||||
requires_grad=requires_grad))
|
||||
else:
|
||||
for shape in ((L,), ()):
|
||||
# Creates a 1D, empty, and scalar tensor
|
||||
for shape in ((L,), (1, 0, 3), ()):
|
||||
yield SampleInput(make_tensor(shape, device=device, dtype=dtype,
|
||||
low=low, high=high,
|
||||
requires_grad=requires_grad))
|
||||
|
|
@ -1956,6 +1957,7 @@ def sample_inputs_binary_pwise(
|
|||
((S, M, S), (S, M, S)),
|
||||
((M, 1, S), (M, S)),
|
||||
((M, 1, S), (1, M, S)),
|
||||
((0, 1, 3), (0, 10, 3))
|
||||
]
|
||||
|
||||
for shape_lhs, shape_rhs_or_scalar in shapes:
|
||||
|
|
@ -3016,6 +3018,101 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
|
|||
args=(0, torch.tensor(0, dtype=torch.int64, device=device))),
|
||||
)
|
||||
|
||||
def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o):
|
||||
for i in range(1 if dim == 0 else m):
|
||||
for j in range(1 if dim == 1 else n):
|
||||
for k in range(1 if dim == 2 else o):
|
||||
ii = [i, j, k]
|
||||
ii[dim] = slice(0, idx.size(dim) + 1)
|
||||
idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
|
||||
|
||||
def error_inputs_gather(op_info, device, **kwargs):
|
||||
# src is [1, 2]
|
||||
# [3, 4]
|
||||
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
|
||||
|
||||
# idx is [0, 0]
|
||||
# [1, 0]
|
||||
idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
|
||||
|
||||
# Index should be smaller than self except on dimesion 1
|
||||
bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
|
||||
yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), error_type=RuntimeError,
|
||||
error_regex="Size does not match at dimension 0")
|
||||
|
||||
# Index must have long dtype
|
||||
bad_idx = idx.to(torch.int32)
|
||||
yield ErrorInput(SampleInput(src, args=(1, bad_idx)), error_type=RuntimeError,
|
||||
error_regex="Expected dtype int64 for index")
|
||||
|
||||
# TODO: FIXME
|
||||
# out.dtype must match src.dtype
|
||||
# Creates new src & idx since SampleInputs can't share tensors
|
||||
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
|
||||
out = torch.empty((2, 2), device=device, dtype=torch.float64)
|
||||
yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), error_type=RuntimeError,
|
||||
error_regex="Expected out tensor to have dtype")
|
||||
|
||||
# src and index tensors must have the same # of dimensions
|
||||
# idx too few dimensions
|
||||
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor((0, 0), device=device, dtype=torch.long)
|
||||
yield ErrorInput(SampleInput(src, args=(1, idx)), error_type=RuntimeError,
|
||||
error_regex="Index tensor must have the same number of dimensions")
|
||||
|
||||
# src too few dimensions
|
||||
src = torch.tensor((1, 2), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
|
||||
yield ErrorInput(SampleInput(src, args=(0, idx)), error_type=RuntimeError,
|
||||
error_regex="Index tensor must have the same number of dimensions")
|
||||
|
||||
# index out of bounds
|
||||
# NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
|
||||
if torch.device(device).type == 'cpu':
|
||||
src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
|
||||
yield ErrorInput(SampleInput(src, args=(1, idx,)), error_type=RuntimeError,
|
||||
error_regex="index 23 is out of bounds for dimension")
|
||||
|
||||
# Error inputs for scatter
|
||||
def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
|
||||
# Error when self.dtype != src.dtype (and src is not a scalar)
|
||||
src = make_tensor((2, 5), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
|
||||
dst = torch.zeros((3, 5), device=device, dtype=torch.double)
|
||||
yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
|
||||
error_regex="Expected self.dtype to be equal to src.dtype")
|
||||
|
||||
# Index dtype must be long
|
||||
src = make_tensor((2, 5), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
|
||||
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
|
||||
yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
|
||||
error_regex="Expected dtype int64 for index")
|
||||
|
||||
# Index and destination must have the same number of dimensions
|
||||
src = make_tensor((2, 5), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
|
||||
dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
|
||||
yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
|
||||
error_regex="Index tensor must have the same number of dimensions as self tensor")
|
||||
|
||||
# Index and src must have the same number of dimensions when src is not a scalar
|
||||
src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
|
||||
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
|
||||
yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
|
||||
error_regex="Index tensor must have the same number of dimensions as src tensor")
|
||||
|
||||
# Index out of bounds
|
||||
# NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
|
||||
if torch.device(device).type == 'cpu':
|
||||
src = make_tensor((2, 5), device=device, dtype=torch.float32)
|
||||
idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
|
||||
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
|
||||
yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_type=RuntimeError,
|
||||
error_regex="index 34 is out of bounds for dimension 0 with size 3")
|
||||
|
||||
def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
|
||||
return (SampleInput(make_tensor((S, S), device, dtype,
|
||||
|
|
@ -13101,6 +13198,7 @@ op_db: List[OpInfo] = [
|
|||
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
error_inputs_func=error_inputs_gather
|
||||
),
|
||||
OpInfo('index_fill',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
|
|
@ -13224,7 +13322,8 @@ op_db: List[OpInfo] = [
|
|||
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
sample_inputs_func=sample_inputs_scatter,),
|
||||
sample_inputs_func=sample_inputs_scatter,
|
||||
error_inputs_func=error_inputs_scatter_and_scatter_add),
|
||||
OpInfo('bfloat16',
|
||||
op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
|
||||
|
|
@ -13656,6 +13755,7 @@ op_db: List[OpInfo] = [
|
|||
OpInfo('scatter_add',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_scatter_add,
|
||||
error_inputs_func=error_inputs_scatter_and_scatter_add,
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
|
|
@ -14347,6 +14447,7 @@ op_db: List[OpInfo] = [
|
|||
# If we pass `condition` first, none of the input which supports
|
||||
# autograd will be tested. Hence the following lambda.
|
||||
op=lambda self, condition, other: torch.where(condition, self, other),
|
||||
ref=lambda self, condition, other: np.where(condition, self, other),
|
||||
sample_inputs_func=sample_inputs_where,
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
|
|
|
|||
|
|
@ -2484,6 +2484,8 @@ def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
|
|||
return deco_retry
|
||||
|
||||
|
||||
# FIXME: modernize these to be consistent with make_tensor
|
||||
# and review including them in torch.testing
|
||||
# Methods for matrix generation
|
||||
|
||||
def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
|
||||
|
|
@ -2776,7 +2778,7 @@ def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
|
|||
indices_tensor = torch.tensor([icoords, jcoords])
|
||||
return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device)
|
||||
|
||||
|
||||
# FIXME: remove this by updating test suites using it
|
||||
def do_test_dtypes(self, dtypes, layout, device):
|
||||
for dtype in dtypes:
|
||||
if dtype != torch.float16:
|
||||
|
|
@ -2785,7 +2787,7 @@ def do_test_dtypes(self, dtypes, layout, device):
|
|||
self.assertIs(layout, out.layout)
|
||||
self.assertEqual(device, out.device)
|
||||
|
||||
|
||||
# FIXME: remove this by updating test suites using it
|
||||
def do_test_empty_full(self, dtypes, layout, device):
|
||||
shape = torch.Size([2, 3])
|
||||
|
||||
|
|
@ -2839,43 +2841,8 @@ def do_test_empty_full(self, dtypes, layout, device):
|
|||
dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
|
||||
int64_dtype, layout, device, fv + 5, False)
|
||||
|
||||
# this helper method is to recursively
|
||||
# clone the tensor-type input of operators tested by OpInfo
|
||||
def clone_input_helper(input):
|
||||
if isinstance(input, torch.Tensor):
|
||||
return torch.clone(input)
|
||||
|
||||
if isinstance(input, Sequence):
|
||||
return tuple(map(clone_input_helper, input))
|
||||
|
||||
return input
|
||||
|
||||
THESE_TAKE_WAY_TOO_LONG = {
|
||||
'test_Conv3d_groups',
|
||||
'test_conv_double_backward',
|
||||
'test_conv_double_backward_groups',
|
||||
'test_Conv3d_dilated',
|
||||
'test_Conv3d_stride_padding',
|
||||
'test_Conv3d_dilated_strided',
|
||||
'test_Conv3d',
|
||||
'test_Conv2d_dilated',
|
||||
'test_ConvTranspose3d_dilated',
|
||||
'test_ConvTranspose2d_dilated',
|
||||
'test_snli',
|
||||
'test_Conv2d',
|
||||
'test_Conv2d_padding',
|
||||
'test_ConvTranspose2d_no_bias',
|
||||
'test_ConvTranspose2d',
|
||||
'test_ConvTranspose3d',
|
||||
'test_Conv2d_no_bias',
|
||||
'test_matmul_4d_4d',
|
||||
'test_multinomial_invalid_probs',
|
||||
}
|
||||
|
||||
|
||||
# FIXME: improve load_tests() documentation here
|
||||
running_script_path = None
|
||||
|
||||
|
||||
def set_running_script_path():
|
||||
global running_script_path
|
||||
try:
|
||||
|
|
@ -2885,7 +2852,6 @@ def set_running_script_path():
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def check_test_defined_in_running_script(test_case):
|
||||
if running_script_path is None:
|
||||
return
|
||||
|
|
@ -2895,7 +2861,6 @@ def check_test_defined_in_running_script(test_case):
|
|||
"accidentally import a unittest.TestCase from another file?".format(
|
||||
test_case.id(), running_script_path, test_case_class_file)
|
||||
|
||||
|
||||
def load_tests(loader, tests, pattern):
|
||||
set_running_script_path()
|
||||
test_suite = unittest.TestSuite()
|
||||
|
|
@ -2906,6 +2871,7 @@ def load_tests(loader, tests, pattern):
|
|||
return test_suite
|
||||
|
||||
|
||||
# FIXME: document this and move it to test_serialization
|
||||
class BytesIOContext(io.BytesIO):
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
|
@ -2980,22 +2946,15 @@ def set_cwd(path: str) -> Iterator[None]:
|
|||
os.chdir(old_cwd)
|
||||
|
||||
|
||||
# Using @precisionOverride specific to your test is the recommended way
|
||||
# FIXME: delete this
|
||||
# Using @toleranceOverride specific to your test is the recommended way
|
||||
# of doing this. These are just some values that worked for test_nn.
|
||||
dtype2prec_DONTUSE = {torch.float: 1e-5,
|
||||
torch.double: 1e-5,
|
||||
torch.half: 1e-2,
|
||||
torch.bfloat16: 1e-1}
|
||||
|
||||
|
||||
def _wrap_warn_once(regex):
|
||||
def decorator(fn):
|
||||
def inner(self, *args, **kwargs):
|
||||
with self.assertWarnsOnceRegex(UserWarning, regex):
|
||||
fn(self, *args, **kwargs)
|
||||
return inner
|
||||
return decorator
|
||||
|
||||
# FIXME: move to test_sparse or sparse utils
|
||||
# This is a wrapper that wraps a test to run this test twice, one with
|
||||
# coalesced=True, another with coalesced=False for coalesced/uncoalesced sparse tensors.
|
||||
def coalescedonoff(f):
|
||||
|
|
@ -3200,6 +3159,8 @@ def get_cycles_per_ms() -> float:
|
|||
return mean(vals[2 : num - 2])
|
||||
|
||||
|
||||
# OpInfo utils
|
||||
|
||||
T = TypeVar('T')
|
||||
def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
|
||||
"""
|
||||
|
|
@ -3210,3 +3171,14 @@ def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
|
|||
return next(iter(samples))
|
||||
except StopIteration:
|
||||
raise unittest.SkipTest('Skipped! Need at least 1 sample input')
|
||||
|
||||
# this helper method is to recursively
|
||||
# clone the tensor-type input of operators tested by OpInfo
|
||||
def clone_input_helper(input):
|
||||
if isinstance(input, torch.Tensor):
|
||||
return torch.clone(input)
|
||||
|
||||
if isinstance(input, Sequence):
|
||||
return tuple(map(clone_input_helper, input))
|
||||
|
||||
return input
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user