pytorch/torch/fx/traceback.py
soulitzer 3cc5c42a23 Fix aot sequence_nr to reset bwd flag (#107210)
The way aot autograd sequence_nr tracking works is that when we run the aot export logic, the dynamo-captured forward graph is run under an fx.Interpreter, which iterates through the nodes of the forward graph while setting `current_metadata`.
Since what runs during backward doesn't correspond to any node in the forward graph, we fall back to the global `current_metadata`. And since this global metadata ends up being shared between runs, forgetting to reset it leads to weirdness, e.g., the printed results will differ depending on whether this is the first test run.
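
Roughly, the shape of the problem (a minimal illustrative sketch, not the actual aot_autograd code; `run_forward` is a hypothetical name):

    import torch.fx
    from torch.fx.traceback import preserve_node_meta, set_current_meta, get_current_meta

    def run_forward(gm: torch.fx.GraphModule):
        with preserve_node_meta():
            for node in gm.graph.nodes:
                with set_current_meta(node):
                    ...  # execute the node, as fx.Interpreter.run_node does

    # Backward runs outside any node's scope, so consumers fall back to the
    # global current_meta; stale keys such as "in_grad_fn" left over from a
    # previous run change what gets recorded unless they are reset.
    print(get_current_meta())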

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107210
Approved by: https://github.com/bdhirsh
2023-08-24 16:58:12 +00:00


import traceback
from contextlib import contextmanager
from typing import List, Any, Dict

from ._compatibility import compatibility

__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
           'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
           'format_stack', 'set_current_meta', 'get_current_meta']

current_meta: Dict[str, Any] = {}
should_preserve_node_meta = False


@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta():
    global should_preserve_node_meta

    saved_should_preserve_node_meta = should_preserve_node_meta
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta = saved_should_preserve_node_meta
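
# Example (illustrative, not part of the original file; assumes the default
# module state): the saved flag is restored on exit, so nesting is safe:
#
#     with preserve_node_meta():
#         assert has_preserved_node_meta()
#     assert not has_preserved_node_meta()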


@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
    global current_meta

    if should_preserve_node_meta and stack:
        current_meta["stack_trace"] = "".join(stack)


@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    global current_meta

    if should_preserve_node_meta:
        # The seq_nr is captured by eager mode in the grad_fn during forward
        current_meta["prev_grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", None)
        current_meta["prev_in_grad_fn"] = current_meta.get("in_grad_fn", None)
        current_meta["grad_fn_seq_nr"] = seq_nr
        current_meta["in_grad_fn"] = True


@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    # NB: reset state properly; this would be helpful towards supporting
    #     reentrant autograd if we actually wanted to do that.
    global current_meta
    if should_preserve_node_meta:
        if current_meta["prev_grad_fn_seq_nr"] is None:
            assert current_meta["prev_in_grad_fn"] is None
            del current_meta["grad_fn_seq_nr"]
            del current_meta["in_grad_fn"]
        else:
            current_meta["grad_fn_seq_nr"] = current_meta["prev_grad_fn_seq_nr"]
            current_meta["in_grad_fn"] = current_meta["prev_in_grad_fn"]


@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    if should_preserve_node_meta:
        return [current_meta.get("stack_trace", "")]
    else:
        # fallback to traceback.format_stack()
        return traceback.format_list(traceback.extract_stack()[:-1])
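
# Example (illustrative, with a made-up trace string): inside
# ``preserve_node_meta()`` the stored stack trace wins over the live stack:
#
#     with preserve_node_meta():
#         set_stack_trace(['File "model.py", line 10, in forward\n'])
#         assert format_stack() == ['File "model.py", line 10, in forward\n']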


@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    return should_preserve_node_meta


@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node):
    global current_meta
    if should_preserve_node_meta and node.meta:
        saved_meta = current_meta
        try:
            current_meta = node.meta.copy()

            # Append (node.name, node.target) onto "from_node" for provenance tracking
            if "from_node" not in current_meta:
                current_meta["from_node"] = [(node.name, node.target)]
            elif current_meta["from_node"][-1][0] != node.name:
                current_meta["from_node"].append((node.name, node.target))

            yield
        finally:
            current_meta = saved_meta
    else:
        yield
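
# Example (illustrative, not part of the original file): ``node.meta`` must
# be non-empty for the copy branch to run, so we seed it with a hypothetical
# key first; each node then contributes a (name, target) provenance entry:
#
#     gm = torch.fx.symbolic_trace(lambda x: x + 1)
#     for node in gm.graph.nodes:
#         node.meta["seeded"] = True
#     with preserve_node_meta():
#         for node in gm.graph.nodes:
#             with set_current_meta(node):
#                 print(node.name, get_current_meta()["from_node"])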


@compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]:
    return current_meta