[precompile] Pickle and check closure variable properly. (#166351)

Summary:

Previously we didn't correctly handle closure tuple when there's content in it. Adding additional code for serializing the tuple and merge it with guard manager local scope.

Test Plan:

pytest test/dynamo/test_aot_compile.py

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166351
Approved by: https://github.com/Lucaskabela
This commit is contained in:
zhxchen17 2025-10-29 00:28:18 +00:00 committed by PyTorch MergeBot
parent 2a058bfecf
commit 56afad4eb3
2 changed files with 83 additions and 2 deletions

View File

@ -84,6 +84,21 @@ class RepeatInterleaveModule(torch.nn.Module):
return y_repeat return y_repeat
class MultiModalMixin(torch.nn.Module):
def forward(self, x):
return super().forward(x)
class TextModel(torch.nn.Module):
def forward(self, x):
return x + 1
class TestVLLMModel(MultiModalMixin, TextModel):
def forward(self, x):
return super().forward(x)
@torch._dynamo.config.patch("enable_aot_compile", True) @torch._dynamo.config.patch("enable_aot_compile", True)
@instantiate_parametrized_tests @instantiate_parametrized_tests
class TestAOTCompile(torch._inductor.test_case.TestCase): class TestAOTCompile(torch._inductor.test_case.TestCase):
@ -532,6 +547,41 @@ from user code:
) )
self.assertEqual(compiled_foo(inputs), foo(inputs)) self.assertEqual(compiled_foo(inputs), foo(inputs))
def test_aot_compile_with_closure_save_and_load(self):
tmp = 2
def fn(x, y):
return x + y + tmp
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
((torch.randn(3, 4), torch.randn(3, 4)), {})
)
inputs = (torch.randn(3, 4), torch.randn(3, 4))
expected = fn(*inputs)
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
def test_aot_compile_with_super_call(self):
fn = TestVLLMModel()
compiled_fn = torch.compile(fn.forward, fullgraph=True).aot_compile(
((torch.randn(3, 4),), {})
)
self.assertEqual(fn.forward.__code__.co_freevars, ("__class__",))
inputs = (torch.randn(3, 4),)
expected = fn(*inputs)
actual = compiled_fn(fn, *inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(fn, *inputs)
self.assertEqual(expected, actual)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -1,6 +1,7 @@
import dataclasses import dataclasses
import importlib import importlib
import inspect import inspect
import io
import logging import logging
import pickle import pickle
import types import types
@ -58,13 +59,40 @@ class CompileArtifacts:
current_system.check_compatibility(self.system_info, self.device_type) current_system.check_compatibility(self.system_info, self.device_type)
class AOTCompilePickler(pickle.Pickler):
@classmethod
def _unpickle_cell(cls, val: Any) -> Any:
def _() -> Any:
return val
assert _.__closure__ is not None
return _.__closure__[0]
# pyrefly: ignore [bad-override]
def reducer_override(self, obj: Any) -> Any:
if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
return type(self)._unpickle_cell, (obj.cell_contents,)
return NotImplemented
@dataclass @dataclass
class AOTCompiledFunction: class AOTCompiledFunction:
_artifacts: CompileArtifacts _artifacts: CompileArtifacts
_guard_check_enabled: bool = True _guard_check_enabled: bool = True
def guard_check(self, *args: Any, **kwargs: Any) -> bool: def guard_check(self, *args: Any, **kwargs: Any) -> bool:
f_locals = bind_locals(self._artifacts.signature, *args, **kwargs) f_locals: dict[str, Any] = {}
if self._artifacts.closure:
assert self._artifacts.bytecode.co_freevars and len(
self._artifacts.closure
) == len(self._artifacts.bytecode.co_freevars)
f_locals = {
name: cell.cell_contents
for name, cell in zip(
self._artifacts.bytecode.co_freevars, self._artifacts.closure
)
}
f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs))
assert self._artifacts.guard_manager is not None assert self._artifacts.guard_manager is not None
return self._artifacts.guard_manager.check(f_locals) return self._artifacts.guard_manager.check(f_locals)
@ -122,7 +150,10 @@ class AOTCompiledFunction:
type(compiled_fn).serialize_compile_artifacts(compiled_fn), type(compiled_fn).serialize_compile_artifacts(compiled_fn),
) )
state["original_code"] = SerializedCode.from_code_object(state["original_code"]) state["original_code"] = SerializedCode.from_code_object(state["original_code"])
return pickle.dumps(state) buf = io.BytesIO()
pickler = AOTCompilePickler(buf)
pickler.dump(state)
return buf.getvalue()
@classmethod @classmethod
def deserialize(cls, data: bytes) -> "AOTCompiledFunction": def deserialize(cls, data: bytes) -> "AOTCompiledFunction":