diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py
index 5f6704b8a39..dd7e93de14c 100644
--- a/test/dynamo/test_logging.py
+++ b/test/dynamo/test_logging.py
@@ -982,6 +982,7 @@ exclusions = {
     "graph_region_expansion",
     "hierarchical_compile",
     "compute_dependencies",
+    "annotation",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
     if name not in exclusions:
diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py
index 83091d867d4..eae75e06a42 100644
--- a/torch/_functorch/_aot_autograd/utils.py
+++ b/torch/_functorch/_aot_autograd/utils.py
@@ -4,6 +4,7 @@ Contains various utils for AOTAutograd, including those for handling collections
 """
 
 import dataclasses
+import logging
 import operator
 import warnings
 from collections.abc import Callable
@@ -40,6 +41,7 @@ KNOWN_TYPES = [
 original_zip = zip
 
 aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
+annotation_log = getArtifactLogger(__name__, "annotation")
 
 
 def strict_zip(*iterables, strict=True, **kwargs):
@@ -443,6 +445,10 @@ def _copy_metadata_to_bw_nodes_in_subgraph(
 ) -> None:
     """Copy metadata from forward nodes to backward nodes in a single subgraph."""
     for node in fx_g.graph.nodes:
+        annotation_log.debug("node: %s", node.name)
+        seq_nr = node.meta.get("seq_nr")
+        annotation_log.debug("seq_nr: %s", seq_nr)
+
         if not _is_backward_node_with_seq_nr(node):
             continue
 
@@ -478,6 +484,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
         if isinstance(submod, torch.fx.GraphModule):
             _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)
 
+    if annotation_log.isEnabledFor(logging.DEBUG):
+        for k, v in fwd_seq_nr_to_node.items():
+            annotation_log.debug("forward:: key: %s, value: %s", k, v)
+
     # Second pass: copy metadata to backward nodes in all subgraphs
     # using the global forward mapping
     for submod in fx_g.modules():
diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py
index 3c6f092ed4d..162ad53a63c 100644
--- a/torch/_logging/_registrations.py
+++ b/torch/_logging/_registrations.py
@@ -246,4 +246,9 @@ register_artifact(
     "Logs debug info for hierarchical compilation",
     off_by_default=True,
 )
+register_artifact(
+    "annotation",
+    "Logs detailed steps of creating annotations on graph nodes",
+    off_by_default=True,
+)
 register_artifact("custom_format_test_artifact", "Testing only", log_format="")
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index 8979dcbabaf..9b359d78d6d 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -17,6 +17,7 @@ from typing import Any, Optional
 import torch
 import torch.fx.traceback as fx_traceback
 from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
+from torch._logging import getArtifactLogger
 from torch.utils._traceback import CapturedTraceback
 
 from ._compatibility import compatibility
@@ -40,6 +41,7 @@ __all__ = [
 
 
 log = logging.getLogger(__name__)
+annotation_log = getArtifactLogger(__name__, "annotation")
 
 
 @compatibility(is_backward_compatible=False)
@@ -202,7 +204,9 @@ class TracerBase:
             # BWD pass we retrieve the sequence_nr stored on the current
             # executing autograd Node. See NOTE [ Sequence Number ].
             if current_meta.get("in_grad_fn", 0) > 0:
+                annotation_log.debug("seq_nr from current_meta")
                 new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
+                annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
                 node.meta["seq_nr"] = new_seq_nr
             elif self.module_stack:
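
Usage sketch: with the "annotation" artifact registered above as off by default, it should be selectable through the standard TORCH_LOGS / torch._logging.set_logs machinery like any other registered artifact. A minimal example, assuming this patch is applied; the script name is a placeholder and not part of the patch:

    # Shell: enable the artifact for a single run.
    #   TORCH_LOGS="annotation" python train.py
    import torch

    # Python: enable the same artifact programmatically via its registered name.
    torch._logging.set_logs(annotation=True)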