[dynamo] Guard serialization for TYPE_MATCH (#152325)

Adding guard serialization for TYPE_MATCH

Differential Revision: [D73780438](https://our.internmc.facebook.com/intern/diff/D73780438/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152325
Approved by: https://github.com/jansel
This commit is contained in:
zhxchen17 2025-04-29 07:56:13 -07:00 committed by PyTorch MergeBot
parent a04f4622e1
commit df663b9e72
2 changed files with 64 additions and 4 deletions

View File

@ -26,6 +26,11 @@ class _FrameState:
f_builtins: dict f_builtins: dict
class GlobalModule(torch.nn.Module):
    """Module declared at module (global) scope.

    Guard serialization requires guarded types to be importable, so this
    class lives at the top level of the test file.
    """

    def forward(self, x):
        # Identity plus one; keeps the traced graph trivial for the test.
        return 1 + x
class TestGuardSerialization(torch._inductor.test_case.TestCase): class TestGuardSerialization(torch._inductor.test_case.TestCase):
def _tracefunc(self, frame, event, arg): def _tracefunc(self, frame, event, arg):
if event != "call": if event != "call":
@ -184,6 +189,27 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
delattr(m, "a") delattr(m, "a")
self._test_check_fn(ref, loaded, {"m": m}, False) self._test_check_fn(ref, loaded, {"m": m}, False)
def test_type_match(self):
    """TYPE_MATCH guards serialize only for globally defined classes."""

    # A class declared inside this function has a qualified name that
    # cannot be resolved on load, so saving its guard must fail loudly.
    class LocalModule(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            return x + 1

    mod = LocalModule()

    def fn(m, x):
        return m(x)

    with self.assertRaisesRegex(
        TypeError, "Please define the class at global scope"
    ):
        self._test_serialization("TYPE_MATCH", fn, mod, torch.randn(3))

    # A module-level class round-trips; the reloaded check function must
    # accept instances of the exact type and reject any other type.
    mod = GlobalModule()
    ref, loaded = self._test_serialization("TYPE_MATCH", fn, mod, torch.randn(3))
    self._test_check_fn(ref, loaded, {"m": mod}, True)
    self._test_check_fn(ref, loaded, {"m": GlobalModule()}, True)
    self._test_check_fn(ref, loaded, {"m": torch.nn.Module()}, False)
if __name__ == "__main__": if __name__ == "__main__":
from torch._dynamo.test_case import run_tests from torch._dynamo.test_case import run_tests

View File

@ -39,7 +39,7 @@ import weakref
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from inspect import currentframe from inspect import currentframe
from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
from weakref import ReferenceType from weakref import ReferenceType
import torch import torch
@ -524,6 +524,14 @@ def get_key_index_source(source, index):
return f"list(dict.keys({source}))[{index}]" return f"list(dict.keys({source}))[{index}]"
def raise_local_type_error(obj: Any) -> NoReturn:
    """Signal that *obj*'s class cannot be serialized for a compile package.

    Always raises TypeError: classes defined in a local scope cannot be
    re-imported when the package is loaded, so only module-level (global)
    class definitions are supported.
    """
    message = (
        f"Type {type(obj)} for object {obj} cannot be saved "
        "into torch.compile() package since it's defined in local scope. "
        "Please define the class at global scope (top level of a module)."
    )
    raise TypeError(message)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class NNModuleAttrAccessorInfo: class NNModuleAttrAccessorInfo:
# Represents where is the attr name is present in the nn module attribute # Represents where is the attr name is present in the nn module attribute
@ -622,6 +630,7 @@ class GuardBuilder(GuardBuilderBase):
global_scope: dict[str, object], global_scope: dict[str, object],
guard_manager: GuardManagerWrapper, guard_manager: GuardManagerWrapper,
check_fn_manager: CheckFunctionManager, check_fn_manager: CheckFunctionManager,
serialization_mode: Optional[str] = None,
): ):
self.f_code = f_code self.f_code = f_code
self.id_ref = id_ref self.id_ref = id_ref
@ -674,6 +683,7 @@ class GuardBuilder(GuardBuilderBase):
str, torch._C._dynamo.guards.GuardManager str, torch._C._dynamo.guards.GuardManager
] = {} ] = {}
self._cached_duplicate_input_guards: set[tuple[str, str]] = set() self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
self.serialization_mode = serialization_mode
def guard_on_dict_keys_and_ignore_order(self, example_value, guard): def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
dict_mgr = self.get_guard_manager(guard) dict_mgr = self.get_guard_manager(guard)
@ -1436,7 +1446,12 @@ class GuardBuilder(GuardBuilderBase):
def TYPE_MATCH(self, guard: Guard) -> None: def TYPE_MATCH(self, guard: Guard) -> None:
# ___check_type_id is same as `id(type(x)) == y` # ___check_type_id is same as `id(type(x)) == y`
t = type(self.get(guard.name)) value = self.get(guard.name)
t = type(value)
if self.serialization_mode == "save":
if t.__qualname__ != t.__name__:
raise_local_type_error(value)
obj_id = self.id_ref(t, f"type({guard.name})") obj_id = self.id_ref(t, f"type({guard.name})")
code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
self._set_guard_export_info(guard, [code]) self._set_guard_export_info(guard, [code])
@ -2495,6 +2510,10 @@ class GuardsStatePickler(pickle.Pickler):
torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw), torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
) )
@classmethod
def _unpickle_python_module(cls, alias: str):
    """Rebuild a Python module object from its dotted import name.

    Counterpart to pickling a module by name in ``reducer_override``;
    simply re-imports (or fetches the cached) module on load.
    """
    module = importlib.import_module(alias)
    return module
def reducer_override(self, obj): def reducer_override(self, obj):
if isinstance(obj, torch.Tensor) and obj.device.type != "meta": if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
return type(self)._unpickle_tensor, ( return type(self)._unpickle_tensor, (
@ -2505,9 +2524,14 @@ class GuardsStatePickler(pickle.Pickler):
) )
elif isinstance(obj, torch.nn.Module): elif isinstance(obj, torch.nn.Module):
if type(obj).__qualname__ == type(obj).__name__:
return NotImplemented
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__: if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
return type(self)._unpickle_module, (obj.__getstate__(),) return type(self)._unpickle_module, (obj.__getstate__(),)
elif inspect.ismodule(obj):
return type(self)._unpickle_python_module, (obj.__name__,)
if type(obj).__qualname__ != type(obj).__name__: if type(obj).__qualname__ != type(obj).__name__:
raise RuntimeError( raise RuntimeError(
f"Type {type(obj)} for object {obj} cannot be saved " f"Type {type(obj)} for object {obj} cannot be saved "
@ -2562,7 +2586,11 @@ class CheckFunctionManager:
sorted_guards = sorted(guards or (), key=Guard.sort_key) sorted_guards = sorted(guards or (), key=Guard.sort_key)
builder, guard_manager = self.build_guards( builder, guard_manager = self.build_guards(
sorted_guards, existing_diff_guard_sources, f_code, output_graph sorted_guards,
existing_diff_guard_sources,
f_code,
output_graph,
None if guard_filter_fn else self.guards_serialization_mode,
) )
if guard_filter_fn: if guard_filter_fn:
@ -2602,7 +2630,11 @@ class CheckFunctionManager:
] ]
# Redo the guards because filtering relies on the results from the last guard builder. # Redo the guards because filtering relies on the results from the last guard builder.
builder, guard_manager = self.build_guards( builder, guard_manager = self.build_guards(
sorted_guards, existing_diff_guard_sources, f_code, output_graph sorted_guards,
existing_diff_guard_sources,
f_code,
output_graph,
self.guards_serialization_mode,
) )
self.guard_manager = guard_manager self.guard_manager = guard_manager
@ -2729,6 +2761,7 @@ class CheckFunctionManager:
existing_diff_guard_sources, existing_diff_guard_sources,
f_code, f_code,
output_graph, output_graph,
serialization_mode=None,
): ):
guard_manager = GuardManagerWrapper() guard_manager = GuardManagerWrapper()
guard_manager.diff_guard_sources = existing_diff_guard_sources guard_manager.diff_guard_sources = existing_diff_guard_sources
@ -2754,6 +2787,7 @@ class CheckFunctionManager:
output_graph.global_scope, output_graph.global_scope,
guard_manager, guard_manager,
self, self,
serialization_mode,
) )
# Break retain cycle. See test_release_scope_memory # Break retain cycle. See test_release_scope_memory