Change unsafe_marked_cacheable_functions to a dictionary, so that you can specify a static cache key (#152486)

Fixes https://github.com/pytorch/pytorch/issues/152434

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152486
Approved by: https://github.com/oulgen
Authored by James Wu on 2025-05-09 15:29:36 +00:00; committed by PyTorch MergeBot
parent 694748dd9d
commit 74d0300804
2 changed files with 46 additions and 3 deletions
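
As a minimal sketch of the new dict-form option introduced by this change (the op name torch.ops.mylib.my_op is hypothetical; keying by torch.__version__ follows the example added to the config docstring below):

    import torch
    import torch._inductor.config as inductor_config

    # Before this commit the option was a plain list of function names:
    #   inductor_config.unsafe_marked_cacheable_functions = ["torch.ops.mylib.my_op"]
    # After this commit it is a dict mapping function name -> static cache key.
    # "torch.ops.mylib.my_op" is a hypothetical custom op used for illustration.
    # Changing the key value changes the resulting cache key, so keying by
    # torch.__version__ invalidates cached entries whenever torch is upgraded.
    inductor_config.unsafe_marked_cacheable_functions = {
        "torch.ops.mylib.my_op": torch.__version__,
    }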


@@ -375,7 +375,9 @@ class AOTAutogradCacheTests(InductorTestCase):
"Allow in graph produces an unserializable cache artifact"
)
-with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
+with inductor_config.patch(
+    "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+):
fn(*args)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
@@ -390,6 +392,36 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
self._clear_dynamo_and_codecache()
+with inductor_config.patch(
+    "unsafe_marked_cacheable_functions", {fn_name: "key2"}
+):
+    fn(*args)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+    self._clear_dynamo_and_codecache()
+    fn(*args)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+# On second try with same key, it should hit once more
+with inductor_config.patch(
+    "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+):
+    self._clear_dynamo_and_codecache()
+    fn(*args)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
+    self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", False)
@functorch_config.patch({"enable_autograd_cache": True})


@@ -145,8 +145,19 @@ force_disable_caches: bool = Config(
# Unsafe way to skip dynamic shape guards to get faster cache load
unsafe_skip_cache_dynamic_shape_guards: bool = False
-# Unsafe way to mark function as cacheable
-unsafe_marked_cacheable_functions: list[str] = []
+# Unsafe way to mark non-torch functions as safe to cache.
+# The dictionary maps function name -> cache key.
+# Any function name in the dictionary is allowed to be cached
+# by AOTAutogradCache and FxGraphCache.
+# Changing the cache key value changes the resulting
+# FxGraphCache key.
+# Example usage:
+# torch._inductor.config.unsafe_marked_cacheable_functions = {
+#     'torch.ops.my_function': torch.__version__
+# }
+# The above example makes the custom op torch.ops.my_function cacheable,
+# with cache entries keyed by the current torch version.
+unsafe_marked_cacheable_functions: dict[str, str] = {}
# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None
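
The tests above scope the setting with inductor_config.patch; a sketch of the same pattern for user code, again with a hypothetical op name and a stand-in model:

    import torch
    from torch._inductor import config as inductor_config

    def my_model(x):
        # stand-in for a model that calls the marked custom op
        return x * 2

    # Within the context manager, any function listed in the dict is treated as
    # cacheable under its key; the previous config value is restored on exit
    # (the same pattern the test uses). "torch.ops.mylib.my_op" is hypothetical.
    with inductor_config.patch(
        "unsafe_marked_cacheable_functions", {"torch.ops.mylib.my_op": "key1"}
    ):
        compiled = torch.compile(my_model)
        compiled(torch.randn(4))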