Rename inductor cache (#156128)
Requested by Simon on a different PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128
Approved by: https://github.com/xmfan
This commit is contained in:
parent 45382b284d
commit a2a75be0f8
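This commit mechanically renames the Inductor cache helpers in `torch._inductor.utils`: `fresh_inductor_cache` → `fresh_cache`, `clear_inductor_caches` → `clear_caches`, and the `clear_on_fresh_inductor_cache` decorator → `clear_on_fresh_cache`; the diff below updates every call site to the new names. A minimal sketch of what the rename means for a caller, assuming a PyTorch build that already includes this change (illustrative only, not part of the diff):

```python
import torch

# Renamed in this commit; previously fresh_inductor_cache / clear_inductor_caches.
from torch._inductor.utils import clear_caches, fresh_cache


def fn(x, y):
    return x @ y


# Compile inside a temporary, isolated Inductor cache directory; artifacts
# written there are discarded when the context manager exits.
with fresh_cache():
    compiled = torch.compile(fn)
    compiled(torch.randn(8, 8), torch.randn(8, 8))

# Drop the in-memory Inductor caches, e.g. between benchmark iterations.
clear_caches()
```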
@@ -8,7 +8,7 @@ import sys
import tempfile
from typing import Callable

-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


logger: logging.Logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ def _run_torchbench_from_args(
warm_compile_time: list[float] = []

for _ in range(cmd_args.repeat):
-with fresh_inductor_cache():
+with fresh_cache():
env = os.environ.copy()
with tempfile.NamedTemporaryFile(suffix=".csv") as file:
args.append("--output=" + file.name)

@@ -51,7 +51,7 @@ from torch._logging.scribe import open_source_signpost

try:
from torch._dynamo.utils import clone_inputs, graph_break_reasons
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
except ImportError:
from _dynamo.utils import clone_inputs, graph_break_reasons

@@ -3416,7 +3416,7 @@ def maybe_fresh_cache(args):
if not cache_dir_assigned and (
args.cold_start_latency or args.warm_start_latency or args.ci
):
-return fresh_inductor_cache()
+return fresh_cache()
else:
return contextlib.nullcontext()

@@ -3,7 +3,7 @@ import timeit

import torch.fx
from torch._dynamo.utils import counters
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache


N = 10000

@@ -20,7 +20,7 @@ def main():
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.fx_graph_remote_cache = False

-with fresh_inductor_cache():
+with fresh_cache():
a = torch.randn(4).cuda()
compiled_fn = torch.compile(huge_graph, backend="inductor")

@@ -30,7 +30,7 @@ def main():

def setup():
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
counters.clear()

@@ -3,7 +3,7 @@ import sys
from benchmark_base import BenchmarkBase

import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


class Benchmark(BenchmarkBase):

@@ -50,7 +50,7 @@ class Benchmark(BenchmarkBase):
result = result.sin()
return result

-with fresh_inductor_cache():
+with fresh_cache():
f(self.a, self.b)

@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase

import torch
import torch.nn as nn
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


class ListOfLinears(nn.Module):

@@ -55,7 +55,7 @@ class Benchmark(BenchmarkBase):

def _work(self):
with (
-fresh_inductor_cache(),
+fresh_cache(),
torch._inductor.config.patch(force_shape_pad=self._force_shape_pad),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(

@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase

import torch
import torch.nn as nn
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


# Create a chain of artificial nesting

@@ -94,7 +94,7 @@ class Benchmark(BenchmarkBase):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with (
-fresh_inductor_cache(),
+fresh_cache(),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
self.m.cuda() if self._is_gpu else self.m

@@ -3,7 +3,7 @@ import sys
from benchmark_base import BenchmarkBase

import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


class Benchmark(BenchmarkBase):

@@ -31,7 +31,7 @@ class Benchmark(BenchmarkBase):
def f(x, y):
return x + y

-with fresh_inductor_cache():
+with fresh_cache():
for i in range(8):
f(torch.arange(3), i * 2.5)

@@ -3,7 +3,7 @@ import sys
from benchmark_base import BenchmarkBase

import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


class Benchmark(BenchmarkBase):

@@ -45,7 +45,7 @@ class Benchmark(BenchmarkBase):
z = torch.mm(z, b)
return z

-with fresh_inductor_cache(), torch._inductor.config.patch(max_autotune=True):
+with fresh_cache(), torch._inductor.config.patch(max_autotune=True):
f(self.a, self.b)

@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase

import torch
import torch.nn as nn
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


class NestedModule(nn.Module):

@@ -67,7 +67,7 @@ class Benchmark(BenchmarkBase):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with (
-fresh_inductor_cache(),
+fresh_cache(),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
self.m.cuda() if self._is_gpu else self.m

@@ -246,7 +246,7 @@ def run_single_experiment_group(

for config in group_config.experiments:
torch._dynamo.reset()
-torch._inductor.utils.clear_inductor_caches()
+torch._inductor.utils.clear_caches()
compiled_op = torch.compile(
op,
options=config.to_options(),

@ -13,7 +13,7 @@ from torch._inductor.fx_passes.micro_pipeline_tp import (
|
|||
micro_pipeline_tp_pass,
|
||||
)
|
||||
from torch._inductor.fx_passes.post_grad import remove_noop_ops, view_to_reshape
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
|
||||
from torch._inductor.utils import fresh_cache, run_and_get_triton_code
|
||||
from torch.distributed._functional_collectives import (
|
||||
all_gather_tensor,
|
||||
reduce_scatter_tensor,
|
||||
|
|
@ -81,7 +81,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
dist.destroy_process_group()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_find_all_gather_patterns(self):
|
||||
group = dist.group.WORLD
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_find_reduce_scatter_patterns(self):
|
||||
group = dist.group.WORLD
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
self.assertEqual(reduce_scatters[1].scatter_dim, 1)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_get_unexposed_collectives(self):
|
||||
group = dist.group.WORLD
|
||||
|
||||
|
|
@ -201,7 +201,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
@parametrize("A_dims", [2, 3])
|
||||
@parametrize("gather_dim", [0, 1, 2])
|
||||
@parametrize("return_A", [True, False])
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_fuse_all_gather_matmul(self, A_dims, gather_dim, return_A):
|
||||
if gather_dim >= A_dims:
|
||||
return
|
||||
|
|
@ -248,7 +248,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
@parametrize("A_dims", [2, 3])
|
||||
@parametrize("gather_dim", [0, 1, 2])
|
||||
@parametrize("return_A", [True, False])
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_fuse_all_gather_scaled_matmul(self, A_dims, gather_dim, return_A):
|
||||
if gather_dim >= A_dims:
|
||||
return
|
||||
|
|
@ -321,7 +321,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@parametrize("A_dims", [2, 3])
|
||||
@parametrize("scatter_dim", [0, 1, 2])
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_fuse_matmul_reduce_scatter(self, A_dims, scatter_dim):
|
||||
if scatter_dim >= A_dims:
|
||||
return
|
||||
|
|
@ -350,7 +350,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@parametrize("A_dims", [2, 3])
|
||||
@parametrize("scatter_dim", [0, 1, 2])
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_fuse_scaled_matmul_reduce_scatter(self, A_dims, scatter_dim):
|
||||
if scatter_dim >= A_dims:
|
||||
return
|
||||
|
|
@ -403,7 +403,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
@runOnRocmArch(MI300_ARCH)
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@parametrize("scatter_dim", [0, 1, 2])
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
|
||||
self, scatter_dim
|
||||
):
|
||||
|
|
@ -465,7 +465,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@parametrize("shard_dim", [0, 1])
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_dtensor_seq_par(self, shard_dim: int):
|
||||
model: torch.nn.Module = MLPModule(device="cuda", bias=False)
|
||||
device_mesh = DeviceMesh(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
from torch._C import FileCheck
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
|
||||
from torch._inductor.utils import fresh_cache, run_and_get_triton_code
|
||||
from torch.distributed._functional_collectives import (
|
||||
all_gather_into_tensor_coalesced,
|
||||
all_gather_tensor,
|
||||
|
|
@ -464,7 +464,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
|||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_threading(self):
|
||||
self._init_process_group()
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
|
|
@ -510,7 +510,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
|||
"_scaled_mm currently only supports sm>=90",
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_fixed_striding(self):
|
||||
self._init_process_group()
|
||||
|
||||
|
|
@ -736,7 +736,7 @@ class CompileTest(TestCase):
|
|||
dist.destroy_process_group()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_all_reduce_single(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
buf0 = arg + 42
|
||||
|
|
@ -773,7 +773,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_all_reduce_coalesced(self):
|
||||
def func(args: list[torch.Tensor]) -> torch.Tensor:
|
||||
bufs = [arg + 42 for arg in args]
|
||||
|
|
@ -819,7 +819,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_inplace_op_on_view(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
buf0 = (arg + 10)[:2]
|
||||
|
|
@ -843,7 +843,7 @@ class CompileTest(TestCase):
|
|||
)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_all_reduce_non_contig_input(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
ar0 = funcol.all_reduce(arg, "avg", "0")
|
||||
|
|
@ -869,7 +869,7 @@ class CompileTest(TestCase):
|
|||
assert "torch.ops._c10d_functional.wait_tensor.default" in code
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_reuse_buffer_after_inplace_collective(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
# Expect allocation
|
||||
|
|
@ -904,7 +904,7 @@ class CompileTest(TestCase):
|
|||
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_all_gather_into_tensor_single(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
ag0 = funcol.all_gather_tensor(arg, 0, "0")
|
||||
|
|
@ -931,7 +931,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_all_gather_into_tensor_coalesced(self):
|
||||
def func(args: list[torch.Tensor]) -> torch.Tensor:
|
||||
ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
|
||||
|
|
@ -965,7 +965,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_wait_tensor(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
t = torch.ops._c10d_functional.all_reduce(arg, "avg", "0")
|
||||
|
|
@ -987,7 +987,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_reduce_scatter_tensor_single(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
|
||||
|
|
@ -1013,7 +1013,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_reduce_scatter_tensor_coalesced(self):
|
||||
def func(args: list[torch.Tensor]) -> torch.Tensor:
|
||||
rs0 = funcol.reduce_scatter_tensor_coalesced(
|
||||
|
|
@ -1049,7 +1049,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_all_to_all_single(self):
|
||||
def _tolist_with_constrain_as_size(tensor):
|
||||
lst = tensor.tolist()
|
||||
|
|
@ -1097,7 +1097,7 @@ class CompileTest(TestCase):
|
|||
)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_inductor_broadcast(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
buf0 = arg + 42
|
||||
|
|
@ -1134,7 +1134,7 @@ class CompileTest(TestCase):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_ranks_and_tag(self):
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
buf0 = arg + 42
|
||||
|
|
|
|||
|
|
@ -1218,11 +1218,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
@patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
|
||||
def test_asymmetric_compilation_with_fx_cache(self):
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size
|
||||
):
|
||||
with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
torch._dynamo.utils.clear_compilation_metrics()
|
||||
|
||||
device = f"cuda:{self.rank}"
|
||||
|
|
@ -1252,7 +1250,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
torch._dynamo.reset()
|
||||
|
||||
if self.rank == 0:
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
f(x)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import torch.distributed as dist
|
|||
import torch.distributed._symmetric_memory as symm_mem
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _SymmetricMemory
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
|
||||
from torch._inductor.utils import fresh_cache, run_and_get_triton_code
|
||||
from torch.distributed._functional_collectives import all_gather_tensor
|
||||
from torch.distributed._symmetric_memory import (
|
||||
_fused_all_gather_matmul_fallback,
|
||||
|
|
@ -1020,7 +1020,7 @@ class LoweringTest(MultiProcContinousTest):
|
|||
@skip("Fails with 'one_shot_all_reduce' not found in AOT graph, TODO: fix")
|
||||
@skipIfRocm # requires registered-buffer support
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_lowering_one_shot_all_reduce(self):
|
||||
self._init_process()
|
||||
arg = torch.rand(4, 4, device=self.device)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from torch._inductor import config as inductor_config
|
|||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch._inductor.runtime.triton_compat import tl, triton
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.compiler._cache import CacheArtifactManager
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
|
@ -165,7 +165,7 @@ class AOTAutogradCacheTests(InductorTestCase):
|
|||
b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
|
||||
|
||||
# Record artifacts
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
|
||||
# A first call should miss in the cache.
|
||||
|
|
@ -201,7 +201,7 @@ class AOTAutogradCacheTests(InductorTestCase):
|
|||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# We did not load anything so dont hit yet
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
eager_result = fn(a, b)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
|
|
@ -221,7 +221,7 @@ class AOTAutogradCacheTests(InductorTestCase):
|
|||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# Hot load and hit
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
|
||||
|
||||
self.assertEqual(len(cache_info.inductor_artifacts), 2)
|
||||
|
|
|
|||
|
|
@ -798,7 +798,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
|
|||
@requires_cuda
|
||||
@unittest.skipIf(not SM90OrLater, "requires H100+ GPU")
|
||||
def test_autotuning(self, records):
|
||||
with torch._inductor.utils.fresh_inductor_cache():
|
||||
with torch._inductor.utils.fresh_cache():
|
||||
|
||||
def f(a, b):
|
||||
return torch.mm(a, b)
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ from torch._dynamo.testing import (
|
|||
)
|
||||
from torch._dynamo.utils import call_size, counters, ifdynstaticdefault
|
||||
from torch._dynamo.variables import builder
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
|
||||
from torch._inductor.utils import fresh_cache, run_and_get_code
|
||||
from torch.ao.quantization import MinMaxObserver
|
||||
from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
|
|
@ -8087,7 +8087,7 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
|
||||
m1 = Model(50)
|
||||
m2 = Model(60)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
m1(torch.rand(1, 2, 3))
|
||||
m2(torch.rand(1, 2, 3))
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import torch._inductor.mock_cache as mock_cache
|
|||
import torch.compiler.config
|
||||
import torch.nested
|
||||
from torch._dynamo.testing import CompileCounter
|
||||
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
|
||||
from torch._inductor.utils import clear_caches, fresh_cache
|
||||
|
||||
|
||||
class PgoTest(torch._dynamo.test_case.TestCase):
|
||||
|
|
@ -24,7 +24,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
|||
torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
|
||||
)
|
||||
if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1":
|
||||
self._test_stack.enter_context(fresh_inductor_cache())
|
||||
self._test_stack.enter_context(fresh_cache())
|
||||
mock_cache.PatchCaches.setUp()
|
||||
|
||||
def tearDown(self):
|
||||
|
|
@ -35,7 +35,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def reset(self):
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
def test_basic(self):
|
||||
cnts = CompileCounter()
|
||||
|
|
@ -244,7 +244,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
cnts.clear()
|
||||
|
||||
with torch.compiler.config.patch(job_id="foo"):
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ import torch.utils._pytree as pytree
|
|||
from torch import nn
|
||||
from torch._dynamo.debug_utils import same_two_models
|
||||
from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
|
|
@ -5527,7 +5527,7 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
y = torch.randn(100, 10)
|
||||
return torch.mm(x, y).sum()
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
torch.compile(fn)()
|
||||
|
||||
torch.compile(fn2)()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import torch
|
|||
from torch._inductor.codecache import get_kernel_bin_format
|
||||
from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.export import Dim
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
|
|
@ -157,7 +157,7 @@ class TestAOTInductorPackage(TestCase):
|
|||
torch.manual_seed(0)
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
ep = torch.export.export(model, example_inputs, strict=True)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# cubin files are removed when exiting this context
|
||||
package_path = torch._inductor.aoti_compile_and_package(
|
||||
ep,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from torch._inductor import config
|
||||
from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
|
|
@ -32,7 +32,7 @@ class TestAsyncCompile(TestCase):
|
|||
pool = AsyncCompile.process_pool()
|
||||
pool.ready_future.result(timeout=120)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
compiled_fn = torch.compile(fn)
|
||||
self.assertEqual(fn(x, y), compiled_fn(x, y))
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from torch._inductor.codegen.triton import TritonScheduling
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.test_operators import realize
|
||||
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
|
||||
from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import slowTest
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
|
|
@ -283,7 +283,7 @@ if HAS_CUDA:
|
|||
self.assertEqual(res, res2, atol=1e-4, rtol=1.1)
|
||||
return code, code2
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@config.patch(max_autotune_gemm_backends="TRITON")
|
||||
def test_equivalent_template_code(self):
|
||||
code, code2 = self._equivalent_output_code_impl(256)
|
||||
|
|
@ -298,7 +298,7 @@ if HAS_CUDA:
|
|||
out_code[0]
|
||||
)
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@config.patch(max_autotune_gemm_backends="ATEN")
|
||||
def test_equivalent_extern_code(self):
|
||||
torch._dynamo.reset()
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from torch._inductor.graph import GraphLowering
|
|||
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
|
||||
from torch._inductor.utils import clear_caches, fresh_cache
|
||||
from torch._library import capture_triton
|
||||
from torch.compiler._cache import (
|
||||
CacheArtifact,
|
||||
|
|
@ -146,7 +146,7 @@ class TestFxGraphCache(TestCase):
|
|||
AOTAutogradCache.clear()
|
||||
PyCodeCache.cache_clear(purge=True)
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
@requires_triton()
|
||||
@config.patch({"fx_graph_cache": True})
|
||||
|
|
@ -379,16 +379,14 @@ class TestFxGraphCache(TestCase):
|
|||
), patch.dict(os.environ), PatchCaches():
|
||||
os.environ.pop("TRITON_CACHE_MANAGER", None)
|
||||
for _ in range(4):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
reset()
|
||||
|
||||
self.assertEqual(global_stats.fx_graph, Stats(1, 3, 1))
|
||||
|
||||
with torch.compiler.config.patch(
|
||||
{"cache_key_tag": "test"}
|
||||
), fresh_inductor_cache():
|
||||
with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
self.assertEqual(fn(a, b), compiled_fn(a, b))
|
||||
|
||||
|
|
@ -426,7 +424,7 @@ class TestFxGraphCache(TestCase):
|
|||
b = torch.rand(100, 100, dtype=dtype, device=device)
|
||||
|
||||
# Record artifacts
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
compiled_fn = torch.compile(fn, dynamic=dynamic)
|
||||
|
||||
# A first call should miss in the cache.
|
||||
|
|
@ -456,7 +454,7 @@ class TestFxGraphCache(TestCase):
|
|||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# We did not load anything so dont hit yet
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
eager_result = fn(a, b)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
|
|
@ -470,7 +468,7 @@ class TestFxGraphCache(TestCase):
|
|||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# Hot load and hit
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
|
||||
|
||||
self.assertEqual(len(cache_info.inductor_artifacts), 1)
|
||||
|
|
@ -503,7 +501,7 @@ class TestFxGraphCache(TestCase):
|
|||
a2 = torch.randn(4, 8)
|
||||
b2 = torch.randn(8, 4)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
eager_result = fn(a, b)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
|
|
@ -519,7 +517,7 @@ class TestFxGraphCache(TestCase):
|
|||
|
||||
self.reset()
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
torch.compiler.load_cache_artifacts(artifact_bytes)
|
||||
eager_result = fn(a, b)
|
||||
compiled_result = compiled_fn(a, b)
|
||||
|
|
@ -531,7 +529,7 @@ class TestFxGraphCache(TestCase):
|
|||
|
||||
self.reset()
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
eager_result = fn(a2, b2)
|
||||
compiled_result = compiled_fn(a2, b2)
|
||||
self.assertEqual(eager_result, compiled_result)
|
||||
|
|
@ -555,7 +553,7 @@ class TestFxGraphCache(TestCase):
|
|||
return x * 2
|
||||
|
||||
# Record artifacts
|
||||
with torch.compiler.config.patch(job_id=self.id()), fresh_inductor_cache():
|
||||
with torch.compiler.config.patch(job_id=self.id()), fresh_cache():
|
||||
f(torch.randn(2, 3))
|
||||
f(torch.randn(2, 4))
|
||||
self.assertEqual(backend.frame_count, 2)
|
||||
|
|
@ -582,7 +580,7 @@ class TestFxGraphCache(TestCase):
|
|||
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
|
||||
|
||||
# Hot load and hit
|
||||
with torch.compiler.config.patch({"job_id": self.id()}), fresh_inductor_cache():
|
||||
with torch.compiler.config.patch({"job_id": self.id()}), fresh_cache():
|
||||
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
|
||||
|
||||
self.assertEqual(len(cache_info.inductor_artifacts), 2)
|
||||
|
|
@ -617,7 +615,7 @@ class TestFxGraphCache(TestCase):
|
|||
with mock.patch(
|
||||
"torch._utils_internal.get_mast_job_name_version", return_value=("foo", 5)
|
||||
):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
f(torch.randn(2, 3))
|
||||
f(torch.randn(2, 4))
|
||||
self.assertEqual(backend.frame_count, 2)
|
||||
|
|
@ -639,7 +637,7 @@ class TestFxGraphCache(TestCase):
|
|||
# Hot load and hit
|
||||
with mock.patch(
|
||||
"torch._utils_internal.get_mast_job_name_version", return_value=("bar", 10)
|
||||
), fresh_inductor_cache():
|
||||
), fresh_cache():
|
||||
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
|
||||
|
||||
self.assertEqual(len(cache_info.pgo_artifacts), 2)
|
||||
|
|
@ -1582,7 +1580,7 @@ class TestStandaloneCompile(TestCase):
|
|||
AOTAutogradCache.clear()
|
||||
PyCodeCache.cache_clear(purge=True)
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
def capture(self, fn, dynamic=None):
|
||||
def inner(*args):
|
||||
|
|
@ -1638,7 +1636,7 @@ class TestStandaloneCompile(TestCase):
|
|||
if format == "unpacked"
|
||||
else os.path.join(temp_dir, "compiled_artifact.bin")
|
||||
)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
gm, args, kwargs = self.capture(f)(x)
|
||||
assert not kwargs
|
||||
|
||||
|
|
@ -1647,7 +1645,7 @@ class TestStandaloneCompile(TestCase):
|
|||
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
|
||||
if dynamic:
|
||||
concrete_args = [
|
||||
|
|
@ -1679,7 +1677,7 @@ class TestStandaloneCompile(TestCase):
|
|||
def backend(gm, args, **kwargs):
|
||||
return torch._inductor.standalone_compile(gm, args)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
|
||||
self.assertEqual(eager_out, compiled_out)
|
||||
|
||||
|
|
@ -1698,7 +1696,7 @@ class TestStandaloneCompile(TestCase):
|
|||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
path = os.path.join(temp_dir, "new_dir")
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
gm, args, kwargs = self.capture(f)(x)
|
||||
assert not kwargs
|
||||
|
||||
|
|
@ -1707,7 +1705,7 @@ class TestStandaloneCompile(TestCase):
|
|||
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
loaded = torch._inductor.CompiledArtifact.load(
|
||||
path=path, format="unpacked"
|
||||
)
|
||||
|
|
@ -1731,7 +1729,7 @@ class TestStandaloneCompile(TestCase):
|
|||
eager_out = f(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
gm, args, kwargs = self.capture(f)(x)
|
||||
assert not kwargs
|
||||
|
||||
|
|
@ -1743,7 +1741,7 @@ class TestStandaloneCompile(TestCase):
|
|||
|
||||
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# Now modify the output file and expect to see the changes
|
||||
for subdir in os.listdir(temp_dir):
|
||||
if subdir in ["aotautograd", "fxgraph"]:
|
||||
|
|
@ -1791,16 +1789,16 @@ class TestStandaloneCompile(TestCase):
|
|||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
path = os.path.join(temp_dir, "compiled_artifact.bin")
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
compiled_artifact = torch._inductor.standalone_compile(gm, args)
|
||||
compiled_artifact.save(path=path)
|
||||
|
||||
script = f"""
|
||||
import torch
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
arg = torch.ones(4, 1)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
loaded = torch._inductor.CompiledArtifact.load(path="{path}")
|
||||
compiled_result = loaded(arg)[0]
|
||||
|
||||
|
|
@ -1832,7 +1830,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
|
|||
|
||||
x = torch.ones(3)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# captured graph is lambda s0, x: x * s0
|
||||
gm, args, kwargs = self.capture(f)(x)
|
||||
assert not kwargs
|
||||
|
|
@ -1854,7 +1852,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
|
|||
|
||||
x = torch.ones(3)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# captured graph is lambda s0, x: x * s0
|
||||
gm, args, kwargs = self.capture(f)(x)
|
||||
assert not kwargs
|
||||
|
|
@ -1890,7 +1888,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
|
|||
return x.shape[0] * x
|
||||
|
||||
static_x = torch.randn(3)
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# static_gm is lambda x: x * 3
|
||||
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
|
||||
assert not kwargs
|
||||
|
|
@ -2440,7 +2438,7 @@ class TestAutotuneCache(TestCase):
|
|||
def reset(self):
|
||||
PyCodeCache.cache_clear(purge=True)
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "Requires CUDA")
|
||||
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
|
||||
|
|
@ -2750,20 +2748,20 @@ class TestRemoteAOTAutogradCache(TestCase):
|
|||
|
||||
class TestUtils(TestCase):
|
||||
@config.patch({"fx_graph_remote_cache": False})
|
||||
def test_fresh_inductor_cache(self):
|
||||
def test_fresh_cache(self):
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
a = torch.rand(10)
|
||||
b = torch.rand(10)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
self.assertEqual(len(PyCodeCache.modules), 0)
|
||||
res1 = torch.compile(fn)(a, b)
|
||||
cache_dir1 = cache_dir()
|
||||
|
||||
torch._dynamo.reset()
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
self.assertEqual(len(PyCodeCache.modules), 0)
|
||||
res2 = torch.compile(fn)(a, b)
|
||||
cache_dir2 = cache_dir()
|
||||
|
|
|
|||
|
|
@ -919,9 +919,9 @@ class CompiledOptimizerTests(TestCase):
|
|||
import torch._dynamo
|
||||
import torch._inductor
|
||||
from torch._dynamo.debug_utils import aot_graph_input_parser
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
kwargs = aot_graph_input_parser(forward)
|
||||
torch.compile(forward)(**kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1814,9 +1814,9 @@ class CudaReproTests(TestCase):
|
|||
|
||||
m = ToyModel().to(device="cuda:0")
|
||||
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
cm = torch.compile(m, mode="max-autotune")
|
||||
out = cm(input_tensor)
|
||||
out2 = m(input_tensor)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from torch._inductor.codecache import CUDACodeCache
|
|||
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
|
||||
from torch._inductor.exc import CUDACompileError
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
|
||||
_SOURCE_CODE = r"""
|
||||
|
|
@ -40,7 +40,7 @@ int saxpy(int n, float a, float *x, float *y) {
|
|||
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
|
||||
class TestCUDACodeCache(InductorTestCase):
|
||||
def test_cuda_load(self):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# Test both .o and .so compilation.
|
||||
(
|
||||
object_file_path,
|
||||
|
|
@ -67,13 +67,13 @@ class TestCUDACodeCache(InductorTestCase):
|
|||
torch.testing.assert_close(y, expected_y)
|
||||
|
||||
def test_compilation_error(self):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
|
||||
with self.assertRaises(CUDACompileError):
|
||||
CUDACodeCache.compile(error_source_code, "o")
|
||||
|
||||
def test_async_compile(self):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
async_compile = AsyncCompile()
|
||||
compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
|
||||
async_compile.wait(globals())
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
from typing import Callable, Optional
|
||||
|
||||
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
||||
from torch._inductor.utils import clear_inductor_caches
|
||||
from torch._inductor.utils import clear_caches
|
||||
from torch.export import Dim
|
||||
from torch.testing._internal.logging_utils import log_settings
|
||||
|
||||
|
|
@ -38,7 +38,7 @@ from torch._inductor.exc import InductorError
|
|||
from torch._inductor.ir import FixedLayout
|
||||
from torch._inductor.select_algorithm import NoValidChoicesError
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import (
|
||||
|
|
@ -173,7 +173,7 @@ class TestCutlassBackend(TestCase):
|
|||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
def run_evt_test(self, model, op, shape, num_fusions=1):
|
||||
M, N = shape
|
||||
|
|
@ -618,7 +618,7 @@ class TestCutlassBackend(TestCase):
|
|||
]
|
||||
for x_shape in x_shapes:
|
||||
torch._dynamo.reset()
|
||||
clear_inductor_caches()
|
||||
clear_caches()
|
||||
|
||||
inputs = [
|
||||
(
|
||||
|
|
@ -1065,7 +1065,7 @@ class TestCutlassBackend(TestCase):
|
|||
def select_no_algorithm(*args, **kwargs):
|
||||
raise NoValidChoicesError
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
|
|
@ -1113,7 +1113,7 @@ class TestCutlassBackend(TestCase):
|
|||
def select_no_algorithm(*args, **kwargs):
|
||||
raise NoValidChoicesError
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
|
|
@ -1187,7 +1187,7 @@ class TestCutlassBackend(TestCase):
|
|||
raise NoValidChoicesError
|
||||
|
||||
def run_test(use_fast_accum):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
|
|
@ -1266,7 +1266,7 @@ class TestCutlassBackend(TestCase):
|
|||
def select_no_algorithm(*args, **kwargs):
|
||||
raise NoValidChoicesError
|
||||
|
||||
with fresh_inductor_cache(), config.patch(
|
||||
with fresh_cache(), config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
|
|
@ -1324,7 +1324,7 @@ class TestCutlassBackend(TestCase):
|
|||
def select_no_algorithm(*args, **kwargs):
|
||||
raise NoValidChoicesError
|
||||
|
||||
with fresh_inductor_cache(), config.patch(
|
||||
with fresh_cache(), config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTLASS",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
from torch._inductor import config, test_operators
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.testing._internal.common_utils import skipIfWindows
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
from torch.testing._internal.logging_utils import multiple_logs_to_string
|
||||
|
|
@ -44,7 +44,7 @@ class TestDebugTrace(test_torchinductor.TestCase):
|
|||
"torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
|
||||
)
|
||||
|
||||
# TODO(aakhundov): make this work with fresh_inductor_cache
|
||||
# TODO(aakhundov): make this work with fresh_cache
|
||||
# instead of force_disable_caches. currently, with the latter
|
||||
# enabled, we get `inductor [('fxgraph_cache_hit', 1)]` in
|
||||
# the counters: so the cache is actually hit and the test fails.
|
||||
|
|
@ -263,7 +263,7 @@ op2.node.kernel = extern_kernels.mm""",
|
|||
# no failure
|
||||
with self.assertLogs(
|
||||
logging.getLogger("torch._inductor.debug"), level=logging.WARNING
|
||||
), fresh_inductor_cache():
|
||||
), fresh_cache():
|
||||
m = ToyModel().to(device=GPU_TYPE)
|
||||
m = torch.compile(m, mode="max-autotune")
|
||||
input_tensor = torch.randn(100).to(device=GPU_TYPE)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch._inductor.metrics as metrics
|
|||
import torch.utils.flop_counter
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.ir import FixedLayout
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.testing._internal.common_cuda import SM70OrLater
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
|
|
@ -77,7 +77,7 @@ class TestScheduler(TestCase):
|
|||
for op, example_inputs, kwargs in tc:
|
||||
comp = torch.compile(op)
|
||||
torch._dynamo.reset()
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
comp(*example_inputs, **kwargs)
|
||||
self.assertEqual(metrics.num_bytes_accessed, 0)
|
||||
self.assertEqual(any(m[1] for m in metrics.node_runtimes), False)
|
||||
|
|
@ -108,7 +108,7 @@ class TestScheduler(TestCase):
|
|||
|
||||
comp = torch.compile(op)
|
||||
torch._dynamo.reset()
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
comp(*example_inputs, **kwargs)
|
||||
self.assertEqual(enba, metrics.num_bytes_accessed)
|
||||
nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0)
|
||||
|
|
@ -152,7 +152,7 @@ class TestScheduler(TestCase):
|
|||
comp = torch.compile(op, options=options)
|
||||
# next two lines are required, otherwise the flops will be cached from pervious runs of this function.
|
||||
torch._dynamo.reset()
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
# actually run to set the counters
|
||||
comp(*example_inputs, **kwargs)
|
||||
with FlopCounterMode() as mode:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from torch._dynamo.testing import rand_strided
|
|||
from torch._inductor import config
|
||||
from torch._inductor.codecache import PyCodeCache
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import xfailIfSM89
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
|
||||
|
|
@ -152,7 +152,7 @@ class TestKernelBenchmark(TestCase):
|
|||
@unittest.skipIf(
|
||||
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
||||
)
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_matmul_triton_kernel_benchmark(self):
|
||||
M = 12544
|
||||
N = 256
|
||||
|
|
@ -170,7 +170,7 @@ class TestKernelBenchmark(TestCase):
|
|||
@config.patch(
|
||||
max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
|
||||
)
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_mm_triton_kernel_benchmark(self):
|
||||
M = 2048
|
||||
N = 2432
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ from torch.utils._triton import has_triton_tma_device
|
|||
aten = torch.ops.aten
|
||||
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
|
||||
from torch._inductor.utils import fresh_cache, run_and_get_code
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing import FileCheck
|
||||
|
|
@ -325,7 +325,7 @@ class TestMaxAutotune(TestCase):
|
|||
|
||||
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
|
||||
@unittest.skipIf(
|
||||
|
|
@ -470,7 +470,7 @@ class TestMaxAutotune(TestCase):
|
|||
FileCheck().check_not("extern_kernels.convolution").run(code[0])
|
||||
self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@config.patch(max_autotune=True, max_fusion_size=2)
|
||||
def test_jit_fusion_matches_aot_fusion(self):
|
||||
# In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
|
||||
|
|
@ -563,7 +563,7 @@ class TestMaxAutotune(TestCase):
|
|||
def f(x, y):
|
||||
return x @ y
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
act = torch.compile(f)(x, y)
|
||||
ref = f(x, y)
|
||||
self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))
|
||||
|
|
@ -1146,7 +1146,7 @@ class TestMaxAutotune(TestCase):
|
|||
self.assertEqual(generate_and_load_args - 1, make_key_args)
|
||||
self.assertEqual(generate_and_load_args, 16)
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
|
|
@ -1213,7 +1213,7 @@ class TestMaxAutotune(TestCase):
|
|||
b = torch.rand(22, 30, device=GPU_TYPE)
|
||||
|
||||
# Valid cache hit.
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
reset_counters()
|
||||
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
||||
eager_results = func_test1(a, b, a, b)
|
||||
|
|
@ -1246,7 +1246,7 @@ class TestMaxAutotune(TestCase):
|
|||
)
|
||||
|
||||
# Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
a = torch.rand(10, 22, device=GPU_TYPE)
|
||||
b = torch.rand(22, 30, device=GPU_TYPE)
|
||||
|
||||
|
|
@ -1297,7 +1297,7 @@ class TestMaxAutotune(TestCase):
|
|||
)
|
||||
|
||||
# Test duck typing.
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
reset_counters()
|
||||
|
||||
compile_results = torch.compile(func_test1, dynamic=True)(a, b, a, b)
|
||||
|
|
@ -1313,7 +1313,7 @@ class TestMaxAutotune(TestCase):
|
|||
x = torch.matmul(x, x)
|
||||
return x
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
reset_counters()
|
||||
input = torch.rand(10, 10, device=GPU_TYPE)
|
||||
|
||||
|
|
@ -1324,7 +1324,7 @@ class TestMaxAutotune(TestCase):
|
|||
self.assertEqual(hits(), 36)
|
||||
self.assertEqual(misses(), 4)
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
reset_counters()
|
||||
input = torch.rand(10, 10, device=GPU_TYPE)
|
||||
|
||||
|
|
@ -1343,7 +1343,7 @@ class TestMaxAutotune(TestCase):
|
|||
b = torch.matmul(torch.cat([x, z], 1), torch.cat([y, m, l], 0))
|
||||
return a, b
|
||||
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
a = torch.rand(10, 22, device=GPU_TYPE)
|
||||
b = torch.rand(22, 30, device=GPU_TYPE)
|
||||
c = torch.rand(10, 11, device=GPU_TYPE)
|
||||
|
|
@ -1384,7 +1384,7 @@ class TestMaxAutotune(TestCase):
|
|||
]
|
||||
|
||||
# Valid cache hit.
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
torch._dynamo.utils.counters.clear()
|
||||
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
||||
eager_results = func_test1(a, b, a, b)
|
||||
|
|
@ -1424,7 +1424,7 @@ class TestMaxAutotune(TestCase):
|
|||
]
|
||||
|
||||
# Valid cache hit.
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
torch._dynamo.utils.counters.clear()
|
||||
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
||||
eager_results = func_test1(a, b, a, b)
|
||||
|
|
@ -1543,7 +1543,7 @@ class TestMaxAutotunePrecompile(TestCase):
|
|||
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@config.patch(search_autotune_cache=True)
|
||||
def test_search_autotune_cache(self):
|
||||
def fn(a, b, c):
|
||||
|
|
@ -1811,12 +1811,12 @@ class TestMaxAutotuneRemoteCache(TestCase):
|
|||
os.environ.pop("TRITON_CACHE_MANAGER", None)
|
||||
with config.patch({"max_autotune": True}):
|
||||
for _ in range(4):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
torch.compile(mm, dynamic=dynamic)(a, b)
|
||||
reset()
|
||||
with torch.compiler.config.patch(
|
||||
{"cache_key_tag": "test"}
|
||||
), fresh_inductor_cache():
|
||||
), fresh_cache():
|
||||
torch.compile(mm, dynamic=dynamic)(a, b)
|
||||
reset()
|
||||
|
||||
|
|
@ -1825,12 +1825,10 @@ class TestMaxAutotuneRemoteCache(TestCase):
|
|||
|
||||
global_stats.reset()
|
||||
for _ in range(4):
|
||||
with fresh_inductor_cache():
|
||||
with fresh_cache():
|
||||
torch.compile(f, dynamic=dynamic)(x, y)
|
||||
reset()
|
||||
with torch.compiler.config.patch(
|
||||
{"cache_key_tag": "test"}
|
||||
), fresh_inductor_cache():
|
||||
with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
|
||||
torch.compile(mm, dynamic=dynamic)(a, b)
|
||||
reset()
|
||||
global_stats.report()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from torch._inductor.fx_passes.pad_mm import (
|
|||
should_pad_mm_bf16,
|
||||
)
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
|
||||
from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||
|
|
@ -362,7 +362,7 @@ class PadMMTest(TestCase):
|
|||
self.assertEqual(out, inps[0] @ inps[1])
|
||||
|
||||
@inductor_config.patch(force_shape_pad=True)
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_pad_addmm_2d_bias(self):
|
||||
@torch.compile()
|
||||
def foo(input, x, y):
|
||||
|
|
@ -419,7 +419,7 @@ class PadMMTest(TestCase):
|
|||
res2, bmm_expected_result
|
||||
), "BMM results are not identical"
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
def test_exclude_padding(self):
|
||||
@torch.compile()
|
||||
def mm(a, b):
|
||||
|
|
@ -448,7 +448,7 @@ class PadMMTest(TestCase):
|
|||
repr(local_cache)
|
||||
)
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@inductor_config.patch(max_pointwise_cat_inputs=2)
|
||||
def test_exclude_cat_padding(self):
|
||||
@torch.compile()
|
||||
|
|
@ -475,7 +475,7 @@ class PadMMTest(TestCase):
|
|||
"No perf regression on H100+ with BF16",
|
||||
)
|
||||
@skipIfRocm
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@inductor_config.patch(
|
||||
post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
|
||||
)
|
||||
|
|
@ -508,7 +508,7 @@ class PadMMTest(TestCase):
|
|||
|
||||
assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
|
||||
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@inductor_config.patch(
|
||||
{
|
||||
"triton.unique_kernel_names": "original_aten",
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
|||
from torch import sym_int, SymBool, SymFloat, SymInt
|
||||
from torch._C import _disabled_torch_function_impl
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
from torch._inductor.utils import fresh_inductor_cache
|
||||
from torch._inductor.utils import fresh_cache
|
||||
from torch.fx.experimental import sym_node
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
|
||||
|
|
@ -3150,7 +3150,7 @@ class TestUnbacked(TestCase):
|
|||
|
||||
|
||||
class TestUbackedOps(TestCase):
|
||||
@fresh_inductor_cache()
|
||||
@fresh_cache()
|
||||
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_reshape1(self):
|
||||
|
|
|
|||
|
|
@@ -42,7 +42,7 @@ from torch._inductor.runtime.compile_tasks import (
     _set_triton_ptxas_path,
     _worker_compile_triton,
 )
-from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.utils import clear_on_fresh_cache
 from torch._inductor.virtualized import V
 from torch.hub import _Faketqdm, tqdm
 from torch.utils._ordered_set import OrderedSet
@@ -162,7 +162,7 @@ def get_compile_threads() -> int:
     return config.compile_threads


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CompiledTritonKernels:
     """
     In memory cache for storing compiled triton kernels.

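CompiledTritonKernels shows the contract behind clear_on_fresh_cache: a registered object only needs a callable cache_clear, which fresh_cache() invokes when it is entered and again when it exits. A minimal sketch with a hypothetical cache class:

from torch._inductor.utils import clear_on_fresh_cache, fresh_cache


@clear_on_fresh_cache  # registration works because the class exposes cache_clear
class _ToyKernelCache:  # hypothetical stand-in for CompiledTritonKernels
    cache: dict[str, bytes] = {}
    cache_clear = staticmethod(cache.clear)


_ToyKernelCache.cache["add_kernel"] = b"compiled artifact"
with fresh_cache():
    # entering fresh_cache() runs cache_clear() on every registered cache
    assert _ToyKernelCache.cache == {}
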
@@ -86,7 +86,7 @@ from torch._inductor.runtime.compile_tasks import _reload_python_module
 from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
 from torch._inductor.utils import (
     ALIGN_BYTES,
-    clear_on_fresh_inductor_cache,
+    clear_on_fresh_cache,
     is_linux,
     is_windows,
 )
@@ -236,7 +236,7 @@ class CacheBase:
         return system

     @staticmethod
-    @clear_on_fresh_inductor_cache
+    @clear_on_fresh_cache
     @functools.cache
     def get_local_cache_path() -> Path:
         return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
@@ -1595,7 +1595,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
         return path, ""


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CudaKernelParamCache:
     cache: dict[str, dict[str, Any]] = {}
     cache_clear = staticmethod(cache.clear)
@@ -2291,7 +2291,7 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, No

 # Precompiled headers are persistent past program runtime, but associated with one
 # specific compiler version and set of flags. We explicitly use default_cache_dir here
-# because these headers need to be global, rather than ignored by fresh_inductor_cache.
+# because these headers need to be global, rather than ignored by fresh_cache.
 _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
 _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")

@@ -2378,7 +2378,7 @@ def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str:
     )


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CppCodeCache:
     """Compiles and caches C++ libraries. Users of this class supply the source code to
     be compiled, while compilation flags are set by CppBuilder."""
@@ -2587,7 +2587,7 @@ def _worker_compile_cpp(


 # Customized Python binding for cpp kernels
-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CppPythonBindingsCodeCache(CppCodeCache):
     cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
     cache_clear = staticmethod(cache.clear)
@@ -2768,7 +2768,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
         return cls.load_pybinding_async(*args, **kwargs)()


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CppWrapperCodeCache(CppPythonBindingsCodeCache):
     cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
     cache_clear = staticmethod(cache.clear)
@@ -2837,7 +2837,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
         return _get_cpp_wrapper_header(device)


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class HalideCodeCache(CppPythonBindingsCodeCache):
     cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
     cache_clear = staticmethod(cache.clear)
@@ -3140,10 +3140,10 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
         target = "host-cuda" if device_type == "cuda" else "host"
         if cls._standalone_runtime_path:
             assert not os.path.exists(cls._standalone_runtime_path)
-            # We hit this case in unittests when we run with fresh_inductor_cache()
+            # We hit this case in unittests when we run with fresh_cache()
             # Generating a fresh runtime over and over causes errors because we initialize
             # cuda hundreds of times in the same process and run out of file descriptors.
-            # Workaround by jail breaking the current fresh_inductor_cache().
+            # Workaround by jail breaking the current fresh_cache().
             base = default_cache_dir()
         else:
             base = cache_dir()
@@ -3239,7 +3239,7 @@ def touch(filename: str) -> None:
     open(filename, "a").close()


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class PyCodeCache:
     # Track the loaded modules so we can remove the on-disk artifacts when
     # clearing the cache. Note also that we may load the same path more
@@ -3625,7 +3625,7 @@ class DLLWrapper:
         self.close()


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CUDACodeCache:
     """
     A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
@@ -3792,7 +3792,7 @@ class CUDACodeCache:
             fh.write(error_json)


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class ROCmCodeCache:
     @dataclasses.dataclass
     class CacheEntry:

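The precompiled-header comment above captures the split this file relies on: cache_dir() follows the temporary directory set up by fresh_cache(), while default_cache_dir() stays global. A hedged sketch of that behavior, using the imports shown in the first hunk of this file:

from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import fresh_cache


cache_before = cache_dir()
default_before = default_cache_dir()
with fresh_cache():
    # the working cache dir is redirected to a throwaway tmp directory ...
    assert cache_dir() != cache_before
    # ... while the "global" location used for precompiled headers is untouched
    assert default_cache_dir() == default_before
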
@@ -4,7 +4,7 @@ import shutil
 from typing import Optional

 import torch
-from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.utils import clear_on_fresh_cache

 from ... import config

@@ -12,7 +12,7 @@ from ... import config
 log = logging.getLogger(__name__)


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 @functools.lru_cache(1)
 def get_cuda_arch() -> Optional[str]:
     try:
@@ -27,7 +27,7 @@ def get_cuda_arch() -> Optional[str]:
         return None


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 @functools.lru_cache(1)
 def get_cuda_version() -> Optional[str]:
     try:

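get_cuda_arch and get_cuda_version show a recurring stacking pattern: a functools cache already exposes cache_clear, so wrapping it with clear_on_fresh_cache makes the memoized probe reset whenever a fresh cache is requested. A small hypothetical example of the same stacking:

import functools

from torch._inductor.utils import clear_on_fresh_cache, fresh_cache


@clear_on_fresh_cache  # registrable because the lru_cache wrapper provides cache_clear()
@functools.lru_cache(1)
def probe_toolchain() -> str:  # hypothetical memoized probe, in the spirit of get_cuda_arch
    print("probing toolchain ...")
    return "sm_90"


probe_toolchain()
probe_toolchain()  # memoized: the probe prints only once
with fresh_cache():
    probe_toolchain()  # the registered cache was cleared on entry, so it probes again
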
@@ -12,7 +12,7 @@ from torch._inductor.codecache import cutlass_key
 from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
 from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
 from torch._inductor.runtime.cache_dir_utils import cache_dir
-from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.utils import clear_on_fresh_cache


 log = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ def _generate_config_filename(request_key: str) -> str:
     return f"{CONFIG_PREFIX}_{request_key}.json"


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 @functools.cache
 def maybe_fetch_ops() -> Optional[list[Any]]:
     """

@@ -13,7 +13,7 @@ from typing import Any, Optional
 import sympy

 import torch
-from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.utils import clear_on_fresh_cache

 from ... import config
 from ...ir import Layout
@@ -250,7 +250,7 @@ class CUTLASSArgs:
         self.architectures = _normalize_cuda_arch(self.architectures)


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 @functools.cache
 def _gen_ops_cached(arch, version) -> dict[Any, Any]:
     # Note: Cache needs to be specific for cuda architecture and version

@@ -13,7 +13,7 @@ import torch.utils._pytree as pytree
 from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
 from torch._inductor.scheduler import BaseSchedulerNode
 from torch._inductor.select_algorithm import create_inputs_key
-from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.utils import clear_on_fresh_cache

 from ... import ir
 from ...config import cuda as inductor_cuda_config
@@ -405,7 +405,7 @@ int main(int argc, char** argv) {
 """ # noqa: B950


-@clear_on_fresh_inductor_cache
+@clear_on_fresh_cache
 class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
     """
     CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels

@@ -75,7 +75,7 @@ from torch._inductor.runtime.cache_dir_utils import cache_dir
 from torch._inductor.utils import (
     BoxedBool,
     count_tangents,
-    fresh_inductor_cache,
+    fresh_cache,
     get_all_devices,
     InputType,
     is_gpu,
@@ -675,7 +675,7 @@ def with_fresh_cache_if_config() -> Generator[None, None, None]:
         # Don't delete the cache dir because it has to survive beyond the
         # compile_fx call. Let's put the temp dirs under the default cache
         # dir so they're easier to locate.
-        with fresh_inductor_cache(dir=cache_dir(), delete=False):
+        with fresh_cache(dir=cache_dir(), delete=False):
             yield
     else:
         yield

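with_fresh_cache_if_config leans on the two optional parameters of fresh_cache: dir chooses where the temporary cache directory is created, and delete=False keeps it after the context exits. A hedged sketch (paths below are illustrative):

import os
import tempfile

from torch._inductor.utils import fresh_cache


parent = tempfile.mkdtemp()  # stand-in for cache_dir() in the hunk above
with fresh_cache(dir=parent, delete=False):
    pass  # compilation inside this block would populate the temporary cache dir
# delete=False keeps the per-context directory, so artifacts can outlive compile_fx
assert os.listdir(parent)
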
@@ -13,7 +13,7 @@ from torch._inductor.compile_worker.subproc_pool import (
     SubprocKind,
     SubprocPool,
 )
-from torch._inductor.utils import clear_inductor_caches
+from torch._inductor.utils import clear_caches

 from .compile_fx_ext import (
     _OutOfProcessFxCompile,
@@ -77,14 +77,14 @@ class _SubprocessFxCompile(_OutOfProcessFxCompile):
         # tmpdir still exists and fails to compile.
         #
         # TODO: We probably should be using a separate tmpdir in the worker
-        # anyway... but we should probably still respect clear_inductor_caches()
+        # anyway... but we should probably still respect clear_caches()
         # in the parent... maybe?
         #
         # TODO: We could be less aggressive by keeping a clock which gets
         # incremented when we clear the cache, send the clock to the worker and
         # only clear caches if the clock changed since last time.
         #
-        clear_inductor_caches()
+        clear_caches()
         torch._inductor.metrics.reset()

         # TODO: turn off config.fx_graph_async_compile

@@ -39,15 +39,15 @@ def triton_cache_dir(device: int) -> str:

 @contextmanager
 def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
-    from torch._inductor.utils import clear_inductor_caches
+    from torch._inductor.utils import clear_caches

     original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
     os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
     try:
-        clear_inductor_caches()
+        clear_caches()
         yield
     finally:
-        clear_inductor_caches()
+        clear_caches()
         if original is None:
             del os.environ["TORCHINDUCTOR_CACHE_DIR"]
         else:

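The helper above swaps TORCHINDUCTOR_CACHE_DIR and clears every registered cache on the way in and out. A usage sketch, assuming temporary_cache_dir is importable from the same module as triton_cache_dir (torch._inductor.runtime.cache_dir_utils):

import os
import tempfile

from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir


with tempfile.TemporaryDirectory() as tmp:
    with temporary_cache_dir(tmp):
        # anything compiled here writes its on-disk cache artifacts under tmp
        assert os.environ["TORCHINDUCTOR_CACHE_DIR"] == tmp
# on exit the previous TORCHINDUCTOR_CACHE_DIR (or its absence) is restored
# and clear_caches() has run again, dropping stale in-memory state
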
@@ -28,7 +28,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
 from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
-from torch._inductor.utils import clear_on_fresh_inductor_cache
+from torch._inductor.utils import clear_on_fresh_cache
 from torch.utils._filelock import FileLock
 from torch.utils._ordered_set import OrderedSet

@@ -1291,7 +1291,7 @@ class TritonTemplate(KernelTemplate):
         self.debug = debug
         self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
         self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
-        clear_on_fresh_inductor_cache(self._generated_code_cache)
+        clear_on_fresh_cache(self._generated_code_cache)
         # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
         # by adding all inputs.
         self.prologue_loads_all_inputs = prologue_loads_all_inputs
@@ -2123,7 +2123,7 @@ class AlgorithmSelectorCache(PersistentCache):
         # list of callbacks that are called after benchmarking
         self.feedback_saver_fns: list[FeedbackFunction] = []

-        clear_on_fresh_inductor_cache(self)
+        clear_on_fresh_cache(self)

     def cache_clear(self) -> None:
         self.precompile_cache.clear()

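Unlike the class-level decorators earlier in this change, TritonTemplate and AlgorithmSelectorCache register individual instances with clear_on_fresh_cache. A hypothetical minimal version of that per-instance pattern:

from torch._inductor.utils import clear_on_fresh_cache, fresh_cache


class _ToyCodegenCache:  # hypothetical analogue of GeneratedCodeCache
    def __init__(self) -> None:
        self._store: dict[str, str] = {}

    def cache_clear(self) -> None:
        self._store.clear()


code_cache = clear_on_fresh_cache(_ToyCodegenCache())  # register the instance, not the class
code_cache._store["mm_template"] = "generated source"
with fresh_cache():
    assert code_cache._store == {}  # cleared when the fresh cache context was entered
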
@@ -8,7 +8,7 @@ from torch._dynamo.test_case import (
 )
 from torch._functorch import config as functorch_config
 from torch._inductor import config
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


 def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
@@ -41,7 +41,7 @@ class TestCase(DynamoTestCase):
             os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
             and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
         ):
-            self._inductor_test_stack.enter_context(fresh_inductor_cache())
+            self._inductor_test_stack.enter_context(fresh_cache())

     def tearDown(self) -> None:
         super().tearDown()

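For test authors, the setUp change above means every inductor TestCase already runs inside fresh_cache() unless one of the two environment switches is set. A rough equivalent of what the fixture does, written out by hand (the ExitStack mirrors _inductor_test_stack):

import contextlib
import os

from torch._inductor.utils import fresh_cache


stack = contextlib.ExitStack()  # what setUp keeps as self._inductor_test_stack
if (
    os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
    and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
):
    stack.enter_context(fresh_cache())  # each test then starts from a cold cache dir
try:
    pass  # the test body would run here
finally:
    stack.close()  # tearDown unwinds the context
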
@@ -1015,29 +1015,6 @@ def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
     return input_devices | out_devices


-_registered_caches: list[Any] = []
-
-
-def clear_on_fresh_inductor_cache(obj: Any) -> Any:
-    """
-    Use this decorator to register any caches that should be cache_clear'd
-    with fresh_inductor_cache().
-    """
-    if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
-        raise AttributeError(f"{obj} does not have a cache_clear method")
-
-    _registered_caches.append(obj)
-    return obj
-
-
-def clear_inductor_caches() -> None:
-    """
-    Clear all registered caches.
-    """
-    for obj in _registered_caches:
-        obj.cache_clear()
-
-
 import gc


@@ -1070,19 +1047,42 @@ def unload_xpu_triton_pyds() -> None:
     gc.collect()


+_registered_caches: list[Any] = []
+
+
+def clear_on_fresh_cache(obj: Any) -> Any:
+    """
+    Use this decorator to register any caches that should be cache_clear'd
+    with fresh_cache().
+    """
+    if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
+        raise AttributeError(f"{obj} does not have a cache_clear method")
+
+    _registered_caches.append(obj)
+    return obj
+
+
+def clear_caches() -> None:
+    """
+    Clear all registered caches.
+    """
+    for obj in _registered_caches:
+        obj.cache_clear()
+
+
 @contextlib.contextmanager
-def fresh_inductor_cache(
+def fresh_cache(
     cache_entries: Optional[dict[str, Any]] = None,
     dir: Optional[str] = None,
     delete: bool = True,
 ) -> Iterator[None]:
     """
-    Contextmanager that provides a clean tmp cachedir for inductor.
+    Contextmanager that provides a clean tmp cachedir for pt2 caches.

     Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
     generated with this cache instance.
     """
-    clear_inductor_caches()
+    clear_caches()

     inductor_cache_dir = tempfile.mkdtemp(dir=dir)
     try:
@@ -1123,7 +1123,13 @@ def fresh_inductor_cache(
         log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
         raise
     finally:
-        clear_inductor_caches()
+        clear_caches()


+# Deprecated functions -- only keeping them for BC reasons
+clear_on_fresh_inductor_cache = clear_on_fresh_cache
+clear_inductor_caches = clear_caches
+fresh_inductor_cache = fresh_cache
+
+
 def argsort(seq: Sequence[Any]) -> list[int]:

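Because the old names are kept as module-level aliases at the end of the hunk above, the rename is backward compatible: existing imports resolve to the same objects. A quick check against this revision:

from torch._inductor import utils


# the deprecated names are plain aliases, so identity (not just equality) holds
assert utils.fresh_inductor_cache is utils.fresh_cache
assert utils.clear_inductor_caches is utils.clear_caches
assert utils.clear_on_fresh_inductor_cache is utils.clear_on_fresh_cache
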
@@ -15,7 +15,7 @@ from benchmark_utils import ( # type: ignore[import-not-found]
 )

 import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


 class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
@@ -59,7 +59,7 @@ class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unim
         )
         b = b.to(dtype=dtype_right)

-        with fresh_inductor_cache():
+        with fresh_cache():

             def mixed_mm(A, B):
                 return torch.mm(A, B.to(A.dtype))

@@ -16,7 +16,7 @@ from benchmark_utils import ( # type: ignore[import-not-found]
 )

 import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


 class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
@@ -57,7 +57,7 @@ class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimporte
             dtype_right=dtype,
         )

-        with fresh_inductor_cache():
+        with fresh_cache():

             def mixed_mm(A: Any, B: Any) -> Any:
                 return torch.mm(A, B)

@@ -18,7 +18,7 @@ import torch
 from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found]
     get_alignment_size_dtype,
 )
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache


 class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
@@ -74,7 +74,7 @@ class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimpo
         print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
         print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")

-        with fresh_inductor_cache():
+        with fresh_cache():

             def mm(a: Any, b: Any) -> Any:
                 return torch.mm(a, b)

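The benchmark runners above all wrap compilation in fresh_cache() so every measurement is a cold start. A hedged sketch of that measurement loop (the helper name, shapes, and repeat count are made up for illustration):

import time

import torch
import torch._dynamo
from torch._inductor.utils import fresh_cache


def bench_cold_compiles(repeat: int = 3) -> list[float]:
    times = []
    for _ in range(repeat):
        torch._dynamo.reset()  # drop compiled frames between iterations
        with fresh_cache():  # start from empty on-disk and in-memory caches
            a, b = torch.randn(64, 64), torch.randn(64, 64)
            start = time.perf_counter()
            torch.compile(lambda x, y: torch.mm(x, y))(a, b)
            times.append(time.perf_counter() - start)
    return times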