Increase tolerance for test_adadelta (#69919)
Summary: Fixes https://github.com/pytorch/pytorch/issues/69698

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69919

Reviewed By: cpuhrsch

Differential Revision: D33286427

Pulled By: jbschlosser

fbshipit-source-id: a2ca90683c14b6669f9b1804881ac675ba925fc5
This commit is contained in:
parent ce409d8f50
commit f9e1a1c97f
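What the change does: instead of skipping the flaky Adadelta tests, it gives them per-dtype tolerance budgets. toleranceOverride, imported in the first hunk below from torch.testing._internal.common_device_type, takes a dict mapping dtypes to tol named tuples, where the first field is the absolute tolerance and the second the relative tolerance, and the test framework applies those budgets while the decorated test runs. A minimal self-contained sketch of the pattern, using only public APIs (the decorator and test class below are illustrative stand-ins, not PyTorch's private test internals):

    # Illustrative stand-in for the toleranceOverride/tol pattern in the
    # diff below; it mimics the private common_device_type helpers rather
    # than reusing them.
    import functools
    import unittest

    import torch


    def tolerance_override(overrides):
        """Attach per-dtype (atol, rtol) budgets to a single test method."""
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(self, *args, **kwargs):
                self.tol_overrides = overrides  # consumed inside the test
                return fn(self, *args, **kwargs)
            return wrapper
        return decorator


    class ToleranceDemo(unittest.TestCase):  # hypothetical test class
        # Same shape as the diff: float32 gets a 4e-3 absolute budget, rtol=0.
        @tolerance_override({torch.float32: (4e-3, 0)})
        def test_within_budget(self):
            atol, rtol = self.tol_overrides[torch.float32]
            expected = torch.ones(8)
            actual = expected + 3e-3  # drift below the 4e-3 budget -> passes
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)


    if __name__ == "__main__":
        unittest.main()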
test/test_optim.py

@@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, S
 from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
     skipIfRocm
-
+from torch.testing._internal.common_device_type import toleranceOverride, tol
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
@@ -600,7 +600,9 @@ class TestOptim(TestCase):
             optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])

     # ROCm precision is too low to pass this test
+    # Tolerance Override Handles https://github.com/pytorch/pytorch/issues/69698
     @skipIfRocm
+    @toleranceOverride({torch.float32: tol(4e-3, 0)})
     def test_adadelta(self):
         for optimizer in [optim.Adadelta, optim_mt.Adadelta]:
             self._test_basic_cases(
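For scale: tol(4e-3, 0) reads as atol=4e-3 and rtol=0, so float32 results may differ by up to 4e-3 in absolute terms. The test iterates over both the per-tensor and the multi-tensor (optim_mt) Adadelta, and those two paths reorder float32 arithmetic, which is the drift issue 69698 reports. A hedged sketch of that comparison using public APIs only (in newer PyTorch the multi-tensor path is requested with foreach=True; the toy loop below is illustrative, not the test's actual _test_basic_cases harness):

    import torch
    from torch import optim

    # One starting point, two Adadelta implementations.
    p_loop = torch.randn(64, requires_grad=True)
    p_many = p_loop.detach().clone().requires_grad_(True)

    opt_loop = optim.Adadelta([p_loop], lr=1.0, foreach=False)
    opt_many = optim.Adadelta([p_many], lr=1.0, foreach=True)

    for _ in range(200):
        grad = torch.randn(64)
        for p, opt in ((p_loop, opt_loop), (p_many, opt_many)):
            opt.zero_grad()
            p.grad = grad.clone()
            opt.step()

    # Identical math, different float32 evaluation order: small divergence
    # is possible, and the override budgets up to 4e-3 of it (rtol=0).
    torch.testing.assert_close(p_loop, p_many, atol=4e-3, rtol=0)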
@@ -622,6 +624,8 @@ class TestOptim(TestCase):
         with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
             optimizer(None, lr=1e-2, rho=1.1)

+    # Tolerance Override Handles https://github.com/pytorch/pytorch/issues/69698
+    @toleranceOverride({torch.float32: tol(2e-2, 0)})
     def test_adadelta_complex(self):
         for optimizer in [optim.Adadelta]:
             self._test_complex_optimizer(
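The complex variant gets a looser budget: atol=2e-2 with rtol=0. For a feel of what that admits on complex inputs, a quick check with the public API (hypothetical values, just to illustrate the bound):

    import torch

    z = torch.full((4,), 1 + 1j, dtype=torch.complex64)
    # Perturbing the real part by 1.5e-2 stays inside the 2e-2 budget.
    torch.testing.assert_close(z + 1.5e-2, z, atol=2e-2, rtol=0)
    # A 3e-2 shift would exceed it and raise an AssertionError:
    # torch.testing.assert_close(z + 3e-2, z, atol=2e-2, rtol=0)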