[annotation] add logging for debugging annotation (#165797)

Add logging for debugging annotation bugs. The logs show up when running with `TORCH_LOGS="+annotation"`.
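For reference, both ways to enable the artifact, as a minimal sketch (the script name is hypothetical; `set_logs` accepts registered artifact names as boolean keywords):

```python
# From the environment, for any entry point:
#   TORCH_LOGS="+annotation" python train.py

# Or programmatically, before compiling/tracing:
import torch._logging

torch._logging.set_logs(annotation=True)
```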

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165797
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/SherlockNoMad
Shangdi Yu 2025-10-20 21:27:38 +00:00 committed by PyTorch MergeBot
parent 4f7f43253d
commit efc277cac7
4 changed files with 20 additions and 0 deletions

File 1 of 4:

@@ -982,6 +982,7 @@ exclusions = {
     "graph_region_expansion",
     "hierarchical_compile",
     "compute_dependencies",
+    "annotation",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
     if name not in exclusions:
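This first hunk is in the logging test harness: the loop below the `exclusions` set appears to generate a default test for every registered artifact, so the new "annotation" artifact is opted out here. A hedged sketch of exercising an excluded artifact by hand (test body elided, names illustrative):

```python
import torch._logging

def test_annotation_artifact_smoke():
    torch._logging.set_logs(annotation=True)
    try:
        ...  # run a small trace/compile expected to emit annotation logs
    finally:
        torch._logging.set_logs()  # reset logging state to defaults
```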

File 2 of 4:

@@ -4,6 +4,7 @@ Contains various utils for AOTAutograd, including those for handling collections
 """
 import dataclasses
+import logging
 import operator
 import warnings
 from collections.abc import Callable
@@ -40,6 +41,7 @@ KNOWN_TYPES = [
 original_zip = zip
 aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
+annotation_log = getArtifactLogger(__name__, "annotation")
 def strict_zip(*iterables, strict=True, **kwargs):
@@ -443,6 +445,10 @@ def _copy_metadata_to_bw_nodes_in_subgraph(
 ) -> None:
     """Copy metadata from forward nodes to backward nodes in a single subgraph."""
     for node in fx_g.graph.nodes:
+        annotation_log.debug("node: %s", node.name)
+        seq_nr = node.meta.get("seq_nr")
+        annotation_log.debug("seq_nr: %s", seq_nr)
+
         if not _is_backward_node_with_seq_nr(node):
             continue
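These per-node records rely on the logging module's deferred %-style formatting, so arguments are only interpolated when the artifact is actually enabled:

```python
# Deferred: args are stored on the LogRecord and rendered only
# if a handler emits the record (cheap when the artifact is off).
annotation_log.debug("node: %s", node.name)

# Eager (avoid): the f-string is built even when logging is disabled.
annotation_log.debug(f"node: {node.name}")
```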
@@ -478,6 +484,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
         if isinstance(submod, torch.fx.GraphModule):
             _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)
+    if annotation_log.isEnabledFor(logging.DEBUG):
+        for k, v in fwd_seq_nr_to_node.items():
+            annotation_log.debug("forward:: key: %s, value: %s", k, v)
+
     # Second pass: copy metadata to backward nodes in all subgraphs
     # using the global forward mapping
     for submod in fx_g.modules():
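The `isEnabledFor` guard added above skips the whole dictionary walk when the artifact is off: one cheap check instead of one disabled `debug` call per entry. The same pattern as a self-contained sketch (the function name is assumed):

```python
import logging

from torch._logging import getArtifactLogger

annotation_log = getArtifactLogger(__name__, "annotation")

def dump_fwd_mapping(fwd_seq_nr_to_node):
    # Bail out with a single check before the O(n) loop.
    if annotation_log.isEnabledFor(logging.DEBUG):
        for k, v in fwd_seq_nr_to_node.items():
            annotation_log.debug("forward:: key: %s, value: %s", k, v)
```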

File 3 of 4:

@@ -246,4 +246,9 @@ register_artifact(
     "Logs debug info for hierarchical compilation",
     off_by_default=True,
 )
+register_artifact(
+    "annotation",
+    "Logs detailed steps of creating annotations on graph nodes",
+    off_by_default=True,
+)
 register_artifact("custom_format_test_artifact", "Testing only", log_format="")
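`register_artifact` only declares the artifact and its description; records are emitted through `getArtifactLogger`, as the other files in this commit do. The full pairing as a sketch, with a made-up artifact name:

```python
from torch._logging import getArtifactLogger
from torch._logging._internal import register_artifact

register_artifact(
    "my_artifact",  # hypothetical name, for illustration only
    "Logs example records",
    off_by_default=True,
)

my_log = getArtifactLogger(__name__, "my_artifact")
my_log.debug("visible only when the artifact is enabled via TORCH_LOGS")
```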

File 4 of 4:

@@ -17,6 +17,7 @@ from typing import Any, Optional
 import torch
 import torch.fx.traceback as fx_traceback
 from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
+from torch._logging import getArtifactLogger
 from torch.utils._traceback import CapturedTraceback
 from ._compatibility import compatibility
@@ -40,6 +41,7 @@ __all__ = [
 log = logging.getLogger(__name__)
+annotation_log = getArtifactLogger(__name__, "annotation")
 @compatibility(is_backward_compatible=False)
@@ -202,7 +204,9 @@ class TracerBase:
             # BWD pass we retrieve the sequence_nr stored on the current
             # executing autograd Node. See NOTE [ Sequence Number ].
             if current_meta.get("in_grad_fn", 0) > 0:
+                annotation_log.debug("seq_nr from current_meta")
                 new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
+                annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
                 node.meta["seq_nr"] = new_seq_nr
             elif self.module_stack:
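The two debug lines bracket the hand-off of `seq_nr` from autograd's `current_meta` onto the FX node; `copy_fwd_metadata_to_bw_nodes` later uses that same `seq_nr` to pair backward nodes with their forward counterparts. A simplified sketch of the pairing, not the literal implementation (separate `fwd_graph`/`bwd_graph` handles are assumed for clarity):

```python
# First pass: index forward nodes by their recorded seq_nr.
fwd_by_seq_nr = {}
for node in fwd_graph.nodes:
    seq_nr = node.meta.get("seq_nr")
    if seq_nr is not None:
        fwd_by_seq_nr[seq_nr] = node

# Second pass: copy forward metadata onto matching backward nodes.
for node in bwd_graph.nodes:
    fwd_node = fwd_by_seq_nr.get(node.meta.get("seq_nr"))
    if fwd_node is not None:
        node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
```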