Change unsafe_marked_cacheable_functions to a dictionary, so that you can specify a static cache key (#152486)
Fixes https://github.com/pytorch/pytorch/issues/152434

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152486
Approved by: https://github.com/oulgen
This commit is contained in:
parent 694748dd9d
commit 74d0300804
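As background for the diff below: before this change, unsafe_marked_cacheable_functions was a plain list of function names; it is now a dict mapping each function name to a static cache key, so changing the key value invalidates previously cached entries for that function. A minimal usage sketch (the custom-op name torch.ops.mylib.my_op is a made-up placeholder, not from this commit):

import torch._inductor.config as inductor_config

# Old style (before this commit): a bare list of cacheable function names.
# inductor_config.unsafe_marked_cacheable_functions = ["torch.ops.mylib.my_op"]

# New style (after this commit): a dict of function name -> static cache key.
# Bumping the key value (e.g. "v1" -> "v2") changes the resulting cache key,
# so stale entries for the marked function are no longer hit.
inductor_config.unsafe_marked_cacheable_functions = {
    "torch.ops.mylib.my_op": "v1",
}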
@@ -375,7 +375,9 @@ class AOTAutogradCacheTests(InductorTestCase):
                 "Allow in graph produces an unserializable cache artifact"
             )
 
-        with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+        ):
             fn(*args)
 
         self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
@@ -390,6 +392,36 @@ class AOTAutogradCacheTests(InductorTestCase):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
 
+        self._clear_dynamo_and_codecache()
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key2"}
+        ):
+            fn(*args)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+        self._clear_dynamo_and_codecache()
+
+        fn(*args)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+        # On second try with same key, it should hit once more
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+        ):
+            self._clear_dynamo_and_codecache()
+
+            fn(*args)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", False)
     @functorch_config.patch({"enable_autograd_cache": True})
@@ -145,8 +145,19 @@ force_disable_caches: bool = Config(
 # Unsafe way to skip dynamic shape guards to get faster cache load
 unsafe_skip_cache_dynamic_shape_guards: bool = False
 
-# Unsafe way to mark function as cacheable
-unsafe_marked_cacheable_functions: list[str] = []
+# Unsafe way to mark non torch functions as safe to cache
+# dictionary is from function name -> cache key
+# Any function name in the dictionary will be allowed to be cacheable
+# by AOTAutogradCache and FxGraphCache.
+# changing the cache key value will change the resulting
+# FXGraphCache key.
+# Example usage:
+# torch._inductor.config.unsafe_marked_cacheable_functions = {
+#   'torch.ops.my_function' : torch.__version__
+# }
+# The above example causes the custom op torch.ops.my_function to be cacheable,
+# and for cache keys to be keyed by the current torch version
+unsafe_marked_cacheable_functions: dict[str, str] = {}
 
 # sleep in inductor for testing
 sleep_sec_TESTING_ONLY: Optional[int] = None
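To make the intent of the new dict concrete, here is a rough, self-contained sketch (not PyTorch's actual implementation) of how a per-function static key can be folded into a combined cache key; this is why changing the value from "key1" to "key2" in the test above forces a fresh cache miss, while reusing the same key hits:

import hashlib

def combined_cache_key(graph_hash: str, marked_functions: dict[str, str]) -> str:
    # Mix the graph hash with every (function name, static key) pair, so that
    # changing any marked function's key changes the overall cache key.
    h = hashlib.sha256(graph_hash.encode())
    for name, key in sorted(marked_functions.items()):
        h.update(f"{name}={key}".encode())
    return h.hexdigest()

# Same graph, different static key for the marked op -> different cache key.
assert combined_cache_key("g", {"torch.ops.my_function": "key1"}) != combined_cache_key(
    "g", {"torch.ops.my_function": "key2"}
)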