mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[precompile] Filter out ID_MATCH family of guards with caching_precompile. (#158368)
Summary: For case like caching_precompile, we almost always want to drop ID_MATCH-type guards since they will block serialization. This diff add this behavior when this global flag is toggled on so that ID_MATCH guards are excluded from compilation and serialization. Test Plan: test_dynamo -- -k test_id_match_with_config Rollback Plan: Differential Revision: D78363609 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158368 Approved by: https://github.com/jamesjwu
This commit is contained in:
parent
e882c761dd
commit
036eb1f65d
|
|
@ -878,6 +878,23 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
|||
):
|
||||
self._test_serialization("ID_MATCH", fn, torch.randn(3))
|
||||
|
||||
@torch._dynamo.config.patch(caching_precompile=True)
|
||||
def test_id_match_with_config(self):
|
||||
def fn(x):
|
||||
return x + id(x)
|
||||
|
||||
ref, loaded = self._test_serialization("ID_MATCH", fn, torch.randn(3))
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True)
|
||||
|
||||
def fn(x):
|
||||
# usage of this context manager installs a FUNCTION_MATCH guard
|
||||
with torch.no_grad():
|
||||
y = x * 2
|
||||
return y
|
||||
|
||||
ref, loaded = self._test_serialization("FUNCTION_MATCH", fn, torch.randn(3))
|
||||
self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True)
|
||||
|
||||
def test_dispatch_key_set_match(self):
|
||||
def fn(x, dks):
|
||||
if dks.has("CPU"):
|
||||
|
|
|
|||
|
|
@ -1591,7 +1591,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
val = self.get(guard.name)
|
||||
id_val = self.id_ref(val, guard.name)
|
||||
code = f"___check_obj_id({ref}, {id_val})"
|
||||
self._set_guard_export_info(guard, [code])
|
||||
self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")
|
||||
|
||||
self.get_guard_manager(guard).add_id_match_guard(
|
||||
id_val, get_verbose_code_parts(code, guard)
|
||||
|
|
@ -2473,7 +2473,9 @@ class GuardBuilder(GuardBuilderBase):
|
|||
self._set_guard_export_info(guard, code)
|
||||
|
||||
# A util that in the case of export, adds data onto guards
|
||||
def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None):
|
||||
def _set_guard_export_info(
|
||||
self, guard, code_list, provided_guarded_object=None, provided_func_name=None
|
||||
):
|
||||
# WARNING: It is important that cur_frame/caller do NOT stay in
|
||||
# the current frame, because they will keep things live longer
|
||||
# than they should. See TestMisc.test_release_module_memory
|
||||
|
|
@ -2482,7 +2484,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
caller = cur_frame.f_back
|
||||
del cur_frame
|
||||
assert caller is not None
|
||||
func_name = caller.f_code.co_name
|
||||
func_name = provided_func_name or caller.f_code.co_name
|
||||
del caller
|
||||
# We use func_name for export, so might as well get a nice defensive check out of it
|
||||
assert func_name in self.__class__.__dict__, (
|
||||
|
|
@ -2842,6 +2844,32 @@ class CheckFunctionManager:
|
|||
if not justknobs_check("pytorch/compiler:guard_nn_modules"):
|
||||
log.warning("guard_nn_modules is turned off using justknobs killswitch")
|
||||
|
||||
# TODO Be more explicit about the behavior for the users.
|
||||
if (
|
||||
torch._dynamo.config.caching_precompile
|
||||
and self.guards_serialization_mode != "load"
|
||||
):
|
||||
_guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs])
|
||||
|
||||
def guard_filter_fn(guards):
|
||||
ret = []
|
||||
for keep, g in zip(_guard_filter_fn(guards), guards):
|
||||
if not keep:
|
||||
ret.append(False)
|
||||
elif (
|
||||
g.guard_type in ("ID_MATCH", "CLOSURE_MATCH", "WEAKREF_ALIVE")
|
||||
or "ID_MATCH" in g.derived_guard_types
|
||||
):
|
||||
log.warning(
|
||||
"%s guard on %s is dropped with caching_precompile=True.",
|
||||
g.guard_type,
|
||||
g.orig_guard.name,
|
||||
)
|
||||
ret.append(False)
|
||||
else:
|
||||
ret.append(True)
|
||||
return ret
|
||||
|
||||
sorted_guards = sorted(guards or (), key=Guard.sort_key)
|
||||
builder, guard_manager = self.build_guards(
|
||||
sorted_guards,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user