pytorch/torch/testing/_internal/logging_utils.py
Yanbo Liang 4c73016ff2 [Dynamo] Enable torch._dynamo.config.suppress_errors by default (#105307)
Summary:
We are working toward full model compilation, where, when a compilation error happens, we just fall back to eager mode rather than erroring out.
But at the same time, we should fix these issues if they are bugs. We will:
* 1/ log warnings in OSS;
* 2/ log warnings and write them into Scuba in fbcode;

to prevent us from ignoring these issues.

Test Plan: Manual test

Differential Revision: D47506314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105307
Approved by: https://github.com/jansel
2023-07-21 19:17:46 +00:00

159 lines
5.0 KiB
Python

import torch._dynamo.test_case
import unittest.mock
import os
import contextlib
import torch._logging
import torch._logging._internal
import logging
@contextlib.contextmanager
def preserve_log_state():
    """Run the managed block against a fresh, empty torch logging state.

    On entry the current log state is captured and replaced with a brand-new
    ``LogState``. On exit, whether normal or via an exception, the captured
    state is reinstated and ``_init_logs()`` is re-run.
    """
    internal = torch._logging._internal
    saved = internal._get_log_state()
    internal._set_log_state(internal.LogState())
    try:
        yield
    finally:
        # Reinstate the state that was active before entry, then re-run
        # log initialisation against it.
        internal._set_log_state(saved)
        internal._init_logs()
def log_settings(settings):
    """Configure torch logging from a ``TORCH_LOGS``-style settings string.

    Snapshots the current log state (via ``preserve_log_state``), sets the
    ``TORCH_LOGS`` entry of ``os.environ`` to *settings*, and runs
    ``torch._logging._internal._init_logs()``.

    Args:
        settings: the value to put in ``TORCH_LOGS``, for example
            ``"+dynamo,graph_breaks"``.

    Returns:
        A ``contextlib.ExitStack``. Closing it (or using it as a ``with``
        target) restores the previous environment and log state.

    If ``_init_logs()`` raises, everything already entered is unwound before
    the exception propagates. A failed call therefore cannot leave a patched
    environment or a swapped-out log state behind.
    """
    with contextlib.ExitStack() as exit_stack:
        # Entered first so it is unwound last: TORCH_LOGS is restored before
        # the saved log state is reinstated and logs are re-initialised.
        exit_stack.enter_context(preserve_log_state())
        exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
        )
        torch._logging._internal._init_logs()
        # Setup succeeded: hand the pending cleanup over to the caller.
        return exit_stack.pop_all()
def log_api(**kwargs):
    """Configure torch logging through ``torch._logging.set_logs(**kwargs)``.

    The current log state is snapshotted (via ``preserve_log_state``) before
    ``set_logs`` is called.

    Returns:
        A ``contextlib.ExitStack``. Closing it (or using it as a ``with``
        target) restores the previous log state.

    If ``set_logs`` raises (for example on an invalid keyword), the snapshot
    is restored before the exception propagates, so the log state does not
    leak.
    """
    with contextlib.ExitStack() as exit_stack:
        exit_stack.enter_context(preserve_log_state())
        torch._logging.set_logs(**kwargs)
        # Setup succeeded: hand the pending cleanup over to the caller.
        return exit_stack.pop_all()
def kwargs_to_settings(**kwargs):
    """Translate ``set_logs``-style keyword arguments into a ``TORCH_LOGS`` string.

    Mapping rules:

    * ``name=True`` produces ``"name"``. ``name=False`` is omitted, so the
      setting stays disabled instead of being switched on.
    * ``name=<level>`` produces a verbosity prefix plus ``name``. The
      prefixes are: 10 (``logging.DEBUG``) -> ``"+"``, 20 (``logging.INFO``)
      -> ``""``, and 40 (``logging.ERROR``) -> ``"-"``.
    * ``modules={qname: level, ...}`` produces one entry per module
      qualified name, using the same level mapping.

    Returns:
        The entries joined with commas, in keyword order (``""`` if none).

    Raises:
        ValueError: for a level outside {10, 20, 40}, a non-``str`` module
            name, or any other value type.
    """
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}
    settings = []

    def append_setting(name, level):
        # Only the three levels expressible as a TORCH_LOGS prefix are accepted.
        if not (
            isinstance(name, str)
            and isinstance(level, int)
            and level in INT_TO_VERBOSITY
        ):
            raise ValueError("Invalid value for setting")
        settings.append(INT_TO_VERBOSITY[level] + name)

    for name, val in kwargs.items():
        # bool must be checked before int, since bool is a subclass of int.
        if isinstance(val, bool):
            # Only an enabled setting is spelled out. Emitting the name for
            # False would turn it on in the env-var run.
            if val:
                settings.append(name)
        elif isinstance(val, int):
            append_setting(name, val)
        elif isinstance(val, dict) and name == "modules":
            for module_qname, level in val.items():
                append_setting(module_qname, level)
        else:
            raise ValueError("Invalid value for setting")
    return ",".join(settings)
# Note on testing strategy:
# The test machinery in this file (make_logging_test together with
# LoggingTestCase._handler_watcher) does three things:
# 1. Runs two versions of a test:
#   1a. patches the env var log settings to some specific value
#   1b. calls torch._logging.set_logs(..)
# 2. patches the emit method of each registered handler to gather the records
#   that are emitted to each console stream
# 3. passes a ref to the gathered records to each test case for checking
#
# The goal of this testing in general is to ensure that, given some settings
# (via the env var or the API), the logs are set up correctly and capture the
# correct records.
def make_logging_test(**kwargs):
    """Build a decorator that runs a logging test under two configurations.

    The generated test calls ``fn(self, records)`` twice:

    1. With logging configured through the ``TORCH_LOGS`` environment
       variable, derived from *kwargs* by ``kwargs_to_settings``. When
       *kwargs* is empty, no environment configuration is applied for this
       run.
    2. With logging configured through ``torch._logging.set_logs(**kwargs)``.

    ``torch._dynamo.reset()`` is called before each run. ``records`` is a
    single list, emptied between the runs, which ``self._handler_watcher``
    fills with the records emitted to the patched handlers.
    """

    def decorator(fn):
        def test_fn(self):
            torch._dynamo.reset()
            captured = []

            # Run 1: environment-variable configuration (or nothing at all).
            if kwargs:
                env_config = log_settings(kwargs_to_settings(**kwargs))
            else:
                env_config = contextlib.nullcontext()
            with env_config, self._handler_watcher(captured):
                fn(self, captured)

            # Run 2: programmatic configuration via the public API.
            torch._dynamo.reset()
            captured.clear()
            with log_api(**kwargs), self._handler_watcher(captured):
                fn(self, captured)

        return test_fn

    return decorator
def make_settings_test(settings):
    """Build a decorator that runs a logging test with ``TORCH_LOGS=settings``.

    The generated test calls ``torch._dynamo.reset()``, then invokes
    ``fn(self, records)`` once. During that call logging is configured from
    the *settings* string and ``self._handler_watcher`` collects the emitted
    records into ``records``.
    """

    def decorator(fn):
        def test_fn(self):
            torch._dynamo.reset()
            captured = []
            with log_settings(settings):
                with self._handler_watcher(captured):
                    fn(self, captured)

        return test_fn

    return decorator
class LoggingTestCase(torch._dynamo.test_case.TestCase):
    """Base class for tests that inspect records emitted by PT2 loggers.

    Tests produced by ``make_logging_test`` / ``make_settings_test`` call
    ``self._handler_watcher`` to collect emitted records.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # NOTE(review): ``cls._exit_stack`` is not created here; presumably the
        # base class's setUpClass provides it. TODO confirm.
        cls._exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
        )
        # Suppress dynamo compilation errors for every test in the class.
        cls._exit_stack.enter_context(
            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
        )

    @classmethod
    def tearDownClass(cls):
        try:
            # Undo the class-level patches applied in setUpClass.
            cls._exit_stack.close()
            # Drop whatever log configuration the tests left behind and
            # re-initialise logging.
            torch._logging._internal.log_state.clear()
            torch._logging._init_logs()
        finally:
            # setUpClass chains to the base class, so teardown must too.
            # If the base class closes ``cls._exit_stack`` again, that is
            # harmless: closing an already-unwound ExitStack does nothing.
            super().tearDownClass()

    # This patches the emit method of each handler to gather records
    # as they are emitted
    def _handler_watcher(self, record_list):
        """Patch every registered PT2 log handler so emitted records are captured.

        Each handler's ``emit`` is replaced by a wrapper. The wrapper first
        calls *that handler's own* original ``emit``, then appends the record
        to *record_list*.

        Args:
            record_list: list that receives every ``logging.LogRecord`` passed
                to a patched handler's ``emit``. A record that reaches two
                handlers is appended twice.

        Returns:
            A ``contextlib.ExitStack`` that restores the original ``emit``
            methods when closed; use it as a ``with`` target.

        Raises:
            self.failureException: if a registered logger has zero or more
                than two handlers. Any patches already applied are undone
                before the exception propagates.
        """

        def _forwarding_emit(original_emit):
            # Factory so each wrapper closes over its own handler's emit.
            # A closure defined directly in the loop would late-bind the loop
            # variable, and every handler would forward to the last handler's
            # emit.
            def new_emit(record):
                original_emit(record)
                record_list.append(record)

            return new_emit

        with contextlib.ExitStack() as exit_stack:
            # registered logs are the only ones with handlers, so patch those
            for log_qname in torch._logging._internal.log_registry.get_log_qnames():
                logger = logging.getLogger(log_qname)
                num_handlers = len(logger.handlers)
                self.assertLessEqual(
                    num_handlers,
                    2,
                    "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
                )
                self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
                for handler in logger.handlers:
                    exit_stack.enter_context(
                        unittest.mock.patch.object(
                            handler, "emit", _forwarding_emit(handler.emit)
                        )
                    )
            # All patches applied: hand the pending cleanup over to the caller.
            return exit_stack.pop_all()