See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
262 lines
8.7 KiB
Python
# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import functools
from typing import Any, TYPE_CHECKING

import onnxscript  # type: ignore[import]
from onnxscript.function_libs.torch_lib import graph_building  # type: ignore[import]

import torch
import torch.fx
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils


if TYPE_CHECKING:
    import logging


# NOTE: The following limit is for the number of items to display in diagnostics for
# a list, tuple or dict. The limit is picked such that common useful scenarios such
# as operator arguments are covered, while preventing excessive processing load on
# considerably large containers such as the dictionary mapping from fx to onnx nodes.
_CONTAINER_ITEM_LIMIT: int = 10

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx` and loaded with `torch`. Hence anything related to `onnxscript`
# cannot be put there.

# [NOTE: `dynamo_export` diagnostics logging]
# The 'dynamo_export' diagnostics leverage the PT2 artifact logger to handle the
# verbosity level of logs that are recorded in each SARIF log diagnostic. In addition
# to the SARIF log, terminal logging is disabled by default. Terminal logging can be
# activated by setting the environment variable `TORCH_LOGS="onnx_diagnostics"`. When
# the environment variable is set, it also fixes the logging level to `logging.DEBUG`,
# overriding the verbosity level specified in the diagnostic options.
# See `torch/_logging/__init__.py` for more on PT2 logging.
_ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME = "onnx_diagnostics"
diagnostic_logger = torch._logging.getArtifactLogger(
    "torch.onnx", _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
)


def is_onnx_diagnostics_log_artifact_enabled() -> bool:
    return torch._logging._internal.log_state.is_artifact_enabled(
        _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
    )
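
# For example, terminal output for these diagnostics can be enabled from the
# shell (a usage sketch; `your_export_script.py` is a placeholder):
#
#   TORCH_LOGS="onnx_diagnostics" python your_export_script.py
#
# or programmatically, assuming `torch._logging.set_logs` accepts this artifact
# name as a keyword:
#
#   torch._logging.set_logs(onnx_diagnostics=True)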


@functools.singledispatch
def _format_argument(obj: Any) -> str:
    return formatter.format_argument(obj)


def format_argument(obj: Any) -> str:
    formatter = _format_argument.dispatch(type(obj))
    return formatter(obj)
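
# A minimal dispatch sketch (return values are illustrative; the exact dtype
# abbreviation comes from `fx_type_utils.from_torch_dtype_to_abbr`):
#
#   format_argument(torch.zeros(2, 3))  # -> e.g. "Tensor(f32[2, 3])" via `_torch_tensor`
#   format_argument(object())           # no registered handler; falls back to
#                                       #    `formatter.format_argument`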


# NOTE: EDITING BELOW? READ THIS FIRST!
#
# The functions below register `format_argument` handlers for different types via
# the `functools.singledispatch` registry. They are invoked by the diagnostics
# system when recording function arguments and return values as part of a
# diagnostic. Hence, code with a heavy workload should be avoided here; for
# example, avoid `torch.fx.GraphModule.print_readable()`.
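
# Registering a formatter for a new type follows the same pattern, e.g. this
# sketch for a hypothetical `MyWrapper` type (not part of this module):
#
#   @_format_argument.register
#   def _my_wrapper(obj: MyWrapper) -> str:
#       # Keep this cheap: it runs for every recorded argument and return value.
#       return f"MyWrapper({obj.name})"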


@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
    return f"torch.nn.Module({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
    return f"torch.fx.GraphModule({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
    node_string = f"fx.Node({obj.target})[{obj.op}]:"
    if "val" not in obj.meta:
        return node_string + "None"
    return node_string + format_argument(obj.meta["val"])


@_format_argument.register
def _torch_fx_symbolic_bool(obj: torch.SymBool) -> str:
    return f"SymBool({obj})"


@_format_argument.register
def _torch_fx_symbolic_int(obj: torch.SymInt) -> str:
    return f"SymInt({obj})"


@_format_argument.register
def _torch_fx_symbolic_float(obj: torch.SymFloat) -> str:
    return f"SymFloat({obj})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
    return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"


@_format_argument.register
def _int(obj: int) -> str:
    return str(obj)


@_format_argument.register
def _float(obj: float) -> str:
    return str(obj)


@_format_argument.register
def _bool(obj: bool) -> str:
    return str(obj)


@_format_argument.register
def _str(obj: str) -> str:
    return obj


@_format_argument.register
def _registration_onnx_function(obj: registration.ONNXFunction) -> str:
    # TODO: Compact display of `param_schema`.
    return f"registration.ONNXFunction({obj.op_full_name}, is_custom={obj.is_custom}, is_complex={obj.is_complex})"


@_format_argument.register
def _list(obj: list) -> str:
    list_string = f"List[length={len(obj)}](\n"
    if not obj:
        return list_string + "None)"
    for i, item in enumerate(obj):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only first _CONTAINER_ITEM_LIMIT items.
            list_string += "...,\n"
            break
        list_string += f"{format_argument(item)},\n"
    return list_string + ")"


@_format_argument.register
def _tuple(obj: tuple) -> str:
    tuple_string = f"Tuple[length={len(obj)}](\n"
    if not obj:
        return tuple_string + "None)"
    for i, item in enumerate(obj):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only first _CONTAINER_ITEM_LIMIT items.
            tuple_string += "...,\n"
            break
        tuple_string += f"{format_argument(item)},\n"
    return tuple_string + ")"


@_format_argument.register
def _dict(obj: dict) -> str:
    dict_string = f"Dict[length={len(obj)}](\n"
    if not obj:
        return dict_string + "None)"
    for i, (key, value) in enumerate(obj.items()):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only first _CONTAINER_ITEM_LIMIT items.
            dict_string += "...\n"
            break
        dict_string += f"{key}: {format_argument(value)},\n"
    return dict_string + ")"
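
# Rendering sketch: a container longer than `_CONTAINER_ITEM_LIMIT` is truncated,
# e.g. `format_argument(list(range(12)))` yields "List[length=12](", the first
# ten items ("0," through "9,"), then "...," and a closing ")".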


@_format_argument.register
def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
    return f"Parameter({format_argument(obj.data)})"
|
@_format_argument.register
|
|
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
|
|
return f"`TorchScriptTensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})`" # type: ignore[arg-type] # noqa: B950
|
|
|
|
|
|
@_format_argument.register
|
|
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
|
|
return f"`OnnxFunction({obj.name})`"
|
|
|
|
|
|
@_format_argument.register
|
|
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
|
|
return f"`TracedOnnxFunction({obj.name})`"
|
|
|
|
|
|


# from torch/fx/graph.py to follow torch format
def _stringify_shape(shape: torch.Size | None) -> str:
    if shape is None:
        return ""
    return f"[{', '.join(str(x) for x in shape)}]"
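
# Behavior sketch:
#
#   _stringify_shape(torch.Size([2, 3]))  # -> "[2, 3]"
#   _stringify_shape(None)                # -> "" (shape omitted from the message)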


rules = diagnostics.rules
levels = diagnostics.levels
RuntimeErrorWithDiagnostic = infra.RuntimeErrorWithDiagnostic
LazyString = formatter.LazyString
DiagnosticOptions = infra.DiagnosticOptions


@dataclasses.dataclass
class Diagnostic(infra.Diagnostic):
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)

    def log(self, level: int, message: str, *args, **kwargs) -> None:
        if self.logger.isEnabledFor(level):
            formatted_message = message % args
            if is_onnx_diagnostics_log_artifact_enabled():
                # Only log to terminal if artifact is enabled.
                # See [NOTE: `dynamo_export` diagnostics logging] for details.
                self.logger.log(level, formatted_message, **kwargs)

            self.additional_messages.append(formatted_message)
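
# Usage sketch (hypothetical call site): the message is always recorded in the
# SARIF diagnostic via `additional_messages`, and is echoed to the terminal only
# when the `onnx_diagnostics` artifact is enabled:
#
#   diagnostic.log(logging.INFO, "exporting fx node %s", node.name)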


@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
    _bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
        init=False, default=Diagnostic
    )

    def __enter__(self):
        self._previous_log_level = self.logger.level
        # Adjust the logger level based on `options.verbosity_level` and the
        # environment variable `TORCH_LOGS`.
        # See [NOTE: `dynamo_export` diagnostics logging] for details.
        if not is_onnx_diagnostics_log_artifact_enabled():
            return super().__enter__()
        else:
            return self
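
# Typical usage sketch; the positional fields come from `infra.DiagnosticContext`,
# so treat the exact constructor arguments as an assumption:
#
#   with DiagnosticContext("torch.onnx.dynamo_export", torch.__version__) as ctx:
#       ...  # run the export; diagnostics accumulate in `ctx`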


diagnose_call = functools.partial(
    decorator.diagnose_call,
    diagnostic_type=Diagnostic,
    format_argument=format_argument,
)


@dataclasses.dataclass
class UnsupportedFxNodeDiagnostic(Diagnostic):
    unsupported_fx_node: torch.fx.Node | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # NOTE: This is a hack to make sure that the additional fields must be set and
        # not None. Ideally they should not be set as optional. But this is a known
        # limitation with `dataclasses`. Resolvable in Python 3.10 with `kw_only=True`.
        # https://stackoverflow.com/questions/69711886/python-dataclasses-inheritance-and-default-values
        if self.unsupported_fx_node is None:
            raise ValueError("unsupported_fx_node must be specified.")
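
# With Python >= 3.10, the workaround above could instead use keyword-only
# dataclass fields (a sketch, not this module's current code):
#
#   @dataclasses.dataclass(kw_only=True)
#   class UnsupportedFxNodeDiagnostic(Diagnostic):
#       unsupported_fx_node: torch.fx.Node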