mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[precompile] Pickle and check closure variable properly. (#166351)
Summary: Previously we didn't correctly handle closure tuple when there's content in it. Adding additional code for serializing the tuple and merge it with guard manager local scope. Test Plan: pytest test/dynamo/test_aot_compile.py Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166351 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
2a058bfecf
commit
56afad4eb3
|
|
@ -84,6 +84,21 @@ class RepeatInterleaveModule(torch.nn.Module):
|
|||
return y_repeat
|
||||
|
||||
|
||||
class MultiModalMixin(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class TextModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
|
||||
class TestVLLMModel(MultiModalMixin, TextModel):
|
||||
def forward(self, x):
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
@torch._dynamo.config.patch("enable_aot_compile", True)
|
||||
@instantiate_parametrized_tests
|
||||
class TestAOTCompile(torch._inductor.test_case.TestCase):
|
||||
|
|
@ -532,6 +547,41 @@ from user code:
|
|||
)
|
||||
self.assertEqual(compiled_foo(inputs), foo(inputs))
|
||||
|
||||
def test_aot_compile_with_closure_save_and_load(self):
|
||||
tmp = 2
|
||||
|
||||
def fn(x, y):
|
||||
return x + y + tmp
|
||||
|
||||
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
|
||||
((torch.randn(3, 4), torch.randn(3, 4)), {})
|
||||
)
|
||||
inputs = (torch.randn(3, 4), torch.randn(3, 4))
|
||||
expected = fn(*inputs)
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_with_super_call(self):
|
||||
fn = TestVLLMModel()
|
||||
compiled_fn = torch.compile(fn.forward, fullgraph=True).aot_compile(
|
||||
((torch.randn(3, 4),), {})
|
||||
)
|
||||
self.assertEqual(fn.forward.__code__.co_freevars, ("__class__",))
|
||||
inputs = (torch.randn(3, 4),)
|
||||
expected = fn(*inputs)
|
||||
actual = compiled_fn(fn, *inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(fn, *inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import dataclasses
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import types
|
||||
|
|
@ -58,13 +59,40 @@ class CompileArtifacts:
|
|||
current_system.check_compatibility(self.system_info, self.device_type)
|
||||
|
||||
|
||||
class AOTCompilePickler(pickle.Pickler):
|
||||
@classmethod
|
||||
def _unpickle_cell(cls, val: Any) -> Any:
|
||||
def _() -> Any:
|
||||
return val
|
||||
|
||||
assert _.__closure__ is not None
|
||||
return _.__closure__[0]
|
||||
|
||||
# pyrefly: ignore [bad-override]
|
||||
def reducer_override(self, obj: Any) -> Any:
|
||||
if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
|
||||
return type(self)._unpickle_cell, (obj.cell_contents,)
|
||||
return NotImplemented
|
||||
|
||||
|
||||
@dataclass
|
||||
class AOTCompiledFunction:
|
||||
_artifacts: CompileArtifacts
|
||||
_guard_check_enabled: bool = True
|
||||
|
||||
def guard_check(self, *args: Any, **kwargs: Any) -> bool:
|
||||
f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
|
||||
f_locals: dict[str, Any] = {}
|
||||
if self._artifacts.closure:
|
||||
assert self._artifacts.bytecode.co_freevars and len(
|
||||
self._artifacts.closure
|
||||
) == len(self._artifacts.bytecode.co_freevars)
|
||||
f_locals = {
|
||||
name: cell.cell_contents
|
||||
for name, cell in zip(
|
||||
self._artifacts.bytecode.co_freevars, self._artifacts.closure
|
||||
)
|
||||
}
|
||||
f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs))
|
||||
assert self._artifacts.guard_manager is not None
|
||||
return self._artifacts.guard_manager.check(f_locals)
|
||||
|
||||
|
|
@ -122,7 +150,10 @@ class AOTCompiledFunction:
|
|||
type(compiled_fn).serialize_compile_artifacts(compiled_fn),
|
||||
)
|
||||
state["original_code"] = SerializedCode.from_code_object(state["original_code"])
|
||||
return pickle.dumps(state)
|
||||
buf = io.BytesIO()
|
||||
pickler = AOTCompilePickler(buf)
|
||||
pickler.dump(state)
|
||||
return buf.getvalue()
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user