mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Replace torch.allClose with self.assertEqual (#39424)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39424 Reviewed By: Krovatkin Differential Revision: D21854870 Pulled By: ailzhang fbshipit-source-id: eb68f1775596e4c963169033444d6d6f4f818d4f
This commit is contained in:
parent
5d2cfb3d4c
commit
46447045ea
|
|
@ -10510,8 +10510,8 @@ class TestTorchDeviceType(TestCase):
|
|||
alias_table, prob_table = torch._multinomial_alias_setup(probs)
|
||||
alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES)
|
||||
alias_dist = torch.unique(alias_samples, return_counts=True)[1].to(dtype=probs.dtype) / MAX_SAMPLES
|
||||
self.assertTrue(torch.allclose(alias_dist, probs, rtol=0.02, atol=0.0),
|
||||
"Actual: {}\nExpected: {}".format(alias_dist, probs))
|
||||
self.assertEqual(alias_dist, probs, rtol=0.02, atol=0.0,
|
||||
msg="Actual: {}\nExpected: {}".format(alias_dist, probs))
|
||||
|
||||
for probs in [torch.tensor([0.2501, 0.25, 0.2499, 0.25], device=device),
|
||||
torch.tensor([0.8, 0.199, 0.001], device=device),
|
||||
|
|
@ -11176,11 +11176,11 @@ class TestTorchDeviceType(TestCase):
|
|||
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
|
||||
actual = torch.cdist(x, y, p=2, compute_mode=cm)
|
||||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertTrue(torch.allclose(expected, actual, rtol=0, atol=0.02))
|
||||
self.assertEqual(expected, actual, rtol=0, atol=0.02)
|
||||
else:
|
||||
actual = torch.cdist(x, y, p=p)
|
||||
expected = self._brute_cdist(x, y, p=p)
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_cdist_norm_batch(self, device):
|
||||
for r1 in [3, 4, 5, 6]:
|
||||
|
|
@ -11193,11 +11193,11 @@ class TestTorchDeviceType(TestCase):
|
|||
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
|
||||
actual = torch.cdist(x, y, p=2, compute_mode=cm)
|
||||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertTrue(torch.allclose(expected, actual, rtol=0, atol=0.02))
|
||||
self.assertEqual(expected, actual, rtol=0, atol=0.02)
|
||||
else:
|
||||
actual = torch.cdist(x, y, p=p)
|
||||
expected = self._brute_cdist(x, y, p=p)
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_cdist_large(self, device):
|
||||
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
|
||||
|
|
@ -11205,7 +11205,7 @@ class TestTorchDeviceType(TestCase):
|
|||
y = torch.randn(1000, 10, device=device)
|
||||
actual = torch.cdist(x, y, p=2, compute_mode=cm)
|
||||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@slowTest
|
||||
def test_cdist_large_batch(self, device):
|
||||
|
|
@ -11214,7 +11214,7 @@ class TestTorchDeviceType(TestCase):
|
|||
y = torch.randn(4, 3, 1000, 10, device=device)
|
||||
actual = torch.cdist(x, y, p=2, compute_mode=cm)
|
||||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_cdist_non_contiguous(self, device):
|
||||
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
|
||||
|
|
@ -11224,7 +11224,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertFalse(x.is_contiguous())
|
||||
self.assertFalse(y.is_contiguous())
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
x = torch.randn(7, 5, device=device)
|
||||
y = torch.randn(5, 3, device=device).t()
|
||||
|
|
@ -11232,7 +11232,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertTrue(x.is_contiguous())
|
||||
self.assertFalse(y.is_contiguous())
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
x = torch.randn(5, 7, device=device).t()
|
||||
y = torch.randn(3, 5, device=device)
|
||||
|
|
@ -11240,7 +11240,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertFalse(x.is_contiguous())
|
||||
self.assertTrue(y.is_contiguous())
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_cdist_non_contiguous_batch(self, device):
|
||||
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
|
||||
|
|
@ -11250,7 +11250,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertFalse(x.is_contiguous())
|
||||
self.assertFalse(y.is_contiguous())
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
x = torch.randn(7, 2, 7, 5, device=device)
|
||||
y = torch.randn(7, 2, 5, 3, device=device).transpose(-1, -2)
|
||||
|
|
@ -11258,7 +11258,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertTrue(x.is_contiguous())
|
||||
self.assertFalse(y.is_contiguous())
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
x = torch.randn(4, 5, 7, device=device).transpose(-1, -2)
|
||||
y = torch.randn(4, 3, 5, device=device)
|
||||
|
|
@ -11266,7 +11266,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = self._brute_cdist(x, y, p=2)
|
||||
self.assertFalse(x.is_contiguous())
|
||||
self.assertTrue(y.is_contiguous())
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_multinomial_constraints(self, device):
|
||||
x = torch.empty(1, 2, 3, dtype=torch.double, device=device)
|
||||
|
|
@ -11617,7 +11617,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = logcumsumexp(a, axis)
|
||||
self.assertEqual(a.dtype, actual.dtype)
|
||||
self.assertEqual(expected.shape, actual.shape)
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# Check that out is actually inplace
|
||||
b = torch.randn(5, 2, device=device)
|
||||
|
|
@ -11626,7 +11626,7 @@ class TestTorchDeviceType(TestCase):
|
|||
expected = logcumsumexp(b, axis)
|
||||
torch.logcumsumexp(b, axis=axis, out=inplace_out)
|
||||
|
||||
self.assertTrue(torch.allclose(inplace_out, expected))
|
||||
self.assertEqual(inplace_out, expected)
|
||||
|
||||
# Check input and inplace_output type mismatch
|
||||
b = torch.randn(5, 2, device=device, dtype=torch.float64)
|
||||
|
|
@ -12822,12 +12822,12 @@ class TestTorchDeviceType(TestCase):
|
|||
actual = torch.pdist(x, p=p)
|
||||
expected = self._brute_pdist(y, p=p)
|
||||
self.assertEqual(expected.shape, actual.shape)
|
||||
self.assertTrue(torch.allclose(expected, actual))
|
||||
self.assertEqual(expected, actual)
|
||||
if grad_check and expected.size() != torch.Size([0]):
|
||||
g0 = torch.rand_like(actual)
|
||||
actual.backward(g0)
|
||||
expected.backward(g0)
|
||||
self.assertTrue(torch.allclose(x.grad, y.grad))
|
||||
self.assertEqual(x.grad, y.grad)
|
||||
|
||||
@slowTest
|
||||
def test_pdist_norm_forward(self, device):
|
||||
|
|
@ -12858,7 +12858,7 @@ class TestTorchDeviceType(TestCase):
|
|||
x = torch.randn(50000, 1, dtype=torch.float32)
|
||||
expected_cpu = torch.pdist(x, p=2)
|
||||
actual_gpu = torch.pdist(x.to(device), p=2)
|
||||
self.assertTrue(torch.allclose(expected_cpu, actual_gpu.cpu()))
|
||||
self.assertEqual(expected_cpu, actual_gpu.cpu())
|
||||
|
||||
def test_atan2(self, device):
|
||||
def _test_atan2_with_size(size, device):
|
||||
|
|
@ -12869,7 +12869,7 @@ class TestTorchDeviceType(TestCase):
|
|||
y = b.view(-1)
|
||||
expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())],
|
||||
device=device, dtype=torch.double)
|
||||
self.assertTrue(torch.allclose(expected, actual.view(-1), rtol=0, atol=0.02))
|
||||
self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02)
|
||||
|
||||
_test_atan2_with_size((2, 2), device)
|
||||
_test_atan2_with_size((3, 3), device)
|
||||
|
|
@ -12881,7 +12881,7 @@ class TestTorchDeviceType(TestCase):
|
|||
x_tensor = torch.tensor([x], dtype=dtype, device=device)
|
||||
y_tensor = torch.tensor([y], dtype=dtype, device=device)
|
||||
actual = torch.atan2(y_tensor, x_tensor)
|
||||
self.assertTrue(torch.allclose(expected_tensor, actual, rtol=0, atol=0.02))
|
||||
self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02)
|
||||
|
||||
for dtype in [torch.float, torch.double]:
|
||||
_test_atan2(0, 0, 0, device, dtype)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user