[precompile] Filter out ID_MATCH family of guards with caching_precompile. (#158368)

Summary: For case like caching_precompile, we almost always want to drop ID_MATCH-type guards since they will block serialization. This diff add this behavior when this global flag is toggled on so that ID_MATCH guards are excluded from compilation and serialization.

Test Plan:
test_dynamo -- -k test_id_match_with_config

Rollback Plan:

Differential Revision: D78363609

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158368
Approved by: https://github.com/jamesjwu
This commit is contained in:
Zhengxu Chen 2025-07-18 14:47:11 +00:00 committed by PyTorch MergeBot
parent e882c761dd
commit 036eb1f65d
2 changed files with 48 additions and 3 deletions

View File

@ -878,6 +878,23 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
):
self._test_serialization("ID_MATCH", fn, torch.randn(3))
@torch._dynamo.config.patch(caching_precompile=True)
def test_id_match_with_config(self):
def fn(x):
return x + id(x)
ref, loaded = self._test_serialization("ID_MATCH", fn, torch.randn(3))
self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True)
def fn(x):
# usage of this context manager installs a FUNCTION_MATCH guard
with torch.no_grad():
y = x * 2
return y
ref, loaded = self._test_serialization("FUNCTION_MATCH", fn, torch.randn(3))
self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True)
def test_dispatch_key_set_match(self):
def fn(x, dks):
if dks.has("CPU"):

View File

@ -1591,7 +1591,7 @@ class GuardBuilder(GuardBuilderBase):
val = self.get(guard.name)
id_val = self.id_ref(val, guard.name)
code = f"___check_obj_id({ref}, {id_val})"
self._set_guard_export_info(guard, [code])
self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")
self.get_guard_manager(guard).add_id_match_guard(
id_val, get_verbose_code_parts(code, guard)
@ -2473,7 +2473,9 @@ class GuardBuilder(GuardBuilderBase):
self._set_guard_export_info(guard, code)
# A util that in the case of export, adds data onto guards
def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None):
def _set_guard_export_info(
self, guard, code_list, provided_guarded_object=None, provided_func_name=None
):
# WARNING: It is important that cur_frame/caller do NOT stay in
# the current frame, because they will keep things live longer
# than they should. See TestMisc.test_release_module_memory
@ -2482,7 +2484,7 @@ class GuardBuilder(GuardBuilderBase):
caller = cur_frame.f_back
del cur_frame
assert caller is not None
func_name = caller.f_code.co_name
func_name = provided_func_name or caller.f_code.co_name
del caller
# We use func_name for export, so might as well get a nice defensive check out of it
assert func_name in self.__class__.__dict__, (
@ -2842,6 +2844,32 @@ class CheckFunctionManager:
if not justknobs_check("pytorch/compiler:guard_nn_modules"):
log.warning("guard_nn_modules is turned off using justknobs killswitch")
# TODO Be more explicit about the behavior for the users.
if (
torch._dynamo.config.caching_precompile
and self.guards_serialization_mode != "load"
):
_guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs])
def guard_filter_fn(guards):
ret = []
for keep, g in zip(_guard_filter_fn(guards), guards):
if not keep:
ret.append(False)
elif (
g.guard_type in ("ID_MATCH", "CLOSURE_MATCH", "WEAKREF_ALIVE")
or "ID_MATCH" in g.derived_guard_types
):
log.warning(
"%s guard on %s is dropped with caching_precompile=True.",
g.guard_type,
g.orig_guard.name,
)
ret.append(False)
else:
ret.append(True)
return ret
sorted_guards = sorted(guards or (), key=Guard.sort_key)
builder, guard_manager = self.build_guards(
sorted_guards,