[dynamo][hooks] config to wrap the top frame in a wrapper (#149758)

This should be done by default but there are too many issues. This PR is a
workaround.

https://github.com/pytorch/pytorch/issues/117584

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149758
Approved by: https://github.com/yf225
ghstack dependencies: #149712
This commit is contained in:
Animesh Jain 2025-03-21 12:34:03 -07:00 committed by PyTorch MergeBot
parent 621c801f78
commit 6bbe8dbd63
3 changed files with 32 additions and 3 deletions

View File

@ -859,6 +859,28 @@ class HooksTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@torch._dynamo.config.patch(wrap_top_frame=True)
def test_wrap_top_frame_with_hooks(self):
class ToyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.net1 = torch.nn.Linear(18, 18, bias=False)
def forward(self, x):
return self.net1(x)
mod = ToyModel()
mod.register_forward_pre_hook(lambda mod, input: input[0] + 1)
cnts = torch._dynamo.testing.CompileCounter()
compiled_mod = torch.compile(mod, backend=cnts)
x = torch.rand(18, 18)
ref = mod(x)
res = compiled_mod(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -600,6 +600,10 @@ run_gc_after_compile = Config( # type: ignore[var-annotated]
env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE",
)
# Takes the function/module decorated with torch.compile and passes it through a
# wrapper. This ensures that nn.module hooks are also compiled in the same frame.
wrap_top_frame = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

View File

@ -323,9 +323,12 @@ class OptimizedModule(torch.nn.Module):
if isinstance(self.dynamo_ctx, DisableContext):
# No need to check trace rules
self.forward = self.dynamo_ctx(self._orig_mod.__call__)
elif isinstance(self._orig_mod.forward, types.MethodType) and (
elif config.wrap_top_frame or (
isinstance(self._orig_mod.forward, types.MethodType)
and (
trace_rules.check(self._orig_mod.forward)
or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
)
):
# This may be a torch.nn.* instance in trace_rules.py which
# won't trigger a frame evaluation workaround to add an extra