Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00).
Fixes #ISSUE_NUMBER. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146140. Approved by: https://github.com/albanD.
654 lines · 24 KiB · Python
# Owner(s): ["module: onnx"]
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import dataclasses
|
|
import io
|
|
import logging
|
|
import typing
|
|
from typing import Protocol
|
|
|
|
import torch
|
|
from torch.onnx import errors
|
|
from torch.onnx._internal import diagnostics
|
|
from torch.onnx._internal.diagnostics import infra
|
|
from torch.onnx._internal.diagnostics.infra import formatter, sarif
|
|
from torch.onnx._internal.fx import diagnostics as fx_diagnostics
|
|
from torch.testing._internal import common_utils, logging_utils
|
|
|
|
|
|
if typing.TYPE_CHECKING:
|
|
import unittest
|
|
|
|
|
|
class _SarifLogBuilder(Protocol):
    """Structural type for objects that can produce a SARIF log.

    In this module both ``diagnostics.engine`` and ``infra.DiagnosticContext``
    instances are passed wherever this protocol is expected.
    """

    def sarif_log(self) -> sarif.SarifLog:
        """Return a ``sarif.SarifLog`` whose runs/results can be inspected."""
|
|
|
|
|
|
def _assert_has_diagnostics(
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
):
    """Check that every expected (rule, level) pair occurs in the SARIF log.

    Expected pairs are normalized to ``(rule.id, level.name.lower())`` and
    matched against each SARIF result's ``(rule_id, level)``. Runs without
    results are skipped. Additional, unexpected results are tolerated.

    Args:
        sarif_log_builder: Object whose ``sarif_log()`` supplies the log.
        rule_level_pairs: The (rule, level) pairs that must all be present.

    Raises:
        AssertionError: If any expected pair is missing. The message lists the
            missing pairs and every observed ``(rule_id, level)`` pair.
    """
    log = sarif_log_builder.sarif_log()
    missing = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
    observed = [
        (result.rule_id, result.level)
        for run in log.runs
        if run.results is not None
        for result in run.results
    ]
    # Discard in observation order so the remaining set matches the one an
    # incremental scan would leave behind.
    for pair in observed:
        missing.discard(pair)

    if not missing:
        return
    raise AssertionError(
        f"Expected diagnostic results of rule id and level pair {missing} not found. "
        f"Actual diagnostic results: {observed}"
    )
|
|
|
|
|
|
@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
    """``RuleCollection`` holding the single rule shared by these tests.

    The rule's message template, ``"rule message"``, has no placeholders, so
    diagnostics for it are constructed without any message arguments.
    """

    # A plain default value; dataclasses treats this exactly like
    # ``dataclasses.field(default=...)``.
    rule_without_message_args: infra.Rule = infra.Rule(
        "1",
        "rule-without-message-args",
        message_default_template="rule message",
    )
|
|
|
|
|
|
@contextlib.contextmanager
def assert_all_diagnostics(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
):
    """Context manager to assert that all diagnostics are emitted.

    An ``errors.OnnxExporterError`` raised inside the block is swallowed, but
    only when at least one expected pair has level ``infra.Level.ERROR``.
    However the block exits, the SARIF log from ``sarif_log_builder`` is
    checked for every expected (rule, level) pair.

    Usage:
        with assert_all_diagnostics(
            self,
            diagnostics.engine,
            {(rule, infra.Level.ERROR)},
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule_level_pairs: A set of rule and level pairs to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: If not all diagnostics are emitted, or if an
            ``OnnxExporterError`` is raised while no ERROR-level diagnostic
            was expected.
    """
    try:
        yield
    except errors.OnnxExporterError:
        # An exporter error is only acceptable when the caller anticipated an
        # ERROR-level diagnostic; otherwise fail (the original exception stays
        # attached as the AssertionError's context).
        test_suite.assertIn(
            infra.Level.ERROR,
            {level for _, level in rule_level_pairs},
            "OnnxExporterError was raised but no ERROR-level diagnostic was expected.",
        )
    finally:
        _assert_has_diagnostics(sarif_log_builder, rule_level_pairs)
|
|
|
|
|
|
def assert_diagnostic(
    test_suite: unittest.TestCase,
    sarif_log_builder: _SarifLogBuilder,
    rule: infra.Rule,
    level: infra.Level,
):
    """Context manager to assert that a diagnostic is emitted.

    Single-pair convenience wrapper around ``assert_all_diagnostics``.

    Usage:
        with assert_diagnostic(
            self,
            diagnostics.engine,
            rule,
            infra.Level.ERROR,
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        sarif_log_builder: The SARIF log builder.
        rule: The rule to assert.
        level: The level to assert.

    Returns:
        A context manager (the one returned by ``assert_all_diagnostics``).

    Raises:
        AssertionError: If the diagnostic is not emitted.
    """
    return assert_all_diagnostics(test_suite, sarif_log_builder, {(rule, level)})
|
|
|
|
|
|
class TestDynamoOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the Dynamo ONNX export code."""

    def setUp(self):
        # Fresh context and rule collection for every test; tests below mutate
        # ``self.diagnostic_context.options`` freely.
        self.diagnostic_context = fx_diagnostics.DiagnosticContext("dynamo_export", "")
        self.rules = _RuleCollectionForTest()
        return super().setUp()

    def test_log_is_recorded_in_sarif_additional_messages_according_to_diagnostic_options_verbosity_level(
        self,
    ):
        """Each ``log`` call is recorded iff its level >= the verbosity level."""
        logging_levels = [
            logging.DEBUG,
            logging.INFO,
            logging.WARNING,
            logging.ERROR,
        ]
        for verbosity_level in logging_levels:
            self.diagnostic_context.options.verbosity_level = verbosity_level
            with self.diagnostic_context:
                diagnostic = fx_diagnostics.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NONE
                )
                for log_level in logging_levels:
                    # Take the baseline per call. A single baseline taken before
                    # this loop would let every assertGreater after the first
                    # recorded message pass even if a later call recorded nothing.
                    additional_messages_count = len(diagnostic.additional_messages)
                    diagnostic.log(level=log_level, message="log message")
                    if log_level >= verbosity_level:
                        self.assertGreater(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should be recorded when log level is {log_level} "
                            f"and verbosity level is {verbosity_level}",
                        )
                    else:
                        self.assertEqual(
                            len(diagnostic.additional_messages),
                            additional_messages_count,
                            f"Additional message should not be recorded when log level is "
                            f"{log_level} and verbosity level is {verbosity_level}",
                        )

    def test_torch_logs_environment_variable_precedes_diagnostic_options_verbosity_level(
        self,
    ):
        # Verbosity says ERROR, but enabling the "onnx_diagnostics" log artifact
        # must still let a DEBUG message be recorded.
        self.diagnostic_context.options.verbosity_level = logging.ERROR
        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )
            additional_messages_count = len(diagnostic.additional_messages)
            diagnostic.debug("message")
            self.assertGreater(
                len(diagnostic.additional_messages), additional_messages_count
            )

    def test_log_is_not_emitted_to_terminal_when_log_artifact_is_not_enabled(self):
        self.diagnostic_context.options.verbosity_level = logging.INFO
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(
                diagnostic.logger, level=logging.INFO
            ) as assert_log_context:
                diagnostic.info("message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log we don't want is not emitted.
                diagnostic.logger.log(logging.ERROR, "dummy message")

            # Only the dummy record may reach the logger.
            self.assertEqual(len(assert_log_context.records), 1)

    def test_log_is_emitted_to_terminal_when_log_artifact_is_enabled(self):
        self.diagnostic_context.options.verbosity_level = logging.INFO

        with logging_utils.log_settings("onnx_diagnostics"), self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NONE
            )

            with self.assertLogs(diagnostic.logger, level=logging.INFO):
                diagnostic.info("message")

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.diagnostic_context.options.verbosity_level = verbosity_level
        with self.diagnostic_context:
            diagnostic = fx_diagnostics.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.diagnostic_context:
            # Dynamo onnx exporter diagnostic context expects fx_diagnostics.Diagnostic
            # instead of base infra.Diagnostic.
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            with self.assertRaises(TypeError):
                self.diagnostic_context.log(diagnostic)
|
|
|
|
|
|
class TestTorchScriptOnnxDiagnostics(common_utils.TestCase):
    """Test cases for diagnostics emitted by the TorchScript ONNX export code."""

    def setUp(self):
        # Clear the shared ``diagnostics.engine`` before each test; helpers and
        # assertions below read ``engine.contexts[-1]`` and its SARIF log.
        engine = diagnostics.engine
        engine.clear()
        # An existing rule used by tests that create diagnostics by hand.
        self._sample_rule = diagnostics.rules.missing_custom_symbolic_function
        super().setUp()

    def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
        self,
    ) -> diagnostics.TorchScriptOnnxExportDiagnostic:
        """Export a model with a custom op and return the resulting warning.

        The exported op ``custom::CustomAdd`` is produced by a custom
        ``symbolic``, which is expected to trigger the
        ``node_missing_onnx_shape_inference`` rule.

        Returns:
            The first diagnostic in the most recent engine context whose rule
            is ``node_missing_onnx_shape_inference`` and whose level is WARNING.

        Raises:
            AssertionError: If no such diagnostic was recorded.
        """
        class CustomAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y

            @staticmethod
            def symbolic(g, x, y):
                return g.op("custom::CustomAdd", x, y)

        class M(torch.nn.Module):
            def forward(self, x):
                return CustomAdd.apply(x, x)

        # trigger warning for missing shape inference.
        rule = diagnostics.rules.node_missing_onnx_shape_inference
        torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())

        # The export above is assumed to have produced the newest engine context.
        context = diagnostics.engine.contexts[-1]
        for diagnostic in context.diagnostics:
            if (
                diagnostic.rule == rule
                and diagnostic.level == diagnostics.levels.WARNING
            ):
                return typing.cast(
                    diagnostics.TorchScriptOnnxExportDiagnostic, diagnostic
                )
        raise AssertionError("No diagnostic found.")

    def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
        # Nothing is exported inside the block, so the expected WARNING is
        # absent and ``assert_diagnostic`` itself must fail.
        with self.assertRaises(AssertionError):
            with assert_diagnostic(
                self,
                diagnostics.engine,
                diagnostics.rules.node_missing_onnx_shape_inference,
                diagnostics.levels.WARNING,
            ):
                pass

    def test_cpp_diagnose_emits_warning(self):
        with assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.node_missing_onnx_shape_inference,
            diagnostics.levels.WARNING,
        ):
            # trigger warning for missing shape inference.
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()

    def test_py_diagnose_emits_error(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x)

        # The exporter error raised here is swallowed by assert_diagnostic
        # because an ERROR-level diagnostic is expected.
        with assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.operator_supported_in_newer_opset_version,
            diagnostics.levels.ERROR,
        ):
            # trigger error for operator unsupported until newer opset version.
            torch.onnx.export(
                M(),
                torch.randn(3, 4),
                io.BytesIO(),
                opset_version=9,
            )

    def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
        self,
    ):
        sample_level = diagnostics.levels.ERROR
        with assert_diagnostic(
            self,
            diagnostics.engine,
            self._sample_rule,
            sample_level,
        ):
            diagnostic = infra.Diagnostic(self._sample_rule, sample_level)
            diagnostics.export_context().log(diagnostic)

    def test_diagnostics_records_python_call_stack(self):
        # The snippet of ``stack.frames[0]`` is asserted below to contain
        # "self._sample_rule", so the constructor call must stay on one line.
        diagnostic = diagnostics.TorchScriptOnnxExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE)  # fmt: skip
        # Do not break the above line, otherwise it will not work with Python-3.8+
        stack = diagnostic.python_call_stack
        assert stack is not None  # for mypy
        self.assertGreater(len(stack.frames), 0)
        frame = stack.frames[0]
        assert frame.location.snippet is not None  # for mypy
        self.assertIn("self._sample_rule", frame.location.snippet)
        assert frame.location.uri is not None  # for mypy
        self.assertIn("test_diagnostics.py", frame.location.uri)

    def test_diagnostics_records_cpp_call_stack(self):
        diagnostic = self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
        stack = diagnostic.cpp_call_stack
        assert stack is not None  # for mypy
        self.assertGreater(len(stack.frames), 0)
        frame_messages = [frame.location.message for frame in stack.frames]
        # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx)
        # after node-level shape type inference and processed symbolic_fn output type
        self.assertTrue(
            any(
                isinstance(message, str) and "torch::jit::NodeToONNX" in message
                for message in frame_messages
            )
        )
|
|
|
|
|
|
@common_utils.instantiate_parametrized_tests
class TestDiagnosticsInfra(common_utils.TestCase):
    """Test cases for diagnostics infra."""

    def setUp(self):
        self.rules = _RuleCollectionForTest()
        # Enter the context via an ExitStack, then hand its pending exit to
        # ``addCleanup`` with ``pop_all()``. The context stays entered for the
        # whole test and is exited during cleanup.
        with contextlib.ExitStack() as stack:
            self.context: infra.DiagnosticContext[infra.Diagnostic] = (
                stack.enter_context(infra.DiagnosticContext("test", "1.0.0"))
            )
            self.addCleanup(stack.pop_all().close)
        return super().setUp()

    def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
        custom_rules = infra.RuleCollection.custom_collection_from_list(
            "CustomRuleCollection",
            [
                infra.Rule(
                    "1",
                    "custom-rule",
                    message_default_template="custom rule message",
                ),
                infra.Rule(
                    "2",
                    "custom-rule-2",
                    message_default_template="custom rule message 2",
                ),
            ],
        )

        with assert_all_diagnostics(
            self,
            self.context,
            {
                (custom_rules.custom_rule, infra.Level.WARNING),  # type: ignore[attr-defined]
                (custom_rules.custom_rule_2, infra.Level.ERROR),  # type: ignore[attr-defined]
            },
        ):
            diagnostic1 = infra.Diagnostic(
                custom_rules.custom_rule,  # type: ignore[attr-defined]
                infra.Level.WARNING,
            )
            self.context.log(diagnostic1)

            diagnostic2 = infra.Diagnostic(
                custom_rules.custom_rule_2,  # type: ignore[attr-defined]
                infra.Level.ERROR,
            )
            self.context.log(diagnostic2)

    def test_diagnostic_log_is_not_emitted_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                diagnostic.log(logging.DEBUG, "debug message")
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.INFO, "info message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)
            self.assertFalse(
                any(
                    message.find("debug message") >= 0
                    for message in diagnostic.additional_messages
                )
            )

    def test_diagnostic_log_is_emitted_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            level_message_pairs = [
                (logging.INFO, "info message"),
                (logging.WARNING, "warning message"),
                (logging.ERROR, "error message"),
            ]

            for level, message in level_message_pairs:
                with self.assertLogs(diagnostic.logger, level=verbosity_level):
                    diagnostic.log(level, message)

                # The generator variable must not be named ``message``; shadowing
                # it would make ``find`` compare each recorded message with itself,
                # so the assertion could never fail.
                self.assertTrue(
                    any(
                        recorded.find(message) >= 0
                        for recorded in diagnostic.additional_messages
                    ),
                    f"{message!r} was not recorded in additional_messages",
                )

    @common_utils.parametrize(
        "log_api, log_level",
        [
            ("debug", logging.DEBUG),
            ("info", logging.INFO),
            ("warning", logging.WARNING),
            ("error", logging.ERROR),
        ],
    )
    def test_diagnostic_log_is_emitted_according_to_api_level_and_diagnostic_options_verbosity_level(
        self, log_api: str, log_level: int
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            message = "log message"
            with self.assertLogs(
                diagnostic.logger, level=verbosity_level
            ) as assert_log_context:
                getattr(diagnostic, log_api)(message)
                # NOTE: self.assertNoLogs only exist >= Python 3.10
                # Add this dummy log such that we can pass self.assertLogs, and inspect
                # assert_log_context.records to check if the log level is correct.
                diagnostic.log(logging.ERROR, "dummy message")

            for record in assert_log_context.records:
                self.assertGreaterEqual(record.levelno, logging.INFO)

            if log_level >= verbosity_level:
                self.assertIn(message, diagnostic.additional_messages)
            else:
                self.assertNotIn(message, diagnostic.additional_messages)

    def test_diagnostic_log_lazy_string_is_not_evaluated_when_level_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should NOT be evaluated.
            diagnostic.debug("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                0,
                "expensive_formatting_function should not be evaluated after being wrapped under LazyString",
            )

    def test_diagnostic_log_lazy_string_is_evaluated_once_when_level_not_less_than_diagnostic_options_verbosity_level(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            reference_val = 0

            def expensive_formatting_function() -> str:
                # Modify the reference_val to reflect this function is evaluated
                nonlocal reference_val
                reference_val += 1
                return f"expensive formatting {reference_val}"

            # `expensive_formatting_function` should be evaluated exactly once.
            diagnostic.info("%s", formatter.LazyString(expensive_formatting_function))
            self.assertEqual(
                reference_val,
                1,
                "expensive_formatting_function should only be evaluated once after being wrapped under LazyString",
            )

    def test_diagnostic_log_emit_correctly_formatted_string(self):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )
            diagnostic.log(
                logging.INFO,
                "%s",
                formatter.LazyString(lambda x, y: f"{x} {y}", "hello", "world"),
            )
            self.assertIn("hello world", diagnostic.additional_messages)

    def test_diagnostic_nested_log_section_emits_messages_with_correct_section_title_indentation(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            diagnostic = infra.Diagnostic(
                self.rules.rule_without_message_args, infra.Level.NOTE
            )

            with diagnostic.log_section(logging.INFO, "My Section"):
                diagnostic.log(logging.INFO, "My Message")
                with diagnostic.log_section(logging.INFO, "My Subsection"):
                    diagnostic.log(logging.INFO, "My Submessage")

            with diagnostic.log_section(logging.INFO, "My Section 2"):
                diagnostic.log(logging.INFO, "My Message 2")

            # Each nesting level adds one more "#" to the section title.
            self.assertIn("## My Section", diagnostic.additional_messages)
            self.assertIn("### My Subsection", diagnostic.additional_messages)
            self.assertIn("## My Section 2", diagnostic.additional_messages)

    def test_diagnostic_log_source_exception_emits_exception_traceback_and_error_message(
        self,
    ):
        verbosity_level = logging.INFO
        self.context.options.verbosity_level = verbosity_level
        with self.context:
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.NOTE
                )
                diagnostic.log_source_exception(logging.ERROR, e)

            diagnostic_message = "\n".join(diagnostic.additional_messages)

            self.assertIn("ValueError: original exception", diagnostic_message)
            self.assertIn("Traceback (most recent call last):", diagnostic_message)

    def test_log_diagnostic_to_diagnostic_context_raises_when_diagnostic_type_is_wrong(
        self,
    ):
        with self.context:
            with self.assertRaises(TypeError):
                # The method expects 'Diagnostic' or its subclasses as arguments.
                # Passing any other type will trigger a TypeError.
                self.context.log("This is a str message.")

    def test_diagnostic_context_raises_if_diagnostic_is_error(self):
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
            )

    def test_diagnostic_context_raises_original_exception_from_diagnostic_created_from_it(
        self,
    ):
        # When the diagnostic carries a source exception, that exception (not
        # RuntimeErrorWithDiagnostic) is what gets raised.
        with self.assertRaises(ValueError):
            try:
                raise ValueError("original exception")
            except ValueError as e:
                diagnostic = infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.ERROR
                )
                diagnostic.log_source_exception(logging.ERROR, e)
                self.context.log_and_raise_if_error(diagnostic)

    def test_diagnostic_context_raises_if_diagnostic_is_warning_and_warnings_as_errors_is_true(
        self,
    ):
        # Configure outside assertRaises so only the call under test is covered.
        self.context.options.warnings_as_errors = True
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic):
            self.context.log_and_raise_if_error(
                infra.Diagnostic(
                    self.rules.rule_without_message_args, infra.Level.WARNING
                )
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test entry point.
    common_utils.run_tests()
|