mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[precompile] Prune local scope variables for guard serialization. (#154431)
Summary: Prune unused local objects from serialized local scope if they are not used in guard reconstruction. This is helpful when a user program takes things like local callable functions or the function call is recursive.
Test Plan:
test/dynamo/test_guard_serialization.py -k test_function_locals
Before pruning locals:
```
state = GuardsState(output_graph=OutputGraphGuardsState(local_scope={'x': tensor([ 0.0461, 0.4024, -1.0115]), 'g': <function ...aints=None, _guards=<torch._guards.GuardsSet object at 0x7fbccc7e9fc0>, _aotautograd_guards=[]), shape_code_parts=None)
def pickle_guards_state(state: GuardsState) -> bytes:
buf = io.BytesIO()
pickler = GuardsStatePickler(buf)
try:
pickler.dump(state)
except AttributeError as e:
> raise torch._dynamo.exc.PackageError(str(e)) from e
E torch._dynamo.exc.PackageError: Can't pickle local object 'TestGuardSerialization.test_function_locals.<locals>.foo'
```
After the diff
```
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D75452123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154431
Approved by: https://github.com/jansel
This commit is contained in:
parent
9db7bcb3fe
commit
5bf74753f6
|
|
@ -235,6 +235,15 @@ pytree.register_constant(CustomConstantType)
|
|||
|
||||
|
||||
class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
def test_function_locals(self):
|
||||
def foo(x):
|
||||
return x + 1
|
||||
|
||||
def fn(x, g):
|
||||
return g(x) + 1
|
||||
|
||||
self._test_serialization("TENSOR_MATCH", fn, torch.randn(3), foo)
|
||||
|
||||
def _tracefunc(self, frame, event, arg):
|
||||
if event != "call":
|
||||
return
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ from torch._C._dynamo.guards import (
|
|||
)
|
||||
from torch._dynamo.source import (
|
||||
get_global_source_name,
|
||||
get_local_source_name,
|
||||
IndexedSource,
|
||||
is_from_flatten_script_object_source,
|
||||
is_from_local_source,
|
||||
|
|
@ -2856,17 +2857,23 @@ class CheckFunctionManager:
|
|||
|
||||
self.guards_state: Optional[bytes] = None
|
||||
if self.guards_serialization_mode == "save":
|
||||
output_graph_guards_state = self.output_graph.dump_guards_state()
|
||||
# Only serialize the global variables that are actually used in guards.
|
||||
used_global_vars = set()
|
||||
for guard in sorted_guards:
|
||||
if name := get_global_source_name(guard.originating_source):
|
||||
assert isinstance(name, str)
|
||||
used_global_vars.add(name)
|
||||
for source in self.output_graph.guard_on_key_order:
|
||||
used_local_vars = set()
|
||||
|
||||
def prune_variable(source):
|
||||
if name := get_global_source_name(source):
|
||||
assert isinstance(name, str)
|
||||
used_global_vars.add(name)
|
||||
elif name := get_local_source_name(source):
|
||||
assert isinstance(name, str)
|
||||
used_local_vars.add(name)
|
||||
|
||||
output_graph_guards_state = self.output_graph.dump_guards_state()
|
||||
# Only serialize the global variables that are actually used in guards.
|
||||
for guard in sorted_guards:
|
||||
prune_variable(guard.originating_source)
|
||||
for source in self.output_graph.guard_on_key_order:
|
||||
prune_variable(source)
|
||||
|
||||
def normalize_create_fn(x):
|
||||
if isinstance(x, functools.partial):
|
||||
|
|
@ -2884,6 +2891,11 @@ class CheckFunctionManager:
|
|||
|
||||
output_graph_guards_state = dataclasses.replace(
|
||||
output_graph_guards_state,
|
||||
local_scope={
|
||||
k: v
|
||||
for k, v in output_graph_guards_state.local_scope.items()
|
||||
if k in used_local_vars
|
||||
},
|
||||
global_scope={
|
||||
k: v
|
||||
for k, v in output_graph_guards_state.global_scope.items()
|
||||
|
|
|
|||
|
|
@ -859,14 +859,18 @@ class BackwardStateSource(Source):
|
|||
return GuardSource.BACKWARD_STATE
|
||||
|
||||
|
||||
def is_from_local_source(source: Source, *, only_allow_input=False):
|
||||
def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]:
|
||||
if isinstance(source, ChainedSource):
|
||||
return is_from_local_source(source.base, only_allow_input=only_allow_input)
|
||||
return get_local_source_name(source.base, only_allow_input=only_allow_input)
|
||||
if not isinstance(source, LocalSource):
|
||||
return False
|
||||
return None
|
||||
if only_allow_input and not source.is_input:
|
||||
return False
|
||||
return True
|
||||
return None
|
||||
return source.local_name
|
||||
|
||||
|
||||
def is_from_local_source(source: Source, *, only_allow_input=False):
|
||||
return get_local_source_name(source, only_allow_input=only_allow_input) is not None
|
||||
|
||||
|
||||
def is_from_global_source(source: Source) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user