[dynamo][hooks] config to wrap the top frame in a wrapper (#149758)
This should be done by default, but there are too many issues. This PR is a workaround. See https://github.com/pytorch/pytorch/issues/117584.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149758
Approved by: https://github.com/yf225
ghstack dependencies: #149712
parent 621c801f78
commit 6bbe8dbd63
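For background (standard nn.Module behavior, not code from this PR): forward hooks and pre-hooks fire inside nn.Module.__call__, not inside forward(), so when Dynamo compiles only the module's forward the hook logic ends up outside that frame. A minimal illustration:

import torch

lin = torch.nn.Linear(4, 4)
lin.register_forward_pre_hook(lambda mod, inputs: (inputs[0] * 2,))
y = lin(torch.randn(2, 4))                 # __call__ fires the pre-hook, then forward()
y_direct = lin.forward(torch.randn(2, 4))  # calling forward() directly skips hooks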
@@ -859,6 +859,28 @@ class HooksTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 2)
 
+    @torch._dynamo.config.patch(wrap_top_frame=True)
+    def test_wrap_top_frame_with_hooks(self):
+        class ToyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.net1 = torch.nn.Linear(18, 18, bias=False)
+
+            def forward(self, x):
+                return self.net1(x)
+
+        mod = ToyModel()
+        mod.register_forward_pre_hook(lambda mod, input: input[0] + 1)
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        compiled_mod = torch.compile(mod, backend=cnts)
+
+        x = torch.rand(18, 18)
+        ref = mod(x)
+        res = compiled_mod(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(cnts.frame_count, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
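Note on the assertions: the preceding hook test (visible in the context lines) ends with frame_count == 2, i.e. the hooked call splits into two compiled frames, while the new test asserts frame_count == 1: with wrap_top_frame=True the forward pre-hook and the module's forward are captured together in a single compiled frame.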
@@ -600,6 +600,10 @@ run_gc_after_compile = Config(  # type: ignore[var-annotated]
     env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE",
 )
 
+# Takes the function/module decorated with torch.compile and passes it through a
+# wrapper. This ensures that nn.module hooks are also compiled in the same frame.
+wrap_top_frame = False
+
 # HACK: this is for testing custom ops profiling only
 _custom_ops_profile: Optional[Any] = None
 
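A hedged usage sketch (assumes a PyTorch build that contains the wrap_top_frame flag added above): the config can be patched for a single function or test, or flipped globally before compiling.

import torch

# Patch the flag for one function/test (mirrors the test added above) ...
@torch._dynamo.config.patch(wrap_top_frame=True)
def compile_with_hooks(mod):
    return torch.compile(mod, backend="eager")

# ... or set it globally before calling torch.compile:
torch._dynamo.config.wrap_top_frame = True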
@@ -323,9 +323,12 @@ class OptimizedModule(torch.nn.Module):
         if isinstance(self.dynamo_ctx, DisableContext):
             # No need to check trace rules
             self.forward = self.dynamo_ctx(self._orig_mod.__call__)
-        elif isinstance(self._orig_mod.forward, types.MethodType) and (
-            trace_rules.check(self._orig_mod.forward)
-            or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
-        ):
+        elif config.wrap_top_frame or (
+            isinstance(self._orig_mod.forward, types.MethodType)
+            and (
+                trace_rules.check(self._orig_mod.forward)
+                or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
+            )
+        ):
             # This may be a torch.nn.* instance in trace_rules.py which
             # won't trigger a frame evaluation workaround to add an extra
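Effect of the condition change: when config.wrap_top_frame is set, OptimizedModule takes the wrapper branch unconditionally, instead of only for modules allowed by trace_rules or FSDP-managed modules, so the extra frame Dynamo captures already contains the hook dispatch from nn.Module.__call__. A conceptual sketch of the wrapping idea (not the actual OptimizedModule code):

def _make_top_frame(mod):
    # Wrapping __call__ in a plain function gives Dynamo a top-level frame
    # to compile that already includes hook dispatch plus forward().
    def top_frame(*args, **kwargs):
        return mod(*args, **kwargs)  # nn.Module.__call__: pre-hooks -> forward -> hooks
    return top_frame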