Revert "Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)"
This reverts commit 8e57cdb746.
Reverted https://github.com/pytorch/pytorch/pull/158048 on behalf of https://github.com/jeffdaily due to ROCm failures caused by a unit test introduced in this PR, with no pre-merge signal available ([comment](https://github.com/pytorch/pytorch/pull/158048#issuecomment-3098746624))
parent b1a0c34dd3
commit bc379aebe2
@@ -15,7 +15,6 @@ import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
@@ -429,39 +428,6 @@ def add(x, y):
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_autotune_cache(self, device):
        if device == "cuda" and not HAS_CUDA:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x, y):
            return x.sin() + y

        arg1 = torch.randn(3, 3, device=device)
        arg2 = torch.randn(3, 3, device=device)
        expected = fn(arg1, arg2).clone()

        with PatchCaches():
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            result = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result)
            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
            DynamoCache.clear()

            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
            self._save_and_reload(
                expected_backends=1, expected_dynamo=1, expected_autotune=1
            )
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            with torch.compiler.set_stance("fail_on_recompile"):
                result1 = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result1)
            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_recompiles(self, device):
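For context, the removed test above reduces to the following user-facing flow. This is a minimal sketch, assuming a CUDA (or XPU) device with Triton available; it drops the PatchCaches/global_stats bookkeeping and the save-and-reload step that the test asserts on.

import torch

def fn(x, y):
    return x.sin() + y

x = torch.randn(3, 3, device="cuda")  # the test also parametrizes over "xpu"
y = torch.randn(3, 3, device="cuda")

# Enable precompile caching the same way the removed test does, then compile
# with autotuning so the local autotune cache is exercised.
with torch._dynamo.config.patch(caching_precompile=True):
    compiled = torch.compile(fn, mode="max-autotune")
    out = compiled(x, y)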
@@ -70,8 +70,7 @@ class PrecompileContext(CacheArtifactManager):

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    - DynamoCodeStateArtifact
    - AutotuneCacheArtifact (regular autotune results, same as Megacache)
    - CodeStateArtifact (from torch._dynamo.package once available)
    """

    # Protected by the compile_lock
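For reference, the recording half of this interface, as it appears at the autotune-cache call sites further down in this diff, follows the shape below. The import location of AutotuneCacheArtifact is not shown in the diff, so treat it as an assumption; the key and payload are placeholders.

import torch
from torch._dynamo.precompile_context import PrecompileContext
# Assumed location for AutotuneCacheArtifact; the diff below only shows its call sites.
from torch._inductor.runtime.autotune_cache import AutotuneCacheArtifact

autotune_artifact_key = "example-autotune-key"   # placeholder key
data = '{"best_config": "..."}'                  # placeholder serialized autotune result

# Mirrors the guarded call sites in the autotune-cache hunks later in this diff.
if torch._dynamo.config.caching_precompile:
    PrecompileContext.record_artifact(
        AutotuneCacheArtifact.type(), autotune_artifact_key, data
    )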
@@ -150,12 +149,8 @@ class PrecompileContext(CacheArtifactManager):
        artifacts_by_key = {}
        cache_info = CacheInfo()
        for artifact in chain(*artifacts.values()):
            if artifact.type() == "autotune":
                # Populate autotune cache artifacts
                artifact.populate_cache()
            else:
                artifacts_by_key[artifact.key] = artifact
            cache_info.add(artifact)
            artifacts_by_key[artifact.key] = artifact

        from torch._dynamo.package import _BackendId, DynamoCache
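The loop above only touches a small surface on each artifact: a type() classifier, a key, and, for autotune entries, populate_cache(). A hypothetical stand-in that satisfies just that contract (not PyTorch's real CacheArtifact hierarchy, which lives in torch.compiler._cache) might look like:

from dataclasses import dataclass

@dataclass
class FakeAutotuneArtifact:
    # Hypothetical illustration of the attributes the loop above relies on;
    # it is not one of the real artifact classes.
    key: str
    content: bytes

    @staticmethod
    def type() -> str:
        return "autotune"

    def populate_cache(self) -> None:
        # A real artifact would write `content` back into the on-disk autotune
        # cache; here we only demonstrate the call shape.
        print(f"populate autotune cache entry for {self.key}")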
@@ -909,37 +909,10 @@ def _compile_fx_inner(
    else:
        log.debug("Failed to generate FX cache key")

    if torch._functorch.config.bundled_autograd_cache:
        assert mb_compiled_graph is None
        assert cache_info is None
        # When using bundled autograd cache, we still want
        # to use the TritonBundler, but we don't want to save
        # the results here. The results will get saved directly
        # to AOTAutogradCache.
        TritonBundler.begin_compile()
        try:
            mb_compiled_graph = fx_codegen_and_compile(
                gm, example_inputs, inputs_to_check, **graph_kwargs
            )
            assert mb_compiled_graph is not None
            (
                triton_bundle,
                triton_bundler_meta,
            ) = TritonBundler.collect()
            mb_compiled_graph.set_triton_bundle(triton_bundle)
        except (ShortenTraceback, SkipFrame):
            raise
        except Exception as e:
            raise InductorError(e, currentframe()).with_traceback(
                e.__traceback__
            ) from None
        finally:
            TritonBundler.end_compile()

    # CACHE BYPASS: Compile the graph, don't save it to the cache
    # (this can happen either because cache was disabled, or we
    # determined the input is uncacheable)
    elif cache_info is None or cache_info["cache_state"] == "bypass":
    if cache_info is None or cache_info["cache_state"] == "bypass":
        assert mb_compiled_graph is None
        log.debug(
            "FX cache bypass reason: %s",
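Stripped of the compile_fx specifics, the removed block follows TritonBundler's begin/collect/end protocol. Below is a sketch of that lifecycle, assuming the usual torch._inductor.triton_bundler location and with compile_fn standing in for the fx_codegen_and_compile(...) call above.

from torch._inductor.triton_bundler import TritonBundler

def compile_with_bundling(compile_fn):
    # compile_fn stands in for fx_codegen_and_compile(gm, example_inputs, ...).
    TritonBundler.begin_compile()
    try:
        compiled_graph = compile_fn()
        # Gather the Triton kernels compiled while bundling was active and
        # attach them to the compiled graph; with the reverted change they
        # were then saved as part of the bundled AOTAutograd cache entry.
        triton_bundle, triton_bundler_meta = TritonBundler.collect()
        compiled_graph.set_triton_bundle(triton_bundle)
        return compiled_graph
    finally:
        TritonBundler.end_compile()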
@@ -35,7 +35,6 @@ from typing import Any, Optional, TYPE_CHECKING
from typing_extensions import override

import torch
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.compiler._cache import (
    CacheArtifact,
@@ -126,7 +125,6 @@ class AutotuneCache:
    ) -> Optional[AutotuneCache]:
        cache = AutotuneCache(configs_hash)
        key = AutotuneCache._prepare_key(filename)

        cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
        cache._setup_remote_autotune_cache(inductor_meta, key)
        if cache.local_cache or cache.remote_cache:
@@ -302,10 +300,6 @@ class AutotuneCache:
        CacheArtifactManager.record_artifact(
            AutotuneCacheArtifact.type(), autotune_artifact_key, data
        )
        if torch._dynamo.config.caching_precompile:
            PrecompileContext.record_artifact(
                AutotuneCacheArtifact.type(), autotune_artifact_key, data
            )

        if log.isEnabledFor(logging.DEBUG):
            type_str = "coordesc" if found_by_coordesc else "heuristic"
@@ -631,10 +625,6 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
        CacheArtifactManager.record_artifact(
            AutotuneCacheArtifact.type(), autotune_artifact_key, result
        )
        if torch._dynamo.config.caching_precompile:
            PrecompileContext.record_artifact(
                AutotuneCacheArtifact.type(), autotune_artifact_key, result
            )
        return result

    @override