pytorch/torch/_dynamo/test_case.py
Edward Z. Yang 96691865b9 [dynamo] Unify raise_on_* config to suppress_errors and raise by default (#87440)
I noticed that a lot of bugs are being suppressed by TorchDynamo's default
error suppression, and worse yet, there's no way to turn that suppression off.
After discussion with voz and soumith, we decided to unify error suppression
into a single option (suppress_errors) and to make it default to False.

If your model used to work and no longer works, try TORCHDYNAMO_SUPPRESS_ERRORS=1
to bring back the old suppression behavior.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87440
Approved by: https://github.com/voznesenskym, https://github.com/albanD
2022-10-21 17:03:29 +00:00

69 lines
1.5 KiB
Python

import contextlib
import importlib
import sys
from unittest.mock import patch
import torch
import torch.testing
from torch.testing._internal.common_utils import (
IS_WINDOWS,
TEST_WITH_CROSSREF,
TEST_WITH_ROCM,
TEST_WITH_TORCHDYNAMO,
TestCase as TorchTestCase,
)
from . import config, reset, utils
def run_tests(needs=()):
from torch.testing._internal.common_utils import run_tests
if (
TEST_WITH_TORCHDYNAMO
or IS_WINDOWS
or TEST_WITH_CROSSREF
or TEST_WITH_ROCM
or sys.version_info >= (3, 11)
):
return # skip testing
if isinstance(needs, str):
needs = (needs,)
for need in needs:
if need == "cuda" and not torch.cuda.is_available():
return
else:
try:
importlib.import_module(need)
except ImportError:
return
run_tests()
class TestCase(TorchTestCase):
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
super().tearDownClass()
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack = contextlib.ExitStack()
cls._exit_stack.enter_context(
patch.object(config, "raise_on_ctx_manager_usage", True)
)
def setUp(self):
super().setUp()
reset()
utils.counters.clear()
def tearDown(self):
for k, v in utils.counters.items():
print(k, v.most_common())
reset()
utils.counters.clear()
super().tearDown()