From 74d03008045b888c6cb4099d2fe80ba4d2adbb7b Mon Sep 17 00:00:00 2001
From: James Wu
Date: Fri, 9 May 2025 15:29:36 +0000
Subject: [PATCH] Change unsafe_marked_cacheable_functions to a dictionary, so that you can specify a static cache key (#152486)

Fixes https://github.com/pytorch/pytorch/issues/152434

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152486
Approved by: https://github.com/oulgen
---
 test/dynamo/test_aot_autograd_cache.py | 34 +++++++++++++++++++++++++-
 torch/_inductor/config.py              | 15 ++++++++++--
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py
index b350347413c..0267a791f82 100644
--- a/test/dynamo/test_aot_autograd_cache.py
+++ b/test/dynamo/test_aot_autograd_cache.py
@@ -375,7 +375,9 @@ class AOTAutogradCacheTests(InductorTestCase):
                 "Allow in graph produces an unserializable cache artifact"
             )
 
-        with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+        ):
             fn(*args)
 
             self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
@@ -390,6 +392,36 @@
             self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
             self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
 
+        self._clear_dynamo_and_codecache()
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key2"}
+        ):
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+            self._clear_dynamo_and_codecache()
+
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+        # On second try with same key, it should hit once more
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+        ):
+            self._clear_dynamo_and_codecache()
+
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", False)
     @functorch_config.patch({"enable_autograd_cache": True})
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index c233ada0513..980cc28ba7b 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -145,8 +145,19 @@ force_disable_caches: bool = Config(
 # Unsafe way to skip dynamic shape guards to get faster cache load
 unsafe_skip_cache_dynamic_shape_guards: bool = False
 
-# Unsafe way to mark function as cacheable
-unsafe_marked_cacheable_functions: list[str] = []
+# Unsafe way to mark non-torch functions as safe to cache.
+# The dictionary maps function name -> cache key.
+# Any function name in the dictionary will be allowed to be cacheable
+# by AOTAutogradCache and FxGraphCache.
+# Changing the cache key value will change the resulting
+# FxGraphCache key.
+# Example usage:
+# torch._inductor.config.unsafe_marked_cacheable_functions = {
+#     'torch.ops.my_function': torch.__version__
+# }
+# The above example makes the custom op torch.ops.my_function cacheable,
+# with its cache entries keyed by the current torch version.
+unsafe_marked_cacheable_functions: dict[str, str] = {}
 
 # sleep in inductor for testing
 sleep_sec_TESTING_ONLY: Optional[int] = None
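
Editor's note (not part of the patch): the config comment above describes the new dict-valued setting, and the test exercises the key1 -> key2 -> key1 invalidation behavior. The following is a minimal usage sketch under stated assumptions: the function name "torch.ops.mylib.my_op" and the key "my_op_impl_v2" are purely illustrative, and the name format simply follows the 'torch.ops.my_function' example from the comment.

import torch
import torch._inductor.config as inductor_config

# Hypothetical custom-op name; the exact string format follows the
# 'torch.ops.my_function' example in the config comment above.
OP_NAME = "torch.ops.mylib.my_op"

# After this patch the config is a dict of function name -> static cache key,
# not a plain list of names: the named function is treated as cacheable by
# AOTAutogradCache / FxGraphCache, and the key value feeds into the cache key.
inductor_config.unsafe_marked_cacheable_functions = {
    OP_NAME: torch.__version__,  # e.g. key cache entries by torch version
}

# Changing the key later (say, when the op's implementation changes) turns
# previously cached graphs into misses, mirroring the key1 -> key2 -> key1
# sequence the test in this patch exercises.
inductor_config.unsafe_marked_cacheable_functions = {
    OP_NAME: "my_op_impl_v2",
}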