mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Guard serialization for TYPE_MATCH (#152325)
Adding guard serialization for TYPE_MATCH Differential Revision: [D73780438](https://our.internmc.facebook.com/intern/diff/D73780438/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152325 Approved by: https://github.com/jansel
This commit is contained in:
parent
a04f4622e1
commit
df663b9e72
|
|
@ -26,6 +26,11 @@ class _FrameState:
|
|||
f_builtins: dict
|
||||
|
||||
|
||||
class GlobalModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
|
||||
class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
def _tracefunc(self, frame, event, arg):
|
||||
if event != "call":
|
||||
|
|
@ -184,6 +189,27 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
delattr(m, "a")
|
||||
self._test_check_fn(ref, loaded, {"m": m}, False)
|
||||
|
||||
def test_type_match(self):
|
||||
class LocalModule(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + 1
|
||||
|
||||
m = LocalModule()
|
||||
|
||||
def fn(m, x):
|
||||
return m(x)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Please define the class at global scope"
|
||||
):
|
||||
self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
|
||||
|
||||
m = GlobalModule()
|
||||
ref, loaded = self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
|
||||
self._test_check_fn(ref, loaded, {"m": m}, True)
|
||||
self._test_check_fn(ref, loaded, {"m": GlobalModule()}, True)
|
||||
self._test_check_fn(ref, loaded, {"m": torch.nn.Module()}, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ import weakref
|
|||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from inspect import currentframe
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
|
||||
from weakref import ReferenceType
|
||||
|
||||
import torch
|
||||
|
|
@ -524,6 +524,14 @@ def get_key_index_source(source, index):
|
|||
return f"list(dict.keys({source}))[{index}]"
|
||||
|
||||
|
||||
def raise_local_type_error(obj: Any) -> NoReturn:
|
||||
raise TypeError(
|
||||
f"Type {type(obj)} for object {obj} cannot be saved "
|
||||
+ "into torch.compile() package since it's defined in local scope. "
|
||||
+ "Please define the class at global scope (top level of a module)."
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NNModuleAttrAccessorInfo:
|
||||
# Represents where is the attr name is present in the nn module attribute
|
||||
|
|
@ -622,6 +630,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
global_scope: dict[str, object],
|
||||
guard_manager: GuardManagerWrapper,
|
||||
check_fn_manager: CheckFunctionManager,
|
||||
serialization_mode: Optional[str] = None,
|
||||
):
|
||||
self.f_code = f_code
|
||||
self.id_ref = id_ref
|
||||
|
|
@ -674,6 +683,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
str, torch._C._dynamo.guards.GuardManager
|
||||
] = {}
|
||||
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
|
||||
self.serialization_mode = serialization_mode
|
||||
|
||||
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
|
||||
dict_mgr = self.get_guard_manager(guard)
|
||||
|
|
@ -1436,7 +1446,12 @@ class GuardBuilder(GuardBuilderBase):
|
|||
|
||||
def TYPE_MATCH(self, guard: Guard) -> None:
|
||||
# ___check_type_id is same as `id(type(x)) == y`
|
||||
t = type(self.get(guard.name))
|
||||
value = self.get(guard.name)
|
||||
t = type(value)
|
||||
if self.serialization_mode == "save":
|
||||
if t.__qualname__ != t.__name__:
|
||||
raise_local_type_error(value)
|
||||
|
||||
obj_id = self.id_ref(t, f"type({guard.name})")
|
||||
code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
|
||||
self._set_guard_export_info(guard, [code])
|
||||
|
|
@ -2495,6 +2510,10 @@ class GuardsStatePickler(pickle.Pickler):
|
|||
torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _unpickle_python_module(cls, alias: str):
|
||||
return importlib.import_module(alias)
|
||||
|
||||
def reducer_override(self, obj):
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
|
||||
return type(self)._unpickle_tensor, (
|
||||
|
|
@ -2505,9 +2524,14 @@ class GuardsStatePickler(pickle.Pickler):
|
|||
)
|
||||
|
||||
elif isinstance(obj, torch.nn.Module):
|
||||
if type(obj).__qualname__ == type(obj).__name__:
|
||||
return NotImplemented
|
||||
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
|
||||
return type(self)._unpickle_module, (obj.__getstate__(),)
|
||||
|
||||
elif inspect.ismodule(obj):
|
||||
return type(self)._unpickle_python_module, (obj.__name__,)
|
||||
|
||||
if type(obj).__qualname__ != type(obj).__name__:
|
||||
raise RuntimeError(
|
||||
f"Type {type(obj)} for object {obj} cannot be saved "
|
||||
|
|
@ -2562,7 +2586,11 @@ class CheckFunctionManager:
|
|||
|
||||
sorted_guards = sorted(guards or (), key=Guard.sort_key)
|
||||
builder, guard_manager = self.build_guards(
|
||||
sorted_guards, existing_diff_guard_sources, f_code, output_graph
|
||||
sorted_guards,
|
||||
existing_diff_guard_sources,
|
||||
f_code,
|
||||
output_graph,
|
||||
None if guard_filter_fn else self.guards_serialization_mode,
|
||||
)
|
||||
|
||||
if guard_filter_fn:
|
||||
|
|
@ -2602,7 +2630,11 @@ class CheckFunctionManager:
|
|||
]
|
||||
# Redo the guards because filtering relies on the results from the last guard builder.
|
||||
builder, guard_manager = self.build_guards(
|
||||
sorted_guards, existing_diff_guard_sources, f_code, output_graph
|
||||
sorted_guards,
|
||||
existing_diff_guard_sources,
|
||||
f_code,
|
||||
output_graph,
|
||||
self.guards_serialization_mode,
|
||||
)
|
||||
|
||||
self.guard_manager = guard_manager
|
||||
|
|
@ -2729,6 +2761,7 @@ class CheckFunctionManager:
|
|||
existing_diff_guard_sources,
|
||||
f_code,
|
||||
output_graph,
|
||||
serialization_mode=None,
|
||||
):
|
||||
guard_manager = GuardManagerWrapper()
|
||||
guard_manager.diff_guard_sources = existing_diff_guard_sources
|
||||
|
|
@ -2754,6 +2787,7 @@ class CheckFunctionManager:
|
|||
output_graph.global_scope,
|
||||
guard_manager,
|
||||
self,
|
||||
serialization_mode,
|
||||
)
|
||||
|
||||
# Break retain cycle. See test_release_scope_memory
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user