mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Use both absolute and relative tolerance in testing (#34258)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34258 This PR allows both atol and rtol to be specified, uses defaults based on the prior analysis (spreadsheet attached to https://github.com/pytorch/pytorch/pull/32538), but retains the absolute tolerance behavior in cases where precision was previously specified explicitly. Test Plan: Imported from OSS Differential Revision: D21110255 Pulled By: nairbv fbshipit-source-id: 57b3a004c7d5ac1be80ee765f03668b1b13f4a7e
This commit is contained in:
parent
3aec9f7924
commit
54ed6fd3ee
|
|
@ -156,13 +156,13 @@ class IntrinsicQATModuleTest(TestCase):
|
||||||
running_var_actual = qat_op.running_var
|
running_var_actual = qat_op.running_var
|
||||||
num_batches_tracked_actual = qat_op.num_batches_tracked
|
num_batches_tracked_actual = qat_op.num_batches_tracked
|
||||||
precision = 1e-10
|
precision = 1e-10
|
||||||
self.assertEqual(input_grad_ref, input_grad_actual, prec=precision)
|
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision)
|
||||||
self.assertEqual(weight_grad_ref, weight_grad_actual, prec=precision)
|
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision)
|
||||||
self.assertEqual(gamma_grad_ref, gamma_grad_actual, prec=precision)
|
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision)
|
||||||
self.assertEqual(beta_grad_ref, beta_grad_actual, prec=precision)
|
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision)
|
||||||
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, prec=precision)
|
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision)
|
||||||
self.assertEqual(running_mean_ref, running_mean_actual, prec=precision)
|
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision)
|
||||||
self.assertEqual(running_var_ref, running_var_actual, prec=precision)
|
self.assertEqual(running_var_ref, running_var_actual, atol=precision)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
|
|
@ -922,7 +922,7 @@ class TestQuantizedOps(TestCase):
|
||||||
qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
|
qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
|
||||||
dtype=torch_type)
|
dtype=torch_type)
|
||||||
|
|
||||||
self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), prec=1.0,
|
self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0,
|
||||||
message=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
|
message=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
|
||||||
self.assertEqual(scale, qX_hat.q_scale(),
|
self.assertEqual(scale, qX_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
||||||
|
|
@ -984,7 +984,7 @@ class TestQuantizedOps(TestCase):
|
||||||
qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(),
|
qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(),
|
||||||
dtype=torch_type)
|
dtype=torch_type)
|
||||||
|
|
||||||
self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), prec=1.0,
|
self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0,
|
||||||
message=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
|
message=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
|
||||||
self.assertEqual(scale, X_hat.q_scale(),
|
self.assertEqual(scale, X_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
|
||||||
|
|
@ -1036,7 +1036,7 @@ class TestQuantizedOps(TestCase):
|
||||||
count_include_pad=count_include_pad, divisor_override=divisor_override)
|
count_include_pad=count_include_pad, divisor_override=divisor_override)
|
||||||
qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
|
qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
|
||||||
dtype=torch_type)
|
dtype=torch_type)
|
||||||
self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), prec=1.0,
|
self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0,
|
||||||
message=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
|
message=error_message.format(name, qX_hat.int_repr(), qX_ref.int_repr()))
|
||||||
self.assertEqual(scale, qX_hat.q_scale(),
|
self.assertEqual(scale, qX_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
||||||
|
|
@ -1100,7 +1100,7 @@ class TestQuantizedOps(TestCase):
|
||||||
qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(),
|
qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(),
|
||||||
dtype=torch_type)
|
dtype=torch_type)
|
||||||
|
|
||||||
self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), prec=1.0,
|
self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0,
|
||||||
message=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
|
message=error_message.format(name, X_hat.int_repr(), qX_ref.int_repr()))
|
||||||
self.assertEqual(scale, X_hat.q_scale(),
|
self.assertEqual(scale, X_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
|
||||||
|
|
@ -1141,7 +1141,7 @@ class TestQuantizedOps(TestCase):
|
||||||
|
|
||||||
for name, op in ops_under_test.items():
|
for name, op in ops_under_test.items():
|
||||||
qX_hat = op(qX, output_size=output_size)
|
qX_hat = op(qX, output_size=output_size)
|
||||||
self.assertEqual(X_ref, qX_hat.int_repr(), prec=1.0,
|
self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0,
|
||||||
message=error_message.format(name, X_ref, qX_hat))
|
message=error_message.format(name, X_ref, qX_hat))
|
||||||
self.assertEqual(scale, qX_hat.q_scale(),
|
self.assertEqual(scale, qX_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
||||||
|
|
@ -1187,7 +1187,7 @@ class TestQuantizedOps(TestCase):
|
||||||
for name, op in ops_under_test.items():
|
for name, op in ops_under_test.items():
|
||||||
X_hat = op(qX, output_size=output_size)
|
X_hat = op(qX, output_size=output_size)
|
||||||
self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
|
self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
|
||||||
self.assertEqual(X_ref, X_hat.int_repr(), prec=1.0,
|
self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0,
|
||||||
message="{} results are off".format(name))
|
message="{} results are off".format(name))
|
||||||
self.assertEqual(scale, X_hat.q_scale(),
|
self.assertEqual(scale, X_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
|
||||||
|
|
@ -1349,7 +1349,7 @@ class TestQuantizedOps(TestCase):
|
||||||
for name, op in ops_under_test.items():
|
for name, op in ops_under_test.items():
|
||||||
qX_hat = op(qX, size=size, scale_factor=scale_factor,
|
qX_hat = op(qX, size=size, scale_factor=scale_factor,
|
||||||
mode=mode, align_corners=align_corners)
|
mode=mode, align_corners=align_corners)
|
||||||
self.assertEqual(X_ref, qX_hat.int_repr(), prec=1.0,
|
self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0,
|
||||||
message="{} results are off".format(name, qX_hat.int_repr(), X_ref))
|
message="{} results are off".format(name, qX_hat.int_repr(), X_ref))
|
||||||
self.assertEqual(scale, qX_hat.q_scale(),
|
self.assertEqual(scale, qX_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
||||||
|
|
@ -1402,7 +1402,7 @@ class TestQuantizedOps(TestCase):
|
||||||
for name, op in ops_under_test.items():
|
for name, op in ops_under_test.items():
|
||||||
qX_hat = op(qX, size=size, scale_factor=scale_factor,
|
qX_hat = op(qX, size=size, scale_factor=scale_factor,
|
||||||
mode=mode, align_corners=align_corners)
|
mode=mode, align_corners=align_corners)
|
||||||
self.assertEqual(X_ref, qX_hat.int_repr(), prec=1.0,
|
self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0,
|
||||||
message="{} results are off".format(name, qX_hat.int_repr(), X_ref))
|
message="{} results are off".format(name, qX_hat.int_repr(), X_ref))
|
||||||
self.assertEqual(scale, qX_hat.q_scale(),
|
self.assertEqual(scale, qX_hat.q_scale(),
|
||||||
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
|
||||||
|
|
|
||||||
|
|
@ -399,7 +399,7 @@ class ModuleAPITest(QuantizationTestCase):
|
||||||
|
|
||||||
qlinear.set_weight_bias(W_q, B)
|
qlinear.set_weight_bias(W_q, B)
|
||||||
# Simple round-trip test to ensure weight()/set_weight() API
|
# Simple round-trip test to ensure weight()/set_weight() API
|
||||||
self.assertEqual(qlinear.weight(), W_q)
|
self.assertEqual(qlinear.weight(), W_q, atol=1e-5)
|
||||||
W_pack = qlinear._packed_params._packed_params
|
W_pack = qlinear._packed_params._packed_params
|
||||||
|
|
||||||
qlinear.scale = float(scale)
|
qlinear.scale = float(scale)
|
||||||
|
|
|
||||||
|
|
@ -552,7 +552,7 @@ class TestAutograd(TestCase):
|
||||||
z.sum().backward()
|
z.sum().backward()
|
||||||
|
|
||||||
self.assertEqual(counter[0], 1, 'bw_hook not called')
|
self.assertEqual(counter[0], 1, 'bw_hook not called')
|
||||||
self.assertEqual(x.grad, torch.ones(5, 5) * 2)
|
self.assertEqual(x.grad, torch.ones(5, 5) * 2, atol=1e-5)
|
||||||
|
|
||||||
def test_hook_none(self):
|
def test_hook_none(self):
|
||||||
# WARNING: this is a test for autograd internals.
|
# WARNING: this is a test for autograd internals.
|
||||||
|
|
@ -5577,7 +5577,7 @@ class TestAutogradDeviceType(TestCase):
|
||||||
input_lengths, target_lengths, reduction='none')
|
input_lengths, target_lengths, reduction='none')
|
||||||
self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
|
self.assertTrue("Cudnn" in str(loss_cudnn.grad_fn))
|
||||||
grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
|
grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
|
||||||
self.assertEqual(grad_cudnn, grad_native, prec=1e-4)
|
self.assertEqual(grad_cudnn, grad_native, atol=1e-4)
|
||||||
|
|
||||||
@skipCUDAIfRocm
|
@skipCUDAIfRocm
|
||||||
def test_leaky_relu_inplace_with_neg_slope(self, device):
|
def test_leaky_relu_inplace_with_neg_slope(self, device):
|
||||||
|
|
|
||||||
|
|
@ -434,7 +434,7 @@ EXAMPLES = [
|
||||||
Example(MixtureSameFamily, [
|
Example(MixtureSameFamily, [
|
||||||
{
|
{
|
||||||
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
|
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
|
||||||
'component_distribution': Normal(torch.randn(5, requires_grad=True),
|
'component_distribution': Normal(torch.randn(5, requires_grad=True),
|
||||||
torch.rand(5, requires_grad=True)),
|
torch.rand(5, requires_grad=True)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -442,7 +442,7 @@ EXAMPLES = [
|
||||||
'component_distribution': MultivariateNormal(
|
'component_distribution': MultivariateNormal(
|
||||||
loc=torch.randn(5, 2, requires_grad=True),
|
loc=torch.randn(5, 2, requires_grad=True),
|
||||||
covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)),
|
covariance_matrix=torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True)),
|
||||||
},
|
},
|
||||||
]),
|
]),
|
||||||
Example(VonMises, [
|
Example(VonMises, [
|
||||||
{
|
{
|
||||||
|
|
@ -913,9 +913,9 @@ class TestDistributions(TestCase):
|
||||||
self.assertRaises(NotImplementedError, Bernoulli(r).rsample)
|
self.assertRaises(NotImplementedError, Bernoulli(r).rsample)
|
||||||
|
|
||||||
# check entropy computation
|
# check entropy computation
|
||||||
self.assertEqual(Bernoulli(p).entropy(), torch.tensor([0.6108, 0.5004, 0.6730]), prec=1e-4)
|
self.assertEqual(Bernoulli(p).entropy(), torch.tensor([0.6108, 0.5004, 0.6730]), atol=1e-4)
|
||||||
self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
|
self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
|
||||||
self.assertEqual(Bernoulli(s).entropy(), torch.tensor(0.6108), prec=1e-4)
|
self.assertEqual(Bernoulli(s).entropy(), torch.tensor(0.6108), atol=1e-4)
|
||||||
|
|
||||||
def test_bernoulli_enumerate_support(self):
|
def test_bernoulli_enumerate_support(self):
|
||||||
examples = [
|
examples = [
|
||||||
|
|
@ -962,8 +962,8 @@ class TestDistributions(TestCase):
|
||||||
self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)
|
self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)
|
||||||
|
|
||||||
# check entropy computation
|
# check entropy computation
|
||||||
self.assertEqual(Geometric(p).entropy(), scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(), prec=1e-3)
|
self.assertEqual(Geometric(p).entropy(), scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(), atol=1e-3)
|
||||||
self.assertEqual(float(Geometric(s).entropy()), scipy.stats.geom(s, loc=-1).entropy().item(), prec=1e-3)
|
self.assertEqual(float(Geometric(s).entropy()), scipy.stats.geom(s, loc=-1).entropy().item(), atol=1e-3)
|
||||||
|
|
||||||
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
||||||
def test_geometric_sample(self):
|
def test_geometric_sample(self):
|
||||||
|
|
@ -1047,8 +1047,8 @@ class TestDistributions(TestCase):
|
||||||
bin1 = Binomial(total_count, torch.tensor(0.5))
|
bin1 = Binomial(total_count, torch.tensor(0.5))
|
||||||
samples = bin1.sample(torch.Size((100000,)))
|
samples = bin1.sample(torch.Size((100000,)))
|
||||||
self.assertTrue((samples <= total_count.type_as(samples)).all())
|
self.assertTrue((samples <= total_count.type_as(samples)).all())
|
||||||
self.assertEqual(samples.mean(dim=0), bin1.mean, prec=0.02)
|
self.assertEqual(samples.mean(dim=0), bin1.mean, atol=0.02)
|
||||||
self.assertEqual(samples.var(dim=0), bin1.variance, prec=0.02)
|
self.assertEqual(samples.var(dim=0), bin1.variance, atol=0.02)
|
||||||
|
|
||||||
def test_negative_binomial(self):
|
def test_negative_binomial(self):
|
||||||
p = torch.arange(0.05, 1, 0.1).requires_grad_()
|
p = torch.arange(0.05, 1, 0.1).requires_grad_()
|
||||||
|
|
@ -1165,7 +1165,7 @@ class TestDistributions(TestCase):
|
||||||
self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)
|
self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)
|
||||||
|
|
||||||
# check entropy computation
|
# check entropy computation
|
||||||
self.assertEqual(Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), prec=1e-4)
|
self.assertEqual(Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), atol=1e-4)
|
||||||
self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))
|
self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))
|
||||||
|
|
||||||
def test_categorical_enumerate_support(self):
|
def test_categorical_enumerate_support(self):
|
||||||
|
|
@ -1467,7 +1467,7 @@ class TestDistributions(TestCase):
|
||||||
set_rng_seed(1)
|
set_rng_seed(1)
|
||||||
self.assertEqual(HalfNormal(std_delta).sample(sample_shape=(1, 2)),
|
self.assertEqual(HalfNormal(std_delta).sample(sample_shape=(1, 2)),
|
||||||
torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
|
torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
|
||||||
prec=1e-4)
|
atol=1e-4)
|
||||||
|
|
||||||
self._gradcheck_log_prob(HalfNormal, (std,))
|
self._gradcheck_log_prob(HalfNormal, (std,))
|
||||||
self._gradcheck_log_prob(HalfNormal, (1.0,))
|
self._gradcheck_log_prob(HalfNormal, (1.0,))
|
||||||
|
|
@ -1514,7 +1514,7 @@ class TestDistributions(TestCase):
|
||||||
set_rng_seed(1)
|
set_rng_seed(1)
|
||||||
self.assertEqual(LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
|
self.assertEqual(LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
|
||||||
torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
|
torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
|
||||||
prec=1e-4)
|
atol=1e-4)
|
||||||
|
|
||||||
self._gradcheck_log_prob(LogNormal, (mean, std))
|
self._gradcheck_log_prob(LogNormal, (mean, std))
|
||||||
self._gradcheck_log_prob(LogNormal, (mean, 1.0))
|
self._gradcheck_log_prob(LogNormal, (mean, 1.0))
|
||||||
|
|
@ -1566,7 +1566,7 @@ class TestDistributions(TestCase):
|
||||||
torch.tensor([math.exp(1) / (1. + 1. + math.exp(1)),
|
torch.tensor([math.exp(1) / (1. + 1. + math.exp(1)),
|
||||||
1. / (1. + 1. + math.exp(1)),
|
1. / (1. + 1. + math.exp(1)),
|
||||||
1. / (1. + 1. + math.exp(1))]),
|
1. / (1. + 1. + math.exp(1))]),
|
||||||
prec=1e-4)
|
atol=1e-4)
|
||||||
|
|
||||||
self._gradcheck_log_prob(LogisticNormal, (mean, std))
|
self._gradcheck_log_prob(LogisticNormal, (mean, std))
|
||||||
self._gradcheck_log_prob(LogisticNormal, (mean, 1.0))
|
self._gradcheck_log_prob(LogisticNormal, (mean, 1.0))
|
||||||
|
|
@ -1617,24 +1617,24 @@ class TestDistributions(TestCase):
|
||||||
Categorical(torch.rand(5)),
|
Categorical(torch.rand(5)),
|
||||||
Normal(torch.randn(5), torch.rand(5)))
|
Normal(torch.randn(5), torch.rand(5)))
|
||||||
normal_case_1d_batch = MixtureSameFamily(
|
normal_case_1d_batch = MixtureSameFamily(
|
||||||
Categorical(torch.rand(3, 5)),
|
Categorical(torch.rand(3, 5)),
|
||||||
Normal(torch.randn(3, 5), torch.rand(3, 5)))
|
Normal(torch.randn(3, 5), torch.rand(3, 5)))
|
||||||
normal_case_1d_multi_batch = MixtureSameFamily(
|
normal_case_1d_multi_batch = MixtureSameFamily(
|
||||||
Categorical(torch.rand(4, 3, 5)),
|
Categorical(torch.rand(4, 3, 5)),
|
||||||
Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)))
|
Normal(torch.randn(4, 3, 5), torch.rand(4, 3, 5)))
|
||||||
normal_case_2d = MixtureSameFamily(
|
normal_case_2d = MixtureSameFamily(
|
||||||
Categorical(torch.rand(5)),
|
Categorical(torch.rand(5)),
|
||||||
Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1))
|
Independent(Normal(torch.randn(5, 2), torch.rand(5, 2)), 1))
|
||||||
normal_case_2d_batch = MixtureSameFamily(
|
normal_case_2d_batch = MixtureSameFamily(
|
||||||
Categorical(torch.rand(3, 5)),
|
Categorical(torch.rand(3, 5)),
|
||||||
Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1))
|
Independent(Normal(torch.randn(3, 5, 2), torch.rand(3, 5, 2)), 1))
|
||||||
normal_case_2d_multi_batch = MixtureSameFamily(
|
normal_case_2d_multi_batch = MixtureSameFamily(
|
||||||
Categorical(torch.rand(4, 3, 5)),
|
Categorical(torch.rand(4, 3, 5)),
|
||||||
Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1))
|
Independent(Normal(torch.randn(4, 3, 5, 2), torch.rand(4, 3, 5, 2)), 1))
|
||||||
|
|
||||||
self.assertEqual(normal_case_1d.sample().size(), ())
|
self.assertEqual(normal_case_1d.sample().size(), ())
|
||||||
self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
|
self.assertEqual(normal_case_1d.sample((2,)).size(), (2,))
|
||||||
self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
|
self.assertEqual(normal_case_1d.sample((2, 7)).size(), (2, 7))
|
||||||
self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
|
self.assertEqual(normal_case_1d_batch.sample().size(), (3,))
|
||||||
self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
|
self.assertEqual(normal_case_1d_batch.sample((2,)).size(), (2, 3))
|
||||||
self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
|
self.assertEqual(normal_case_1d_batch.sample((2, 7)).size(), (2, 7, 3))
|
||||||
|
|
@ -1644,7 +1644,7 @@ class TestDistributions(TestCase):
|
||||||
|
|
||||||
self.assertEqual(normal_case_2d.sample().size(), (2,))
|
self.assertEqual(normal_case_2d.sample().size(), (2,))
|
||||||
self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
|
self.assertEqual(normal_case_2d.sample((2,)).size(), (2, 2))
|
||||||
self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
|
self.assertEqual(normal_case_2d.sample((2, 7)).size(), (2, 7, 2))
|
||||||
self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
|
self.assertEqual(normal_case_2d_batch.sample().size(), (3, 2))
|
||||||
self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
|
self.assertEqual(normal_case_2d_batch.sample((2,)).size(), (2, 3, 2))
|
||||||
self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
|
self.assertEqual(normal_case_2d_batch.sample((2, 7)).size(), (2, 7, 3, 2))
|
||||||
|
|
@ -1668,7 +1668,7 @@ class TestDistributions(TestCase):
|
||||||
self.assertAlmostEqual(log_prob, expected, places=3)
|
self.assertAlmostEqual(log_prob, expected, places=3)
|
||||||
|
|
||||||
self._check_log_prob(
|
self._check_log_prob(
|
||||||
MixtureSameFamily(Categorical(probs=probs),
|
MixtureSameFamily(Categorical(probs=probs),
|
||||||
Normal(loc, scale)), ref_log_prob)
|
Normal(loc, scale)), ref_log_prob)
|
||||||
|
|
||||||
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
||||||
|
|
@ -1695,7 +1695,7 @@ class TestDistributions(TestCase):
|
||||||
self._check_sampler_sampler(
|
self._check_sampler_sampler(
|
||||||
MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
|
MixtureSameFamily(Categorical(probs=probs), Normal(loc, scale)),
|
||||||
ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
|
ScipyMixtureNormal(probs.numpy(), loc.numpy(), scale.numpy()),
|
||||||
'''MixtureSameFamily(Categorical(probs={}),
|
'''MixtureSameFamily(Categorical(probs={}),
|
||||||
Normal(loc={}, scale={}))'''.format(probs, loc, scale))
|
Normal(loc={}, scale={}))'''.format(probs, loc, scale))
|
||||||
|
|
||||||
def test_normal(self):
|
def test_normal(self):
|
||||||
|
|
@ -1716,7 +1716,7 @@ class TestDistributions(TestCase):
|
||||||
set_rng_seed(1)
|
set_rng_seed(1)
|
||||||
self.assertEqual(Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
|
self.assertEqual(Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
|
||||||
torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
|
torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
|
||||||
prec=1e-4)
|
atol=1e-4)
|
||||||
|
|
||||||
self._gradcheck_log_prob(Normal, (loc, scale))
|
self._gradcheck_log_prob(Normal, (loc, scale))
|
||||||
self._gradcheck_log_prob(Normal, (loc, 1.0))
|
self._gradcheck_log_prob(Normal, (loc, 1.0))
|
||||||
|
|
@ -1865,9 +1865,9 @@ class TestDistributions(TestCase):
|
||||||
d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
|
d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
|
||||||
samples = d.rsample((100000,))
|
samples = d.rsample((100000,))
|
||||||
empirical_mean = samples.mean(0)
|
empirical_mean = samples.mean(0)
|
||||||
self.assertEqual(d.mean, empirical_mean, prec=0.01)
|
self.assertEqual(d.mean, empirical_mean, atol=0.01)
|
||||||
empirical_var = samples.var(0)
|
empirical_var = samples.var(0)
|
||||||
self.assertEqual(d.variance, empirical_var, prec=0.02)
|
self.assertEqual(d.variance, empirical_var, atol=0.02)
|
||||||
|
|
||||||
def test_multivariate_normal_shape(self):
|
def test_multivariate_normal_shape(self):
|
||||||
mean = torch.randn(5, 3, requires_grad=True)
|
mean = torch.randn(5, 3, requires_grad=True)
|
||||||
|
|
@ -1983,7 +1983,7 @@ class TestDistributions(TestCase):
|
||||||
multivariate=True)
|
multivariate=True)
|
||||||
self._check_sampler_sampler(MultivariateNormal(mean, precision_matrix=prec),
|
self._check_sampler_sampler(MultivariateNormal(mean, precision_matrix=prec),
|
||||||
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
|
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
|
||||||
'MultivariateNormal(loc={}, prec={})'.format(mean, prec),
|
'MultivariateNormal(loc={}, atol={})'.format(mean, prec),
|
||||||
multivariate=True)
|
multivariate=True)
|
||||||
self._check_sampler_sampler(MultivariateNormal(mean, scale_tril=scale_tril),
|
self._check_sampler_sampler(MultivariateNormal(mean, scale_tril=scale_tril),
|
||||||
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
|
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
|
||||||
|
|
@ -2005,9 +2005,9 @@ class TestDistributions(TestCase):
|
||||||
d = MultivariateNormal(mean, scale_tril=scale_tril)
|
d = MultivariateNormal(mean, scale_tril=scale_tril)
|
||||||
samples = d.rsample((100000,))
|
samples = d.rsample((100000,))
|
||||||
empirical_mean = samples.mean(0)
|
empirical_mean = samples.mean(0)
|
||||||
self.assertEqual(d.mean, empirical_mean, prec=0.01)
|
self.assertEqual(d.mean, empirical_mean, atol=0.01)
|
||||||
empirical_var = samples.var(0)
|
empirical_var = samples.var(0)
|
||||||
self.assertEqual(d.variance, empirical_var, prec=0.05)
|
self.assertEqual(d.variance, empirical_var, atol=0.05)
|
||||||
|
|
||||||
def test_exponential(self):
|
def test_exponential(self):
|
||||||
rate = torch.randn(5, 5).abs().requires_grad_()
|
rate = torch.randn(5, 5).abs().requires_grad_()
|
||||||
|
|
@ -2062,7 +2062,7 @@ class TestDistributions(TestCase):
|
||||||
set_rng_seed(0)
|
set_rng_seed(0)
|
||||||
self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
|
self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
|
||||||
torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
|
torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
|
||||||
prec=1e-4)
|
atol=1e-4)
|
||||||
|
|
||||||
self._gradcheck_log_prob(Laplace, (loc, scale))
|
self._gradcheck_log_prob(Laplace, (loc, scale))
|
||||||
self._gradcheck_log_prob(Laplace, (loc, 1.0))
|
self._gradcheck_log_prob(Laplace, (loc, 1.0))
|
||||||
|
|
@ -2438,11 +2438,11 @@ class TestDistributions(TestCase):
|
||||||
self._check_log_prob(ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
|
self._check_log_prob(ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
|
||||||
|
|
||||||
# check entropy computation
|
# check entropy computation
|
||||||
self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), prec=1e-4)
|
self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), atol=1e-4)
|
||||||
# entropy below corresponds to the clamped value of prob when using float 64
|
# entropy below corresponds to the clamped value of prob when using float 64
|
||||||
# the value for float32 should be -1.76898
|
# the value for float32 should be -1.76898
|
||||||
self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]))
|
self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]), atol=1e-5)
|
||||||
self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), prec=1e-4)
|
self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), atol=1e-4)
|
||||||
|
|
||||||
def test_continuous_bernoulli_3d(self):
|
def test_continuous_bernoulli_3d(self):
|
||||||
p = torch.full((2, 3, 5), 0.5).requires_grad_()
|
p = torch.full((2, 3, 5), 0.5).requires_grad_()
|
||||||
|
|
@ -2857,7 +2857,7 @@ class TestRsample(TestCase):
|
||||||
num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
|
num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
|
||||||
den = (1.0 + alpha)**2 * (1.0 + 2.0 * alpha)**3
|
den = (1.0 + alpha)**2 * (1.0 + 2.0 * alpha)**3
|
||||||
expected_grad = num / den
|
expected_grad = num / den
|
||||||
self.assertEqual(actual_grad, expected_grad, 0.002, '\n'.join([
|
self.assertEqual(actual_grad, expected_grad, atol=0.002, message='\n'.join([
|
||||||
"alpha = alpha_c + %.2g" % shift,
|
"alpha = alpha_c + %.2g" % shift,
|
||||||
"expected_grad: %.5g" % expected_grad,
|
"expected_grad: %.5g" % expected_grad,
|
||||||
"actual_grad: %.5g" % actual_grad,
|
"actual_grad: %.5g" % actual_grad,
|
||||||
|
|
@ -3690,7 +3690,7 @@ class TestKL(TestCase):
|
||||||
expected = -dist.log_prob(x).mean(0)
|
expected = -dist.log_prob(x).mean(0)
|
||||||
ignore = (expected == inf) | (expected == -inf)
|
ignore = (expected == inf) | (expected == -inf)
|
||||||
expected[ignore] = actual[ignore]
|
expected[ignore] = actual[ignore]
|
||||||
self.assertEqual(actual, expected, prec=0.2, message='\n'.join([
|
self.assertEqual(actual, expected, atol=0.2, message='\n'.join([
|
||||||
'{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
|
'{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
|
||||||
'Expected (monte carlo) {}'.format(expected),
|
'Expected (monte carlo) {}'.format(expected),
|
||||||
'Actual (analytic) {}'.format(actual),
|
'Actual (analytic) {}'.format(actual),
|
||||||
|
|
@ -3763,7 +3763,7 @@ class TestNumericalStability(TestCase):
|
||||||
probs=None,
|
probs=None,
|
||||||
logits=None,
|
logits=None,
|
||||||
expected_gradient=None,
|
expected_gradient=None,
|
||||||
prec=1e-5):
|
atol=1e-5):
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
p = probs.detach().requires_grad_()
|
p = probs.detach().requires_grad_()
|
||||||
dist = dist_class(p)
|
dist = dist_class(p)
|
||||||
|
|
@ -3774,13 +3774,13 @@ class TestNumericalStability(TestCase):
|
||||||
log_pdf.sum().backward()
|
log_pdf.sum().backward()
|
||||||
self.assertEqual(log_pdf,
|
self.assertEqual(log_pdf,
|
||||||
expected_value,
|
expected_value,
|
||||||
prec=prec,
|
atol=atol,
|
||||||
message='Incorrect value for tensor type: {}. Expected = {}, Actual = {}'
|
message='Incorrect value for tensor type: {}. Expected = {}, Actual = {}'
|
||||||
.format(type(x), expected_value, log_pdf))
|
.format(type(x), expected_value, log_pdf))
|
||||||
if expected_gradient is not None:
|
if expected_gradient is not None:
|
||||||
self.assertEqual(p.grad,
|
self.assertEqual(p.grad,
|
||||||
expected_gradient,
|
expected_gradient,
|
||||||
prec=prec,
|
atol=atol,
|
||||||
message='Incorrect gradient for tensor type: {}. Expected = {}, Actual = {}'
|
message='Incorrect gradient for tensor type: {}. Expected = {}, Actual = {}'
|
||||||
.format(type(x), expected_gradient, p.grad))
|
.format(type(x), expected_gradient, p.grad))
|
||||||
|
|
||||||
|
|
@ -3813,14 +3813,14 @@ class TestNumericalStability(TestCase):
|
||||||
x=tensor_type([0]),
|
x=tensor_type([0]),
|
||||||
expected_value=tensor_type([math.log(1e-4)]),
|
expected_value=tensor_type([math.log(1e-4)]),
|
||||||
expected_gradient=tensor_type([-10000]),
|
expected_gradient=tensor_type([-10000]),
|
||||||
prec=2)
|
atol=2)
|
||||||
|
|
||||||
self._test_pdf_score(dist_class=Bernoulli,
|
self._test_pdf_score(dist_class=Bernoulli,
|
||||||
logits=tensor_type([math.log(9999)]),
|
logits=tensor_type([math.log(9999)]),
|
||||||
x=tensor_type([0]),
|
x=tensor_type([0]),
|
||||||
expected_value=tensor_type([math.log(1e-4)]),
|
expected_value=tensor_type([math.log(1e-4)]),
|
||||||
expected_gradient=tensor_type([-1]),
|
expected_gradient=tensor_type([-1]),
|
||||||
prec=1e-3)
|
atol=1e-3)
|
||||||
|
|
||||||
def test_bernoulli_with_logits_underflow(self):
|
def test_bernoulli_with_logits_underflow(self):
|
||||||
for tensor_type, lim in ([(torch.FloatTensor, -1e38),
|
for tensor_type, lim in ([(torch.FloatTensor, -1e38),
|
||||||
|
|
@ -3928,21 +3928,21 @@ class TestNumericalStability(TestCase):
|
||||||
x=tensor_type([1]),
|
x=tensor_type([1]),
|
||||||
expected_value=tensor_type([expec_val(1, probs=1e-4)]),
|
expected_value=tensor_type([expec_val(1, probs=1e-4)]),
|
||||||
expected_gradient=tensor_type(tensor_type([expec_grad(1, probs=1e-4)])),
|
expected_gradient=tensor_type(tensor_type([expec_grad(1, probs=1e-4)])),
|
||||||
prec=1e-3)
|
atol=1e-3)
|
||||||
|
|
||||||
self._test_pdf_score(dist_class=ContinuousBernoulli,
|
self._test_pdf_score(dist_class=ContinuousBernoulli,
|
||||||
probs=tensor_type([1 - 1e-4]),
|
probs=tensor_type([1 - 1e-4]),
|
||||||
x=tensor_type([0.1]),
|
x=tensor_type([0.1]),
|
||||||
expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
|
expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
|
||||||
expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
|
expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
|
||||||
prec=2)
|
atol=2)
|
||||||
|
|
||||||
self._test_pdf_score(dist_class=ContinuousBernoulli,
|
self._test_pdf_score(dist_class=ContinuousBernoulli,
|
||||||
logits=tensor_type([math.log(9999)]),
|
logits=tensor_type([math.log(9999)]),
|
||||||
x=tensor_type([0]),
|
x=tensor_type([0]),
|
||||||
expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
|
expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
|
||||||
expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
|
expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
|
||||||
prec=1e-3)
|
atol=1e-3)
|
||||||
|
|
||||||
self._test_pdf_score(dist_class=ContinuousBernoulli,
|
self._test_pdf_score(dist_class=ContinuousBernoulli,
|
||||||
logits=tensor_type([0.001]),
|
logits=tensor_type([0.001]),
|
||||||
|
|
|
||||||
|
|
@ -7208,7 +7208,7 @@ a")
|
||||||
continue
|
continue
|
||||||
msg = ("Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}"
|
msg = ("Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}"
|
||||||
.format(func_name=func_name, a=a, b=b, res_python=res_python, res_script=res_script))
|
.format(func_name=func_name, a=a, b=b, res_python=res_python, res_script=res_script))
|
||||||
self.assertEqual(res_python, res_script, message=msg, prec=(1e-4) * max(abs(res_python), res_script))
|
self.assertEqual(res_python, res_script, message=msg, atol=(1e-4) * max(abs(res_python), res_script))
|
||||||
|
|
||||||
unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf",
|
unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf",
|
||||||
"erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan",
|
"erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan",
|
||||||
|
|
|
||||||
|
|
@ -488,7 +488,7 @@ class TestFuser(JitTestCase):
|
||||||
with torch.jit.optimized_execution(False):
|
with torch.jit.optimized_execution(False):
|
||||||
out_noopt = model_noopt(x, y)
|
out_noopt = model_noopt(x, y)
|
||||||
rep_noopt = str(model_noopt.graph_for(x, y))
|
rep_noopt = str(model_noopt.graph_for(x, y))
|
||||||
self.assertEqual(out, out_noopt, prec=3e-5)
|
self.assertEqual(out, out_noopt, atol=3e-5)
|
||||||
|
|
||||||
# Check that normalization op has really been decomposed
|
# Check that normalization op has really been decomposed
|
||||||
for node_in_graph in in_opt_graph:
|
for node_in_graph in in_opt_graph:
|
||||||
|
|
|
||||||
|
|
@ -3777,12 +3777,12 @@ class TestNN(NNTestCase):
|
||||||
conv2.weight.data.copy_(conv1.weight.data)
|
conv2.weight.data.copy_(conv1.weight.data)
|
||||||
out1 = conv1(inputs)
|
out1 = conv1(inputs)
|
||||||
out2 = conv2(inputs)
|
out2 = conv2(inputs)
|
||||||
self.assertEqual(out1, out2, prec=0.0)
|
self.assertEqual(out1, out2, atol=0.0)
|
||||||
y = torch.randn(out1.size(), device="cuda", dtype=dtype)
|
y = torch.randn(out1.size(), device="cuda", dtype=dtype)
|
||||||
out1.backward(y)
|
out1.backward(y)
|
||||||
out2.backward(y)
|
out2.backward(y)
|
||||||
self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, prec=0.0)
|
self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0)
|
||||||
self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, prec=0.0)
|
self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0)
|
||||||
|
|
||||||
def test_Conv2d_missing_argument(self):
|
def test_Conv2d_missing_argument(self):
|
||||||
c = nn.Conv2d(3, 3, 3)
|
c = nn.Conv2d(3, 3, 3)
|
||||||
|
|
@ -3988,18 +3988,18 @@ class TestNN(NNTestCase):
|
||||||
output2.backward(grad_output[:, offset:].contiguous())
|
output2.backward(grad_output[:, offset:].contiguous())
|
||||||
|
|
||||||
self.assertEqual(output, torch.cat([output1, output2], 1),
|
self.assertEqual(output, torch.cat([output1, output2], 1),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
self.assertEqual(i.grad.data,
|
self.assertEqual(i.grad.data,
|
||||||
torch.cat([i1.grad.data, i2.grad.data], 1),
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
self.assertEqual(m.bias.grad.data,
|
self.assertEqual(m.bias.grad.data,
|
||||||
torch.cat([m1.bias.grad.data,
|
torch.cat([m1.bias.grad.data,
|
||||||
m2.bias.grad.data], 0),
|
m2.bias.grad.data], 0),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
self.assertEqual(m.weight.grad.data,
|
self.assertEqual(m.weight.grad.data,
|
||||||
torch.cat([m1.weight.grad.data,
|
torch.cat([m1.weight.grad.data,
|
||||||
m2.weight.grad.data], 0),
|
m2.weight.grad.data], 0),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
|
|
||||||
def test_MaxUnpool2d_output_size(self):
|
def test_MaxUnpool2d_output_size(self):
|
||||||
m = nn.MaxPool2d(3, stride=2, return_indices=True)
|
m = nn.MaxPool2d(3, stride=2, return_indices=True)
|
||||||
|
|
@ -4981,7 +4981,7 @@ class TestNN(NNTestCase):
|
||||||
def check_rnn_grads(rnn1, rnn2):
|
def check_rnn_grads(rnn1, rnn2):
|
||||||
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
|
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
|
||||||
for x, y in zip(x_layer, y_layer):
|
for x, y in zip(x_layer, y_layer):
|
||||||
self.assertEqual(x.grad, y.grad, prec=5e-5)
|
self.assertEqual(x.grad, y.grad, atol=5e-5)
|
||||||
|
|
||||||
input_size = 10
|
input_size = 10
|
||||||
hidden_size = 6
|
hidden_size = 6
|
||||||
|
|
@ -5396,12 +5396,12 @@ class TestNN(NNTestCase):
|
||||||
self.assertEqual(list(outputs_cpu.keys()), list(outputs_gpu.keys()))
|
self.assertEqual(list(outputs_cpu.keys()), list(outputs_gpu.keys()))
|
||||||
for key in outputs_cpu.keys():
|
for key in outputs_cpu.keys():
|
||||||
if key != 'weights':
|
if key != 'weights':
|
||||||
self.assertEqual(outputs_cpu[key], outputs_gpu[key], prec=5e-5, message=key)
|
self.assertEqual(outputs_cpu[key], outputs_gpu[key], atol=5e-5, message=key)
|
||||||
|
|
||||||
# check grad weights separately, as nested dict
|
# check grad weights separately, as nested dict
|
||||||
for cpu_layer_weight, gpu_layer_weight in zip(outputs_cpu['weights'], outputs_gpu['weights']):
|
for cpu_layer_weight, gpu_layer_weight in zip(outputs_cpu['weights'], outputs_gpu['weights']):
|
||||||
for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
|
for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
|
||||||
self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, prec=5e-5)
|
self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, atol=5e-5)
|
||||||
|
|
||||||
for module in (nn.RNN, nn.LSTM, nn.GRU):
|
for module in (nn.RNN, nn.LSTM, nn.GRU):
|
||||||
for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \
|
for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \
|
||||||
|
|
@ -6479,7 +6479,7 @@ class TestNN(NNTestCase):
|
||||||
|
|
||||||
out_cuda.backward(gradients.cuda())
|
out_cuda.backward(gradients.cuda())
|
||||||
self.assertEqual(input_cpu.grad, input_cuda.grad)
|
self.assertEqual(input_cpu.grad, input_cuda.grad)
|
||||||
self.assertEqual(grid_cpu.grad, grid_cuda.grad, prec=5e-5)
|
self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5)
|
||||||
|
|
||||||
# check that zero-dimensional input strides don't error out
|
# check that zero-dimensional input strides don't error out
|
||||||
base_input = torch.randn(N, C, 1, IW)
|
base_input = torch.randn(N, C, 1, IW)
|
||||||
|
|
@ -6618,8 +6618,8 @@ class TestNN(NNTestCase):
|
||||||
raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode))
|
raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode))
|
||||||
output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
|
output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
|
||||||
align_corners=align_corners)
|
align_corners=align_corners)
|
||||||
self.assertEqual(output, groundtruth,
|
self.assertEqual(output, groundtruth, atol=1e-5,
|
||||||
"groundtruth comparison failed for mode={}, "
|
message="groundtruth comparison failed for mode={}, "
|
||||||
"padding_mode={}".format(mode, padding_mode))
|
"padding_mode={}".format(mode, padding_mode))
|
||||||
|
|
||||||
# explicit check for gradient edge cases
|
# explicit check for gradient edge cases
|
||||||
|
|
@ -6707,7 +6707,7 @@ class TestNN(NNTestCase):
|
||||||
|
|
||||||
out_cuda.backward(gradients.cuda())
|
out_cuda.backward(gradients.cuda())
|
||||||
self.assertEqual(input_cpu.grad, input_cuda.grad)
|
self.assertEqual(input_cpu.grad, input_cuda.grad)
|
||||||
self.assertEqual(grid_cpu.grad, grid_cuda.grad, prec=5e-5)
|
self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5)
|
||||||
|
|
||||||
# check that zero-dimensional input strides don't error out
|
# check that zero-dimensional input strides don't error out
|
||||||
base_input = torch.randn(N, C, 1, IH, IW)
|
base_input = torch.randn(N, C, 1, IH, IW)
|
||||||
|
|
@ -7216,7 +7216,7 @@ class TestNN(NNTestCase):
|
||||||
[6.10547, 6.43750, 6.98438, 7.31641]]]])
|
[6.10547, 6.43750, 6.98438, 7.31641]]]])
|
||||||
out_t = F.interpolate(in_t, scale_factor=2, mode='bicubic', align_corners=False)
|
out_t = F.interpolate(in_t, scale_factor=2, mode='bicubic', align_corners=False)
|
||||||
torch.set_printoptions(precision=5)
|
torch.set_printoptions(precision=5)
|
||||||
self.assertEqual(out_t, expected_out_t)
|
self.assertEqual(out_t, expected_out_t, atol=1e-5)
|
||||||
|
|
||||||
device_list = ['cpu']
|
device_list = ['cpu']
|
||||||
if TEST_CUDA:
|
if TEST_CUDA:
|
||||||
|
|
@ -7230,7 +7230,7 @@ class TestNN(NNTestCase):
|
||||||
in_t = torch.ones(2, 2, 2, 2).to(device)
|
in_t = torch.ones(2, 2, 2, 2).to(device)
|
||||||
out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
|
out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
|
||||||
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
|
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
|
||||||
self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data)
|
self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data, atol=1e-5)
|
||||||
|
|
||||||
input = torch.randn(2, 2, 2, 2, requires_grad=True)
|
input = torch.randn(2, 2, 2, 2, requires_grad=True)
|
||||||
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
|
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
|
||||||
|
|
@ -7262,7 +7262,7 @@ class TestNN(NNTestCase):
|
||||||
[5.92212, 6.16094, 6.62870, 7.04680]]]])
|
[5.92212, 6.16094, 6.62870, 7.04680]]]])
|
||||||
out_t = F.interpolate(in_t, scale_factor=2.3, mode='bicubic', align_corners=False, recompute_scale_factor=False)
|
out_t = F.interpolate(in_t, scale_factor=2.3, mode='bicubic', align_corners=False, recompute_scale_factor=False)
|
||||||
torch.set_printoptions(precision=5)
|
torch.set_printoptions(precision=5)
|
||||||
self.assertEqual(out_t, expected_out_t)
|
self.assertEqual(out_t, expected_out_t, atol=1e-5)
|
||||||
|
|
||||||
device_list = ['cpu']
|
device_list = ['cpu']
|
||||||
if TEST_CUDA:
|
if TEST_CUDA:
|
||||||
|
|
@ -7276,7 +7276,7 @@ class TestNN(NNTestCase):
|
||||||
in_t = torch.ones(2, 2, 2, 2).to(device)
|
in_t = torch.ones(2, 2, 2, 2).to(device)
|
||||||
out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
|
out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
|
||||||
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
|
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
|
||||||
self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data)
|
self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data, atol=1e-5)
|
||||||
|
|
||||||
input = torch.randn(2, 2, 2, 2, requires_grad=True)
|
input = torch.randn(2, 2, 2, 2, requires_grad=True)
|
||||||
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
|
gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
|
||||||
|
|
@ -7800,12 +7800,12 @@ class TestNN(NNTestCase):
|
||||||
outf = F.log_softmax(inputf, dim=-1)
|
outf = F.log_softmax(inputf, dim=-1)
|
||||||
out = F.log_softmax(input, dim=-1)
|
out = F.log_softmax(input, dim=-1)
|
||||||
self.assertEqual(out.dtype, dtype)
|
self.assertEqual(out.dtype, dtype)
|
||||||
self.assertEqual(out, outf, prec=0.1)
|
self.assertEqual(out, outf, atol=0.1)
|
||||||
|
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
outf.sum().backward()
|
outf.sum().backward()
|
||||||
self.assertEqual(input.grad.dtype, dtype)
|
self.assertEqual(input.grad.dtype, dtype)
|
||||||
self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0.1)
|
self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1)
|
||||||
|
|
||||||
def test_adaptive_log_softmax(self):
|
def test_adaptive_log_softmax(self):
|
||||||
# args validation
|
# args validation
|
||||||
|
|
@ -7909,12 +7909,12 @@ class TestNN(NNTestCase):
|
||||||
outf = loss_cpu(inputf, target)
|
outf = loss_cpu(inputf, target)
|
||||||
out = loss_cpu(input, target)
|
out = loss_cpu(input, target)
|
||||||
self.assertEqual(out.dtype, dtype)
|
self.assertEqual(out.dtype, dtype)
|
||||||
self.assertEqual(out, outf, prec=1e-1)
|
self.assertEqual(out, outf, atol=1e-1)
|
||||||
|
|
||||||
outf.backward()
|
outf.backward()
|
||||||
out.backward()
|
out.backward()
|
||||||
self.assertEqual(input.grad.dtype, dtype)
|
self.assertEqual(input.grad.dtype, dtype)
|
||||||
self.assertEqual(input.grad, inputf.grad, prec=1e-1)
|
self.assertEqual(input.grad, inputf.grad, atol=1e-1)
|
||||||
|
|
||||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||||
def test_convert_sync_batchnorm(self):
|
def test_convert_sync_batchnorm(self):
|
||||||
|
|
@ -8362,10 +8362,10 @@ class TestNNInit(TestCase):
|
||||||
flattened_tensor = input_tensor.view(rows, cols)
|
flattened_tensor = input_tensor.view(rows, cols)
|
||||||
if rows > cols:
|
if rows > cols:
|
||||||
self.assertEqual(torch.mm(flattened_tensor.t(), flattened_tensor),
|
self.assertEqual(torch.mm(flattened_tensor.t(), flattened_tensor),
|
||||||
torch.eye(cols) * gain ** 2, prec=1e-6)
|
torch.eye(cols) * gain ** 2, atol=1e-6)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(torch.mm(flattened_tensor, flattened_tensor.t()),
|
self.assertEqual(torch.mm(flattened_tensor, flattened_tensor.t()),
|
||||||
torch.eye(rows) * gain ** 2, prec=1e-6)
|
torch.eye(rows) * gain ** 2, atol=1e-6)
|
||||||
|
|
||||||
def test_deprecation(self):
|
def test_deprecation(self):
|
||||||
x = torch.randn(3, 3)
|
x = torch.randn(3, 3)
|
||||||
|
|
@ -9527,7 +9527,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
# Shape unchanged
|
# Shape unchanged
|
||||||
self.assertTrue(y_draw.shape == logits.shape)
|
self.assertTrue(y_draw.shape == logits.shape)
|
||||||
# One choice per draw
|
# One choice per draw
|
||||||
self.assertEqual(y_draw.sum(), count_expected, prec=torch.finfo(y_draw.dtype).eps)
|
self.assertEqual(y_draw.sum(), count_expected, atol=torch.finfo(y_draw.dtype).eps)
|
||||||
|
|
||||||
def _test_gumbel_softmax_straight_through(self, device, dtype):
|
def _test_gumbel_softmax_straight_through(self, device, dtype):
|
||||||
num_draws = 100
|
num_draws = 100
|
||||||
|
|
@ -9545,7 +9545,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
# All values positive
|
# All values positive
|
||||||
self.assertGreaterEqual(y_draw.min(), 0)
|
self.assertGreaterEqual(y_draw.min(), 0)
|
||||||
# Each experiment should result in 1 draw.
|
# Each experiment should result in 1 draw.
|
||||||
self.assertEqual(counts.sum(), num_draws, prec=torch.finfo(counts.dtype).eps)
|
self.assertEqual(counts.sum(), num_draws, atol=torch.finfo(counts.dtype).eps)
|
||||||
|
|
||||||
# check results is asymptotically as expected.
|
# check results is asymptotically as expected.
|
||||||
expected = probs * num_draws
|
expected = probs * num_draws
|
||||||
|
|
@ -9769,12 +9769,12 @@ class TestNNDeviceType(NNTestCase):
|
||||||
out = F.softmax(input, dim=-1, dtype=torch.float)
|
out = F.softmax(input, dim=-1, dtype=torch.float)
|
||||||
outf = F.softmax(inputf, dim=-1)
|
outf = F.softmax(inputf, dim=-1)
|
||||||
# should be bitwise equal
|
# should be bitwise equal
|
||||||
self.assertEqual(out, outf, prec=0)
|
self.assertEqual(out, outf, atol=0)
|
||||||
gO = torch.empty_like(outf).uniform_()
|
gO = torch.empty_like(outf).uniform_()
|
||||||
out.backward(gO)
|
out.backward(gO)
|
||||||
outf.backward(gO)
|
outf.backward(gO)
|
||||||
# should be bitwise equal
|
# should be bitwise equal
|
||||||
self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0)
|
self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0)
|
||||||
|
|
||||||
@onlyCUDA
|
@onlyCUDA
|
||||||
def test_pool3d_size_one_feature_dim(self, device):
|
def test_pool3d_size_one_feature_dim(self, device):
|
||||||
|
|
@ -9904,7 +9904,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
expected = self._embedding_bag_reference_impl(
|
expected = self._embedding_bag_reference_impl(
|
||||||
input, reference_weights, offsets, mode, ref_per_sample_weights)
|
input, reference_weights, offsets, mode, ref_per_sample_weights)
|
||||||
result = es(input, offsets, per_sample_weights)
|
result = es(input, offsets, per_sample_weights)
|
||||||
self.assertEqual(result, expected, prec=dtype2prec_DONTUSE[dtype])
|
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype])
|
||||||
|
|
||||||
grad = torch.randn_like(expected)
|
grad = torch.randn_like(expected)
|
||||||
result.backward(grad)
|
result.backward(grad)
|
||||||
|
|
@ -9913,7 +9913,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
dtype2prec_DONTUSE[dtype])
|
dtype2prec_DONTUSE[dtype])
|
||||||
if trainable_scale:
|
if trainable_scale:
|
||||||
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
|
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
|
|
||||||
if device == 'cuda':
|
if device == 'cuda':
|
||||||
dtypes = (torch.float, torch.double, torch.half)
|
dtypes = (torch.float, torch.double, torch.half)
|
||||||
|
|
@ -9944,7 +9944,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
expected = self._embedding_bag_reference_impl(
|
expected = self._embedding_bag_reference_impl(
|
||||||
input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset)
|
input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset)
|
||||||
result = es(input, offsets, per_sample_weights)
|
result = es(input, offsets, per_sample_weights)
|
||||||
self.assertEqual(result, expected, prec=dtype2prec_DONTUSE[dtype])
|
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype])
|
||||||
|
|
||||||
grad = torch.randn_like(expected)
|
grad = torch.randn_like(expected)
|
||||||
result.backward(grad)
|
result.backward(grad)
|
||||||
|
|
@ -9953,7 +9953,7 @@ class TestNNDeviceType(NNTestCase):
|
||||||
dtype2prec_DONTUSE[dtype])
|
dtype2prec_DONTUSE[dtype])
|
||||||
if trainable_scale:
|
if trainable_scale:
|
||||||
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
|
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
|
|
||||||
if device == 'cuda':
|
if device == 'cuda':
|
||||||
dtypes = (torch.float, torch.double, torch.half)
|
dtypes = (torch.float, torch.double, torch.half)
|
||||||
|
|
@ -10240,13 +10240,13 @@ class TestNNDeviceType(NNTestCase):
|
||||||
self.assertEqual(output, torch.cat([output1, output2], 1))
|
self.assertEqual(output, torch.cat([output1, output2], 1))
|
||||||
self.assertEqual(i.grad.data,
|
self.assertEqual(i.grad.data,
|
||||||
torch.cat([i1.grad.data, i2.grad.data], 1),
|
torch.cat([i1.grad.data, i2.grad.data], 1),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
self.assertEqual(m.bias.grad.data,
|
self.assertEqual(m.bias.grad.data,
|
||||||
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
self.assertEqual(m.weight.grad.data,
|
self.assertEqual(m.weight.grad.data,
|
||||||
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
|
||||||
prec=dtype2prec_DONTUSE[dtype])
|
atol=dtype2prec_DONTUSE[dtype])
|
||||||
|
|
||||||
def _test_batchnorm_grad(self, device, dtype=torch.double):
|
def _test_batchnorm_grad(self, device, dtype=torch.double):
|
||||||
bs, n_feat, size_feat = 4, 5, 6
|
bs, n_feat, size_feat = 4, 5, 6
|
||||||
|
|
@ -10692,8 +10692,8 @@ class TestNNDeviceType(NNTestCase):
|
||||||
out2 = op_bfp16(input2)
|
out2 = op_bfp16(input2)
|
||||||
out2.backward(grad_input2)
|
out2.backward(grad_input2)
|
||||||
|
|
||||||
self.assertEqual(out1, out2, prec=prec)
|
self.assertEqual(out1, out2, atol=prec)
|
||||||
self.assertEqual(input1.grad.data, input2.grad.data, prec=prec)
|
self.assertEqual(input1.grad.data, input2.grad.data, atol=prec)
|
||||||
|
|
||||||
@onlyCUDA
|
@onlyCUDA
|
||||||
@skipCUDAIfNotRocm
|
@skipCUDAIfNotRocm
|
||||||
|
|
|
||||||
|
|
@ -230,7 +230,7 @@ class _TestTorchMixin(object):
|
||||||
input = torch.empty(10).uniform_(d, 10)
|
input = torch.empty(10).uniform_(d, 10)
|
||||||
res_torch = torch.mvlgamma(input, d)
|
res_torch = torch.mvlgamma(input, d)
|
||||||
res_scipy = multigammaln(input.numpy(), d)
|
res_scipy = multigammaln(input.numpy(), d)
|
||||||
self.assertEqual(res_torch.numpy(), res_scipy)
|
self.assertEqual(res_torch.numpy(), res_scipy, atol=1e-5)
|
||||||
|
|
||||||
def test_mvlgamma_argcheck(self):
|
def test_mvlgamma_argcheck(self):
|
||||||
def run_test(d):
|
def run_test(d):
|
||||||
|
|
@ -503,7 +503,7 @@ class _TestTorchMixin(object):
|
||||||
for i, j in iter_indices(m1):
|
for i, j in iter_indices(m1):
|
||||||
res2[i] += m1[i][j] * v1[j]
|
res2[i] += m1[i][j] * v1[j]
|
||||||
|
|
||||||
self.assertEqual(res1, res2)
|
self.assertEqual(res1, res2, atol=1e-5)
|
||||||
|
|
||||||
_test_mv(torch.randn(100, 100, dtype=torch.float32), torch.randn(100, dtype=torch.float32))
|
_test_mv(torch.randn(100, 100, dtype=torch.float32), torch.randn(100, dtype=torch.float32))
|
||||||
_test_mv(torch.randn(100, 100, dtype=torch.float64), torch.randn(100, dtype=torch.float64))
|
_test_mv(torch.randn(100, 100, dtype=torch.float64), torch.randn(100, dtype=torch.float64))
|
||||||
|
|
@ -1565,7 +1565,7 @@ class _TestTorchMixin(object):
|
||||||
|
|
||||||
def test_not_equal(self):
|
def test_not_equal(self):
|
||||||
ones = torch.ones(10, dtype=torch.int)
|
ones = torch.ones(10, dtype=torch.int)
|
||||||
self.assertRaisesRegex(AssertionError, "0 not greater than or equal to",
|
self.assertRaisesRegex(AssertionError, "0 not greater than",
|
||||||
lambda: self.assertNotEqual(ones, ones))
|
lambda: self.assertNotEqual(ones, ones))
|
||||||
|
|
||||||
def assertIsOrdered(self, order, x, mxx, ixx, task):
|
def assertIsOrdered(self, order, x, mxx, ixx, task):
|
||||||
|
|
@ -1955,10 +1955,10 @@ class _TestTorchMixin(object):
|
||||||
for normalized in (True, False):
|
for normalized in (True, False):
|
||||||
res = x.fft(signal_ndim, normalized=normalized)
|
res = x.fft(signal_ndim, normalized=normalized)
|
||||||
rec = res.ifft(signal_ndim, normalized=normalized)
|
rec = res.ifft(signal_ndim, normalized=normalized)
|
||||||
self.assertEqual(x, rec, 1e-8, 'fft and ifft')
|
self.assertEqual(x, rec, atol=1e-8, message='fft and ifft')
|
||||||
res = x.ifft(signal_ndim, normalized=normalized)
|
res = x.ifft(signal_ndim, normalized=normalized)
|
||||||
rec = res.fft(signal_ndim, normalized=normalized)
|
rec = res.fft(signal_ndim, normalized=normalized)
|
||||||
self.assertEqual(x, rec, 1e-8, 'ifft and fft')
|
self.assertEqual(x, rec, atol=1e-8, message='ifft and fft')
|
||||||
|
|
||||||
def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
|
def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
|
||||||
x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
|
x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
|
||||||
|
|
@ -1991,11 +1991,11 @@ class _TestTorchMixin(object):
|
||||||
test_input_signal_sizes = [signal_sizes]
|
test_input_signal_sizes = [signal_sizes]
|
||||||
rec = res.irfft(signal_ndim, normalized=normalized,
|
rec = res.irfft(signal_ndim, normalized=normalized,
|
||||||
onesided=onesided, signal_sizes=signal_sizes)
|
onesided=onesided, signal_sizes=signal_sizes)
|
||||||
self.assertEqual(x, rec, 1e-8, 'rfft and irfft')
|
self.assertEqual(x, rec, atol=1e-8, message='rfft and irfft')
|
||||||
if not onesided: # check that we can use C2C ifft
|
if not onesided: # check that we can use C2C ifft
|
||||||
rec = res.ifft(signal_ndim, normalized=normalized)
|
rec = res.ifft(signal_ndim, normalized=normalized)
|
||||||
self.assertEqual(x, rec.select(-1, 0), 1e-8, 'twosided rfft and ifft real')
|
self.assertEqual(x, rec.select(-1, 0), atol=1e-8, message='twosided rfft and ifft real')
|
||||||
self.assertEqual(rec.select(-1, 1).abs().mean(), 0, 1e-8, 'twosided rfft and ifft imaginary')
|
self.assertEqual(rec.select(-1, 1).abs().mean(), 0, atol=1e-8, message='twosided rfft and ifft imaginary')
|
||||||
|
|
||||||
# contiguous case
|
# contiguous case
|
||||||
_test_real((100,), 1)
|
_test_real((100,), 1)
|
||||||
|
|
@ -2057,10 +2057,10 @@ class _TestTorchMixin(object):
|
||||||
imvx2 = torch.xcorr2(x, ki, 'V')
|
imvx2 = torch.xcorr2(x, ki, 'V')
|
||||||
imfx = torch.xcorr2(x, ki, 'F')
|
imfx = torch.xcorr2(x, ki, 'F')
|
||||||
|
|
||||||
self.assertEqual(imvc, imvc2, 0, 'torch.conv2')
|
self.assertEqual(imvc, imvc2, atol=0, message='torch.conv2')
|
||||||
self.assertEqual(imvc, imvx, 0, 'torch.conv2')
|
self.assertEqual(imvc, imvx, atol=0, message='torch.conv2')
|
||||||
self.assertEqual(imvc, imvx2, 0, 'torch.conv2')
|
self.assertEqual(imvc, imvx2, atol=0, message='torch.conv2')
|
||||||
self.assertEqual(imfc, imfx, 0, 'torch.conv2')
|
self.assertEqual(imfc, imfx, atol=0, message='torch.conv2')
|
||||||
self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10, 'torch.conv2')
|
self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10, 'torch.conv2')
|
||||||
|
|
||||||
xx = torch.Tensor(2, x.size(1), x.size(2))
|
xx = torch.Tensor(2, x.size(1), x.size(2))
|
||||||
|
|
@ -2074,11 +2074,11 @@ class _TestTorchMixin(object):
|
||||||
immvc2 = torch.conv2(xx, kk, 'V')
|
immvc2 = torch.conv2(xx, kk, 'V')
|
||||||
immfc = torch.conv2(xx, kk, 'F')
|
immfc = torch.conv2(xx, kk, 'F')
|
||||||
|
|
||||||
self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv2')
|
self.assertEqual(immvc[0], immvc[1], atol=0, message='torch.conv2')
|
||||||
self.assertEqual(immvc[0], imvc, 0, 'torch.conv2')
|
self.assertEqual(immvc[0], imvc, atol=0, message='torch.conv2')
|
||||||
self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv2')
|
self.assertEqual(immvc2[0], imvc2, atol=0, message='torch.conv2')
|
||||||
self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv2')
|
self.assertEqual(immfc[0], immfc[1], atol=0, message='torch.conv2')
|
||||||
self.assertEqual(immfc[0], imfc, 0, 'torch.conv2')
|
self.assertEqual(immfc[0], imfc, atol=0, message='torch.conv2')
|
||||||
|
|
||||||
@unittest.skip("Not implemented yet")
|
@unittest.skip("Not implemented yet")
|
||||||
def test_conv3(self):
|
def test_conv3(self):
|
||||||
|
|
@ -2101,10 +2101,10 @@ class _TestTorchMixin(object):
|
||||||
imvx2 = torch.xcorr3(x, ki, 'V')
|
imvx2 = torch.xcorr3(x, ki, 'V')
|
||||||
imfx = torch.xcorr3(x, ki, 'F')
|
imfx = torch.xcorr3(x, ki, 'F')
|
||||||
|
|
||||||
self.assertEqual(imvc, imvc2, 0, 'torch.conv3')
|
self.assertEqual(imvc, imvc2, atol=0, message='torch.conv3')
|
||||||
self.assertEqual(imvc, imvx, 0, 'torch.conv3')
|
self.assertEqual(imvc, imvx, atol=0, message='torch.conv3')
|
||||||
self.assertEqual(imvc, imvx2, 0, 'torch.conv3')
|
self.assertEqual(imvc, imvx2, atol=0, message='torch.conv3')
|
||||||
self.assertEqual(imfc, imfx, 0, 'torch.conv3')
|
self.assertEqual(imfc, imfx, atol=0, message='torch.conv3')
|
||||||
self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10, 'torch.conv3')
|
self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10, 'torch.conv3')
|
||||||
|
|
||||||
xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3))
|
xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3))
|
||||||
|
|
@ -2118,11 +2118,11 @@ class _TestTorchMixin(object):
|
||||||
immvc2 = torch.conv3(xx, kk, 'V')
|
immvc2 = torch.conv3(xx, kk, 'V')
|
||||||
immfc = torch.conv3(xx, kk, 'F')
|
immfc = torch.conv3(xx, kk, 'F')
|
||||||
|
|
||||||
self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv3')
|
self.assertEqual(immvc[0], immvc[1], atol=0, message='torch.conv3')
|
||||||
self.assertEqual(immvc[0], imvc, 0, 'torch.conv3')
|
self.assertEqual(immvc[0], imvc, atol=0, message='torch.conv3')
|
||||||
self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv3')
|
self.assertEqual(immvc2[0], imvc2, atol=0, message='torch.conv3')
|
||||||
self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv3')
|
self.assertEqual(immfc[0], immfc[1], atol=0, message='torch.conv3')
|
||||||
self.assertEqual(immfc[0], imfc, 0, 'torch.conv3')
|
self.assertEqual(immfc[0], imfc, atol=0, message='torch.conv3')
|
||||||
|
|
||||||
@unittest.skip("Not implemented yet")
|
@unittest.skip("Not implemented yet")
|
||||||
def _test_conv_corr_eq(self, fn, fn_2_to_3):
|
def _test_conv_corr_eq(self, fn, fn_2_to_3):
|
||||||
|
|
@ -2202,7 +2202,7 @@ class _TestTorchMixin(object):
|
||||||
# Dramatically alter the internal state of the main generator
|
# Dramatically alter the internal state of the main generator
|
||||||
_ = torch.rand(100000)
|
_ = torch.rand(100000)
|
||||||
forked_value = torch.rand(1000, generator=gen)
|
forked_value = torch.rand(1000, generator=gen)
|
||||||
self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.")
|
self.assertEqual(target_value, forked_value, atol=0, message="RNG has not forked correctly.")
|
||||||
|
|
||||||
def test_RNG_after_pickle(self):
|
def test_RNG_after_pickle(self):
|
||||||
torch.random.manual_seed(100)
|
torch.random.manual_seed(100)
|
||||||
|
|
@ -2226,10 +2226,10 @@ class _TestTorchMixin(object):
|
||||||
repeat_midstream = torch.randn(odd_number)
|
repeat_midstream = torch.randn(odd_number)
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
reseeded = torch.randn(odd_number)
|
reseeded = torch.randn(odd_number)
|
||||||
self.assertEqual(midstream, repeat_midstream, 0,
|
self.assertEqual(midstream, repeat_midstream, atol=0,
|
||||||
'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
|
message='get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
|
||||||
self.assertEqual(seeded, reseeded, 0,
|
self.assertEqual(seeded, reseeded, atol=0,
|
||||||
'repeated calls to manual_seed not generating same sequence of normally distributed numbers')
|
message='repeated calls to manual_seed not generating same sequence of normally distributed numbers')
|
||||||
|
|
||||||
def test_manual_seed(self):
|
def test_manual_seed(self):
|
||||||
rng_state = torch.get_rng_state()
|
rng_state = torch.get_rng_state()
|
||||||
|
|
@ -2703,6 +2703,7 @@ class _TestTorchMixin(object):
|
||||||
for wi in w:
|
for wi in w:
|
||||||
self.assertEqual(str(wi.message)[0:52], str(warn))
|
self.assertEqual(str(wi.message)[0:52], str(warn))
|
||||||
|
|
||||||
|
|
||||||
def test_unbiased(self):
|
def test_unbiased(self):
|
||||||
tensor = torch.randn(100)
|
tensor = torch.randn(100)
|
||||||
self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
|
self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
|
||||||
|
|
@ -3527,7 +3528,7 @@ class _TestTorchMixin(object):
|
||||||
expected_1d = [0.16478512, 0.43221009, 0.84261382, 0.99750268, 0.27460563,
|
expected_1d = [0.16478512, 0.43221009, 0.84261382, 0.99750268, 0.27460563,
|
||||||
0.01084163, 0.73373985, 0.65039611, 0.12329865, 0.35587373]
|
0.01084163, 0.73373985, 0.65039611, 0.12329865, 0.35587373]
|
||||||
actual_1d = engine_1d.draw(10)
|
actual_1d = engine_1d.draw(10)
|
||||||
self.assertEqual(actual_1d.flatten(), torch.tensor(expected_1d))
|
self.assertEqual(actual_1d.flatten(), torch.tensor(expected_1d), atol=1e-5)
|
||||||
self.assertEqual(actual_1d.size(), torch.Size([10, 1]))
|
self.assertEqual(actual_1d.size(), torch.Size([10, 1]))
|
||||||
# make sure random seed if chosen if none is provided
|
# make sure random seed if chosen if none is provided
|
||||||
engine_1d_a = torch.quasirandom.SobolEngine(1, scramble=True)
|
engine_1d_a = torch.quasirandom.SobolEngine(1, scramble=True)
|
||||||
|
|
@ -4037,6 +4038,12 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||||
self.assertEqual(x, xv)
|
self.assertEqual(x, xv)
|
||||||
self.assertEqual(xv, x)
|
self.assertEqual(xv, x)
|
||||||
|
|
||||||
|
self.assertRaisesRegex(AssertionError, "don't combine",
|
||||||
|
lambda: self.assertEqual(x, xv, 1.0, atol=4))
|
||||||
|
|
||||||
|
self.assertRaisesRegex(TypeError, "takes from 3 to 4 positional arguments",
|
||||||
|
lambda: self.assertEqual(x, xv, 1.0, ""))
|
||||||
|
|
||||||
def test_new(self):
|
def test_new(self):
|
||||||
x = torch.autograd.Variable(torch.Tensor())
|
x = torch.autograd.Variable(torch.Tensor())
|
||||||
y = torch.autograd.Variable(torch.randn(4, 4))
|
y = torch.autograd.Variable(torch.randn(4, 4))
|
||||||
|
|
@ -4699,19 +4706,19 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||||
float_tensor = torch.FloatTensor([1.0, tiny_float])
|
float_tensor = torch.FloatTensor([1.0, tiny_float])
|
||||||
double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
|
double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
|
||||||
|
|
||||||
self.assertEqual(float_tensor[0], 1.0, prec=0.0)
|
self.assertEqual(float_tensor[0], 1.0, atol=0.0)
|
||||||
self.assertEqual(float_tensor[1], tiny_float, prec=tiny_float / 16)
|
self.assertEqual(float_tensor[1], tiny_float, atol=tiny_float / 16)
|
||||||
self.assertEqual(double_tensor[0], 1.0, prec=0.0)
|
self.assertEqual(double_tensor[0], 1.0, atol=0.0)
|
||||||
self.assertEqual(double_tensor[1], tiny_float, prec=0.0)
|
self.assertEqual(double_tensor[1], tiny_float, atol=0.0)
|
||||||
self.assertEqual(double_tensor[2], tiny_double, prec=0.0)
|
self.assertEqual(double_tensor[2], tiny_double, atol=0.0)
|
||||||
|
|
||||||
torch.set_flush_denormal(True)
|
torch.set_flush_denormal(True)
|
||||||
self.assertEqual(float_tensor[0], 1.0, prec=0.0)
|
self.assertEqual(float_tensor[0], 1.0, atol=0.0)
|
||||||
self.assertEqual(float_tensor[1], 0.0, prec=0.0) # tiny_float to zero
|
self.assertEqual(float_tensor[1], 0.0, atol=0.0) # tiny_float to zero
|
||||||
self.assertEqual(double_tensor[0], 1.0, prec=0.0)
|
self.assertEqual(double_tensor[0], 1.0, atol=0.0)
|
||||||
# tiny_float is not converted to zero in double type
|
# tiny_float is not converted to zero in double type
|
||||||
self.assertEqual(double_tensor[1], tiny_float, prec=0.0)
|
self.assertEqual(double_tensor[1], tiny_float, atol=0.0)
|
||||||
self.assertEqual(double_tensor[2], 0.0, prec=0.0) # tiny_double to zero
|
self.assertEqual(double_tensor[2], 0.0, atol=0.0) # tiny_double to zero
|
||||||
torch.set_flush_denormal(False)
|
torch.set_flush_denormal(False)
|
||||||
|
|
||||||
def test_show_config(self):
|
def test_show_config(self):
|
||||||
|
|
@ -5907,16 +5914,17 @@ class TestTorchDeviceType(TestCase):
|
||||||
# no batches: 2-D tensors
|
# no batches: 2-D tensors
|
||||||
matrix = random_fullrank_matrix_distinct_singular_value(5).to(device)
|
matrix = random_fullrank_matrix_distinct_singular_value(5).to(device)
|
||||||
matrix_inverse = torch.inverse(matrix)
|
matrix_inverse = torch.inverse(matrix)
|
||||||
|
|
||||||
identity = torch.eye(5, dtype=torch.float64, device=device)
|
identity = torch.eye(5, dtype=torch.float64, device=device)
|
||||||
self.assertEqual(identity, torch.mm(matrix, matrix_inverse), 1e-8, 'inverse value')
|
self.assertEqual(identity, torch.mm(matrix, matrix_inverse), atol=1e-8, message='inverse value')
|
||||||
self.assertEqual(identity, torch.mm(matrix_inverse, matrix), 1e-8, 'inverse value')
|
self.assertEqual(identity, torch.mm(matrix_inverse, matrix), atol=1e-8, message='inverse value')
|
||||||
|
|
||||||
matrix_inverse_out = torch.empty(5, 5, dtype=torch.float64, device=device)
|
matrix_inverse_out = torch.empty(5, 5, dtype=torch.float64, device=device)
|
||||||
torch.inverse(matrix, out=matrix_inverse_out)
|
torch.inverse(matrix, out=matrix_inverse_out)
|
||||||
self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place')
|
self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, message='inverse value in-place')
|
||||||
# second call, now that matrix_inverse_out is transposed
|
# second call, now that matrix_inverse_out is transposed
|
||||||
torch.inverse(matrix, out=matrix_inverse_out)
|
torch.inverse(matrix, out=matrix_inverse_out)
|
||||||
self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place')
|
self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, message='inverse value in-place')
|
||||||
|
|
||||||
# one batch
|
# one batch
|
||||||
matrix = random_fullrank_matrix_distinct_singular_value(5, 1).to(device)
|
matrix = random_fullrank_matrix_distinct_singular_value(5, 1).to(device)
|
||||||
|
|
@ -6550,10 +6558,10 @@ class TestTorchDeviceType(TestCase):
|
||||||
# Testing against definition for pseudo-inverses
|
# Testing against definition for pseudo-inverses
|
||||||
MPI = torch.pinverse(M)
|
MPI = torch.pinverse(M)
|
||||||
if M.numel() > 0:
|
if M.numel() > 0:
|
||||||
self.assertEqual(M, M.matmul(MPI).matmul(M), 1e-8, 'pseudo-inverse condition 1')
|
self.assertEqual(M, M.matmul(MPI).matmul(M), atol=1e-8, message='pseudo-inverse condition 1')
|
||||||
self.assertEqual(MPI, MPI.matmul(M).matmul(MPI), 1e-8, 'pseudo-inverse condition 2')
|
self.assertEqual(MPI, MPI.matmul(M).matmul(MPI), atol=1e-8, message='pseudo-inverse condition 2')
|
||||||
self.assertEqual(M.matmul(MPI), (M.matmul(MPI)).transpose(-2, -1), 1e-8, 'pseudo-inverse condition 3')
|
self.assertEqual(M.matmul(MPI), (M.matmul(MPI)).transpose(-2, -1), atol=1e-8, message='pseudo-inverse condition 3')
|
||||||
self.assertEqual(MPI.matmul(M), (MPI.matmul(M)).transpose(-2, -1), 1e-8, 'pseudo-inverse condition 4')
|
self.assertEqual(MPI.matmul(M), (MPI.matmul(M)).transpose(-2, -1), atol=1e-8, message='pseudo-inverse condition 4')
|
||||||
else:
|
else:
|
||||||
self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
|
self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
|
||||||
for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5), # square matrices
|
for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5), # square matrices
|
||||||
|
|
@ -6569,7 +6577,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
batchdims = sizes[:-2]
|
batchdims = sizes[:-2]
|
||||||
M = fullrank(matsize, *batchdims, dtype=dtype, device=device)
|
M = fullrank(matsize, *batchdims, dtype=dtype, device=device)
|
||||||
self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
|
self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
|
||||||
1e-7, 'pseudo-inverse for invertible matrix')
|
atol=1e-7, message='pseudo-inverse for invertible matrix')
|
||||||
|
|
||||||
@skipCUDAIfNoMagma
|
@skipCUDAIfNoMagma
|
||||||
@skipCPUIfNoLapack
|
@skipCPUIfNoLapack
|
||||||
|
|
@ -6704,12 +6712,13 @@ class TestTorchDeviceType(TestCase):
|
||||||
sdet, logabsdet = M.slogdet()
|
sdet, logabsdet = M.slogdet()
|
||||||
|
|
||||||
# Test det
|
# Test det
|
||||||
self.assertEqual(det, target_sdet * target_logabsdet.exp(), 1e-7, '{} (det)'.format(desc))
|
self.assertEqual(det, target_sdet * target_logabsdet.exp(), atol=1e-7, message='{} (det)'.format(desc))
|
||||||
|
|
||||||
# Test slogdet
|
# Test slogdet
|
||||||
# Compare the overall value rather than individual parts because of
|
# Compare the overall value rather than individual parts because of
|
||||||
# precision issues when det is near zero.
|
# precision issues when det is near zero.
|
||||||
self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), 1e-7, '{} (slogdet)'.format(desc))
|
self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(),
|
||||||
|
atol=1e-7, message='{} (slogdet)'.format(desc))
|
||||||
|
|
||||||
# Test logdet
|
# Test logdet
|
||||||
# Compare logdet against our own pytorch slogdet because they should
|
# Compare logdet against our own pytorch slogdet because they should
|
||||||
|
|
@ -6719,7 +6728,8 @@ class TestTorchDeviceType(TestCase):
|
||||||
if sdet.item() < 0:
|
if sdet.item() < 0:
|
||||||
self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc))
|
self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc))
|
||||||
else:
|
else:
|
||||||
self.assertEqual(logdet.exp(), target_logabsdet.exp(), 1e-7, '{} (logdet non-negative case)'.format(desc))
|
self.assertEqual(logdet.exp(), target_logabsdet.exp(),
|
||||||
|
atol=1e-7, message='{} (logdet non-negative case)'.format(desc))
|
||||||
|
|
||||||
eye = torch.eye(5, dtype=dtype, device=device)
|
eye = torch.eye(5, dtype=dtype, device=device)
|
||||||
test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
|
test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
|
||||||
|
|
@ -7154,12 +7164,12 @@ class TestTorchDeviceType(TestCase):
|
||||||
# test Upper Triangular
|
# test Upper Triangular
|
||||||
U = torch.cholesky(A, True)
|
U = torch.cholesky(A, True)
|
||||||
B = torch.mm(U.t(), U)
|
B = torch.mm(U.t(), U)
|
||||||
self.assertEqual(A, B, 1e-14, 'cholesky (upper) did not allow rebuilding the original matrix')
|
self.assertEqual(A, B, atol=1e-14, message='cholesky (upper) did not allow rebuilding the original matrix')
|
||||||
|
|
||||||
# test Lower Triangular
|
# test Lower Triangular
|
||||||
L = torch.cholesky(A, False)
|
L = torch.cholesky(A, False)
|
||||||
B = torch.mm(L, L.t())
|
B = torch.mm(L, L.t())
|
||||||
self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix')
|
self.assertEqual(A, B, atol=1e-14, message='cholesky (lower) did not allow rebuilding the original matrix')
|
||||||
|
|
||||||
def test_view(self, device):
|
def test_view(self, device):
|
||||||
tensor = torch.rand(15, device=device)
|
tensor = torch.rand(15, device=device)
|
||||||
|
|
@ -8689,7 +8699,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
|
|
||||||
if eigenvectors:
|
if eigenvectors:
|
||||||
x_recon = torch.matmul(torch.matmul(outv, torch.diag_embed(oute)), outv.transpose(-2, -1))
|
x_recon = torch.matmul(torch.matmul(outv, torch.diag_embed(oute)), outv.transpose(-2, -1))
|
||||||
self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using V @ diag(e) @ V.T')
|
self.assertEqual(x, x_recon, atol=1e-8, message='Incorrect reconstruction using V @ diag(e) @ V.T')
|
||||||
else:
|
else:
|
||||||
eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper)
|
eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper)
|
||||||
self.assertEqual(eigvals, oute, 'Eigenvalues mismatch')
|
self.assertEqual(eigvals, oute, 'Eigenvalues mismatch')
|
||||||
|
|
@ -8708,7 +8718,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper)
|
rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper)
|
||||||
if eigenvectors:
|
if eigenvectors:
|
||||||
x_recon = torch.matmul(torch.matmul(resv, torch.diag_embed(rese)), resv.transpose(-2, -1))
|
x_recon = torch.matmul(torch.matmul(resv, torch.diag_embed(rese)), resv.transpose(-2, -1))
|
||||||
self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using V @ diag(e) @ V.T')
|
self.assertEqual(x, x_recon, atol=1e-8, message='Incorrect reconstruction using V @ diag(e) @ V.T')
|
||||||
else:
|
else:
|
||||||
eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper)
|
eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper)
|
||||||
self.assertEqual(eigvals, rese, 'Eigenvalues mismatch')
|
self.assertEqual(eigvals, rese, 'Eigenvalues mismatch')
|
||||||
|
|
@ -8732,12 +8742,12 @@ class TestTorchDeviceType(TestCase):
|
||||||
if compute_uv:
|
if compute_uv:
|
||||||
if some:
|
if some:
|
||||||
x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1)))
|
x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1)))
|
||||||
self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T')
|
self.assertEqual(x, x_recon, atol=1e-8, message='Incorrect reconstruction using U @ diag(S) @ V.T')
|
||||||
else:
|
else:
|
||||||
narrow_u = outu[..., :min(*dims[-2:])]
|
narrow_u = outu[..., :min(*dims[-2:])]
|
||||||
narrow_v = outv[..., :min(*dims[-2:])]
|
narrow_v = outv[..., :min(*dims[-2:])]
|
||||||
x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1)))
|
x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1)))
|
||||||
self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T')
|
self.assertEqual(x, x_recon, atol=1e-8, message='Incorrect reconstruction using U @ diag(S) @ V.T')
|
||||||
else:
|
else:
|
||||||
_, singvals, _ = torch.svd(x, compute_uv=True)
|
_, singvals, _ = torch.svd(x, compute_uv=True)
|
||||||
self.assertEqual(singvals, outs, 'Singular values mismatch')
|
self.assertEqual(singvals, outs, 'Singular values mismatch')
|
||||||
|
|
@ -8759,12 +8769,12 @@ class TestTorchDeviceType(TestCase):
|
||||||
if compute_uv:
|
if compute_uv:
|
||||||
if some:
|
if some:
|
||||||
x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1)))
|
x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1)))
|
||||||
self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T')
|
self.assertEqual(x, x_recon, atol=1e-8, message='Incorrect reconstruction using U @ diag(S) @ V.T')
|
||||||
else:
|
else:
|
||||||
narrow_u = resu[..., :min(*dims[-2:])]
|
narrow_u = resu[..., :min(*dims[-2:])]
|
||||||
narrow_v = resv[..., :min(*dims[-2:])]
|
narrow_v = resv[..., :min(*dims[-2:])]
|
||||||
x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1)))
|
x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1)))
|
||||||
self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T')
|
self.assertEqual(x, x_recon, atol=1e-8, message='Incorrect reconstruction using U @ diag(S) @ V.T')
|
||||||
else:
|
else:
|
||||||
_, singvals, _ = torch.svd(x, compute_uv=True)
|
_, singvals, _ = torch.svd(x, compute_uv=True)
|
||||||
self.assertEqual(singvals, ress, 'Singular values mismatch')
|
self.assertEqual(singvals, ress, 'Singular values mismatch')
|
||||||
|
|
@ -8982,7 +8992,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
for p in [0, 1, 2, 3, 4, inf, -inf]:
|
for p in [0, 1, 2, 3, 4, inf, -inf]:
|
||||||
res = x.norm(p).item()
|
res = x.norm(p).item()
|
||||||
expected = np.linalg.norm(xn, p)
|
expected = np.linalg.norm(xn, p)
|
||||||
self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p))
|
self.assertEqual(res, expected, atol=1e-5, message="full reduction failed for {}-norm".format(p))
|
||||||
|
|
||||||
# one dimension
|
# one dimension
|
||||||
x = torch.randn(25, 25, device=device)
|
x = torch.randn(25, 25, device=device)
|
||||||
|
|
@ -12167,7 +12177,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
# NB: librosa defaults to np.complex64 output, no matter what
|
# NB: librosa defaults to np.complex64 output, no matter what
|
||||||
# the input dtype
|
# the input dtype
|
||||||
ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
|
ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
|
||||||
self.assertEqual(result, ref_result, 7e-6, 'stft comparison against librosa', exact_dtype=False)
|
self.assertEqual(result, ref_result, atol=7e-6, message='stft comparison against librosa', exact_dtype=False)
|
||||||
else:
|
else:
|
||||||
self.assertRaises(expected_error,
|
self.assertRaises(expected_error,
|
||||||
lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
|
lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
|
||||||
|
|
@ -13931,12 +13941,12 @@ class TestTorchDeviceType(TestCase):
|
||||||
v = torch.zeros(4, 4, dtype=dtype, device=device)
|
v = torch.zeros(4, 4, dtype=dtype, device=device)
|
||||||
torch.eig(X, True, out=(e, v))
|
torch.eig(X, True, out=(e, v))
|
||||||
Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
|
Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
|
||||||
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
|
self.assertEqual(X, Xhat, atol=1e-8, message='VeV\' wrong')
|
||||||
self.assertFalse(v.is_contiguous(), 'V is contiguous')
|
self.assertFalse(v.is_contiguous(), 'V is contiguous')
|
||||||
|
|
||||||
torch.eig(X, True, out=(e, v))
|
torch.eig(X, True, out=(e, v))
|
||||||
Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
|
Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
|
||||||
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
|
self.assertEqual(X, Xhat, atol=1e-8, message='VeV\' wrong')
|
||||||
self.assertFalse(v.is_contiguous(), 'V is contiguous')
|
self.assertFalse(v.is_contiguous(), 'V is contiguous')
|
||||||
|
|
||||||
# test non-contiguous
|
# test non-contiguous
|
||||||
|
|
@ -13948,7 +13958,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
self.assertFalse(e.is_contiguous(), 'E is contiguous')
|
self.assertFalse(e.is_contiguous(), 'E is contiguous')
|
||||||
torch.eig(X, True, out=(e, v))
|
torch.eig(X, True, out=(e, v))
|
||||||
Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
|
Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
|
||||||
self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
|
self.assertEqual(X, Xhat, atol=1e-8, message='VeV\' wrong')
|
||||||
|
|
||||||
@skipCUDAIfNoMagma
|
@skipCUDAIfNoMagma
|
||||||
@skipCPUIfNoLapack
|
@skipCPUIfNoLapack
|
||||||
|
|
@ -13988,7 +13998,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
self.assertEqual(qform(B, X[:, :k]), I)
|
self.assertEqual(qform(B, X[:, :k]), I)
|
||||||
|
|
||||||
# Check block equation
|
# Check block equation
|
||||||
self.assertEqual(qform(A, X[:, :k]) / E[:k], I, prec=0.2)
|
self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2)
|
||||||
|
|
||||||
orig_lobpcg = lobpcg
|
orig_lobpcg = lobpcg
|
||||||
|
|
||||||
|
|
@ -14021,7 +14031,7 @@ class TestTorchDeviceType(TestCase):
|
||||||
E, V = lobpcg(A, k=k, n=n, largest=False)
|
E, V = lobpcg(A, k=k, n=n, largest=False)
|
||||||
self.assertEqual(E.shape, batches + (k,))
|
self.assertEqual(E.shape, batches + (k,))
|
||||||
self.assertEqual(V.shape, batches + (m, k))
|
self.assertEqual(V.shape, batches + (m, k))
|
||||||
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), prec=prec)
|
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec)
|
||||||
e = torch.symeig(A)[0]
|
e = torch.symeig(A)[0]
|
||||||
e_smallest = e[..., :k]
|
e_smallest = e[..., :k]
|
||||||
self.assertEqual(E, e_smallest)
|
self.assertEqual(E, e_smallest)
|
||||||
|
|
@ -14029,17 +14039,17 @@ class TestTorchDeviceType(TestCase):
|
||||||
# classical eigenvalue problem, largest eigenvalues
|
# classical eigenvalue problem, largest eigenvalues
|
||||||
E, V = lobpcg(A, k=k, n=n, largest=True)
|
E, V = lobpcg(A, k=k, n=n, largest=True)
|
||||||
e_largest, _ = torch.sort(e[..., -k:], descending=True)
|
e_largest, _ = torch.sort(e[..., -k:], descending=True)
|
||||||
self.assertEqual(E, e_largest, prec=prec)
|
self.assertEqual(E, e_largest, atol=prec)
|
||||||
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), prec=prec)
|
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec)
|
||||||
|
|
||||||
# generalized eigenvalue problem, smallest eigenvalues
|
# generalized eigenvalue problem, smallest eigenvalues
|
||||||
E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
|
E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
|
||||||
self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), prec=prec)
|
self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec)
|
||||||
|
|
||||||
# generalized eigenvalue problem, largest eigenvalues
|
# generalized eigenvalue problem, largest eigenvalues
|
||||||
E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
|
E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
|
||||||
self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
|
self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
|
||||||
prec=prec)
|
atol=prec)
|
||||||
|
|
||||||
# check sparse input
|
# check sparse input
|
||||||
for m, n, k, density in [
|
for m, n, k, density in [
|
||||||
|
|
@ -14061,21 +14071,21 @@ class TestTorchDeviceType(TestCase):
|
||||||
# classical eigenvalue problem, smallest eigenvalues
|
# classical eigenvalue problem, smallest eigenvalues
|
||||||
E, V = lobpcg(A, k=k, n=n, largest=False)
|
E, V = lobpcg(A, k=k, n=n, largest=False)
|
||||||
self.assertEqual(E, e_smallest)
|
self.assertEqual(E, e_smallest)
|
||||||
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), prec=prec)
|
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec)
|
||||||
|
|
||||||
# classical eigenvalue problem, largest eigenvalues
|
# classical eigenvalue problem, largest eigenvalues
|
||||||
E, V = lobpcg(A, k=k, n=n, largest=True)
|
E, V = lobpcg(A, k=k, n=n, largest=True)
|
||||||
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), prec=prec)
|
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec)
|
||||||
self.assertEqual(E, e_largest)
|
self.assertEqual(E, e_largest)
|
||||||
|
|
||||||
# generalized eigenvalue problem, smallest eigenvalues
|
# generalized eigenvalue problem, smallest eigenvalues
|
||||||
E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
|
E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
|
||||||
self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), prec=prec)
|
self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec)
|
||||||
|
|
||||||
# generalized eigenvalue problem, largest eigenvalues
|
# generalized eigenvalue problem, largest eigenvalues
|
||||||
E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
|
E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
|
||||||
self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
|
self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
|
||||||
prec=prec)
|
atol=prec)
|
||||||
|
|
||||||
@skipCPUIfNoLapack
|
@skipCPUIfNoLapack
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
|
|
@ -14323,7 +14333,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
res2[i] += m[i, j] * v[j]
|
res2[i] += m[i, j] * v[j]
|
||||||
|
|
||||||
self.assertEqual(res1, res2)
|
self.assertEqual(res1, res2, atol=self.precision)
|
||||||
|
|
||||||
# Test 0-strided
|
# Test 0-strided
|
||||||
t = torch.randn(1, device=device).to(dtype).expand(10)
|
t = torch.randn(1, device=device).to(dtype).expand(10)
|
||||||
|
|
@ -14336,7 +14346,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
res2[i] += m[i, j] * v[j]
|
res2[i] += m[i, j] * v[j]
|
||||||
|
|
||||||
self.assertEqual(res1, res2)
|
self.assertEqual(res1, res2, atol=self.precision)
|
||||||
|
|
||||||
@slowTest
|
@slowTest
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
|
|
@ -14928,7 +14938,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
|
||||||
self.assertEqual(torch.baddbmm(1, res2, 0, b1, b2), res2)
|
self.assertEqual(torch.baddbmm(1, res2, 0, b1, b2), res2)
|
||||||
|
|
||||||
res4 = torch.baddbmm(res2, b1, b2, beta=1, alpha=.5)
|
res4 = torch.baddbmm(res2, b1, b2, beta=1, alpha=.5)
|
||||||
self.assertEqual(res4, res * 3, prec=2e-5)
|
self.assertEqual(res4, res * 3, atol=2e-5)
|
||||||
|
|
||||||
res5 = torch.baddbmm(res2, b1, b2, beta=0, alpha=1)
|
res5 = torch.baddbmm(res2, b1, b2, beta=0, alpha=1)
|
||||||
self.assertEqual(res5, res)
|
self.assertEqual(res5, res)
|
||||||
|
|
@ -15727,7 +15737,7 @@ class TestDevicePrecision(TestCase):
|
||||||
def _test_linspace(self, device, dtype, steps):
|
def _test_linspace(self, device, dtype, steps):
|
||||||
a = torch.linspace(0, 10, steps=steps, dtype=dtype, device=device)
|
a = torch.linspace(0, 10, steps=steps, dtype=dtype, device=device)
|
||||||
b = torch.linspace(0, 10, steps=steps)
|
b = torch.linspace(0, 10, steps=steps)
|
||||||
self.assertEqual(a, b, exact_dtype=False)
|
self.assertEqual(a, b, atol=self.precision, exact_dtype=False)
|
||||||
|
|
||||||
# See NOTE [Linspace+Logspace precision override]
|
# See NOTE [Linspace+Logspace precision override]
|
||||||
@precisionOverride({torch.half: 0.0039 + LINSPACE_LOGSPACE_EXTRA_EPS})
|
@precisionOverride({torch.half: 0.0039 + LINSPACE_LOGSPACE_EXTRA_EPS})
|
||||||
|
|
@ -15745,12 +15755,12 @@ class TestDevicePrecision(TestCase):
|
||||||
def _test_logspace(self, device, dtype, steps):
|
def _test_logspace(self, device, dtype, steps):
|
||||||
a = torch.logspace(1, 1.1, steps=steps, dtype=dtype, device=device)
|
a = torch.logspace(1, 1.1, steps=steps, dtype=dtype, device=device)
|
||||||
b = torch.logspace(1, 1.1, steps=steps)
|
b = torch.logspace(1, 1.1, steps=steps)
|
||||||
self.assertEqual(a, b, exact_dtype=False)
|
self.assertEqual(a, b, atol=self.precision, exact_dtype=False)
|
||||||
|
|
||||||
def _test_logspace_base2(self, device, dtype, steps):
|
def _test_logspace_base2(self, device, dtype, steps):
|
||||||
a = torch.logspace(1, 1.1, steps=steps, base=2, dtype=dtype, device=device)
|
a = torch.logspace(1, 1.1, steps=steps, base=2, dtype=dtype, device=device)
|
||||||
b = torch.logspace(1, 1.1, steps=steps, base=2)
|
b = torch.logspace(1, 1.1, steps=steps, base=2)
|
||||||
self.assertEqual(a, b, exact_dtype=False)
|
self.assertEqual(a, b, atol=self.precision, exact_dtype=False)
|
||||||
|
|
||||||
# See NOTE [Linspace+Logspace precision override]
|
# See NOTE [Linspace+Logspace precision override]
|
||||||
@precisionOverride({torch.half: 0.0157 + LINSPACE_LOGSPACE_EXTRA_EPS})
|
@precisionOverride({torch.half: 0.0157 + LINSPACE_LOGSPACE_EXTRA_EPS})
|
||||||
|
|
@ -15870,7 +15880,7 @@ class TestDevicePrecision(TestCase):
|
||||||
index = index.to(device=device)
|
index = index.to(device=device)
|
||||||
out_gpu = inp_tensor.index_add(0, index, t)
|
out_gpu = inp_tensor.index_add(0, index, t)
|
||||||
|
|
||||||
self.assertEqual(out_cpu, out_gpu, prec=1e-2)
|
self.assertEqual(out_cpu, out_gpu, atol=1e-2)
|
||||||
|
|
||||||
@skipCUDAIfRocm
|
@skipCUDAIfRocm
|
||||||
@dtypes(torch.double)
|
@dtypes(torch.double)
|
||||||
|
|
@ -16949,9 +16959,9 @@ def generate_test_function(cls,
|
||||||
# Compares CPU and device inputs and outputs
|
# Compares CPU and device inputs and outputs
|
||||||
precision = dtype2precision.get(dtype, float_precision)
|
precision = dtype2precision.get(dtype, float_precision)
|
||||||
|
|
||||||
self.assertEqual(cpu_tensor, device_tensor, prec=precision, exact_dtype=False, allow_inf=True)
|
self.assertEqual(cpu_tensor, device_tensor, atol=precision, exact_dtype=False, allow_inf=True)
|
||||||
self.assertEqual(cpu_args, device_args, prec=precision, exact_dtype=False, allow_inf=True)
|
self.assertEqual(cpu_args, device_args, atol=precision, exact_dtype=False, allow_inf=True)
|
||||||
self.assertEqual(cpu_result, device_result, prec=precision, exact_dtype=False, allow_inf=True)
|
self.assertEqual(cpu_result, device_result, atol=precision, exact_dtype=False, allow_inf=True)
|
||||||
|
|
||||||
test_name = "test_" + op_str + subtest_str
|
test_name = "test_" + op_str + subtest_str
|
||||||
assert not hasattr(cls, test_name), "{0} already in TestDevicePrecision".format(test_name)
|
assert not hasattr(cls, test_name), "{0} already in TestDevicePrecision".format(test_name)
|
||||||
|
|
@ -17286,7 +17296,7 @@ class TestTensorDeviceOps(TestCase):
|
||||||
# then the corresponding column of the V has to be changed.
|
# then the corresponding column of the V has to be changed.
|
||||||
# Thus here we only compare result[..., :m].abs() from CPU and device.
|
# Thus here we only compare result[..., :m].abs() from CPU and device.
|
||||||
for x, y in zip(cpu_result, device_result):
|
for x, y in zip(cpu_result, device_result):
|
||||||
self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), prec=1e-5)
|
self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5)
|
||||||
|
|
||||||
@skipCUDAIfNoMagma
|
@skipCUDAIfNoMagma
|
||||||
@dtypes(*_float_types_no_half)
|
@dtypes(*_float_types_no_half)
|
||||||
|
|
|
||||||
|
|
@ -592,7 +592,7 @@ class TestTypePromotion(TestCase):
|
||||||
if op_name != 'div':
|
if op_name != 'div':
|
||||||
sparse = op(s1, s2)
|
sparse = op(s1, s2)
|
||||||
self.assertEqual(sparse.dtype, e.dtype)
|
self.assertEqual(sparse.dtype, e.dtype)
|
||||||
self.assertEqual(e, sparse.to_dense(), prec=precision, message=err)
|
self.assertEqual(e, sparse.to_dense(), atol=precision, message=err)
|
||||||
else:
|
else:
|
||||||
# sparse division only supports division by a scalar
|
# sparse division only supports division by a scalar
|
||||||
self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())
|
self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())
|
||||||
|
|
@ -602,7 +602,7 @@ class TestTypePromotion(TestCase):
|
||||||
if inplace:
|
if inplace:
|
||||||
e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
|
e, d1, s1, d2, s2 = [x.clone() for x in test_tensors]
|
||||||
dense_sparse = op(d1, s2)
|
dense_sparse = op(d1, s2)
|
||||||
self.assertEqual(e, dense_sparse, prec=precision, message=err)
|
self.assertEqual(e, dense_sparse, atol=precision, message=err)
|
||||||
else:
|
else:
|
||||||
# sparse division only supports division by a scalar
|
# sparse division only supports division by a scalar
|
||||||
# mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
|
# mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
|
||||||
|
|
@ -623,7 +623,7 @@ class TestTypePromotion(TestCase):
|
||||||
sparse = op(s1, scalar)
|
sparse = op(s1, scalar)
|
||||||
dense_scalar = op(d1, scalar)
|
dense_scalar = op(d1, scalar)
|
||||||
self.assertEqual(sparse.dtype, dense_scalar.dtype)
|
self.assertEqual(sparse.dtype, dense_scalar.dtype)
|
||||||
self.assertEqual(dense_scalar, sparse.to_dense(), prec=precision, message=err)
|
self.assertEqual(dense_scalar, sparse.to_dense(), atol=precision, message=err)
|
||||||
else:
|
else:
|
||||||
# add(sparse, dense) is not supported. Use add(dense, sparse) instead.
|
# add(sparse, dense) is not supported. Use add(dense, sparse) instead.
|
||||||
# "mul_cpu" / "div_cpu" not implemented for 'Half'
|
# "mul_cpu" / "div_cpu" not implemented for 'Half'
|
||||||
|
|
|
||||||
|
|
@ -362,11 +362,11 @@ class TestBottleneck(TestCase):
|
||||||
def _check_run_args(self):
|
def _check_run_args(self):
|
||||||
# Check that this fails due to missing args
|
# Check that this fails due to missing args
|
||||||
rc, out, err = self._run_bottleneck('bottleneck_test/test_args.py')
|
rc, out, err = self._run_bottleneck('bottleneck_test/test_args.py')
|
||||||
self.assertEqual(rc, 2, None, self._fail_msg('Missing args should error', out + err))
|
self.assertEqual(rc, 2, atol=0, message=self._fail_msg('Missing args should error', out + err))
|
||||||
|
|
||||||
# This should succeed
|
# This should succeed
|
||||||
rc, out, err = self._run_bottleneck('bottleneck_test/test_args.py', '--foo foo --bar bar')
|
rc, out, err = self._run_bottleneck('bottleneck_test/test_args.py', '--foo foo --bar bar')
|
||||||
self.assertEqual(rc, 0, None, self._fail_msg('Should pass args to script', out + err))
|
self.assertEqual(rc, 0, atol=0, message=self._fail_msg('Should pass args to script', out + err))
|
||||||
|
|
||||||
def _fail_msg(self, msg, output):
|
def _fail_msg(self, msg, output):
|
||||||
return '{}, output was:\n{}'.format(msg, output)
|
return '{}, output was:\n{}'.format(msg, output)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ rand_like = torch.rand_like
|
||||||
randn_like = torch.randn_like
|
randn_like = torch.randn_like
|
||||||
|
|
||||||
|
|
||||||
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True):
|
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg=''):
|
||||||
if not isinstance(actual, torch.Tensor):
|
if not isinstance(actual, torch.Tensor):
|
||||||
actual = torch.tensor(actual)
|
actual = torch.tensor(actual)
|
||||||
if not isinstance(expected, torch.Tensor):
|
if not isinstance(expected, torch.Tensor):
|
||||||
|
|
@ -50,13 +50,14 @@ def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True):
|
||||||
|
|
||||||
# Count number of offenders
|
# Count number of offenders
|
||||||
count = (~close).long().sum()
|
count = (~close).long().sum()
|
||||||
|
if msg == '' or msg is None:
|
||||||
|
msg = ('Not within tolerance rtol={} atol={} at input{} ({} vs. {}) and {}'
|
||||||
|
' other locations ({:2.2f}%)')
|
||||||
|
msg = msg.format(
|
||||||
|
rtol, atol, list(index), actual[index].item(), expected[index].item(),
|
||||||
|
count - 1, 100 * count / actual.numel())
|
||||||
|
|
||||||
msg = ('Not within tolerance rtol={} atol={} at input{} ({} vs. {}) and {}'
|
raise AssertionError(msg)
|
||||||
' other locations ({:2.2f}%)')
|
|
||||||
|
|
||||||
raise AssertionError(msg.format(
|
|
||||||
rtol, atol, list(index), actual[index].item(), expected[index].item(),
|
|
||||||
count - 1, 100. * count / actual.numel()))
|
|
||||||
|
|
||||||
def make_non_contiguous(tensor):
|
def make_non_contiguous(tensor):
|
||||||
if tensor.numel() <= 1: # can't make non-contiguous
|
if tensor.numel() <= 1: # can't make non-contiguous
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@ import __main__
|
||||||
import errno
|
import errno
|
||||||
|
|
||||||
from torch.testing._internal import expecttest
|
from torch.testing._internal import expecttest
|
||||||
|
from torch.testing import get_all_dtypes
|
||||||
|
from torch.testing import get_all_complex_dtypes
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
@ -561,7 +563,7 @@ class CudaMemoryLeakCheck():
|
||||||
for i, (before, after) in enumerate(zip(self.befores, afters)):
|
for i, (before, after) in enumerate(zip(self.befores, afters)):
|
||||||
if not TEST_WITH_ROCM:
|
if not TEST_WITH_ROCM:
|
||||||
self.testcase.assertEqual(
|
self.testcase.assertEqual(
|
||||||
before, after, '{} leaked {} bytes CUDA memory on device {}'.format(
|
before, after, message='{} leaked {} bytes CUDA memory on device {}'.format(
|
||||||
self.name, after - before, i))
|
self.name, after - before, i))
|
||||||
else:
|
else:
|
||||||
# TODO: Investigate ROCm memory leaking.
|
# TODO: Investigate ROCm memory leaking.
|
||||||
|
|
@ -795,27 +797,84 @@ class TestCase(expecttest.TestCase):
|
||||||
|
|
||||||
return tg
|
return tg
|
||||||
|
|
||||||
def assertEqual(self, x, y, prec=None, message='', allow_inf=False, exact_dtype=None):
|
# Some analysis of tolerance by logging tests from test_torch.py can be found
|
||||||
|
# in https://github.com/pytorch/pytorch/pull/32538.
|
||||||
|
# dtype name : (rtol, atol)
|
||||||
|
dtype_precisions = {
|
||||||
|
'float16': (0.001, 1e-5),
|
||||||
|
'bfloat16': (0.016, 1e-5),
|
||||||
|
'float32': (1.3e-6, 1e-5),
|
||||||
|
'float64': (1e-7, 1e-7),
|
||||||
|
'complex32': (0.001, 1e-5),
|
||||||
|
'complex64': (1.3e-6, 1e-5),
|
||||||
|
'complex128': (1e-7, 1e-7),
|
||||||
|
}
|
||||||
|
|
||||||
|
# todo: implement numpy-like issubdtype
|
||||||
|
def is_integral(self, dtype):
|
||||||
|
# Skip complex/quantized types
|
||||||
|
dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
|
||||||
|
return dtype in dtypes and not dtype.is_floating_point
|
||||||
|
|
||||||
|
# accepts tensors, dtypes, or np.ndarrays
|
||||||
|
def get_default_tolerance(self, a, b=None):
|
||||||
|
if b is None:
|
||||||
|
dtype = torch.float
|
||||||
|
if isinstance(a, torch.Tensor):
|
||||||
|
dtype = a.dtype
|
||||||
|
elif isinstance(a, torch.dtype):
|
||||||
|
dtype = a
|
||||||
|
elif TEST_NUMPY and isinstance(a, numpy.ndarray):
|
||||||
|
# Some tests call assertEqual with numpy Unicode.
|
||||||
|
if numpy.issubdtype(a.dtype, numpy.dtype('U')):
|
||||||
|
dtype = torch.float
|
||||||
|
else:
|
||||||
|
dtype = torch.from_numpy(a).dtype
|
||||||
|
if self.is_integral(dtype):
|
||||||
|
return (0, 0)
|
||||||
|
dtype = str(dtype).split('.')[-1]
|
||||||
|
return self.dtype_precisions.get(dtype, (self.precision, self.precision))
|
||||||
|
|
||||||
|
a_tol = self.get_default_tolerance(a)
|
||||||
|
b_tol = self.get_default_tolerance(b)
|
||||||
|
return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
|
||||||
|
|
||||||
|
def assertEqual(self, x, y, message='', **kwargs):
|
||||||
|
self.assertIsNone(kwargs.get('prec', None), 'prec is no longer supported. Use atol or rtol.')
|
||||||
|
rtol = kwargs.get('rtol', None)
|
||||||
|
atol = kwargs.get('atol', None)
|
||||||
|
allow_inf = kwargs.get('allow_inf', False)
|
||||||
|
exact_dtype = kwargs.get('exact_dtype', None)
|
||||||
|
# we allow setting an absolute tolerance as a positional arg for BC with legacy testing behavior.
|
||||||
|
if isinstance(message, Number):
|
||||||
|
self.assertIsNone(atol, "don't combine positional prec and atol")
|
||||||
|
self.assertIsNone(rtol, "don't combine positionial prec and rtol")
|
||||||
|
atol = message
|
||||||
|
message = ''
|
||||||
|
rtol = 0
|
||||||
|
elif atol is None and rtol is None:
|
||||||
|
# if both are None, use defaults per-dtype
|
||||||
|
(rtol, atol) = self.get_default_tolerance(x, y)
|
||||||
|
else:
|
||||||
|
if rtol is None:
|
||||||
|
rtol = 0
|
||||||
|
if atol is None:
|
||||||
|
atol = 0
|
||||||
|
|
||||||
if exact_dtype is None:
|
if exact_dtype is None:
|
||||||
exact_dtype = self.exact_dtype
|
exact_dtype = self.exact_dtype
|
||||||
|
|
||||||
if isinstance(prec, str) and message == '':
|
|
||||||
message = prec
|
|
||||||
prec = None
|
|
||||||
if prec is None:
|
|
||||||
prec = self.precision
|
|
||||||
|
|
||||||
if isinstance(x, torch.Tensor) and isinstance(y, Number):
|
if isinstance(x, torch.Tensor) and isinstance(y, Number):
|
||||||
self.assertEqual(x.item(), y, prec=prec, message=message,
|
self.assertEqual(x.item(), y, atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
|
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
|
||||||
self.assertEqual(x, y.item(), prec=prec, message=message,
|
self.assertEqual(x, y.item(), atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_):
|
elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_):
|
||||||
self.assertEqual(x.item(), y, prec=prec, message=message,
|
self.assertEqual(x.item(), y, atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_):
|
elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_):
|
||||||
self.assertEqual(x, y.item(), prec=prec, message=message,
|
self.assertEqual(x, y.item(), atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||||
def assertTensorsEqual(a, b):
|
def assertTensorsEqual(a, b):
|
||||||
|
|
@ -840,24 +899,23 @@ class TestCase(expecttest.TestCase):
|
||||||
a = a.to(torch.int)
|
a = a.to(torch.int)
|
||||||
b = b.to(torch.int)
|
b = b.to(torch.int)
|
||||||
|
|
||||||
diff = a - b
|
if a.is_complex():
|
||||||
if a.dtype.is_complex or a.dtype.is_floating_point:
|
# todo: assert_allclose should handle complex types directly.
|
||||||
# check that NaNs are in the same locations
|
float_dtype = torch.float if a.dtype == torch.complex64 else torch.double
|
||||||
nan_mask = torch.isnan(a)
|
self.assertEqual(a.copy_real().to(float_dtype), b.copy_real().to(float_dtype),
|
||||||
self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
|
atol=atol, rtol=rtol, message=message)
|
||||||
diff[nan_mask] = 0
|
self.assertEqual(a.copy_imag().to(float_dtype), b.copy_imag().to(float_dtype),
|
||||||
# inf check if allow_inf=True
|
atol=atol, rtol=rtol, message=message)
|
||||||
if allow_inf:
|
elif a.is_floating_point():
|
||||||
inf_mask = torch.isinf(a)
|
torch.testing.assert_allclose(a, b, atol=atol, rtol=rtol, equal_nan=True, msg=message)
|
||||||
inf_sign = inf_mask.sign()
|
else:
|
||||||
self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
|
diff = a - b
|
||||||
diff[inf_mask] = 0
|
# TODO: implement abs on CharTensor (int8)
|
||||||
# TODO: implement abs on CharTensor (int8)
|
if diff.is_signed() and diff.dtype != torch.int8:
|
||||||
# TODO: modify abs to return float/double for ComplexFloat/ComplexDouble
|
diff = diff.abs()
|
||||||
if diff.is_signed() and diff.dtype != torch.int8:
|
max_err = diff.max()
|
||||||
diff = diff.abs()
|
self.assertLessEqual(max_err, atol, message)
|
||||||
max_err = diff.max()
|
|
||||||
self.assertLessEqual(max_err, prec, message)
|
|
||||||
super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
|
super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
|
||||||
super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
|
super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
|
||||||
if x.is_sparse:
|
if x.is_sparse:
|
||||||
|
|
@ -866,30 +924,26 @@ class TestCase(expecttest.TestCase):
|
||||||
assertTensorsEqual(x._indices(), y._indices())
|
assertTensorsEqual(x._indices(), y._indices())
|
||||||
assertTensorsEqual(x._values(), y._values())
|
assertTensorsEqual(x._values(), y._values())
|
||||||
elif x.is_quantized and y.is_quantized:
|
elif x.is_quantized and y.is_quantized:
|
||||||
self.assertEqual(x.qscheme(), y.qscheme(), prec=prec,
|
self.assertEqual(x.qscheme(), y.qscheme(), atol=atol, rtol=rtol,
|
||||||
message=message, allow_inf=allow_inf,
|
message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
exact_dtype=exact_dtype)
|
|
||||||
if x.qscheme() == torch.per_tensor_affine:
|
if x.qscheme() == torch.per_tensor_affine:
|
||||||
self.assertEqual(x.q_scale(), y.q_scale(), prec=prec,
|
self.assertEqual(x.q_scale(), y.q_scale(), atol=atol, rtol=rtol,
|
||||||
message=message, allow_inf=allow_inf,
|
message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
exact_dtype=exact_dtype)
|
|
||||||
self.assertEqual(x.q_zero_point(), y.q_zero_point(),
|
self.assertEqual(x.q_zero_point(), y.q_zero_point(),
|
||||||
prec=prec, message=message,
|
atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif x.qscheme() == torch.per_channel_affine:
|
elif x.qscheme() == torch.per_channel_affine:
|
||||||
self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec,
|
self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), atol=atol, rtol=rtol,
|
||||||
message=message, allow_inf=allow_inf,
|
message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
exact_dtype=exact_dtype)
|
|
||||||
self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
|
self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
|
||||||
prec=prec, message=message,
|
atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
|
self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
|
||||||
prec=prec, message=message)
|
atol=atol, rtol=rtol, message=message)
|
||||||
self.assertEqual(x.dtype, y.dtype)
|
self.assertEqual(x.dtype, y.dtype)
|
||||||
self.assertEqual(x.int_repr().to(torch.int32),
|
self.assertEqual(x.int_repr().to(torch.int32),
|
||||||
y.int_repr().to(torch.int32), prec=prec,
|
y.int_repr().to(torch.int32), atol=atol, rtol=rtol,
|
||||||
message=message, allow_inf=allow_inf,
|
message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
exact_dtype=exact_dtype)
|
|
||||||
else:
|
else:
|
||||||
assertTensorsEqual(x, y)
|
assertTensorsEqual(x, y)
|
||||||
elif isinstance(x, string_classes) and isinstance(y, string_classes):
|
elif isinstance(x, string_classes) and isinstance(y, string_classes):
|
||||||
|
|
@ -898,22 +952,20 @@ class TestCase(expecttest.TestCase):
|
||||||
super(TestCase, self).assertEqual(x, y, message)
|
super(TestCase, self).assertEqual(x, y, message)
|
||||||
elif isinstance(x, dict) and isinstance(y, dict):
|
elif isinstance(x, dict) and isinstance(y, dict):
|
||||||
if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
|
if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
|
||||||
self.assertEqual(x.items(), y.items(), prec=prec,
|
self.assertEqual(x.items(), y.items(), atol=atol, rtol=rtol,
|
||||||
message=message, allow_inf=allow_inf,
|
message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
exact_dtype=exact_dtype)
|
|
||||||
else:
|
else:
|
||||||
self.assertEqual(set(x.keys()), set(y.keys()), prec=prec,
|
self.assertEqual(set(x.keys()), set(y.keys()), atol=atol, rtol=rtol,
|
||||||
message=message, allow_inf=allow_inf,
|
message=message, allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
exact_dtype=exact_dtype)
|
|
||||||
key_list = list(x.keys())
|
key_list = list(x.keys())
|
||||||
self.assertEqual([x[k] for k in key_list],
|
self.assertEqual([x[k] for k in key_list],
|
||||||
[y[k] for k in key_list],
|
[y[k] for k in key_list],
|
||||||
prec=prec, message=message,
|
atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif is_iterable(x) and is_iterable(y):
|
elif is_iterable(x) and is_iterable(y):
|
||||||
super(TestCase, self).assertEqual(len(x), len(y), message)
|
super(TestCase, self).assertEqual(len(x), len(y), message)
|
||||||
for x_, y_ in zip(x, y):
|
for x_, y_ in zip(x, y):
|
||||||
self.assertEqual(x_, y_, prec=prec, message=message,
|
self.assertEqual(x_, y_, atol=atol, rtol=rtol, message=message,
|
||||||
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
allow_inf=allow_inf, exact_dtype=exact_dtype)
|
||||||
elif isinstance(x, bool) and isinstance(y, bool):
|
elif isinstance(x, bool) and isinstance(y, bool):
|
||||||
super(TestCase, self).assertEqual(x, y, message)
|
super(TestCase, self).assertEqual(x, y, message)
|
||||||
|
|
@ -924,22 +976,21 @@ class TestCase(expecttest.TestCase):
|
||||||
else:
|
else:
|
||||||
self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
|
self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
|
||||||
return
|
return
|
||||||
super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
|
super(TestCase, self).assertLessEqual(abs(x - y), atol, message)
|
||||||
else:
|
else:
|
||||||
super(TestCase, self).assertEqual(x, y, message)
|
super(TestCase, self).assertEqual(x, y, message)
|
||||||
|
|
||||||
def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
|
def assertAlmostEqual(self, x, y, places=None, msg='', delta=None, allow_inf=None):
|
||||||
prec = delta
|
prec = delta
|
||||||
if places:
|
if places:
|
||||||
prec = 10**(-places)
|
prec = 10**(-places)
|
||||||
self.assertEqual(x, y, prec, msg, allow_inf)
|
self.assertEqual(x, y, msg, atol=prec, allow_inf=allow_inf)
|
||||||
|
|
||||||
def assertNotEqual(self, x, y, prec=None, message=''):
|
def assertNotEqual(self, x, y, message='', atol=None):
|
||||||
if isinstance(prec, str) and message == '':
|
if not isinstance(message, str):
|
||||||
message = prec
|
raise Error("fix this test, message should be a string")
|
||||||
prec = None
|
if atol is None:
|
||||||
if prec is None:
|
(_, atol) = self.get_default_tolerance(x, y)
|
||||||
prec = self.precision
|
|
||||||
|
|
||||||
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||||
if x.size() != y.size():
|
if x.size() != y.size():
|
||||||
|
|
@ -956,14 +1007,14 @@ class TestCase(expecttest.TestCase):
|
||||||
# Use `item()` to work around:
|
# Use `item()` to work around:
|
||||||
# https://github.com/pytorch/pytorch/issues/22301
|
# https://github.com/pytorch/pytorch/issues/22301
|
||||||
max_err = diff.max().item()
|
max_err = diff.max().item()
|
||||||
self.assertGreaterEqual(max_err, prec, message)
|
self.assertGreater(max_err, atol, message)
|
||||||
elif type(x) == str and type(y) == str:
|
elif type(x) == str and type(y) == str:
|
||||||
super(TestCase, self).assertNotEqual(x, y)
|
super(TestCase, self).assertNotEqual(x, y)
|
||||||
elif is_iterable(x) and is_iterable(y):
|
elif is_iterable(x) and is_iterable(y):
|
||||||
super(TestCase, self).assertNotEqual(x, y)
|
super(TestCase, self).assertNotEqual(x, y)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.assertGreaterEqual(abs(x - y), prec, message)
|
self.assertGreater(abs(x - y), atol, message)
|
||||||
return
|
return
|
||||||
except (TypeError, AssertionError):
|
except (TypeError, AssertionError):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user