diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py
index 12efbdab94f..c822569f624 100644
--- a/test/dynamo/test_aot_compile.py
+++ b/test/dynamo/test_aot_compile.py
@@ -84,6 +84,21 @@ class RepeatInterleaveModule(torch.nn.Module):
         return y_repeat
 
 
+class MultiModalMixin(torch.nn.Module):
+    def forward(self, x):
+        return super().forward(x)
+
+
+class TextModel(torch.nn.Module):
+    def forward(self, x):
+        return x + 1
+
+
+class TestVLLMModel(MultiModalMixin, TextModel):
+    def forward(self, x):
+        return super().forward(x)
+
+
 @torch._dynamo.config.patch("enable_aot_compile", True)
 @instantiate_parametrized_tests
 class TestAOTCompile(torch._inductor.test_case.TestCase):
@@ -532,6 +547,41 @@ from user code:
         )
         self.assertEqual(compiled_foo(inputs), foo(inputs))
 
+    def test_aot_compile_with_closure_save_and_load(self):
+        tmp = 2
+
+        def fn(x, y):
+            return x + y + tmp
+
+        compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
+            ((torch.randn(3, 4), torch.randn(3, 4)), {})
+        )
+        inputs = (torch.randn(3, 4), torch.randn(3, 4))
+        expected = fn(*inputs)
+        actual = compiled_fn(*inputs)
+        self.assertEqual(expected, actual)
+        compiled_fn.save_compiled_function(self.path())
+        with open(self.path(), "rb") as f:
+            compiled_fn = torch.compiler.load_compiled_function(f)
+        actual = compiled_fn(*inputs)
+        self.assertEqual(expected, actual)
+
+    def test_aot_compile_with_super_call(self):
+        fn = TestVLLMModel()
+        compiled_fn = torch.compile(fn.forward, fullgraph=True).aot_compile(
+            ((torch.randn(3, 4),), {})
+        )
+        self.assertEqual(fn.forward.__code__.co_freevars, ("__class__",))
+        inputs = (torch.randn(3, 4),)
+        expected = fn(*inputs)
+        actual = compiled_fn(fn, *inputs)
+        self.assertEqual(expected, actual)
+        compiled_fn.save_compiled_function(self.path())
+        with open(self.path(), "rb") as f:
+            compiled_fn = torch.compiler.load_compiled_function(f)
+        actual = compiled_fn(fn, *inputs)
+        self.assertEqual(expected, actual)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py
index 26d012d4176..4b13d677f5a 100644
--- a/torch/_dynamo/aot_compile.py
+++ b/torch/_dynamo/aot_compile.py
@@ -1,6 +1,7 @@
 import dataclasses
 import importlib
 import inspect
+import io
 import logging
 import pickle
 import types
@@ -58,13 +59,40 @@ class CompileArtifacts:
         current_system.check_compatibility(self.system_info, self.device_type)
 
 
+class AOTCompilePickler(pickle.Pickler):
+    @classmethod
+    def _unpickle_cell(cls, val: Any) -> Any:
+        def _() -> Any:
+            return val
+
+        assert _.__closure__ is not None
+        return _.__closure__[0]
+
+    # pyrefly: ignore [bad-override]
+    def reducer_override(self, obj: Any) -> Any:
+        if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])):  # type: ignore[index]  # noqa: PLC3002
+            return type(self)._unpickle_cell, (obj.cell_contents,)
+        return NotImplemented
+
+
 @dataclass
 class AOTCompiledFunction:
     _artifacts: CompileArtifacts
     _guard_check_enabled: bool = True
 
     def guard_check(self, *args: Any, **kwargs: Any) -> bool:
-        f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
+        f_locals: dict[str, Any] = {}
+        if self._artifacts.closure:
+            assert self._artifacts.bytecode.co_freevars and len(
+                self._artifacts.closure
+            ) == len(self._artifacts.bytecode.co_freevars)
+            f_locals = {
+                name: cell.cell_contents
+                for name, cell in zip(
+                    self._artifacts.bytecode.co_freevars, self._artifacts.closure
+                )
+            }
+        f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs))
         assert self._artifacts.guard_manager is not None
         return self._artifacts.guard_manager.check(f_locals)
 
@@ -122,7 +150,10 @@
             type(compiled_fn).serialize_compile_artifacts(compiled_fn),
         )
         state["original_code"] = SerializedCode.from_code_object(state["original_code"])
-        return pickle.dumps(state)
+        buf = io.BytesIO()
+        pickler = AOTCompilePickler(buf)
+        pickler.dump(state)
+        return buf.getvalue()
 
     @classmethod
     def deserialize(cls, data: bytes) -> "AOTCompiledFunction":