import contextlib
import logging
import os
import unittest.mock

import torch._dynamo.test_case
import torch._logging
import torch._logging._internal


@contextlib.contextmanager
def preserve_log_state():
    # Snapshot the global log state, hand the caller a fresh one, and restore
    # the snapshot (reinitializing the logs) on exit.
    prev_state = torch._logging._internal._get_log_state()
    torch._logging._internal._set_log_state(torch._logging._internal.LogState())
    try:
        yield
    finally:
        torch._logging._internal._set_log_state(prev_state)
        torch._logging._internal._init_logs()


def log_settings(settings):
    # Apply a raw TORCH_LOGS settings string via the environment-variable path.
    # Returns an ExitStack; closing it restores the prior state.
    exit_stack = contextlib.ExitStack()
    settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
    exit_stack.enter_context(preserve_log_state())
    exit_stack.enter_context(settings_patch)
    torch._logging._internal._init_logs()
    return exit_stack


def log_api(**kwargs):
    # Apply the same settings through the public torch._logging.set_logs() API.
    exit_stack = contextlib.ExitStack()
    exit_stack.enter_context(preserve_log_state())
    torch._logging.set_logs(**kwargs)
    return exit_stack


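# Usage sketch (illustrative only; assumes "dynamo" is a registered log name;
# the two helpers are interchangeable ways to apply the same settings):
#
#   with log_settings("+dynamo"):        # env var path: TORCH_LOGS="+dynamo"
#       ...                              # run code under test
#
#   with log_api(dynamo=logging.DEBUG):  # API path: torch._logging.set_logs
#       ...

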
def kwargs_to_settings(**kwargs):
    # Translate set_logs()-style kwargs into the equivalent TORCH_LOGS string.
    # Levels map to prefixes: DEBUG (10) -> "+", INFO (20) -> "", ERROR (40) -> "-".
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}

    settings = []

    def append_setting(name, level):
        if (
            isinstance(name, str)
            and isinstance(level, int)
            and level in INT_TO_VERBOSITY
        ):
            settings.append(INT_TO_VERBOSITY[level] + name)
        else:
            raise ValueError("Invalid value for setting")

    for name, val in kwargs.items():
        # bool is checked before int because bools are ints in Python;
        # a bool value appends the bare name with no verbosity prefix.
        if isinstance(val, bool):
            settings.append(name)
        elif isinstance(val, int):
            append_setting(name, val)
        elif isinstance(val, dict) and name == "modules":
            for module_qname, level in val.items():
                append_setting(module_qname, level)
        else:
            raise ValueError("Invalid value for setting")

    return ",".join(settings)


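# For example (assuming "dynamo" and "inductor" are registered log names):
#
#   kwargs_to_settings(dynamo=logging.DEBUG, modules={"inductor": logging.INFO})
#   # -> "+dynamo,inductor"

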
# Note on testing strategy:
# This class does three things:
# 1. Runs two versions of each test:
#    1a. patches the env var log settings to some specific value
#    1b. calls torch._logging.set_logs(...)
# 2. Patches the emit method of each registered handler to gather the records
#    emitted to each console stream
# 3. Passes a ref to the gathered records to each test case for checking
#
# The overall goal is to ensure that, given some settings (env var or API),
# the logs are set up correctly and capture the correct records.
def make_logging_test(**kwargs):
    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            records = []
            # run with env var
            if len(kwargs) == 0:
                with self._handler_watcher(records):
                    fn(self, records)
            else:
                with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(
                    records
                ):
                    fn(self, records)

            # run with API
            torch._dynamo.reset()
            records.clear()
            with log_api(**kwargs), self._handler_watcher(records):
                fn(self, records)

        return test_fn

    return wrapper


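# Usage sketch (hypothetical test; assumes a "dynamo" log is registered and
# that the decorated method lives on a LoggingTestCase subclass):
#
#   class MyLoggingTests(LoggingTestCase):
#       @make_logging_test(dynamo=logging.DEBUG)
#       def test_emits_debug_records(self, records):
#           torch.compile(lambda x: x + 1)(torch.ones(2))
#           self.assertGreater(len(records), 0)

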
def make_settings_test(settings):
    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            records = []
            # run with env var
            with log_settings(settings), self._handler_watcher(records):
                fn(self, records)

        return test_fn

    return wrapper


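# Usage sketch (hypothetical; exercises a raw TORCH_LOGS string directly,
# which may have no set_logs() kwargs equivalent):
#
#   class MySettingsTests(LoggingTestCase):
#       @make_settings_test("+dynamo,graph")
#       def test_raw_settings(self, records):
#           ...

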
class LoggingTestCase(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
        )
        cls._exit_stack.enter_context(
            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        torch._logging._internal.log_state.clear()
        torch._logging._init_logs()

    # This patches the emit method of each handler to gather records
    # as they are emitted
    def _handler_watcher(self, record_list):
        exit_stack = contextlib.ExitStack()

        def emit_post_hook(record):
            record_list.append(record)

        # registered logs are the only ones with handlers, so patch those
        for log_qname in torch._logging._internal.log_registry.get_log_qnames():
            logger = logging.getLogger(log_qname)
            num_handlers = len(logger.handlers)
            self.assertLessEqual(
                num_handlers,
                2,
                "All pt2 loggers should have at most two handlers (debug artifacts "
                "and messages above debug level).",
            )
            self.assertGreater(
                num_handlers, 0, "All pt2 loggers should have more than zero handlers"
            )

            for handler in logger.handlers:
                # Bind each handler's original emit as a default argument;
                # a bare closure over the loop variable would late-bind every
                # patched emit to the last handler's method.
                def new_emit(record, old_emit=handler.emit):
                    old_emit(record)
                    emit_post_hook(record)

                exit_stack.enter_context(
                    unittest.mock.patch.object(handler, "emit", new_emit)
                )

        return exit_stack