Increase tolerance for test_adadelta (#69919)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/69698

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69919

Reviewed By: cpuhrsch

Differential Revision: D33286427

Pulled By: jbschlosser

fbshipit-source-id: a2ca90683c14b6669f9b1804881ac675ba925fc5
This commit is contained in:
Rishi Puri 2022-01-05 15:00:39 -08:00 committed by Facebook GitHub Bot
parent ce409d8f50
commit f9e1a1c97f

View File

@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, S
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
skipIfRocm
from torch.testing._internal.common_device_type import toleranceOverride, tol
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@ -600,7 +600,9 @@ class TestOptim(TestCase):
optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])
# ROCm precision is too low to pass this test
# Tolerance Override Handles https://github.com/pytorch/pytorch/issues/69698
@skipIfRocm
@toleranceOverride({torch.float32: tol(4e-3, 0)})
def test_adadelta(self):
for optimizer in [optim.Adadelta, optim_mt.Adadelta]:
self._test_basic_cases(
@ -622,6 +624,8 @@ class TestOptim(TestCase):
with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
optimizer(None, lr=1e-2, rho=1.1)
# Tolerance Override Handles https://github.com/pytorch/pytorch/issues/69698
@toleranceOverride({torch.float32: tol(2e-2, 0)})
def test_adadelta_complex(self):
for optimizer in [optim.Adadelta]:
self._test_complex_optimizer(