[dynamo] Guard serialization for TYPE_MATCH (#152325)

Adding guard serialization for TYPE_MATCH

Differential Revision: [D73780438](https://our.internmc.facebook.com/intern/diff/D73780438/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152325
Approved by: https://github.com/jansel
This commit is contained in:
zhxchen17 2025-04-29 07:56:13 -07:00 committed by PyTorch MergeBot
parent a04f4622e1
commit df663b9e72
2 changed files with 64 additions and 4 deletions

View File

@ -26,6 +26,11 @@ class _FrameState:
f_builtins: dict
class GlobalModule(torch.nn.Module):
def forward(self, x):
return x + 1
class TestGuardSerialization(torch._inductor.test_case.TestCase):
def _tracefunc(self, frame, event, arg):
if event != "call":
@ -184,6 +189,27 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
delattr(m, "a")
self._test_check_fn(ref, loaded, {"m": m}, False)
def test_type_match(self):
class LocalModule(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x + 1
m = LocalModule()
def fn(m, x):
return m(x)
with self.assertRaisesRegex(
TypeError, "Please define the class at global scope"
):
self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
m = GlobalModule()
ref, loaded = self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
self._test_check_fn(ref, loaded, {"m": m}, True)
self._test_check_fn(ref, loaded, {"m": GlobalModule()}, True)
self._test_check_fn(ref, loaded, {"m": torch.nn.Module()}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -39,7 +39,7 @@ import weakref
from contextlib import contextmanager
from copy import deepcopy
from inspect import currentframe
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
from weakref import ReferenceType
import torch
@ -524,6 +524,14 @@ def get_key_index_source(source, index):
return f"list(dict.keys({source}))[{index}]"
def raise_local_type_error(obj: Any) -> NoReturn:
raise TypeError(
f"Type {type(obj)} for object {obj} cannot be saved "
+ "into torch.compile() package since it's defined in local scope. "
+ "Please define the class at global scope (top level of a module)."
)
@dataclasses.dataclass(frozen=True)
class NNModuleAttrAccessorInfo:
# Represents where is the attr name is present in the nn module attribute
@ -622,6 +630,7 @@ class GuardBuilder(GuardBuilderBase):
global_scope: dict[str, object],
guard_manager: GuardManagerWrapper,
check_fn_manager: CheckFunctionManager,
serialization_mode: Optional[str] = None,
):
self.f_code = f_code
self.id_ref = id_ref
@ -674,6 +683,7 @@ class GuardBuilder(GuardBuilderBase):
str, torch._C._dynamo.guards.GuardManager
] = {}
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
self.serialization_mode = serialization_mode
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
dict_mgr = self.get_guard_manager(guard)
@ -1436,7 +1446,12 @@ class GuardBuilder(GuardBuilderBase):
def TYPE_MATCH(self, guard: Guard) -> None:
# ___check_type_id is same as `id(type(x)) == y`
t = type(self.get(guard.name))
value = self.get(guard.name)
t = type(value)
if self.serialization_mode == "save":
if t.__qualname__ != t.__name__:
raise_local_type_error(value)
obj_id = self.id_ref(t, f"type({guard.name})")
code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
self._set_guard_export_info(guard, [code])
@ -2495,6 +2510,10 @@ class GuardsStatePickler(pickle.Pickler):
torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
)
@classmethod
def _unpickle_python_module(cls, alias: str):
return importlib.import_module(alias)
def reducer_override(self, obj):
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
return type(self)._unpickle_tensor, (
@ -2505,9 +2524,14 @@ class GuardsStatePickler(pickle.Pickler):
)
elif isinstance(obj, torch.nn.Module):
if type(obj).__qualname__ == type(obj).__name__:
return NotImplemented
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
return type(self)._unpickle_module, (obj.__getstate__(),)
elif inspect.ismodule(obj):
return type(self)._unpickle_python_module, (obj.__name__,)
if type(obj).__qualname__ != type(obj).__name__:
raise RuntimeError(
f"Type {type(obj)} for object {obj} cannot be saved "
@ -2562,7 +2586,11 @@ class CheckFunctionManager:
sorted_guards = sorted(guards or (), key=Guard.sort_key)
builder, guard_manager = self.build_guards(
sorted_guards, existing_diff_guard_sources, f_code, output_graph
sorted_guards,
existing_diff_guard_sources,
f_code,
output_graph,
None if guard_filter_fn else self.guards_serialization_mode,
)
if guard_filter_fn:
@ -2602,7 +2630,11 @@ class CheckFunctionManager:
]
# Redo the guards because filtering relies on the results from the last guard builder.
builder, guard_manager = self.build_guards(
sorted_guards, existing_diff_guard_sources, f_code, output_graph
sorted_guards,
existing_diff_guard_sources,
f_code,
output_graph,
self.guards_serialization_mode,
)
self.guard_manager = guard_manager
@ -2729,6 +2761,7 @@ class CheckFunctionManager:
existing_diff_guard_sources,
f_code,
output_graph,
serialization_mode=None,
):
guard_manager = GuardManagerWrapper()
guard_manager.diff_guard_sources = existing_diff_guard_sources
@ -2754,6 +2787,7 @@ class CheckFunctionManager:
output_graph.global_scope,
guard_manager,
self,
serialization_mode,
)
# Break retain cycle. See test_release_scope_memory