mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
we should have a config-based way to skip flaky tests (#30978)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30978 This particular approach queries our issue tracker for test titles that match the following format: ``` DISABLED test_async_grad_guard_with_grad (jit.test_async.TestAsync) ``` And then skips the python test for them. There is 1 second timeout so if the internet flakes we still run the test suite, without disabling any tests. This is intended as a quick fix, similar to ninja unland, to get to a green master. Long term test disables should go into the code. Test Plan: Imported from OSS Pulled By: zdevito Differential Revision: D18890532 fbshipit-source-id: fe9447e59a6d5c9ad345f7c3ff15d63b6d2a09e2
This commit is contained in:
parent
d2067569e7
commit
dab5f72543
|
|
@ -27,6 +27,11 @@ from itertools import product
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import json
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
from urllib2 import urlopen # noqa f811
|
||||||
|
else:
|
||||||
|
from urllib.request import urlopen
|
||||||
|
|
||||||
import __main__
|
import __main__
|
||||||
import errno
|
import errno
|
||||||
|
|
@ -580,6 +585,34 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('Fail to import hypothesis in common_utils, tests are not derandomized')
|
print('Fail to import hypothesis in common_utils, tests are not derandomized')
|
||||||
|
|
||||||
|
disabled_test_from_issues = None
|
||||||
|
def check_disabled(test_name):
|
||||||
|
global disabled_test_from_issues
|
||||||
|
if disabled_test_from_issues is None:
|
||||||
|
disabled_test_from_issues = {}
|
||||||
|
|
||||||
|
def read_and_process():
|
||||||
|
url = 'https://raw.githubusercontent.com/zdevito/pytorch_disabled_tests/master/result.json'
|
||||||
|
contents = urlopen(url, timeout=1).read().decode('utf-8')
|
||||||
|
the_response = json.loads(contents)
|
||||||
|
for item in the_response['items']:
|
||||||
|
title = item['title']
|
||||||
|
key = 'DISABLED '
|
||||||
|
if title.startswith(key):
|
||||||
|
test_name = title[len(key):].strip()
|
||||||
|
disabled_test_from_issues[test_name] = item['html_url']
|
||||||
|
|
||||||
|
if not IS_SANDCASTLE and os.getenv("PYTORCH_RUN_DISABLED_TESTS", "0") != "1":
|
||||||
|
try:
|
||||||
|
read_and_process()
|
||||||
|
except Exception:
|
||||||
|
print("Couldn't download test skip set, leaving all tests enabled...")
|
||||||
|
|
||||||
|
if test_name in disabled_test_from_issues:
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
"Test is disabled because an issue exists disabling it: {}".format(disabled_test_from_issues[test_name]) +
|
||||||
|
" To enable set the environment variable PYTORCH_RUN_DISABLED_TESTS=1")
|
||||||
|
|
||||||
class TestCase(expecttest.TestCase):
|
class TestCase(expecttest.TestCase):
|
||||||
precision = 1e-5
|
precision = 1e-5
|
||||||
maxDiff = None
|
maxDiff = None
|
||||||
|
|
@ -634,10 +667,14 @@ class TestCase(expecttest.TestCase):
|
||||||
def wrap_with_cuda_memory_check(self, method):
|
def wrap_with_cuda_memory_check(self, method):
|
||||||
return self.wrap_method_with_cuda_policy(method, self.assertLeaksNoCudaTensors)
|
return self.wrap_method_with_cuda_policy(method, self.assertLeaksNoCudaTensors)
|
||||||
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
||||||
|
|
||||||
if TEST_SKIP_FAST:
|
if TEST_SKIP_FAST:
|
||||||
if not getattr(self, self._testMethodName).__dict__.get('slow_test', False):
|
if not getattr(self, self._testMethodName).__dict__.get('slow_test', False):
|
||||||
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
|
raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
|
||||||
|
check_disabled(str(self))
|
||||||
|
|
||||||
set_rng_seed(SEED)
|
set_rng_seed(SEED)
|
||||||
|
|
||||||
|
|
|
||||||
9
tools/update_disabled_tests.sh
Executable file
9
tools/update_disabled_tests.sh
Executable file
|
|
@ -0,0 +1,9 @@
|
||||||
|
#!/bin/bash
|
||||||
|
EXTRACTED_REPO=https://$USERNAME:$API_KEY@github.com/zdevito/pytorch_disabled_tests.git
|
||||||
|
git clone $EXTRACTED_REPO
|
||||||
|
cd pytorch_disabled_tests
|
||||||
|
curl 'https://api.github.com/search/issues?q=is%3Aissue+is%3Aopen+label%3A%22topic%3A+flaky-tests%22+repo:pytorch/pytorch+in%3Atitle+DISABLED' \
|
||||||
|
| sed 's/"score": [0-9\.]*/"score": 0.0/g' > result.json
|
||||||
|
# score changes every request, so we strip it out to avoid creating a commit every time we query.
|
||||||
|
git commit -a -m 'update'
|
||||||
|
git push
|
||||||
Loading…
Reference in New Issue
Block a user