Revert "inductor codecache: include private inductor configs in cache key (#153672)"
This reverts commit 2c1cb38d95. Reverted https://github.com/pytorch/pytorch/pull/153672 on behalf of https://github.com/malfet because it looks like it regressed pr_time_benchmarks, see ba3f91af97/1 ([comment](https://github.com/pytorch/pytorch/pull/153672#issuecomment-2922759739))
parent 4b1f047a33
commit 31f95b5d2e
@@ -56,7 +56,6 @@ from torch.testing._internal.inductor_utils import (
     requires_gpu,
     requires_triton,
 )
-from torch.testing._internal.logging_utils import multiple_logs_to_string
 from torch.testing._internal.triton_utils import requires_cuda
@@ -2121,90 +2120,6 @@ class TestFxGraphCacheHashing(TestCase):
             pickler.dumps(details3),
         )

-    def test_hash_private_config_changes(self):
-        """
-        Test that private config settings affect hashes.
-        """
-        with config.patch({"_micro_pipeline_tp": False}):
-            details1 = FxGraphHashDetails(None, [], {}, [])
-            details2 = FxGraphHashDetails(None, [], {}, [])
-
-        with config.patch({"_micro_pipeline_tp": True}):
-            details3 = FxGraphHashDetails(None, [], {}, [])
-
-        gm = torch.fx.GraphModule({}, torch.fx.Graph())
-        pickler = FxGraphCachePickler(gm)
-
-        self.assertEqual(
-            pickler.dumps(details1),
-            pickler.dumps(details2),
-        )
-        self.assertNotEqual(
-            pickler.dumps(details1),
-            pickler.dumps(details3),
-        )
-
-    def test_non_serializable_custom_passes_causes_cache_miss(self):
-        class Mod(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 4))
-
-            def forward(self, x):
-                return x @ self.param
-
-        mod1 = Mod()
-        mod_compiled = torch.compile(mod1)
-        with torch.no_grad():
-            x = torch.rand(4, 4)
-            # miss
-            mod_compiled(x)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            # hit
-            torch._dynamo.reset()
-            mod_compiled(x)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-            torch._dynamo.reset()
-            counters.clear()
-
-            # hit
-            mod_compiled(x)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-            with config.patch({"_fuse_ddp_communication_passes": ["new_pass_foo_bar"]}):
-                # miss (private config changed)
-                torch._dynamo.reset()
-                mod_compiled(x)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-                torch._dynamo.reset()
-                counters.clear()
-
-                (codecache_stream,), ctx = multiple_logs_to_string(
-                    "torch._inductor.codecache", "codecache"
-                )
-                with ctx(), config.patch(
-                    {"_fuse_ddp_communication_passes": [lambda *args: None]}
-                ):
-                    # bypass (custom pass is not serializable)
-                    mod_compiled(x)
-                    self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 1)
-                    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-                    self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-                    counters.clear()
-                # assert that our bypass is explicit
-                codecache_logs = codecache_stream.getvalue().strip()
-                self.assertTrue(
-                    "Bypassing FX Graph Cache because 'Unsupported _fuse_ddp_communication_pass'"
-                    in codecache_logs
-                )
-
     def test_hash_custom_passes(self):
         """
         Test CustomGraphPass usage.
@@ -881,7 +881,7 @@ class FxGraphHashDetails:
         # Also hash on various system info (including the triton compiler version).
         self.torch_version = torch_key()
         self.system_info = CacheBase.get_system()
-        self.inductor_config = config.save_config_portable(ignore_private_configs=False)
+        self.inductor_config = config.save_config_portable()
         # Custom post grad passes should provide an ID to hash.
         self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
             config.post_grad_custom_pre_pass
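For context on the line swapped above: once the hash details go back to calling save_config_portable() with no ignore_private_configs flag, underscore-prefixed ("private") settings are filtered out of the config dict that feeds the FX graph cache key, so flipping one of them no longer changes the key. A minimal, self-contained sketch of that effect follows; it uses plain dicts, json, and hashlib rather than the real FxGraphHashDetails/FxGraphCachePickler, and the helper names portable_view and cache_key are made up for illustration.

import hashlib
import json

def portable_view(cfg, include_private):
    # Drop "_"-prefixed (private) settings unless explicitly included,
    # mirroring the prefix filtering inside save_config_portable().
    return {k: v for k, v in cfg.items() if include_private or not k.startswith("_")}

def cache_key(cfg, include_private):
    # Stand-in for pickling FxGraphHashDetails: hash a stable serialization.
    payload = json.dumps(portable_view(cfg, include_private), sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()

base = {"max_autotune": False, "_micro_pipeline_tp": False}
flipped = {"max_autotune": False, "_micro_pipeline_tp": True}

# Behavior after this revert: the private flag is excluded, so the keys match.
assert cache_key(base, include_private=False) == cache_key(flipped, include_private=False)
# Behavior the reverted PR had introduced: the private flag is included, so the keys differ.
assert cache_key(base, include_private=True) != cache_key(flipped, include_private=True)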
@@ -889,36 +889,6 @@ class FxGraphHashDetails:
         self.post_grad_custom_post_pass = self._get_custom_pass_detail(
             config.post_grad_custom_post_pass
         )
-        self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe(
-            config._pre_fusion_custom_pass
-        )
-        self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe(
-            config._fuse_ddp_communication_passes
-        )
-
-    # This is mainly added to handle these two inductor configs, which are (unfortunately)
-    # sometimes cache safe:
-    # - _pre_fusion_custom_pass
-    # - _fuse_ddp_communication_passes
-    # Their types can be found in `torch/_inductor/config.py`, but:
-    # - if they are string names, we can cache them safely (one is by default)
-    # - if any of them are set to custom callables, we will need to cache miss
-    # Future work is for someone to find any places where these functions are used
-    # and force them to be of type CustomGraphPass, so we can guarantee serialization.
-    def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]:
-        if not custom_pass:
-            return None
-        if isinstance(custom_pass, list):
-            return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass]
-        if isinstance(custom_pass, str):
-            return custom_pass
-        if isinstance(custom_pass, CustomGraphPass):
-            return custom_pass.uuid()
-        if callable(custom_pass):
-            # Returning None is safe here because we raise an explicit bypass error
-            # later if we detect these passes are set to callables
-            return None
-        raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
-
     def _get_custom_pass_detail(
         self, custom_pass: CustomGraphPassType
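The comment block removed above also names the longer-term fix: custom passes that implement the CustomGraphPass interface expose a uuid() that can be hashed into the cache key, so they never force a bypass. Below is a rough sketch of such a pass; it assumes the CustomGraphPass and get_hash_for_files helpers from torch/_inductor/custom_graph_pass.py, and the RemoveNoOpsPass class itself is hypothetical, not part of PyTorch.

import torch
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files

class RemoveNoOpsPass(CustomGraphPass):  # hypothetical example pass
    def __call__(self, graph: torch.fx.Graph) -> None:
        # The actual graph surgery would go here; the cache only cares about uuid().
        pass

    def uuid(self) -> bytes:
        # Hash this source file so the cache key changes whenever the pass changes.
        return get_hash_for_files((__file__,))

# Registering an instance (rather than a raw callable) keeps the FX graph cache usable,
# because the pass contributes a stable uuid() to the hash instead of forcing a bypass.
inductor_config.post_grad_custom_post_pass = RemoveNoOpsPass()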
@@ -1396,14 +1366,6 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
             if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
                 raise BypassFxGraphCache("Unsupported post grad custom pass")
-        # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes
-        # and ensure they are not passing us raw callables
-        if config._pre_fusion_custom_pass is not None:
-            if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass):
-                raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass")
-        for p in config._fuse_ddp_communication_passes:
-            if callable(p) and not isinstance(p, CustomGraphPass):
-                raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass")

         # Freezing can embed constants that wouldn't be static across runs.
         if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
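The deleted checks above follow the cache's general bypass pattern: a validation step raises BypassFxGraphCache, and the caller records the bypass and compiles without the cache rather than failing the compilation. A stripped-down sketch of that control flow is below; CacheBypass, check_cache_safe, and compile_with_cache are stand-in names, not PyTorch APIs.

class CacheBypass(Exception):
    """Stand-in for BypassFxGraphCache: means 'skip the cache', not 'fail'."""

def check_cache_safe(custom_pass):
    # Mirrors the deleted checks: a raw callable with no stable identity
    # cannot be folded into a cache key, so the cache must be bypassed.
    if callable(custom_pass) and not hasattr(custom_pass, "uuid"):
        raise CacheBypass("unsupported custom pass")

def compile_with_cache(custom_pass, compile_fn, cache, counters):
    try:
        if custom_pass is not None:
            check_cache_safe(custom_pass)
        key = "some-key"  # the real key hashes the graph, example inputs, and config
        if key not in cache:
            counters["miss"] += 1
            cache[key] = compile_fn()
        else:
            counters["hit"] += 1
        return cache[key]
    except CacheBypass:
        counters["bypass"] += 1
        return compile_fn()  # still compile, just without caching

counters = {"hit": 0, "miss": 0, "bypass": 0}
cache = {}
compile_with_cache(None, lambda: "kernel", cache, counters)            # miss
compile_with_cache(None, lambda: "kernel", cache, counters)            # hit
compile_with_cache(lambda g: None, lambda: "kernel", cache, counters)  # bypass
assert counters == {"hit": 1, "miss": 1, "bypass": 1}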
@@ -1635,8 +1635,6 @@ _save_config_ignore: list[str] = [
     "aot_inductor.dump_aoti_minifier",
     "post_grad_custom_pre_pass",
     "post_grad_custom_post_pass",
-    "_fuse_ddp_communication_passes",
-    "_pre_fusion_custom_pass",
 ]

 _cache_config_ignore_prefix: list[str] = [
@@ -1650,8 +1648,6 @@ _cache_config_ignore_prefix: list[str] = [
     # see CustomGraphPass; these are handled specially
     "post_grad_custom_post_pass",
     "post_grad_custom_pre_pass",
-    "_fuse_ddp_communication_passes",
-    "_pre_fusion_custom_pass",
     # tests assume that changes here don't invalidate cache
     "always_complex_memory_overlap_TESTING_ONLY",
 ]
@@ -508,13 +508,9 @@ class ConfigModule(ModuleType):
             protocol=2,
         )

-    def save_config_portable(
-        self, *, ignore_private_configs: bool = True
-    ) -> dict[str, Any]:
+    def save_config_portable(self) -> dict[str, Any]:
         """Convert config to portable format"""
-        prefixes = []
-        if ignore_private_configs:
-            prefixes.append("_")
+        prefixes = ["_"]
         prefixes.extend(getattr(self, "_cache_config_ignore_prefix", []))
         return self._get_dict(ignored_prefixes=prefixes)
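After this revert, save_config_portable() unconditionally seeds the ignored prefixes with "_" and then appends _cache_config_ignore_prefix, instead of gating the "_" prefix behind an ignore_private_configs flag. A toy version of that filtering over a plain dict follows; SAMPLE_CONFIG and portable_config are illustrative names, with the ignore entries taken from the config hunks above.

SAMPLE_CONFIG = {
    "max_autotune": True,
    "_micro_pipeline_tp": False,                          # private: starts with "_"
    "post_grad_custom_pre_pass": None,                    # listed in _cache_config_ignore_prefix
    "always_complex_memory_overlap_TESTING_ONLY": False,  # listed in _cache_config_ignore_prefix
}

CACHE_CONFIG_IGNORE_PREFIX = [
    "post_grad_custom_pre_pass",
    "always_complex_memory_overlap_TESTING_ONLY",
]

def portable_config(cfg):
    # Reverted behavior: "_" is always an ignored prefix, with the entries of
    # _cache_config_ignore_prefix appended; any key matching a prefix is dropped.
    prefixes = ["_"]
    prefixes.extend(CACHE_CONFIG_IGNORE_PREFIX)
    return {k: v for k, v in cfg.items() if not any(k.startswith(p) for p in prefixes)}

assert portable_config(SAMPLE_CONFIG) == {"max_autotune": True}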
@@ -24,7 +24,7 @@ Note that the import should happen before the call to install_config_module(), o
 assert TYPE_CHECKING, "Do not use at runtime"

 def save_config() -> bytes: ...
-def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
+def save_config_portable() -> dict[str, Any]: ...
 def codegen_config() -> str: ...
 def get_hash() -> bytes: ...
 def to_dict() -> dict[str, Any]: ...