mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Guard serialization for TYPE_MATCH (#152325)
Adding guard serialization for TYPE_MATCH Differential Revision: [D73780438](https://our.internmc.facebook.com/intern/diff/D73780438/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152325 Approved by: https://github.com/jansel
This commit is contained in:
parent
a04f4622e1
commit
df663b9e72
|
|
@ -26,6 +26,11 @@ class _FrameState:
|
||||||
f_builtins: dict
|
f_builtins: dict
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalModule(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
|
||||||
class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||||
def _tracefunc(self, frame, event, arg):
|
def _tracefunc(self, frame, event, arg):
|
||||||
if event != "call":
|
if event != "call":
|
||||||
|
|
@ -184,6 +189,27 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||||
delattr(m, "a")
|
delattr(m, "a")
|
||||||
self._test_check_fn(ref, loaded, {"m": m}, False)
|
self._test_check_fn(ref, loaded, {"m": m}, False)
|
||||||
|
|
||||||
|
def test_type_match(self):
|
||||||
|
class LocalModule(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
m = LocalModule()
|
||||||
|
|
||||||
|
def fn(m, x):
|
||||||
|
return m(x)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, "Please define the class at global scope"
|
||||||
|
):
|
||||||
|
self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
|
||||||
|
|
||||||
|
m = GlobalModule()
|
||||||
|
ref, loaded = self._test_serialization("TYPE_MATCH", fn, m, torch.randn(3))
|
||||||
|
self._test_check_fn(ref, loaded, {"m": m}, True)
|
||||||
|
self._test_check_fn(ref, loaded, {"m": GlobalModule()}, True)
|
||||||
|
self._test_check_fn(ref, loaded, {"m": torch.nn.Module()}, False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ import weakref
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from inspect import currentframe
|
from inspect import currentframe
|
||||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
|
||||||
from weakref import ReferenceType
|
from weakref import ReferenceType
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -524,6 +524,14 @@ def get_key_index_source(source, index):
|
||||||
return f"list(dict.keys({source}))[{index}]"
|
return f"list(dict.keys({source}))[{index}]"
|
||||||
|
|
||||||
|
|
||||||
|
def raise_local_type_error(obj: Any) -> NoReturn:
|
||||||
|
raise TypeError(
|
||||||
|
f"Type {type(obj)} for object {obj} cannot be saved "
|
||||||
|
+ "into torch.compile() package since it's defined in local scope. "
|
||||||
|
+ "Please define the class at global scope (top level of a module)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class NNModuleAttrAccessorInfo:
|
class NNModuleAttrAccessorInfo:
|
||||||
# Represents where is the attr name is present in the nn module attribute
|
# Represents where is the attr name is present in the nn module attribute
|
||||||
|
|
@ -622,6 +630,7 @@ class GuardBuilder(GuardBuilderBase):
|
||||||
global_scope: dict[str, object],
|
global_scope: dict[str, object],
|
||||||
guard_manager: GuardManagerWrapper,
|
guard_manager: GuardManagerWrapper,
|
||||||
check_fn_manager: CheckFunctionManager,
|
check_fn_manager: CheckFunctionManager,
|
||||||
|
serialization_mode: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.f_code = f_code
|
self.f_code = f_code
|
||||||
self.id_ref = id_ref
|
self.id_ref = id_ref
|
||||||
|
|
@ -674,6 +683,7 @@ class GuardBuilder(GuardBuilderBase):
|
||||||
str, torch._C._dynamo.guards.GuardManager
|
str, torch._C._dynamo.guards.GuardManager
|
||||||
] = {}
|
] = {}
|
||||||
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
|
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
|
||||||
|
self.serialization_mode = serialization_mode
|
||||||
|
|
||||||
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
|
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
|
||||||
dict_mgr = self.get_guard_manager(guard)
|
dict_mgr = self.get_guard_manager(guard)
|
||||||
|
|
@ -1436,7 +1446,12 @@ class GuardBuilder(GuardBuilderBase):
|
||||||
|
|
||||||
def TYPE_MATCH(self, guard: Guard) -> None:
|
def TYPE_MATCH(self, guard: Guard) -> None:
|
||||||
# ___check_type_id is same as `id(type(x)) == y`
|
# ___check_type_id is same as `id(type(x)) == y`
|
||||||
t = type(self.get(guard.name))
|
value = self.get(guard.name)
|
||||||
|
t = type(value)
|
||||||
|
if self.serialization_mode == "save":
|
||||||
|
if t.__qualname__ != t.__name__:
|
||||||
|
raise_local_type_error(value)
|
||||||
|
|
||||||
obj_id = self.id_ref(t, f"type({guard.name})")
|
obj_id = self.id_ref(t, f"type({guard.name})")
|
||||||
code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
|
code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
|
||||||
self._set_guard_export_info(guard, [code])
|
self._set_guard_export_info(guard, [code])
|
||||||
|
|
@ -2495,6 +2510,10 @@ class GuardsStatePickler(pickle.Pickler):
|
||||||
torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
|
torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _unpickle_python_module(cls, alias: str):
|
||||||
|
return importlib.import_module(alias)
|
||||||
|
|
||||||
def reducer_override(self, obj):
|
def reducer_override(self, obj):
|
||||||
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
|
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
|
||||||
return type(self)._unpickle_tensor, (
|
return type(self)._unpickle_tensor, (
|
||||||
|
|
@ -2505,9 +2524,14 @@ class GuardsStatePickler(pickle.Pickler):
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(obj, torch.nn.Module):
|
elif isinstance(obj, torch.nn.Module):
|
||||||
|
if type(obj).__qualname__ == type(obj).__name__:
|
||||||
|
return NotImplemented
|
||||||
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
|
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
|
||||||
return type(self)._unpickle_module, (obj.__getstate__(),)
|
return type(self)._unpickle_module, (obj.__getstate__(),)
|
||||||
|
|
||||||
|
elif inspect.ismodule(obj):
|
||||||
|
return type(self)._unpickle_python_module, (obj.__name__,)
|
||||||
|
|
||||||
if type(obj).__qualname__ != type(obj).__name__:
|
if type(obj).__qualname__ != type(obj).__name__:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Type {type(obj)} for object {obj} cannot be saved "
|
f"Type {type(obj)} for object {obj} cannot be saved "
|
||||||
|
|
@ -2562,7 +2586,11 @@ class CheckFunctionManager:
|
||||||
|
|
||||||
sorted_guards = sorted(guards or (), key=Guard.sort_key)
|
sorted_guards = sorted(guards or (), key=Guard.sort_key)
|
||||||
builder, guard_manager = self.build_guards(
|
builder, guard_manager = self.build_guards(
|
||||||
sorted_guards, existing_diff_guard_sources, f_code, output_graph
|
sorted_guards,
|
||||||
|
existing_diff_guard_sources,
|
||||||
|
f_code,
|
||||||
|
output_graph,
|
||||||
|
None if guard_filter_fn else self.guards_serialization_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
if guard_filter_fn:
|
if guard_filter_fn:
|
||||||
|
|
@ -2602,7 +2630,11 @@ class CheckFunctionManager:
|
||||||
]
|
]
|
||||||
# Redo the guards because filtering relies on the results from the last guard builder.
|
# Redo the guards because filtering relies on the results from the last guard builder.
|
||||||
builder, guard_manager = self.build_guards(
|
builder, guard_manager = self.build_guards(
|
||||||
sorted_guards, existing_diff_guard_sources, f_code, output_graph
|
sorted_guards,
|
||||||
|
existing_diff_guard_sources,
|
||||||
|
f_code,
|
||||||
|
output_graph,
|
||||||
|
self.guards_serialization_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.guard_manager = guard_manager
|
self.guard_manager = guard_manager
|
||||||
|
|
@ -2729,6 +2761,7 @@ class CheckFunctionManager:
|
||||||
existing_diff_guard_sources,
|
existing_diff_guard_sources,
|
||||||
f_code,
|
f_code,
|
||||||
output_graph,
|
output_graph,
|
||||||
|
serialization_mode=None,
|
||||||
):
|
):
|
||||||
guard_manager = GuardManagerWrapper()
|
guard_manager = GuardManagerWrapper()
|
||||||
guard_manager.diff_guard_sources = existing_diff_guard_sources
|
guard_manager.diff_guard_sources = existing_diff_guard_sources
|
||||||
|
|
@ -2754,6 +2787,7 @@ class CheckFunctionManager:
|
||||||
output_graph.global_scope,
|
output_graph.global_scope,
|
||||||
guard_manager,
|
guard_manager,
|
||||||
self,
|
self,
|
||||||
|
serialization_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Break retain cycle. See test_release_scope_memory
|
# Break retain cycle. See test_release_scope_memory
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user