mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Guard serialization for NOT_PRESENT_IN_GENERIC_DICT (#151343)
Adding guard serialization for type NOT_PRESENT_IN_GENERIC_DICT Differential Revision: [D73057304](https://our.internmc.facebook.com/intern/diff/D73057304/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/151343 Approved by: https://github.com/jansel, https://github.com/anijain2305 ghstack dependencies: #151318
This commit is contained in:
parent
a34c28e0d2
commit
558f45190e
|
|
@ -13,7 +13,7 @@ import torch.onnx.operators
|
|||
import torch.utils.cpp_extension
|
||||
from torch._dynamo.bytecode_transformation import transform_code_object
|
||||
from torch._dynamo.guards import CheckFunctionManager, CompileId
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator, SpeculationLog
|
||||
from torch._dynamo.utils import dynamo_timed, get_metrics_context
|
||||
from torch._guards import compile_context, CompileContext, tracing
|
||||
|
||||
|
|
@ -35,8 +35,8 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
return
|
||||
|
||||
self._frame_state = _FrameState(
|
||||
f_locals=frame.f_locals,
|
||||
f_globals=frame.f_globals,
|
||||
f_locals=dict(frame.f_locals),
|
||||
f_globals=dict(frame.f_globals),
|
||||
f_code=frame.f_code,
|
||||
f_builtins=frame.f_builtins,
|
||||
)
|
||||
|
|
@ -44,6 +44,8 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
def _test_serialization(self, guard_type, fn, *args, **kwargs):
|
||||
self._frame_state = None
|
||||
sys.settrace(self._tracefunc)
|
||||
if isinstance(fn, torch.nn.Module):
|
||||
fn = fn.forward
|
||||
try:
|
||||
fn(*args, **kwargs)
|
||||
finally:
|
||||
|
|
@ -52,7 +54,9 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
assert self._frame_state is not None
|
||||
|
||||
def guard_filter_fn(guards):
|
||||
return [g.guard_type == guard_type for g in guards]
|
||||
ret = [g.guard_type == guard_type for g in guards]
|
||||
self.assertTrue(any(ret))
|
||||
return ret
|
||||
|
||||
ref_gm = None
|
||||
loaded_gm = None
|
||||
|
|
@ -73,7 +77,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
self._frame_state.f_locals,
|
||||
self._frame_state.f_globals,
|
||||
self._frame_state.f_builtins,
|
||||
(), # TODO closure
|
||||
fn.__closure__ or (),
|
||||
[], # TODO tf_mode_stack,
|
||||
code_options,
|
||||
lambda gm, *args, **kwargs: gm.forward,
|
||||
|
|
@ -81,7 +85,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
export=False,
|
||||
export_constraints=None,
|
||||
frame_state=None,
|
||||
speculation_log=None,
|
||||
speculation_log=SpeculationLog(),
|
||||
exn_vt_stack=None,
|
||||
distributed_state=None,
|
||||
)
|
||||
|
|
@ -140,6 +144,24 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
)
|
||||
self._test_check_fn(ref, loaded, {"x": None}, False)
|
||||
|
||||
def test_not_present_in_generic_dict(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x + 1
|
||||
|
||||
m = Module()
|
||||
|
||||
def fn(x):
|
||||
return m(x)
|
||||
|
||||
ref, loaded = self._test_serialization(
|
||||
"NOT_PRESENT_IN_GENERIC_DICT", fn, torch.ones(2, dtype=torch.float32)
|
||||
)
|
||||
self._test_check_fn(ref, loaded, {"m": m}, True)
|
||||
|
||||
m.forward = types.MethodType(lambda x: x + 2, m)
|
||||
self._test_check_fn(ref, loaded, {"m": m}, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -2679,6 +2679,20 @@ class CheckFunctionManager:
|
|||
assert isinstance(name, str)
|
||||
used_global_vars.add(name)
|
||||
|
||||
def normalize_create_fn(x):
|
||||
if isinstance(x, functools.partial):
|
||||
|
||||
def _ref(x):
|
||||
if isinstance(x, (TensorWeakRef, weakref.ref)):
|
||||
return x()
|
||||
return x
|
||||
|
||||
new_args = tuple(_ref(a) for a in x.args)
|
||||
new_keywords = {k: _ref(v) for k, v in x.keywords.items()}
|
||||
return functools.partial(x.func, *new_args, **new_keywords)
|
||||
|
||||
return x
|
||||
|
||||
output_graph_guards_state = dataclasses.replace(
|
||||
output_graph_guards_state,
|
||||
global_scope={
|
||||
|
|
@ -2692,7 +2706,7 @@ class CheckFunctionManager:
|
|||
guard,
|
||||
obj_weakref=None,
|
||||
guarded_class_weakref=None,
|
||||
create_fn=guard.inner_create_fn(),
|
||||
create_fn=normalize_create_fn(guard.create_fn),
|
||||
)
|
||||
for guard in sorted_guards
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user