[precompile] Prune local scope variables for guard serialization. (#154431)

Summary: Prune unused local objects from serialized local scope if they are not used in guard reconstruction. This is helpful when a user program takes things like local callable functions or the function call is recursive.

Test Plan:
test/dynamo/test_guard_serialization.py -k test_function_locals

Before pruning locals:
```
state = GuardsState(output_graph=OutputGraphGuardsState(local_scope={'x': tensor([ 0.0461,  0.4024, -1.0115]), 'g': <function ...aints=None, _guards=<torch._guards.GuardsSet object at 0x7fbccc7e9fc0>, _aotautograd_guards=[]), shape_code_parts=None)

    def pickle_guards_state(state: GuardsState) -> bytes:
        buf = io.BytesIO()
        pickler = GuardsStatePickler(buf)
        try:
            pickler.dump(state)
        except AttributeError as e:
>           raise torch._dynamo.exc.PackageError(str(e)) from e
E           torch._dynamo.exc.PackageError: Can't pickle local object 'TestGuardSerialization.test_function_locals.<locals>.foo'
```
After the diff
```
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D75452123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154431
Approved by: https://github.com/jansel
This commit is contained in:
Zhengxu Chen 2025-05-28 16:03:02 +00:00 committed by PyTorch MergeBot
parent 9db7bcb3fe
commit 5bf74753f6
3 changed files with 37 additions and 12 deletions

View File

@ -235,6 +235,15 @@ pytree.register_constant(CustomConstantType)
class TestGuardSerialization(torch._inductor.test_case.TestCase):
def test_function_locals(self):
def foo(x):
return x + 1
def fn(x, g):
return g(x) + 1
self._test_serialization("TENSOR_MATCH", fn, torch.randn(3), foo)
def _tracefunc(self, frame, event, arg):
if event != "call":
return

View File

@ -60,6 +60,7 @@ from torch._C._dynamo.guards import (
)
from torch._dynamo.source import (
get_global_source_name,
get_local_source_name,
IndexedSource,
is_from_flatten_script_object_source,
is_from_local_source,
@ -2856,17 +2857,23 @@ class CheckFunctionManager:
self.guards_state: Optional[bytes] = None
if self.guards_serialization_mode == "save":
output_graph_guards_state = self.output_graph.dump_guards_state()
# Only serialize the global variables that are actually used in guards.
used_global_vars = set()
for guard in sorted_guards:
if name := get_global_source_name(guard.originating_source):
assert isinstance(name, str)
used_global_vars.add(name)
for source in self.output_graph.guard_on_key_order:
used_local_vars = set()
def prune_variable(source):
if name := get_global_source_name(source):
assert isinstance(name, str)
used_global_vars.add(name)
elif name := get_local_source_name(source):
assert isinstance(name, str)
used_local_vars.add(name)
output_graph_guards_state = self.output_graph.dump_guards_state()
# Only serialize the global variables that are actually used in guards.
for guard in sorted_guards:
prune_variable(guard.originating_source)
for source in self.output_graph.guard_on_key_order:
prune_variable(source)
def normalize_create_fn(x):
if isinstance(x, functools.partial):
@ -2884,6 +2891,11 @@ class CheckFunctionManager:
output_graph_guards_state = dataclasses.replace(
output_graph_guards_state,
local_scope={
k: v
for k, v in output_graph_guards_state.local_scope.items()
if k in used_local_vars
},
global_scope={
k: v
for k, v in output_graph_guards_state.global_scope.items()

View File

@ -859,14 +859,18 @@ class BackwardStateSource(Source):
return GuardSource.BACKWARD_STATE
def is_from_local_source(source: Source, *, only_allow_input=False):
def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]:
if isinstance(source, ChainedSource):
return is_from_local_source(source.base, only_allow_input=only_allow_input)
return get_local_source_name(source.base, only_allow_input=only_allow_input)
if not isinstance(source, LocalSource):
return False
return None
if only_allow_input and not source.is_input:
return False
return True
return None
return source.local_name
def is_from_local_source(source: Source, *, only_allow_input=False):
return get_local_source_name(source, only_allow_input=only_allow_input) is not None
def is_from_global_source(source: Source) -> bool: