pytorch/torch/_dynamo/exc.py
Shangdi Yu 4a6cf0a93e Fix dynamo stack trace (#165930)
Fixes #165911

- Add the error message to the `AttributeError` args so we see `Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])` instead of just `Developer debug context: raised exception AttributeError([])`
- Add a stack trace to `ObservedException` so we display the innermost error's stack trace back to user code

Output:

```
/data/users/shangdiy/pytorch/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
  warnings.warn(
Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1385, in var_getattr
    subobj = self._getattr_static(name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1256, in _getattr_static
    subobj = type(self.value).__getattribute__(self.value, name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Linear' object has no attribute 'w'

During handling of the above exception, another exception occurred:

torch._dynamo.exc.ObservedAttributeError: 'Linear' object has no attribute 'w'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/test.py", line 34, in <module>
    mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 481, in inner
    out = fullgraph_capture(
          ^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1053, in fullgraph_capture
    return _fullgraph_capture_frame(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1115, in _fullgraph_capture_frame
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
  Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
  Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html

from user code:
   File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 171, in forward
    res = self._export_root(*args, **kwargs)
  File "/data/users/shangdiy/pytorch/test.py", line 31, in forward
    weight = self.linear.w

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```
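For reference, a minimal sketch of the kind of user code that triggers this error (reconstructed from the traceback above; the layer sizes, input shape, and module structure are guesses):

```
import torch
import torch._dynamo.functional_export  # internal API, called here exactly as in the traceback above
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        weight = self.linear.w  # typo: nn.Linear has `.weight`, not `.w` -> AttributeError
        return x @ weight.t()


x = torch.randn(2, 4)
mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
```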

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165930
Approved by: https://github.com/anijain2305
2025-10-21 01:32:23 +00:00


"""Exception handling and error reporting for TorchDynamo.

This module provides a comprehensive set of exception classes and utilities for error
handling in TorchDynamo. It includes:

Base Exceptions:
- TorchDynamoException: Base class for all TorchDynamo-specific exceptions
- Various specialized subclasses for different error scenarios

User Error Handling:
- UserError: Exceptions for user-facing errors in TorchDynamo usage
- UserErrorType: Enumeration of different categories of user errors
- Formatted error messages with debugging information

Observed Exceptions:
- Classes for handling exceptions observed during tracing
- Special handling for StopIteration, LookupError, etc.
- Exception state management during compilation

Error Formatting:
- Stack trace filtering and formatting
- Error message augmentation
- Debugging utilities for error reporting
"""

from __future__ import annotations

import json
import logging
import os
import re
import textwrap
import typing
from enum import auto, Enum
from functools import lru_cache
from pathlib import Path
from traceback import extract_stack, format_exc, format_list, StackSummary
from typing import Any, NoReturn, Optional, TYPE_CHECKING

import torch._guards
from torch._utils_internal import get_file_path_2

from . import config
from .utils import counters


if TYPE_CHECKING:
    import types

    from torch._guards import CompileId

    from .output_graph import DynamoTracerOutput
    from .symbolic_convert import InstructionTranslatorBase
    from .types import DynamoFrameType


def exportdb_error_message(case_name: str) -> str:
    return (
        "For more information about this error, see: "
        + "https://pytorch.org/docs/main/generated/exportdb/index.html#"
        + case_name.replace("_", "-")
    )


log = logging.getLogger(__name__)
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


class TorchDynamoException(RuntimeError):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._torch_dynamo_tracer_output: Optional[DynamoTracerOutput] = None


class InternalTorchDynamoError(TorchDynamoException):
    pass


class ResumePrologueTracingError(TorchDynamoException):
    pass


class RestartAnalysis(TorchDynamoException):
    restart_reason: Optional[str]

    def __init__(self, *args: Any, restart_reason: Optional[str] = None) -> None:
        self.restart_reason = restart_reason
        super().__init__(*args)


class SpeculationRestartAnalysis(RestartAnalysis):
    pass


class UnspecializeRestartAnalysis(RestartAnalysis):
    pass


class CompileCollectiveRestartAnalysis(RestartAnalysis):
    pass


class TensorifyScalarRestartAnalysis(RestartAnalysis):
    pass


class SkipFrame(TorchDynamoException):
    pass


class TorchRuntimeError(TorchDynamoException):
    pass


class InvalidBackend(TorchDynamoException):
    def __init__(self, name: str) -> None:
        super().__init__(
            f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
        )


class ResetRequired(TorchDynamoException):
    def __init__(self) -> None:
        super().__init__(
            textwrap.dedent(
                """
                Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
                `torch.compile()` with a different backend compiler arguments.
                """
            )
        )


class ShortenTraceback(TorchDynamoException):
    def __init__(
        self, *args: Any, first_useful_frame: Optional[types.FrameType], **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self.first_useful_frame = first_useful_frame

    def remove_dynamo_frames(self) -> typing.Self:
        tb = self.__traceback__
        if self.first_useful_frame is None or tb is None or config.verbose:
            return self
        while tb.tb_frame is not self.first_useful_frame:
            tb = tb.tb_next
            assert tb is not None, "internal error, please report a bug"
        return self.with_traceback(tb)


class BackendCompilerFailed(ShortenTraceback):
    def __init__(
        self,
        backend_fn: Any,
        inner_exception: Exception,
        first_useful_frame: Optional[types.FrameType],
    ) -> None:
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}"
        super().__init__(msg, first_useful_frame=first_useful_frame)


class Unsupported(TorchDynamoException):
    def __init__(
        self,
        msg: str,
        *,
        case_name: Optional[str] = None,
        real_stack: None | StackSummary = None,
    ) -> None:
        super().__init__(msg)
        if not real_stack:
            real_stack = torch._guards.TracingContext.extract_stack()
        self.real_stack = real_stack
        self.msg = msg
        self.category: Optional[str] = None
        self.add_to_stats()
        self.case_name: Optional[str] = case_name

    def remove_from_stats(self) -> None:
        assert self.category is not None
        counters[self.category][self.msg] -= 1
        if counters[self.category][self.msg] <= 0:
            del counters[self.category][self.msg]

    def add_to_stats(self, category: str = "unimplemented") -> None:
        self.category = category
        counters[category][self.msg] += 1


class UnknownPropertiesDuringBackwardTrace(Unsupported):
    pass


class RecompileError(TorchDynamoException):
    pass


class ArgsMismatchError(Unsupported):
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class AttributeMutationError(Unsupported):
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class InfiniteGeneratorError(Unsupported):
    # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class SideEffectsError(Unsupported):
    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class CondOpArgsMismatchError(ArgsMismatchError):
    """
    Internal error from cond() due to arguments mismatch.
    """

    def __init__(self, msg: str) -> None:
        super().__init__(msg)


class UserErrorType(Enum):
    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
    UNSUPPORTED_ALIASED_MUTATED_DYNAMIC_INPUTS = auto()


class UserError(Unsupported):
    def __init__(
        self, error_type: UserErrorType, msg: str, case_name: Optional[str] = None
    ) -> None:
        """
        Type of errors that would be valid in Eager, but not supported in TorchDynamo.
        The error message should tell user about next actions.

        error_type: Type of user error
        msg: Actionable error message
        case_name: (Optional) Unique name (snake case) for the usage example in exportdb.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            if msg.endswith("."):
                msg += " "
            else:
                msg += "\n"
            msg += exportdb_error_message(case_name)
        super().__init__(msg)
        self.error_type = error_type
        self.message = msg


class SkipCodeRecursiveException(TorchDynamoException):
    pass


class RecompileLimitExceeded(Unsupported):
    pass


# debug exception thrown when tracing torch._dynamo.step_unsupported()
class StepUnsupported(TorchDynamoException):
    pass


class UnsafeScriptObjectError(TorchDynamoException):
    pass


class UncapturedHigherOrderOpError(TorchDynamoException):
    def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None:
        super().__init__(msg)
        self.msg = msg
        self.real_stack = (
            real_stack
            if real_stack is not None
            else torch._guards.TracingContext.extract_stack()
        )


class IncorrectUsage(Exception):
    pass


# TODO: I'm a little uncertain about what error classification we should have
# for this. This is potentially a user error, but regressions in
# specialization in PyTorch proper could also trigger this problem
class FailOnRecompileLimitHit(Exception):
    pass


class PackageError(TorchDynamoException):
    pass


class ObservedException(TorchDynamoException):
    # An exception observed during the tracing. This exception is used by Dynamo to handle exceptions.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.real_stack: StackSummary = torch._guards.TracingContext.extract_stack()


class ObservedUserStopIteration(ObservedException):
    # A UserStopIteration exception observed during Dynamo tracing (e.g. Dynamo tracing __next__)
    value: Optional[Any]

    # Reference `StopIteration_init` in CPython
    # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__("unhandled `raise StopIteration`")
        if len(args) > 0:
            self.value = args[0]
        else:
            self.value = None


class ObservedLookupError(ObservedException):
    # A LookupError exception to be raised from inside Dynamo tracing. This can happen on __getitem__
    pass


class ObservedIndexError(ObservedLookupError):
    # An IndexError exception to be raised from inside Dynamo tracing. This can happen on list __getitem__
    pass


class ObservedKeyError(ObservedLookupError):
    # A KeyError exception to be raised from inside Dynamo tracing. This can happen on dict __getitem__
    pass


class ObservedGeneratorExit(ObservedException):
    pass


class ObservedAttributeError(ObservedException):
    # An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user-defined object __getattr__
    pass


class ObservedRuntimeError(ObservedException):
    # A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on the generator.throw(..) method
    pass


class ObservedNotImplementedError(ObservedException):
    pass


class ObservedTypeError(ObservedException):
    # A TypeError exception to be raised from inside Dynamo tracing. This can happen on the generator.send(..) method
    pass


observed_exception_map = {
    StopIteration: ObservedUserStopIteration,
    LookupError: ObservedLookupError,
    IndexError: ObservedIndexError,
    GeneratorExit: ObservedGeneratorExit,
    KeyError: ObservedKeyError,
    AttributeError: ObservedAttributeError,
    RuntimeError: ObservedRuntimeError,
    NotImplementedError: ObservedNotImplementedError,
    TypeError: ObservedTypeError,
}


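# Returns the Observed* counterpart for a given exception type, synthesizing and caching a
# new ObservedException subclass for types not in the map above. For example (illustrative),
# get_dynamo_observed_exception(AttributeError) is ObservedAttributeError, while an unmapped
# ZeroDivisionError would get a generated "ObservedZeroDivisionErrorError" class.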
def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]:
    if exc_type not in observed_exception_map:
        name = getattr(exc_type, "__name__", str(exc_type))
        observed_exception_map[exc_type] = type(  # type: ignore[assignment]
            f"Observed{name}Error", (ObservedException,), {}
        )
    # pyrefly: ignore # index-error
    return observed_exception_map[exc_type]


def raise_observed_exception(
    exc_type: type[Exception],
    tx: InstructionTranslatorBase,
    *,
    args: Optional[list[Any]] = None,
    kwargs: Optional[dict[str, Any]] = None,
    msg: Optional[str] = None,
) -> NoReturn:
    from .variables import BuiltinVariable

    # CPython here raises an exception. Since there is no python code, we have to manually
    # set up the exception stack and raise the exception.
    # If a message is provided but no args, use the message as the first argument
    if msg is not None and (args is None or len(args) == 0):
        args = [msg]
    exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {})  # type: ignore[arg-type]
    tx.exn_vt_stack.set_current_exception(exception_vt)  # type: ignore[arg-type]
    raised_exc = get_dynamo_observed_exception(exc_type)
    # Store the original exception arguments for better error messages
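    # e.g. this is what makes the "Developer debug context" in the example above show
    # AttributeError(["'Linear' object has no attribute 'w'"]) instead of AttributeError([])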
    if args:
        raise raised_exc(*args)
    raise raised_exc


def handle_observed_exception(tx: Any) -> None:
    # This is essentially exception handling code, equivalent of this pseudo code
    #
    # try:
    #     ... somebody raising StopIteration
    # except StopIteration
    #     pass
    #
    # If this was going through the python code, we would have called exception_handler method, but FOR_ITER
    # handles the exception completely in CPython. For example for 3.11, the resulting bytecode is
    #
    #
    #   6          46 LOAD_GLOBAL              2 (StopIteration)
    #              58 RAISE_VARARGS            1
    #         >>   60 PUSH_EXC_INFO
    #   7          62 LOAD_GLOBAL              2 (StopIteration)
    #              74 CHECK_EXC_MATCH
    #              76 POP_JUMP_FORWARD_IF_FALSE     3 (to 84)
    #              78 POP_TOP
    #   8          80 POP_EXCEPT
    #
    # Fortunately this translates to a simple pop from the exn_vt_stack
    tx.exn_vt_stack.clear_current_exception()


# These exceptions are ok to fallback to eager/graph_break.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
    torch._subclasses.fake_tensor.UnsupportedMutationAliasingException,
)


def unimplemented_with_warning(
    e: Exception, code: types.CodeType, msg: str
) -> NoReturn:
    # This function calls unimplemented internally and eventually graph breaks
    # or falls back to eager. unimplemented itself does not print any user warnings,
    # i.e., it's very silent. This helper function is intended for when an error is
    # encountered in the torch.compile stack which is worth showing as a warning
    # to the user. For example, if the AOT Autograd backend fails with a fake tensor
    # exception, it's ok to fall back to eager, but not silently. Here, we can use
    # this function to log the message and the stack trace.
    graph_break_msg = format_error_msg_verbose(e, code)
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        },
        payload_fn=lambda: graph_break_msg,
    )
    graph_breaks_log.debug("%s", graph_break_msg)
    log.warning(msg)
    unimplemented(msg, from_exc=e)


_NOTHING = object()


def unimplemented(
    msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None
) -> NoReturn:
    assert msg != os.environ.get("BREAK", False)
    if from_exc is not _NOTHING:
        raise Unsupported(msg, case_name=case_name) from from_exc
    raise Unsupported(msg, case_name=case_name)


def unimplemented_v2_with_warning(
    e: Exception,
    code: types.CodeType,
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
) -> NoReturn:
    # This function calls unimplemented internally and eventually graph breaks
    # or falls back to eager. unimplemented itself does not print any user warnings,
    # i.e., it's very silent. This helper function is intended for when an error is
    # encountered in the torch.compile stack which is worth showing as a warning
    # to the user. For example, if the AOT Autograd backend fails with a fake tensor
    # exception, it's ok to fall back to eager, but not silently. Here, we can use
    # this function to log the message and the stack trace.
    graph_break_msg = format_error_msg_verbose(e, code)
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        },
        payload_fn=lambda: graph_break_msg,
    )
    graph_breaks_log.debug("%s", graph_break_msg)
    unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True)


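# Illustrative output layout (cf. the "Observed exception" graph break in the commit message above):
#
#   Observed exception
#     Explanation: Dynamo found no exception handler at the top-level compiled function ...
#     Hint: Dynamo has detected that tracing the code will result in an error when running in eager. ...
#     Hint: It may be possible to write Dynamo tracing rules for this code. ...
#
#     Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])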
def format_graph_break_message(
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
) -> str:
    explanation = textwrap.indent(explanation, "    ").lstrip()
    hints_str = "\n".join(
        "  Hint: " + textwrap.indent(hint, "    ").lstrip() for hint in hints
    )
    context = textwrap.indent(context, "    ").lstrip()
    msg = f"""\
{gb_type}
  Explanation: {explanation}
{hints_str}

  Developer debug context: {context}
"""
    return msg


@lru_cache(maxsize=1)
def _load_gb_type_to_gb_id_map() -> dict[str, Any]:
    """
    Loads the gb_type-to-gb_id map from the graph break registry JSON file, with caching.
    Includes historical gb_types (the mapping behavior for duplicate gb_types with different gb_ids is undefined).
    """
    try:
        script_dir = Path(__file__).resolve().parent
        registry_path = get_file_path_2(
            "", str(script_dir), "graph_break_registry.json"
        )
        with open(registry_path) as f:
            registry = json.load(f)
    except Exception:
        log.exception("Error accessing the registry file")
        registry = {}

    mapping = {}
    for k, v in registry.items():
        for entry in v:
            mapping[entry["Gb_type"]] = k
    return mapping


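# The registry maps each gb_id to a list of entries keyed by "Gb_type", e.g. (illustrative)
# {"GB0088": [{"Gb_type": "Observed exception", ...}]}, so the mapping above would contain
# {"Observed exception": "GB0088"} and get_gbid_documentation_link would return
# https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html for that gb_type.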
def get_gbid_documentation_link(gb_type: str) -> Optional[str]:
    """
    Retrieves the GBID documentation link for a given graph break type.

    Args:
        gb_type: The graph break type to look up.

    Returns:
        A string containing the documentation URL if found, otherwise None.
    """
    GRAPH_BREAK_SITE_URL = (
        "https://meta-pytorch.github.io/compile-graph-break-site/gb/"  # @lint-ignore
    )

    gb_type_to_gb_id_map = _load_gb_type_to_gb_id_map()
    if gb_type in gb_type_to_gb_id_map:
        return (
            f"{GRAPH_BREAK_SITE_URL}gb{gb_type_to_gb_id_map[gb_type].lstrip('GB')}.html"
        )
    return None


# TODO replace old unimplemented later
def unimplemented_v2(
    gb_type: str,
    context: str,
    explanation: str,
    hints: list[str],
    *,
    from_exc: Any = _NOTHING,
    log_warning: bool = False,
) -> NoReturn:
    """
    Called within dynamo to cause a graph break.

    Args:
        gb_type: Context-free graph break type. It should be a short string without any
            information specific to the tracing context (i.e. no dynamically-generated strings)
        context: Developer context for the graph break. It can contain tracing context/dynamic strings.
        explanation: User-facing context-dependent explanation for the graph break. Can be dynamic.
        hints: List of user-facing hints for the graph break.
    """
    msg = format_graph_break_message(gb_type, context, explanation, hints)
    documentation_link = get_gbid_documentation_link(gb_type)
    if documentation_link:
        msg += f"\n For more details about this graph break, please visit: {documentation_link}"

    if log_warning:
        log.warning(msg)

    if from_exc is not _NOTHING:
        past_real_stack = None
        if hasattr(from_exc, "real_stack"):
            past_real_stack = from_exc.real_stack
        raise Unsupported(msg, real_stack=past_real_stack) from from_exc
    raise Unsupported(msg)


# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
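# Illustration: str(KeyError("a\nb")) is "'a\\nb'" because KeyError repr()s its single argument,
# whereas str(KeyError(KeyErrorMsg("a\nb"))) preserves the newline, keeping the augmented
# multi-line messages produced below readable.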
class KeyErrorMsg:
    def __init__(self, value: Any) -> None:
        self.value = value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return self.__str__()


def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    import traceback

    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]
    real_stack = get_real_stack(exc)
    if real_stack is not None and len(real_stack) > 0:
        exc.innermost_user_frame_summary = real_stack[-1]  # type: ignore[attr-defined]
        msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}"

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        msg += (
            f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
 torch._dynamo.replay('{exc.record_filename}').\n"
        )

    if not config.verbose and hasattr(exc, "real_stack"):
        msg += (
            "\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace "
            "(please do this especially if you're reporting a bug to PyTorch). "
            'For even more developer context, set TORCH_LOGS="+dynamo"\n'
        )

    if hasattr(exc, "inner_exception") and hasattr(
        exc.inner_exception, "minifier_path"
    ):
        if hasattr(exc.inner_exception, "buck_command"):
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {exc.inner_exception.buck_command}\n"
            )
        else:
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )

    old_msg = "" if len(exc.args) == 0 else str(exc.args[0])

    if isinstance(exc, KeyError):
        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
    else:
        new_msg = old_msg + msg
        exc.args = (new_msg,) + exc.args[1:]


def get_exc_message(
    e: Exception, compile_id: CompileId
) -> tuple[Optional[str], Optional[int]]:
    filename = None
    lineno = None
    if e.innermost_user_frame_summary is not None:  # type: ignore[attr-defined]
        filename = e.innermost_user_frame_summary.filename  # type: ignore[attr-defined]
        lineno = e.innermost_user_frame_summary.lineno  # type: ignore[attr-defined]
    e.compile_id = compile_id  # type: ignore[attr-defined]
    return filename, lineno


def get_stack_above_dynamo() -> StackSummary:
    return filter_stack(extract_stack())


def get_real_stack(
    exc: Exception, frame: Optional[DynamoFrameType] = None
) -> Optional[StackSummary]:
    real_stack = getattr(exc, "real_stack", None)
    if real_stack is None:
        return None

    # NB: it's possible for real_stack to be []; we still attempt to
    # report a stack anyway because the stack_above_dynamo may still
    # be useful for debugging

    if frame is not None:
        # NB: frame is PyInterpreterFrame on Python 3.11 and later,
        # not a TRUE frame object. You can't actually feed it
        # to traceback because it doesn't have enough information.
        # To solve this problem, we technically should just materialize
        # the frame, the same way _PyFrame_GetFrameObject would do
        # (but we cannot actually do this, because this populates
        # frame_obj field, which default eval frame doesn't like).
        #
        # Fortunately, in this case, we can hack it: there's no need
        # to actually use the truly top frame, we can just extract
        # from where we are right now and rely on filter_stack to
        # get rid of all the dynamo frames. For ease of testing
        # we apply this behavior to ALL Python versions
        stack_above_dynamo = get_stack_above_dynamo()
    else:
        stack_above_dynamo = StackSummary()

    return StackSummary.from_list(stack_above_dynamo + real_stack)


# filter out all frames after entering dynamo
def filter_stack(stack: StackSummary) -> StackSummary:
    user_stack = StackSummary()
    for frame in stack:
        if frame.filename is None:
            continue
        if "convert_frame" in frame.filename:
            break
        if "eval_frame" in frame.filename or (
            frame.line and "torch._dynamo.optimize(" in frame.line
        ):
            continue
        user_stack.append(frame)
    return user_stack


def remove_resume_prefix(name: str) -> Optional[str]:
    from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX

    match = re.match(f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_(\\w+)_at_\\d+", name)
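    # e.g. "torch_dynamo_resume_in_f_at_80" matches and yields "f"; names without the
    # resume prefix fall through and return None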
    if match:
        return match.group(1)
    return None


def collapse_resume_frames(stack: StackSummary) -> StackSummary:
    """
    When we graph break, we create a resume function and make a regular Python call
    to it, which gets intercepted by Dynamo. This behavior is normally shown in the
    traceback, which can be confusing to a user. So we can filter out resume frames
    for better traceback clarity.

    Example:
        File "..." line 3, in f
            <line 3>
        File "..." line 5, in torch_dynamo_resume_in_f_at_80
            <line 5>
        File "..." line 10, in torch_dynamo_resume_in_f_at_120
            <line 10>

    becomes

        File "..." line 10, in f
            <line 10>
    """
    new_stack = StackSummary()
    for frame in stack:
        if frame.filename is None:
            continue
        name = remove_resume_prefix(frame.name)
        if new_stack and name and new_stack[-1].name == name:
            new_stack[-1] = frame
            frame.name = name
        else:
            new_stack.append(frame)
    return new_stack


def format_error_msg_verbose(
    exc: Exception,
    code: types.CodeType,
    record_filename: Optional[str] = None,
    frame: Optional[DynamoFrameType] = None,
) -> str:
    msg = (
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n"
    )
    msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
    msg += format_exc()
    real_stack = get_real_stack(exc, frame)
    if real_stack is not None:
        msg += (
            "\n"
            + "=" * 10
            + " The above exception occurred while processing the following code "
            + "=" * 10
            + "\n\n"
        )
        msg += "".join(format_list(real_stack))
        msg += "\n"
        msg += "=" * 10
    return msg


def format_error_msg(
    exc: Exception,
    code: types.CodeType,
    record_filename: Optional[str] = None,
    frame: Optional[DynamoFrameType] = None,
) -> str:
    if config.verbose:
        return format_error_msg_verbose(exc, code, record_filename, frame)
    return f"WON'T CONVERT {code.co_name} {code.co_filename}\
 line {code.co_firstlineno} \ndue to: \n{format_exc()}"