mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
I noticed that a lot of bugs are being suppressed by torchdynamo's default error suppression, and worse yet, there's no way to unsuppress them. After discussion with voz and soumith, we decided that we will unify error suppression into a single option (suppress_errors) and default suppression to False. If your model used to work and no longer works, try TORCHDYNAMO_SUPPRESS_ERRORS=1 to bring back the old suppression behavior. Signed-off-by: Edward Z. Yang <ezyang@fb.com> cc @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang Pull Request resolved: https://github.com/pytorch/pytorch/pull/87440 Approved by: https://github.com/voznesenskym, https://github.com/albanD
69 lines
1.5 KiB
Python
69 lines
1.5 KiB
Python
import contextlib
|
|
import importlib
|
|
import sys
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch.testing
|
|
from torch.testing._internal.common_utils import (
|
|
IS_WINDOWS,
|
|
TEST_WITH_CROSSREF,
|
|
TEST_WITH_ROCM,
|
|
TEST_WITH_TORCHDYNAMO,
|
|
TestCase as TorchTestCase,
|
|
)
|
|
|
|
from . import config, reset, utils
|
|
|
|
|
|
def run_tests(needs=()):
    """Run this module's tests via torch's ``common_utils.run_tests`` unless skipped.

    Nothing is run (the function just returns) when any of
    ``TEST_WITH_TORCHDYNAMO``, ``IS_WINDOWS``, ``TEST_WITH_CROSSREF`` or
    ``TEST_WITH_ROCM`` is truthy, or when the interpreter is Python 3.11+.

    Args:
        needs: A single requirement name (``str``) or an iterable of names.
            The special name ``"cuda"`` is satisfied when
            ``torch.cuda.is_available()`` is true.  Any other name must be
            importable with ``importlib.import_module``.  If any requirement
            is not met, testing is skipped.
    """
    # Local import; inside this body the name shadows this wrapper function.
    from torch.testing._internal.common_utils import run_tests

    if (
        TEST_WITH_TORCHDYNAMO
        or IS_WINDOWS
        or TEST_WITH_CROSSREF
        or TEST_WITH_ROCM
        or sys.version_info >= (3, 11)
    ):
        return  # skip testing

    if isinstance(needs, str):
        # Accept a bare string as a single requirement.
        needs = (needs,)
    for need in needs:
        if need == "cuda":
            # "cuda" is a device requirement, not a module name: check
            # availability only and never try to import a module called "cuda"
            # (which previously skipped the run even when CUDA was available).
            if not torch.cuda.is_available():
                return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()
|
|
|
|
|
|
class TestCase(TorchTestCase):
    """Base test case that isolates dynamo state between tests.

    For the lifetime of the class, ``config.raise_on_ctx_manager_usage`` is
    patched to ``True``.  Before each test, ``reset()`` is called and
    ``utils.counters`` is cleared.  After each test, the counters are printed
    and the state is reset and cleared again.
    """

    @classmethod
    def tearDownClass(cls):
        try:
            # Undo the patches entered in setUpClass.
            cls._exit_stack.close()
        finally:
            # Always let the base class clean up, even if undoing a patch failed.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()
        # The patch stays active until tearDownClass closes the stack.
        cls._exit_stack.enter_context(
            patch.object(config, "raise_on_ctx_manager_usage", True)
        )

    def setUp(self):
        super().setUp()
        # Start every test from a fresh state.
        reset()
        utils.counters.clear()

    def tearDown(self):
        try:
            # Dump the counters recorded during the test.  Each value is
            # expected to provide ``most_common()`` (e.g. a collections.Counter).
            for k, v in utils.counters.items():
                print(k, v.most_common())
            reset()
            utils.counters.clear()
        finally:
            # Run the base-class teardown even if the cleanup above raised.
            super().tearDown()
|