[dynamo] Guard serialization for NOT_PRESENT_IN_GENERIC_DICT (#151343)

Adding guard serialization for type NOT_PRESENT_IN_GENERIC_DICT

Differential Revision: [D73057304](https://our.internmc.facebook.com/intern/diff/D73057304/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151343
Approved by: https://github.com/jansel, https://github.com/anijain2305
ghstack dependencies: #151318
This commit is contained in:
zhxchen17 2025-04-24 14:09:19 -07:00 committed by PyTorch MergeBot
parent a34c28e0d2
commit 558f45190e
2 changed files with 43 additions and 7 deletions

View File

@ -13,7 +13,7 @@ import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._dynamo.symbolic_convert import InstructionTranslator, SpeculationLog
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext, tracing
@ -35,8 +35,8 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
return
self._frame_state = _FrameState(
f_locals=frame.f_locals,
f_globals=frame.f_globals,
f_locals=dict(frame.f_locals),
f_globals=dict(frame.f_globals),
f_code=frame.f_code,
f_builtins=frame.f_builtins,
)
@ -44,6 +44,8 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
def _test_serialization(self, guard_type, fn, *args, **kwargs):
self._frame_state = None
sys.settrace(self._tracefunc)
if isinstance(fn, torch.nn.Module):
fn = fn.forward
try:
fn(*args, **kwargs)
finally:
@ -52,7 +54,9 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
assert self._frame_state is not None
def guard_filter_fn(guards):
return [g.guard_type == guard_type for g in guards]
ret = [g.guard_type == guard_type for g in guards]
self.assertTrue(any(ret))
return ret
ref_gm = None
loaded_gm = None
@ -73,7 +77,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
self._frame_state.f_locals,
self._frame_state.f_globals,
self._frame_state.f_builtins,
(), # TODO closure
fn.__closure__ or (),
[], # TODO tf_mode_stack,
code_options,
lambda gm, *args, **kwargs: gm.forward,
@ -81,7 +85,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
export=False,
export_constraints=None,
frame_state=None,
speculation_log=None,
speculation_log=SpeculationLog(),
exn_vt_stack=None,
distributed_state=None,
)
@ -140,6 +144,24 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
)
self._test_check_fn(ref, loaded, {"x": None}, False)
def test_not_present_in_generic_dict(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x + 1
m = Module()
def fn(x):
return m(x)
ref, loaded = self._test_serialization(
"NOT_PRESENT_IN_GENERIC_DICT", fn, torch.ones(2, dtype=torch.float32)
)
self._test_check_fn(ref, loaded, {"m": m}, True)
m.forward = types.MethodType(lambda x: x + 2, m)
self._test_check_fn(ref, loaded, {"m": m}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -2679,6 +2679,20 @@ class CheckFunctionManager:
assert isinstance(name, str)
used_global_vars.add(name)
def normalize_create_fn(x):
if isinstance(x, functools.partial):
def _ref(x):
if isinstance(x, (TensorWeakRef, weakref.ref)):
return x()
return x
new_args = tuple(_ref(a) for a in x.args)
new_keywords = {k: _ref(v) for k, v in x.keywords.items()}
return functools.partial(x.func, *new_args, **new_keywords)
return x
output_graph_guards_state = dataclasses.replace(
output_graph_guards_state,
global_scope={
@ -2692,7 +2706,7 @@ class CheckFunctionManager:
guard,
obj_weakref=None,
guarded_class_weakref=None,
create_fn=guard.inner_create_fn(),
create_fn=normalize_create_fn(guard.create_fn),
)
for guard in sorted_guards
}