mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Change unsafe_marked_cacheable_functions to a dictionary, so that you can specify a static cache key (#152486)
Fixes https://github.com/pytorch/pytorch/issues/152434 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152486 Approved by: https://github.com/oulgen
This commit is contained in:
parent
694748dd9d
commit
74d0300804
|
|
@ -375,7 +375,9 @@ class AOTAutogradCacheTests(InductorTestCase):
|
||||||
"Allow in graph produces an unserializable cache artifact"
|
"Allow in graph produces an unserializable cache artifact"
|
||||||
)
|
)
|
||||||
|
|
||||||
with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
|
with inductor_config.patch(
|
||||||
|
"unsafe_marked_cacheable_functions", {fn_name: "key1"}
|
||||||
|
):
|
||||||
fn(*args)
|
fn(*args)
|
||||||
|
|
||||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
||||||
|
|
@ -390,6 +392,36 @@ class AOTAutogradCacheTests(InductorTestCase):
|
||||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
||||||
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
||||||
|
|
||||||
|
self._clear_dynamo_and_codecache()
|
||||||
|
with inductor_config.patch(
|
||||||
|
"unsafe_marked_cacheable_functions", {fn_name: "key2"}
|
||||||
|
):
|
||||||
|
fn(*args)
|
||||||
|
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
||||||
|
|
||||||
|
self._clear_dynamo_and_codecache()
|
||||||
|
|
||||||
|
fn(*args)
|
||||||
|
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
||||||
|
|
||||||
|
# On second try with same key, it should hit once more
|
||||||
|
with inductor_config.patch(
|
||||||
|
"unsafe_marked_cacheable_functions", {fn_name: "key1"}
|
||||||
|
):
|
||||||
|
self._clear_dynamo_and_codecache()
|
||||||
|
|
||||||
|
fn(*args)
|
||||||
|
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
|
||||||
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
|
||||||
|
|
||||||
@inductor_config.patch("fx_graph_remote_cache", False)
|
@inductor_config.patch("fx_graph_remote_cache", False)
|
||||||
@inductor_config.patch("fx_graph_cache", False)
|
@inductor_config.patch("fx_graph_cache", False)
|
||||||
@functorch_config.patch({"enable_autograd_cache": True})
|
@functorch_config.patch({"enable_autograd_cache": True})
|
||||||
|
|
|
||||||
|
|
@ -145,8 +145,19 @@ force_disable_caches: bool = Config(
|
||||||
# Unsafe way to skip dynamic shape guards to get faster cache load
|
# Unsafe way to skip dynamic shape guards to get faster cache load
|
||||||
unsafe_skip_cache_dynamic_shape_guards: bool = False
|
unsafe_skip_cache_dynamic_shape_guards: bool = False
|
||||||
|
|
||||||
# Unsafe way to mark function as cacheable
|
# Unsafe way to mark non torch functions as safe to cache
|
||||||
unsafe_marked_cacheable_functions: list[str] = []
|
# dictionary is from function name -> cache key
|
||||||
|
# Any function name in the dictionary will be allowed to be cacheable
|
||||||
|
# by AOTAutogradCache and FxGraphCache.
|
||||||
|
# changing the cache key value will change the resulting
|
||||||
|
# FXGraphCache key.
|
||||||
|
# Example usage:
|
||||||
|
# torch._inductor.config.unsafe_marked_cacheable_functions = {
|
||||||
|
# 'torch.ops.my_function' : torch.__version__
|
||||||
|
# }
|
||||||
|
# The above example causes the custom op torch.ops.my_function to be cacheable,
|
||||||
|
# and for cache keys to be keyed by the current torch version
|
||||||
|
unsafe_marked_cacheable_functions: dict[str, str] = {}
|
||||||
|
|
||||||
# sleep in inductor for testing
|
# sleep in inductor for testing
|
||||||
sleep_sec_TESTING_ONLY: Optional[int] = None
|
sleep_sec_TESTING_ONLY: Optional[int] = None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user