mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[precompile] Pickle and check closure variable properly. (#166351)
Summary: Previously we didn't correctly handle closure tuple when there's content in it. Adding additional code for serializing the tuple and merge it with guard manager local scope. Test Plan: pytest test/dynamo/test_aot_compile.py Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166351 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
2a058bfecf
commit
56afad4eb3
|
|
@ -84,6 +84,21 @@ class RepeatInterleaveModule(torch.nn.Module):
|
||||||
return y_repeat
|
return y_repeat
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalMixin(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TextModel(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestVLLMModel(MultiModalMixin, TextModel):
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
@torch._dynamo.config.patch("enable_aot_compile", True)
|
@torch._dynamo.config.patch("enable_aot_compile", True)
|
||||||
@instantiate_parametrized_tests
|
@instantiate_parametrized_tests
|
||||||
class TestAOTCompile(torch._inductor.test_case.TestCase):
|
class TestAOTCompile(torch._inductor.test_case.TestCase):
|
||||||
|
|
@ -532,6 +547,41 @@ from user code:
|
||||||
)
|
)
|
||||||
self.assertEqual(compiled_foo(inputs), foo(inputs))
|
self.assertEqual(compiled_foo(inputs), foo(inputs))
|
||||||
|
|
||||||
|
def test_aot_compile_with_closure_save_and_load(self):
|
||||||
|
tmp = 2
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return x + y + tmp
|
||||||
|
|
||||||
|
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
|
||||||
|
((torch.randn(3, 4), torch.randn(3, 4)), {})
|
||||||
|
)
|
||||||
|
inputs = (torch.randn(3, 4), torch.randn(3, 4))
|
||||||
|
expected = fn(*inputs)
|
||||||
|
actual = compiled_fn(*inputs)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
compiled_fn.save_compiled_function(self.path())
|
||||||
|
with open(self.path(), "rb") as f:
|
||||||
|
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||||
|
actual = compiled_fn(*inputs)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
|
def test_aot_compile_with_super_call(self):
|
||||||
|
fn = TestVLLMModel()
|
||||||
|
compiled_fn = torch.compile(fn.forward, fullgraph=True).aot_compile(
|
||||||
|
((torch.randn(3, 4),), {})
|
||||||
|
)
|
||||||
|
self.assertEqual(fn.forward.__code__.co_freevars, ("__class__",))
|
||||||
|
inputs = (torch.randn(3, 4),)
|
||||||
|
expected = fn(*inputs)
|
||||||
|
actual = compiled_fn(fn, *inputs)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
compiled_fn.save_compiled_function(self.path())
|
||||||
|
with open(self.path(), "rb") as f:
|
||||||
|
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||||
|
actual = compiled_fn(fn, *inputs)
|
||||||
|
self.assertEqual(expected, actual)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import types
|
import types
|
||||||
|
|
@ -58,13 +59,40 @@ class CompileArtifacts:
|
||||||
current_system.check_compatibility(self.system_info, self.device_type)
|
current_system.check_compatibility(self.system_info, self.device_type)
|
||||||
|
|
||||||
|
|
||||||
|
class AOTCompilePickler(pickle.Pickler):
|
||||||
|
@classmethod
|
||||||
|
def _unpickle_cell(cls, val: Any) -> Any:
|
||||||
|
def _() -> Any:
|
||||||
|
return val
|
||||||
|
|
||||||
|
assert _.__closure__ is not None
|
||||||
|
return _.__closure__[0]
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
|
def reducer_override(self, obj: Any) -> Any:
|
||||||
|
if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
|
||||||
|
return type(self)._unpickle_cell, (obj.cell_contents,)
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AOTCompiledFunction:
|
class AOTCompiledFunction:
|
||||||
_artifacts: CompileArtifacts
|
_artifacts: CompileArtifacts
|
||||||
_guard_check_enabled: bool = True
|
_guard_check_enabled: bool = True
|
||||||
|
|
||||||
def guard_check(self, *args: Any, **kwargs: Any) -> bool:
|
def guard_check(self, *args: Any, **kwargs: Any) -> bool:
|
||||||
f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
|
f_locals: dict[str, Any] = {}
|
||||||
|
if self._artifacts.closure:
|
||||||
|
assert self._artifacts.bytecode.co_freevars and len(
|
||||||
|
self._artifacts.closure
|
||||||
|
) == len(self._artifacts.bytecode.co_freevars)
|
||||||
|
f_locals = {
|
||||||
|
name: cell.cell_contents
|
||||||
|
for name, cell in zip(
|
||||||
|
self._artifacts.bytecode.co_freevars, self._artifacts.closure
|
||||||
|
)
|
||||||
|
}
|
||||||
|
f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs))
|
||||||
assert self._artifacts.guard_manager is not None
|
assert self._artifacts.guard_manager is not None
|
||||||
return self._artifacts.guard_manager.check(f_locals)
|
return self._artifacts.guard_manager.check(f_locals)
|
||||||
|
|
||||||
|
|
@ -122,7 +150,10 @@ class AOTCompiledFunction:
|
||||||
type(compiled_fn).serialize_compile_artifacts(compiled_fn),
|
type(compiled_fn).serialize_compile_artifacts(compiled_fn),
|
||||||
)
|
)
|
||||||
state["original_code"] = SerializedCode.from_code_object(state["original_code"])
|
state["original_code"] = SerializedCode.from_code_object(state["original_code"])
|
||||||
return pickle.dumps(state)
|
buf = io.BytesIO()
|
||||||
|
pickler = AOTCompilePickler(buf)
|
||||||
|
pickler.dump(state)
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
|
def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user