[annotation] add logging for debugging annotation (#165797)
Add logging for debugging annotation bugs. Logs will show with `TORCH_LOGS="+annotation"`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165797
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/SherlockNoMad
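To see the new output in practice, enable the `annotation` artifact before running a trace/compile workload. A minimal sketch: the environment variable is quoted from the commit message above, while `torch._logging.set_logs` is the in-process equivalent, and accepting `annotation=True` assumes the artifact registration added later in this diff.

    # From the shell, as stated in the commit message:
    #   TORCH_LOGS="+annotation" python your_script.py
    #
    # Or from inside the process (assumed equivalent once the
    # "annotation" artifact is registered):
    import torch

    torch._logging.set_logs(annotation=True)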
This commit is contained in:
parent 4f7f43253d
commit efc277cac7
@@ -982,6 +982,7 @@ exclusions = {
     "graph_region_expansion",
     "hierarchical_compile",
     "compute_dependencies",
+    "annotation",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
     if name not in exclusions:
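This hunk adds the new artifact to the test exclusion set: the surrounding loop auto-generates a logging test for every registered artifact name not listed there. To see which names the registry currently holds, a quick sketch using the same internal attribute the loop iterates over:

    import torch._logging

    # Internal registry; the test loop above iterates this same attribute.
    print(sorted(torch._logging._internal.log_registry.artifact_names))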
@@ -4,6 +4,7 @@ Contains various utils for AOTAutograd, including those for handling collections
 """

 import dataclasses
+import logging
 import operator
 import warnings
 from collections.abc import Callable
@@ -40,6 +41,7 @@ KNOWN_TYPES = [
 original_zip = zip

 aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
+annotation_log = getArtifactLogger(__name__, "annotation")


 def strict_zip(*iterables, strict=True, **kwargs):
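The module-level logger follows the standard artifact-logger pattern: obtained once at import time, silent unless the artifact is switched on. A self-contained sketch (the call mirrors the line added above; the message is illustrative):

    from torch._logging import getArtifactLogger

    # Bound to the "annotation" artifact; %s-style arguments are only
    # formatted when the artifact is actually enabled, so disabled
    # logging stays cheap.
    annotation_log = getArtifactLogger(__name__, "annotation")
    annotation_log.debug("node: %s", "example_node")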
@@ -443,6 +445,10 @@ def _copy_metadata_to_bw_nodes_in_subgraph(
 ) -> None:
     """Copy metadata from forward nodes to backward nodes in a single subgraph."""
     for node in fx_g.graph.nodes:
+        annotation_log.debug("node: %s", node.name)
+        seq_nr = node.meta.get("seq_nr")
+        annotation_log.debug("seq_nr: %s", seq_nr)
+
         if not _is_backward_node_with_seq_nr(node):
             continue

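The new debug lines dump every node's name and seq_nr before the backward filter runs, which is exactly the information needed to spot an unmatched forward/backward pair. The underlying idea as a minimal sketch (fx_g, the forward mapping, and the is_bwd predicate are stand-ins for the real _is_backward_node_with_seq_nr machinery; the copied key is illustrative):

    # Forward and backward nodes that share a seq_nr come from the same
    # autograd op, so forward metadata can be copied to the backward node.
    def copy_fwd_meta(fx_g, fwd_seq_nr_to_node, is_bwd):
        for node in fx_g.graph.nodes:
            seq_nr = node.meta.get("seq_nr")
            if seq_nr is None or not is_bwd(node):
                continue
            fwd_node = fwd_seq_nr_to_node.get(seq_nr)
            if fwd_node is not None:
                # Illustrative key; the real code copies stack traces
                # and annotations onto the backward node's meta.
                node.meta["fwd_meta"] = dict(fwd_node.meta)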
@@ -478,6 +484,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
         if isinstance(submod, torch.fx.GraphModule):
             _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)

+    if annotation_log.isEnabledFor(logging.DEBUG):
+        for k, v in fwd_seq_nr_to_node.items():
+            annotation_log.debug("forward:: key: %s, value: %s", k, v)
+
     # Second pass: copy metadata to backward nodes in all subgraphs
     # using the global forward mapping
     for submod in fx_g.modules():
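The isEnabledFor guard here is worth copying in similar spots: it skips the whole dictionary walk when the artifact is off, rather than paying for one no-op debug call per entry. The general pattern, as a stdlib-only sketch:

    import logging

    log = logging.getLogger("example")

    def dump_mapping(mapping):
        # One cheap check up front instead of len(mapping) no-op calls.
        if log.isEnabledFor(logging.DEBUG):
            for k, v in mapping.items():
                log.debug("key: %s, value: %s", k, v)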
@@ -246,4 +246,9 @@ register_artifact(
     "Logs debug info for hierarchical compilation",
     off_by_default=True,
 )
+register_artifact(
+    "annotation",
+    "Logs detailed steps of the creating annotation on graph nodes",
+    off_by_default=True,
+)
 register_artifact("custom_format_test_artifact", "Testing only", log_format="")
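Registration is what makes the name visible to both TORCH_LOGS and getArtifactLogger. A hypothetical registration following the same shape (the import path is assumed from how the registrations module is structured, and "my_debug_artifact" is made up for illustration):

    from torch._logging._internal import register_artifact

    # Declares a new off-by-default artifact name; loggers bound to it
    # via getArtifactLogger stay silent until TORCH_LOGS opts in.
    register_artifact(
        "my_debug_artifact",
        "Logs steps of my experimental pass",
        off_by_default=True,
    )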
@@ -17,6 +17,7 @@ from typing import Any, Optional
 import torch
 import torch.fx.traceback as fx_traceback
 from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
+from torch._logging import getArtifactLogger
 from torch.utils._traceback import CapturedTraceback

 from ._compatibility import compatibility
@@ -40,6 +41,7 @@ __all__ = [


 log = logging.getLogger(__name__)
+annotation_log = getArtifactLogger(__name__, "annotation")


 @compatibility(is_backward_compatible=False)
@@ -202,7 +204,9 @@ class TracerBase:
             # BWD pass we retrieve the sequence_nr stored on the current
             # executing autograd Node. See NOTE [ Sequence Number ].
             if current_meta.get("in_grad_fn", 0) > 0:
+                annotation_log.debug("seq_nr from current_meta")
                 new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
+                annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name)
                 node.meta["seq_nr"] = new_seq_nr

             elif self.module_stack:
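These two debug lines trace where a backward node's seq_nr comes from: the sequence number stored on the currently executing autograd Node (see NOTE [ Sequence Number ]). To observe the resulting metadata, a sketch that traces forward and backward together; the make_fx location is assumed from torch.fx.experimental, and seq_nr may be None for nodes created outside autograd:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def fwd_bwd(x):
        y = (x * x).sum()
        (g,) = torch.autograd.grad(y, x)  # backward runs under tracing
        return g

    gm = make_fx(fwd_bwd)(torch.randn(3, requires_grad=True))
    for node in gm.graph.nodes:
        print(node.name, node.meta.get("seq_nr"))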