Revert "[dynamo][annotate] Remove the need of external ctx mgr of preserve_node_meta (#165188)"

This reverts commit f0325d0787.

Reverted https://github.com/pytorch/pytorch/pull/165188 on behalf of https://github.com/malfet due to Looks like it broke bunch of tests, see 2d4654d208/1 ([comment](https://github.com/pytorch/pytorch/pull/165188#issuecomment-3393674273))
This commit is contained in:
PyTorch MergeBot 2025-10-11 21:38:45 +00:00
parent 2d4654d208
commit a19123b37e
3 changed files with 25 additions and 8 deletions

View File

@ -18,6 +18,17 @@ def checkpoint_wrapper(fn):
class AnnotateTests(torch._dynamo.test_case.TestCase):
# TODO - should not need this because we should turn this on in Dynamo but
# for some reasons, test fail.
def setUp(self):
super().setUp()
self.cm = torch.fx.traceback.preserve_node_meta()
self.cm.__enter__()
def tearDown(self):
super().tearDown()
self.cm.__exit__(None, None, None)
def get_custom_metadata(self, gm):
def helper(gm):
custom_metadata = []

View File

@ -45,6 +45,17 @@ def aot_eager_regional_inductor():
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class RegionalInductorTests(torch._inductor.test_case.TestCase):
# TODO - should not need this because we should turn this on in Dynamo but
# for some reasons, test fail.
def setUp(self):
super().setUp()
self.cm = torch.fx.traceback.preserve_node_meta()
self.cm.__enter__()
def tearDown(self):
super().tearDown()
self.cm.__exit__(None, None, None)
def test_simple(self):
def fn(x, y):
sin = torch.sin(x)

View File

@ -23,7 +23,6 @@ restoring state changes.
import inspect
import sys
import warnings
from contextlib import ExitStack
from typing import TYPE_CHECKING, Union
import torch._C
@ -1279,13 +1278,9 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
)
def enter(self, tx, *args):
# Run the annotation ctx manager in eager. Also ensure that
# preserve_node_meta context manager is setup. This is important to pass
# on the metadata to the create_proxy nodes.
stack = ExitStack()
stack.enter_context(torch.fx.traceback.annotate(self.target_values))
stack.enter_context(torch.fx.traceback.preserve_node_meta())
self.set_cleanup_hook(tx, lambda: stack.close())
cm = torch.fx.traceback.annotate(self.target_values)
cm.__enter__()
self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
return variables.ConstantVariable.create(None)
def module_name(self):