Revert "inductor codecache: include private inductor configs in cache key (#153672)"

This reverts commit 2c1cb38d95.

Reverted https://github.com/pytorch/pytorch/pull/153672 on behalf of https://github.com/malfet due to Looks like it regressed pr_time_benchmarks, see ba3f91af97/1 ([comment](https://github.com/pytorch/pytorch/pull/153672#issuecomment-2922759739))
PyTorch MergeBot 2025-05-30 15:54:12 +00:00
parent 4b1f047a33
commit 31f95b5d2e
5 changed files with 4 additions and 135 deletions

View File

@@ -56,7 +56,6 @@ from torch.testing._internal.inductor_utils import (
     requires_gpu,
     requires_triton,
 )
-from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.testing._internal.triton_utils import requires_cuda
@@ -2121,90 +2120,6 @@ class TestFxGraphCacheHashing(TestCase):
             pickler.dumps(details3),
         )
 
-    def test_hash_private_config_changes(self):
-        """
-        Test that private config settings affect hashes.
-        """
-        with config.patch({"_micro_pipeline_tp": False}):
-            details1 = FxGraphHashDetails(None, [], {}, [])
-            details2 = FxGraphHashDetails(None, [], {}, [])
-
-        with config.patch({"_micro_pipeline_tp": True}):
-            details3 = FxGraphHashDetails(None, [], {}, [])
-
-        gm = torch.fx.GraphModule({}, torch.fx.Graph())
-        pickler = FxGraphCachePickler(gm)
-
-        self.assertEqual(
-            pickler.dumps(details1),
-            pickler.dumps(details2),
-        )
-        self.assertNotEqual(
-            pickler.dumps(details1),
-            pickler.dumps(details3),
-        )
-
-    def test_non_serializable_custom_passes_causes_cache_miss(self):
-        class Mod(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 4))
-
-            def forward(self, x):
-                return x @ self.param
-
-        mod1 = Mod()
-        mod_compiled = torch.compile(mod1)
-        with torch.no_grad():
-            x = torch.rand(4, 4)
-            # miss
-            mod_compiled(x)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            # hit
-            torch._dynamo.reset()
-            mod_compiled(x)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-            torch._dynamo.reset()
-            counters.clear()
-            # hit
-            mod_compiled(x)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-            with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
-                # miss (private config changed)
-                torch._dynamo.reset()
-                mod_compiled(x)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-            torch._dynamo.reset()
-            counters.clear()
-            (codecache_stream,), ctx = multiple_logs_to_string(
-                "torch._inductor.codecache", "codecache"
-            )
-            with ctx(), config.patch(
-                {"_fuse_ddp_communication_passes": [lambda *args: None]}
-            ):
-                # bypass (custom pass is not serializable)
-                mod_compiled(x)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            counters.clear()
-
-            # assert that our bypass is explicit
-            codecache_logs = codecache_stream.getvalue().strip()
-            self.assertTrue(
-                "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
-                in codecache_logs
-            )
 
     def test_hash_custom_passes(self):
         """
         Test CustomGraphPass usage.

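For reference, this is the behavior the revert restores: with "_"-prefixed options excluded from the cache key again, toggling a private config such as _micro_pipeline_tp should no longer change the hash details. A rough sketch mirroring the removed test above (illustrative only, not part of the diff):

import torch
from torch._inductor import config
from torch._inductor.codecache import FxGraphCachePickler, FxGraphHashDetails

# Build hash details under two different values of a private config.
with config.patch({"_micro_pipeline_tp": False}):
    details_off = FxGraphHashDetails(None, [], {}, [])
with config.patch({"_micro_pipeline_tp": True}):
    details_on = FxGraphHashDetails(None, [], {}, [])

pickler = FxGraphCachePickler(torch.fx.GraphModule({}, torch.fx.Graph()))
# After the revert, private configs are filtered out of the portable config dump,
# so the two pickled hash details are expected to match again.
assert pickler.dumps(details_off) == pickler.dumps(details_on)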
View File

@@ -881,7 +881,7 @@ class FxGraphHashDetails:
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-        self.inductor_config = config.save_config_portable(ignore_private_configs=False)
+        self.inductor_config = config.save_config_portable()
         # Custom post grad passes should provide an ID to hash.
         self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
             config.post_grad_custom_pre_pass
@ -889,36 +889,6 @@ class FxGraphHashDetails:
self.post_grad_custom_post_pass = self._get_custom_pass_detail(
config.post_grad_custom_post_pass
)
self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
config._pre_fusion_custom_pass
)
self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
config._fuse_ddp_communication_passes
)
# This is mainly added to handle these two inductor configs, which are (unfortunately)
# sometimes cache safe:
# - _pre_fusion_custom_pass
# - _fuse_ddp_communication_passes
# Their types can be found in `torch/_inductor/config.py`, but:
# - if they are string names, we can cache them safely (one is by default)
# - if any of them are set to custom callables, we will need to cache miss
# Future work is for someone to find any places where these functions are used
# and force them to be of type CustomGraphPass, so we can guarantee serialization.
def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
if not custom_pass:
return None
if isinstance(custom_pass, list):
return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
if isinstance(custom_pass, str):
return custom_pass
if isinstance(custom_pass, CustomGraphPass):
return custom_pass.uuid()
if callable(custom_pass):
# Returning None is safe here because we raise an explicit bypass error
# later if we detect these passes are set to callables
return None
raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
def _get_custom_pass_detail(
self, custom_pass: CustomGraphPassType
@@ -1396,14 +1366,6 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
             if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                 raise BypassFxGraphCache("Unsupported post grad custom pass")
-        # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
-        # and ensure they are not passing us raw callables
-        if config._pre_fusion_custom_pass is not None:
-            if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
-                raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
-        for p in config._fuse_ddp_communication_passes:
-            if callable(p) and not isinstance(p, CustomGraphPass):
-                raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")
 
         # Freezing can embed constants that wouldn't be static across runs.
         if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(

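The check kept above still requires post_grad_custom_pre_pass / post_grad_custom_post_pass to be CustomGraphPass instances with a truthy uuid(); anything else raises BypassFxGraphCache. A minimal cache-friendly pass might look like the following sketch (the class name and uuid value are made up for illustration):

import torch
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass


class NoopPostGradPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.Graph) -> None:
        # Mutate the post-grad FX graph in place (no-op in this sketch).
        pass

    def uuid(self) -> bytes:
        # Any stable identifier works; it should change whenever the pass's
        # behavior changes so stale cache entries are not reused.
        return b"noop-post-grad-pass-v1"


# Assigning an instance keeps the FX graph cache usable instead of bypassing it.
config.post_grad_custom_post_pass = NoopPostGradPass()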
View File

@@ -1635,8 +1635,6 @@ _save_config_ignore: list[str] = [
     "aot_inductor.dump_aoti_minifier",
     "post_grad_custom_pre_pass",
     "post_grad_custom_post_pass",
-    "_fuse_ddp_communication_passes",
-    "_pre_fusion_custom_pass",
 ]
 
 _cache_config_ignore_prefix: list[str] = [
@@ -1650,8 +1648,6 @@ _cache_config_ignore_prefix: list[str] = [
     # see CustomGraphPass; these are handled specially
     "post_grad_custom_post_pass",
     "post_grad_custom_pre_pass",
-    "_fuse_ddp_communication_passes",
-    "_pre_fusion_custom_pass",
     # tests assume that changes here don't invalidate cache
     "always_complex_memory_overlap_TESTING_ONLY",
 ]

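With these two entries dropped from the ignore lists, the private pass configs are simply covered again by the blanket "_" prefix filter in save_config_portable, so patching them has no effect on the portable dump used for the cache key. A quick illustrative check (the pass name here is a placeholder):

from torch._inductor import config

with config.patch({"_fuse_ddp_communication_passes": ["some_ddp_pass_name"]}):
    portable = config.save_config_portable()
    # Underscore-prefixed options are skipped again, so the patched value is absent.
    assert "_fuse_ddp_communication_passes" not in portable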
View File

@@ -508,13 +508,9 @@ class ConfigModule(ModuleType):
             protocol=2,
         )
 
-    def save_config_portable(
-        self, *, ignore_private_configs: bool = True
-    ) -> dict[str, Any]:
+    def save_config_portable(self) -> dict[str, Any]:
         """Convert config to portable format"""
-        prefixes = []
-        if ignore_private_configs:
-            prefixes.append("_")
+        prefixes = ["_"]
         prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
         return self._get_dict(ignored_prefixes=prefixes)

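For context, after this revert save_config_portable boils down to prefix filtering over the config entries: anything starting with "_", plus any prefix listed in _cache_config_ignore_prefix, is dropped before the result is hashed into the cache key. Roughly, as a standalone sketch (the helper and its arguments are illustrative, not the real _get_dict):

from typing import Any


def portable_view(options: dict[str, Any], ignored_prefixes: list[str]) -> dict[str, Any]:
    # Keep only options whose names do not start with an ignored prefix,
    # e.g. "_" for private configs or any _cache_config_ignore_prefix entry.
    return {
        name: value
        for name, value in options.items()
        if not any(name.startswith(prefix) for prefix in ignored_prefixes)
    }


# Usage sketch: portable_view(all_options, ["_"] + cache_config_ignore_prefix)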
View File

@@ -24,7 +24,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"
 
 def save_config() -> bytes: ...
-def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
+def save_config_portable() -> dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> dict[str, Any]: ...