Rename inductor cache (#156128)

Requested by Simon on a different PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128
Approved by: https://github.com/xmfan
Oguz Ulgen 2025-06-16 15:28:16 -07:00 committed by PyTorch MergeBot
parent 45382b284d
commit a2a75be0f8
48 changed files with 232 additions and 232 deletions
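The rename is mechanical: fresh_inductor_cache becomes fresh_cache and clear_inductor_caches becomes clear_caches, both still exported from torch._inductor.utils. A minimal sketch of the call patterns being updated follows; the toy fn and the standalone framing are illustrative only and not part of this PR.

import torch
import torch._dynamo
# New names introduced by this PR; previously fresh_inductor_cache / clear_inductor_caches.
from torch._inductor.utils import clear_caches, fresh_cache


def fn(x, y):  # toy function, for illustration only
    return x + y


# Context-manager form: compile against a fresh, temporary Inductor cache.
with fresh_cache():
    compiled_fn = torch.compile(fn)
    compiled_fn(torch.rand(10), torch.rand(10))

# Reset dynamo and drop Inductor's in-process caches, as the updated tests do between runs.
torch._dynamo.reset()
clear_caches()


# Decorator form, as used on the test methods in the diff below.
@fresh_cache()
def run_in_fresh_cache():
    return torch.compile(fn)(torch.rand(10), torch.rand(10))

Aside from the test renamed to test_fresh_cache and a few with-statements collapsed onto one line, the changes below are a find-and-replace of the two helper names.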

View File

@@ -8,7 +8,7 @@ import sys
import tempfile
from typing import Callable
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
logger: logging.Logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ def _run_torchbench_from_args(
warm_compile_time: list[float] = []
for _ in range(cmd_args.repeat):
-with fresh_inductor_cache():
+with fresh_cache():
env = os.environ.copy()
with tempfile.NamedTemporaryFile(suffix=".csv") as file:
args.append("--output=" + file.name)

View File

@@ -51,7 +51,7 @@ from torch._logging.scribe import open_source_signpost
try:
from torch._dynamo.utils import clone_inputs, graph_break_reasons
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
except ImportError:
from _dynamo.utils import clone_inputs, graph_break_reasons
@@ -3416,7 +3416,7 @@ def maybe_fresh_cache(args):
if not cache_dir_assigned and (
args.cold_start_latency or args.warm_start_latency or args.ci
):
-return fresh_inductor_cache()
+return fresh_cache()
else:
return contextlib.nullcontext()

View File

@@ -3,7 +3,7 @@ import timeit
import torch.fx
from torch._dynamo.utils import counters
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache
N = 10000
@@ -20,7 +20,7 @@ def main():
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.fx_graph_remote_cache = False
-with fresh_inductor_cache():
+with fresh_cache():
a = torch.randn(4).cuda()
compiled_fn = torch.compile(huge_graph, backend="inductor")
@@ -30,7 +30,7 @@ def main():
def setup():
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
counters.clear()

View File

@@ -3,7 +3,7 @@ import sys
from benchmark_base import BenchmarkBase
import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
class Benchmark(BenchmarkBase):
@@ -50,7 +50,7 @@ class Benchmark(BenchmarkBase):
result = result.sin()
return result
-with fresh_inductor_cache():
+with fresh_cache():
f(self.a, self.b)

View File

@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase
import torch
import torch.nn as nn
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
class ListOfLinears(nn.Module):
@@ -55,7 +55,7 @@ class Benchmark(BenchmarkBase):
def _work(self):
with (
-fresh_inductor_cache(),
+fresh_cache(),
torch._inductor.config.patch(force_shape_pad=self._force_shape_pad),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(

View File

@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase
import torch
import torch.nn as nn
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
# Create a chain of artificial nesting
@@ -94,7 +94,7 @@ class Benchmark(BenchmarkBase):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with (
-fresh_inductor_cache(),
+fresh_cache(),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
self.m.cuda() if self._is_gpu else self.m

View File

@@ -3,7 +3,7 @@ import sys
from benchmark_base import BenchmarkBase
import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
class Benchmark(BenchmarkBase):
@@ -31,7 +31,7 @@ class Benchmark(BenchmarkBase):
def f(x, y):
return x + y
-with fresh_inductor_cache():
+with fresh_cache():
for i in range(8):
f(torch.arange(3), i * 2.5)

View File

@@ -3,7 +3,7 @@ import sys
from benchmark_base import BenchmarkBase
import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
class Benchmark(BenchmarkBase):
@@ -45,7 +45,7 @@ class Benchmark(BenchmarkBase):
z = torch.mm(z, b)
return z
-with fresh_inductor_cache(), torch._inductor.config.patch(max_autotune=True):
+with fresh_cache(), torch._inductor.config.patch(max_autotune=True):
f(self.a, self.b)

View File

@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase
import torch
import torch.nn as nn
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
class NestedModule(nn.Module):
@@ -67,7 +67,7 @@ class Benchmark(BenchmarkBase):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with (
-fresh_inductor_cache(),
+fresh_cache(),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
self.m.cuda() if self._is_gpu else self.m

View File

@@ -246,7 +246,7 @@ def run_single_experiment_group(
for config in group_config.experiments:
torch._dynamo.reset()
-torch._inductor.utils.clear_inductor_caches()
+torch._inductor.utils.clear_caches()
compiled_op = torch.compile(
op,
options=config.to_options(),

View File

@@ -13,7 +13,7 @@ from torch._inductor.fx_passes.micro_pipeline_tp import (
micro_pipeline_tp_pass,
)
from torch._inductor.fx_passes.post_grad import remove_noop_ops, view_to_reshape
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
all_gather_tensor,
reduce_scatter_tensor,
@@ -81,7 +81,7 @@ class MicroPipelineTPTest(TestCase):
dist.destroy_process_group()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_find_all_gather_patterns(self):
group = dist.group.WORLD
@@ -134,7 +134,7 @@ class MicroPipelineTPTest(TestCase):
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_find_reduce_scatter_patterns(self):
group = dist.group.WORLD
@@ -173,7 +173,7 @@ class MicroPipelineTPTest(TestCase):
self.assertEqual(reduce_scatters[1].scatter_dim, 1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_get_unexposed_collectives(self):
group = dist.group.WORLD
@@ -201,7 +201,7 @@ class MicroPipelineTPTest(TestCase):
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@parametrize("return_A", [True, False])
-@fresh_inductor_cache()
+@fresh_cache()
def test_fuse_all_gather_matmul(self, A_dims, gather_dim, return_A):
if gather_dim >= A_dims:
return
@@ -248,7 +248,7 @@ class MicroPipelineTPTest(TestCase):
@parametrize("A_dims", [2, 3])
@parametrize("gather_dim", [0, 1, 2])
@parametrize("return_A", [True, False])
-@fresh_inductor_cache()
+@fresh_cache()
def test_fuse_all_gather_scaled_matmul(self, A_dims, gather_dim, return_A):
if gather_dim >= A_dims:
return
@@ -321,7 +321,7 @@ class MicroPipelineTPTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("scatter_dim", [0, 1, 2])
-@fresh_inductor_cache()
+@fresh_cache()
def test_fuse_matmul_reduce_scatter(self, A_dims, scatter_dim):
if scatter_dim >= A_dims:
return
@@ -350,7 +350,7 @@ class MicroPipelineTPTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("A_dims", [2, 3])
@parametrize("scatter_dim", [0, 1, 2])
-@fresh_inductor_cache()
+@fresh_cache()
def test_fuse_scaled_matmul_reduce_scatter(self, A_dims, scatter_dim):
if scatter_dim >= A_dims:
return
@@ -403,7 +403,7 @@ class MicroPipelineTPTest(TestCase):
@runOnRocmArch(MI300_ARCH)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("scatter_dim", [0, 1, 2])
-@fresh_inductor_cache()
+@fresh_cache()
def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
self, scatter_dim
):
@@ -465,7 +465,7 @@ class MicroPipelineTPTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("shard_dim", [0, 1])
-@fresh_inductor_cache()
+@fresh_cache()
def test_dtensor_seq_par(self, shard_dim: int):
model: torch.nn.Module = MLPModule(device="cuda", bias=False)
device_mesh = DeviceMesh(

View File

@@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C import FileCheck
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import (
all_gather_into_tensor_coalesced,
all_gather_tensor,
@@ -464,7 +464,7 @@ class TestWithNCCL(MultiProcessTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
-@fresh_inductor_cache()
+@fresh_cache()
def test_threading(self):
self._init_process_group()
device = torch.device(f"cuda:{self.rank}")
@@ -510,7 +510,7 @@ class TestWithNCCL(MultiProcessTestCase):
"_scaled_mm currently only supports sm>=90",
)
@skip_if_lt_x_gpu(2)
-@fresh_inductor_cache()
+@fresh_cache()
def test_fixed_striding(self):
self._init_process_group()
@@ -736,7 +736,7 @@ class CompileTest(TestCase):
dist.destroy_process_group()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_all_reduce_single(self):
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42
@@ -773,7 +773,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_all_reduce_coalesced(self):
def func(args: list[torch.Tensor]) -> torch.Tensor:
bufs = [arg + 42 for arg in args]
@@ -819,7 +819,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_inplace_op_on_view(self):
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = (arg + 10)[:2]
@@ -843,7 +843,7 @@ class CompileTest(TestCase):
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_all_reduce_non_contig_input(self):
def func(arg: torch.Tensor) -> torch.Tensor:
ar0 = funcol.all_reduce(arg, "avg", "0")
@@ -869,7 +869,7 @@ class CompileTest(TestCase):
assert "torch.ops._c10d_functional.wait_tensor.default" in code
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_reuse_buffer_after_inplace_collective(self):
def func(arg: torch.Tensor) -> torch.Tensor:
# Expect allocation
@@ -904,7 +904,7 @@ class CompileTest(TestCase):
assert "= torch.ops._c10d_functional.wait_tensor.default" not in code
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_all_gather_into_tensor_single(self):
def func(arg: torch.Tensor) -> torch.Tensor:
ag0 = funcol.all_gather_tensor(arg, 0, "0")
@@ -931,7 +931,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_all_gather_into_tensor_coalesced(self):
def func(args: list[torch.Tensor]) -> torch.Tensor:
ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
@@ -965,7 +965,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
-@fresh_inductor_cache()
+@fresh_cache()
def test_wait_tensor(self):
def func(arg: torch.Tensor) -> torch.Tensor:
t = torch.ops._c10d_functional.all_reduce(arg, "avg", "0")
@@ -987,7 +987,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_reduce_scatter_tensor_single(self):
def func(arg: torch.Tensor) -> torch.Tensor:
rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0")
@@ -1013,7 +1013,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_reduce_scatter_tensor_coalesced(self):
def func(args: list[torch.Tensor]) -> torch.Tensor:
rs0 = funcol.reduce_scatter_tensor_coalesced(
@@ -1049,7 +1049,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_all_to_all_single(self):
def _tolist_with_constrain_as_size(tensor):
lst = tensor.tolist()
@@ -1097,7 +1097,7 @@ class CompileTest(TestCase):
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_inductor_broadcast(self):
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42
@@ -1134,7 +1134,7 @@ class CompileTest(TestCase):
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-@fresh_inductor_cache()
+@fresh_cache()
def test_ranks_and_tag(self):
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42

View File

@@ -1218,11 +1218,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
def test_asymmetric_compilation_with_fx_cache(self):
from torch._dynamo.utils import counters
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
-with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
-self.rank, self.world_size
-):
+with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
device = f"cuda:{self.rank}"
@@ -1252,7 +1250,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
torch._dynamo.reset()
if self.rank == 0:
-with fresh_inductor_cache():
+with fresh_cache():
f(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

View File

@@ -10,7 +10,7 @@ import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import all_gather_tensor
from torch.distributed._symmetric_memory import (
_fused_all_gather_matmul_fallback,
@@ -1020,7 +1020,7 @@ class LoweringTest(MultiProcContinousTest):
@skip("Fails with 'one_shot_all_reduce' not found in AOT graph, TODO: fix")
@skipIfRocm # requires registered-buffer support
@skip_if_lt_x_gpu(2)
-@fresh_inductor_cache()
+@fresh_cache()
def test_lowering_one_shot_all_reduce(self):
self._init_process()
arg = torch.rand(4, 4, device=self.device)

View File

@@ -25,7 +25,7 @@ from torch._inductor import config as inductor_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.runtime.triton_compat import tl, triton
from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
from torch._subclasses import FakeTensorMode
from torch.compiler._cache import CacheArtifactManager
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -165,7 +165,7 @@ class AOTAutogradCacheTests(InductorTestCase):
b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
# Record artifacts
-with fresh_inductor_cache():
+with fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
# A first call should miss in the cache.
@@ -201,7 +201,7 @@ class AOTAutogradCacheTests(InductorTestCase):
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# We did not load anything so dont hit yet
-with fresh_inductor_cache():
+with fresh_cache():
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
@@ -221,7 +221,7 @@ class AOTAutogradCacheTests(InductorTestCase):
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
-with fresh_inductor_cache():
+with fresh_cache():
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.inductor_artifacts), 2)

View File

@@ -798,7 +798,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
@requires_cuda
@unittest.skipIf(not SM90OrLater, "requires H100+ GPU")
def test_autotuning(self, records):
-with torch._inductor.utils.fresh_inductor_cache():
+with torch._inductor.utils.fresh_cache():
def f(a, b):
return torch.mm(a, b)

View File

@@ -54,7 +54,7 @@ from torch._dynamo.testing import (
)
from torch._dynamo.utils import call_size, counters, ifdynstaticdefault
from torch._dynamo.variables import builder
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
+from torch._inductor.utils import fresh_cache, run_and_get_code
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
@@ -8087,7 +8087,7 @@ utils_device.CURRENT_DEVICE == None""".split(
m1 = Model(50)
m2 = Model(60)
-with fresh_inductor_cache():
+with fresh_cache():
m1(torch.rand(1, 2, 3))
m2(torch.rand(1, 2, 3))

View File

@@ -12,7 +12,7 @@ import torch._inductor.mock_cache as mock_cache
import torch.compiler.config
import torch.nested
from torch._dynamo.testing import CompileCounter
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache
class PgoTest(torch._dynamo.test_case.TestCase):
@@ -24,7 +24,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
)
if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1":
-self._test_stack.enter_context(fresh_inductor_cache())
+self._test_stack.enter_context(fresh_cache())
mock_cache.PatchCaches.setUp()
def tearDown(self):
@@ -35,7 +35,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
def reset(self):
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
def test_basic(self):
cnts = CompileCounter()
@@ -244,7 +244,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 2)
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
cnts.clear()
with torch.compiler.config.patch(job_id="foo"):

View File

@@ -43,7 +43,7 @@ import torch.utils._pytree as pytree
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
from torch.nn import functional as F
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -5527,7 +5527,7 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
y = torch.randn(100, 10)
return torch.mm(x, y).sum()
-with fresh_inductor_cache():
+with fresh_cache():
torch.compile(fn)()
torch.compile(fn2)()

View File

@@ -18,7 +18,7 @@ import torch
from torch._inductor.codecache import get_kernel_bin_format
from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
from torch._inductor.test_case import TestCase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
from torch.export import Dim
from torch.testing._internal.common_utils import (
IS_FBCODE,
@@ -157,7 +157,7 @@ class TestAOTInductorPackage(TestCase):
torch.manual_seed(0)
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
ep = torch.export.export(model, example_inputs, strict=True)
-with fresh_inductor_cache():
+with fresh_cache():
# cubin files are removed when exiting this context
package_path = torch._inductor.aoti_compile_and_package(
ep,

View File

@@ -3,7 +3,7 @@ import torch
from torch._inductor import config
from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers
from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -32,7 +32,7 @@ class TestAsyncCompile(TestCase):
pool = AsyncCompile.process_pool()
pool.ready_future.result(timeout=120)
-with fresh_inductor_cache():
+with fresh_cache():
compiled_fn = torch.compile(fn)
self.assertEqual(fn(x, y), compiled_fn(x, y))

View File

@@ -7,7 +7,7 @@ import torch
from torch._inductor.codegen.triton import TritonScheduling
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.test_operators import realize
-from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
+from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import slowTest
from torch.testing._internal.inductor_utils import (
@@ -283,7 +283,7 @@ if HAS_CUDA:
self.assertEqual(res, res2, atol=1e-4, rtol=1.1)
return code, code2
-@fresh_inductor_cache()
+@fresh_cache()
@config.patch(max_autotune_gemm_backends="TRITON")
def test_equivalent_template_code(self):
code, code2 = self._equivalent_output_code_impl(256)
@@ -298,7 +298,7 @@ if HAS_CUDA:
out_code[0]
)
-@fresh_inductor_cache()
+@fresh_cache()
@config.patch(max_autotune_gemm_backends="ATEN")
def test_equivalent_extern_code(self):
torch._dynamo.reset()

View File

@@ -38,7 +38,7 @@ from torch._inductor.graph import GraphLowering
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache
from torch._library import capture_triton
from torch.compiler._cache import (
CacheArtifact,
@@ -146,7 +146,7 @@ class TestFxGraphCache(TestCase):
AOTAutogradCache.clear()
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@@ -379,16 +379,14 @@ class TestFxGraphCache(TestCase):
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)
for _ in range(4):
-with fresh_inductor_cache():
+with fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
self.assertEqual(fn(a, b), compiled_fn(a, b))
reset()
self.assertEqual(global_stats.fx_graph, Stats(1, 3, 1))
-with torch.compiler.config.patch(
-{"cache_key_tag": "test"}
-), fresh_inductor_cache():
+with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
self.assertEqual(fn(a, b), compiled_fn(a, b))
@@ -426,7 +424,7 @@ class TestFxGraphCache(TestCase):
b = torch.rand(100, 100, dtype=dtype, device=device)
# Record artifacts
-with fresh_inductor_cache():
+with fresh_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
# A first call should miss in the cache.
@@ -456,7 +454,7 @@ class TestFxGraphCache(TestCase):
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# We did not load anything so dont hit yet
-with fresh_inductor_cache():
+with fresh_cache():
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
@@ -470,7 +468,7 @@ class TestFxGraphCache(TestCase):
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
-with fresh_inductor_cache():
+with fresh_cache():
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.inductor_artifacts), 1)
@@ -503,7 +501,7 @@ class TestFxGraphCache(TestCase):
a2 = torch.randn(4, 8)
b2 = torch.randn(8, 4)
-with fresh_inductor_cache():
+with fresh_cache():
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
self.assertEqual(eager_result, compiled_result)
@@ -519,7 +517,7 @@ class TestFxGraphCache(TestCase):
self.reset()
-with fresh_inductor_cache():
+with fresh_cache():
torch.compiler.load_cache_artifacts(artifact_bytes)
eager_result = fn(a, b)
compiled_result = compiled_fn(a, b)
@@ -531,7 +529,7 @@ class TestFxGraphCache(TestCase):
self.reset()
-with fresh_inductor_cache():
+with fresh_cache():
eager_result = fn(a2, b2)
compiled_result = compiled_fn(a2, b2)
self.assertEqual(eager_result, compiled_result)
@@ -555,7 +553,7 @@ class TestFxGraphCache(TestCase):
return x * 2
# Record artifacts
-with torch.compiler.config.patch(job_id=self.id()), fresh_inductor_cache():
+with torch.compiler.config.patch(job_id=self.id()), fresh_cache():
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(backend.frame_count, 2)
@@ -582,7 +580,7 @@ class TestFxGraphCache(TestCase):
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
# Hot load and hit
-with torch.compiler.config.patch({"job_id": self.id()}), fresh_inductor_cache():
+with torch.compiler.config.patch({"job_id": self.id()}), fresh_cache():
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.inductor_artifacts), 2)
@@ -617,7 +615,7 @@ class TestFxGraphCache(TestCase):
with mock.patch(
"torch._utils_internal.get_mast_job_name_version", return_value=("foo", 5)
):
-with fresh_inductor_cache():
+with fresh_cache():
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(backend.frame_count, 2)
@@ -639,7 +637,7 @@ class TestFxGraphCache(TestCase):
# Hot load and hit
with mock.patch(
"torch._utils_internal.get_mast_job_name_version", return_value=("bar", 10)
-), fresh_inductor_cache():
+), fresh_cache():
cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
self.assertEqual(len(cache_info.pgo_artifacts), 2)
@@ -1582,7 +1580,7 @@ class TestStandaloneCompile(TestCase):
AOTAutogradCache.clear()
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
def capture(self, fn, dynamic=None):
def inner(*args):
@@ -1638,7 +1636,7 @@ class TestStandaloneCompile(TestCase):
if format == "unpacked"
else os.path.join(temp_dir, "compiled_artifact.bin")
)
-with fresh_inductor_cache():
+with fresh_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
@@ -1647,7 +1645,7 @@ class TestStandaloneCompile(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-with fresh_inductor_cache():
+with fresh_cache():
loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
if dynamic:
concrete_args = [
@@ -1679,7 +1677,7 @@ class TestStandaloneCompile(TestCase):
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(gm, args)
-with fresh_inductor_cache():
+with fresh_cache():
compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x)
self.assertEqual(eager_out, compiled_out)
@@ -1698,7 +1696,7 @@ class TestStandaloneCompile(TestCase):
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "new_dir")
-with fresh_inductor_cache():
+with fresh_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
@@ -1707,7 +1705,7 @@ class TestStandaloneCompile(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-with fresh_inductor_cache():
+with fresh_cache():
loaded = torch._inductor.CompiledArtifact.load(
path=path, format="unpacked"
)
@@ -1731,7 +1729,7 @@ class TestStandaloneCompile(TestCase):
eager_out = f(x)
with tempfile.TemporaryDirectory() as temp_dir:
-with fresh_inductor_cache():
+with fresh_cache():
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
@@ -1743,7 +1741,7 @@ class TestStandaloneCompile(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-with fresh_inductor_cache():
+with fresh_cache():
# Now modify the output file and expect to see the changes
for subdir in os.listdir(temp_dir):
if subdir in ["aotautograd", "fxgraph"]:
@@ -1791,16 +1789,16 @@ class TestStandaloneCompile(TestCase):
with tempfile.TemporaryDirectory() as temp_dir:
path = os.path.join(temp_dir, "compiled_artifact.bin")
-with fresh_inductor_cache():
+with fresh_cache():
compiled_artifact = torch._inductor.standalone_compile(gm, args)
compiled_artifact.save(path=path)
script = f"""
import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
arg = torch.ones(4, 1)
-with fresh_inductor_cache():
+with fresh_cache():
loaded = torch._inductor.CompiledArtifact.load(path="{path}")
compiled_result = loaded(arg)[0]
@@ -1832,7 +1830,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)
-with fresh_inductor_cache():
+with fresh_cache():
# captured graph is lambda s0, x: x * s0
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
@@ -1854,7 +1852,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)
-with fresh_inductor_cache():
+with fresh_cache():
# captured graph is lambda s0, x: x * s0
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
@@ -1890,7 +1888,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
return x.shape[0] * x
static_x = torch.randn(3)
-with fresh_inductor_cache():
+with fresh_cache():
# static_gm is lambda x: x * 3
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
@@ -2440,7 +2438,7 @@ class TestAutotuneCache(TestCase):
def reset(self):
PyCodeCache.cache_clear(purge=True)
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
@unittest.skipIf(not HAS_CUDA, "Requires CUDA")
@unittest.skipIf(not SM80OrLater, "Requires SM80+")
@@ -2750,20 +2748,20 @@ class TestRemoteAOTAutogradCache(TestCase):
class TestUtils(TestCase):
@config.patch({"fx_graph_remote_cache": False})
-def test_fresh_inductor_cache(self):
+def test_fresh_cache(self):
def fn(x, y):
return x + y
a = torch.rand(10)
b = torch.rand(10)
-with fresh_inductor_cache():
+with fresh_cache():
self.assertEqual(len(PyCodeCache.modules), 0)
res1 = torch.compile(fn)(a, b)
cache_dir1 = cache_dir()
torch._dynamo.reset()
-with fresh_inductor_cache():
+with fresh_cache():
self.assertEqual(len(PyCodeCache.modules), 0)
res2 = torch.compile(fn)(a, b)
cache_dir2 = cache_dir()

View File

@@ -919,9 +919,9 @@ class CompiledOptimizerTests(TestCase):
import torch._dynamo
import torch._inductor
from torch._dynamo.debug_utils import aot_graph_input_parser
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
-with fresh_inductor_cache():
+with fresh_cache():
kwargs = aot_graph_input_parser(forward)
torch.compile(forward)(**kwargs)

View File

@@ -1814,9 +1814,9 @@ class CudaReproTests(TestCase):
m = ToyModel().to(device="cuda:0")
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
-with fresh_inductor_cache():
+with fresh_cache():
cm = torch.compile(m, mode="max-autotune")
out = cm(input_tensor)
out2 = m(input_tensor)

View File

@@ -10,7 +10,7 @@ from torch._inductor.codecache import CUDACodeCache
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
from torch._inductor.exc import CUDACompileError
from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
_SOURCE_CODE = r"""
@@ -40,7 +40,7 @@ int saxpy(int n, float a, float *x, float *y) {
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
class TestCUDACodeCache(InductorTestCase):
def test_cuda_load(self):
-with fresh_inductor_cache():
+with fresh_cache():
# Test both .o and .so compilation.
(
object_file_path,
@@ -67,13 +67,13 @@ class TestCUDACodeCache(InductorTestCase):
torch.testing.assert_close(y, expected_y)
def test_compilation_error(self):
-with fresh_inductor_cache():
+with fresh_cache():
error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
with self.assertRaises(CUDACompileError):
CUDACodeCache.compile(error_source_code, "o")
def test_async_compile(self):
-with fresh_inductor_cache():
+with fresh_cache():
async_compile = AsyncCompile()
compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
async_compile.wait(globals())

View File

@@ -13,7 +13,7 @@ from pathlib import Path
from typing import Callable, Optional
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
-from torch._inductor.utils import clear_inductor_caches
+from torch._inductor.utils import clear_caches
from torch.export import Dim
from torch.testing._internal.logging_utils import log_settings
@@ -38,7 +38,7 @@ from torch._inductor.exc import InductorError
from torch._inductor.ir import FixedLayout
from torch._inductor.select_algorithm import NoValidChoicesError
from torch._inductor.test_case import run_tests, TestCase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
@@ -173,7 +173,7 @@ class TestCutlassBackend(TestCase):
def tearDown(self):
super().tearDown()
-clear_inductor_caches()
+clear_caches()
def run_evt_test(self, model, op, shape, num_fusions=1):
M, N = shape
@@ -618,7 +618,7 @@ class TestCutlassBackend(TestCase):
]
for x_shape in x_shapes:
torch._dynamo.reset()
-clear_inductor_caches()
+clear_caches()
inputs = [
(
@@ -1065,7 +1065,7 @@ class TestCutlassBackend(TestCase):
def select_no_algorithm(*args, **kwargs):
raise NoValidChoicesError
-with fresh_inductor_cache():
+with fresh_cache():
with config.patch(
{
"max_autotune": True,
@@ -1113,7 +1113,7 @@ class TestCutlassBackend(TestCase):
def select_no_algorithm(*args, **kwargs):
raise NoValidChoicesError
-with fresh_inductor_cache():
+with fresh_cache():
with config.patch(
{
"max_autotune": True,
@@ -1187,7 +1187,7 @@ class TestCutlassBackend(TestCase):
raise NoValidChoicesError
def run_test(use_fast_accum):
-with fresh_inductor_cache():
+with fresh_cache():
with config.patch(
{
"max_autotune": True,
@@ -1266,7 +1266,7 @@ class TestCutlassBackend(TestCase):
def select_no_algorithm(*args, **kwargs):
raise NoValidChoicesError
-with fresh_inductor_cache(), config.patch(
+with fresh_cache(), config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
@@ -1324,7 +1324,7 @@ class TestCutlassBackend(TestCase):
def select_no_algorithm(*args, **kwargs):
raise NoValidChoicesError
-with fresh_inductor_cache(), config.patch(
+with fresh_cache(), config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",

View File

@@ -10,7 +10,7 @@ from pathlib import Path
import torch
from torch._inductor import config, test_operators
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.logging_utils import multiple_logs_to_string
@@ -44,7 +44,7 @@ class TestDebugTrace(test_torchinductor.TestCase):
"torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
)
-# TODO(aakhundov): make this work with fresh_inductor_cache
+# TODO(aakhundov): make this work with fresh_cache
# instead of force_disable_caches. currently, with the latter
# enabled, we get `inductor [('fxgraph_cache_hit', 1)]` in
# the counters: so the cache is actually hit and the test fails.
@@ -263,7 +263,7 @@ op2.node.kernel = extern_kernels.mm""",
# no failure
with self.assertLogs(
logging.getLogger("torch._inductor.debug"), level=logging.WARNING
-), fresh_inductor_cache():
+), fresh_cache():
m = ToyModel().to(device=GPU_TYPE)
m = torch.compile(m, mode="max-autotune")
input_tensor = torch.randn(100).to(device=GPU_TYPE)
View File
@ -5,7 +5,7 @@ import torch._inductor.metrics as metrics
import torch.utils.flop_counter import torch.utils.flop_counter
from torch._dynamo.utils import counters from torch._dynamo.utils import counters
from torch._inductor.ir import FixedLayout from torch._inductor.ir import FixedLayout
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
from torch.testing._internal.common_cuda import SM70OrLater from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
dtypes, dtypes,
@ -77,7 +77,7 @@ class TestScheduler(TestCase):
for op, example_inputs, kwargs in tc: for op, example_inputs, kwargs in tc:
comp = torch.compile(op) comp = torch.compile(op)
torch._dynamo.reset() torch._dynamo.reset()
with fresh_inductor_cache(): with fresh_cache():
comp(*example_inputs, **kwargs) comp(*example_inputs, **kwargs)
self.assertEqual(metrics.num_bytes_accessed, 0) self.assertEqual(metrics.num_bytes_accessed, 0)
self.assertEqual(any(m[1] for m in metrics.node_runtimes), False) self.assertEqual(any(m[1] for m in metrics.node_runtimes), False)
@ -108,7 +108,7 @@ class TestScheduler(TestCase):
comp = torch.compile(op) comp = torch.compile(op)
torch._dynamo.reset() torch._dynamo.reset()
with fresh_inductor_cache(): with fresh_cache():
comp(*example_inputs, **kwargs) comp(*example_inputs, **kwargs)
self.assertEqual(enba, metrics.num_bytes_accessed) self.assertEqual(enba, metrics.num_bytes_accessed)
nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0) nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0)
@ -152,7 +152,7 @@ class TestScheduler(TestCase):
comp = torch.compile(op, options=options) comp = torch.compile(op, options=options)
# next two lines are required, otherwise the flops will be cached from previous runs of this function. # next two lines are required, otherwise the flops will be cached from previous runs of this function.
torch._dynamo.reset() torch._dynamo.reset()
with fresh_inductor_cache(): with fresh_cache():
# actually run to set the counters # actually run to set the counters
comp(*example_inputs, **kwargs) comp(*example_inputs, **kwargs)
with FlopCounterMode() as mode: with FlopCounterMode() as mode:
View File
@ -13,7 +13,7 @@ from torch._dynamo.testing import rand_strided
from torch._inductor import config from torch._inductor import config
from torch._inductor.codecache import PyCodeCache from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.common_cuda import xfailIfSM89 from torch.testing._internal.common_cuda import xfailIfSM89
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
@ -152,7 +152,7 @@ class TestKernelBenchmark(TestCase):
@unittest.skipIf( @unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)" not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
) )
@fresh_inductor_cache() @fresh_cache()
def test_matmul_triton_kernel_benchmark(self): def test_matmul_triton_kernel_benchmark(self):
M = 12544 M = 12544
N = 256 N = 256
@ -170,7 +170,7 @@ class TestKernelBenchmark(TestCase):
@config.patch( @config.patch(
max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
) )
@fresh_inductor_cache() @fresh_cache()
def test_mm_triton_kernel_benchmark(self): def test_mm_triton_kernel_benchmark(self):
M = 2048 M = 2048
N = 2432 N = 2432
View File
@ -49,7 +49,7 @@ from torch.utils._triton import has_triton_tma_device
aten = torch.ops.aten aten = torch.ops.aten
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code from torch._inductor.utils import fresh_cache, run_and_get_code
from torch._inductor.virtualized import V from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck from torch.testing import FileCheck
@ -325,7 +325,7 @@ class TestMaxAutotune(TestCase):
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
@fresh_inductor_cache() @fresh_cache()
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout") @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
@unittest.skipIf( @unittest.skipIf(
@ -470,7 +470,7 @@ class TestMaxAutotune(TestCase):
FileCheck().check_not("extern_kernels.convolution").run(code[0]) FileCheck().check_not("extern_kernels.convolution").run(code[0])
self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0) self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)
@fresh_inductor_cache() @fresh_cache()
@config.patch(max_autotune=True, max_fusion_size=2) @config.patch(max_autotune=True, max_fusion_size=2)
def test_jit_fusion_matches_aot_fusion(self): def test_jit_fusion_matches_aot_fusion(self):
# In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due # In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
@ -563,7 +563,7 @@ class TestMaxAutotune(TestCase):
def f(x, y): def f(x, y):
return x @ y return x @ y
with fresh_inductor_cache(): with fresh_cache():
act = torch.compile(f)(x, y) act = torch.compile(f)(x, y)
ref = f(x, y) ref = f(x, y)
self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3)) self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))
@ -1146,7 +1146,7 @@ class TestMaxAutotune(TestCase):
self.assertEqual(generate_and_load_args - 1, make_key_args) self.assertEqual(generate_and_load_args - 1, make_key_args)
self.assertEqual(generate_and_load_args, 16) self.assertEqual(generate_and_load_args, 16)
@fresh_inductor_cache() @fresh_cache()
@config.patch( @config.patch(
{ {
"max_autotune": True, "max_autotune": True,
@ -1213,7 +1213,7 @@ class TestMaxAutotune(TestCase):
b = torch.rand(22, 30, device=GPU_TYPE) b = torch.rand(22, 30, device=GPU_TYPE)
# Valid cache hit. # Valid cache hit.
with fresh_inductor_cache(): with fresh_cache():
reset_counters() reset_counters()
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b) compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
eager_results = func_test1(a, b, a, b) eager_results = func_test1(a, b, a, b)
@ -1246,7 +1246,7 @@ class TestMaxAutotune(TestCase):
) )
# Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs. # Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
with fresh_inductor_cache(): with fresh_cache():
a = torch.rand(10, 22, device=GPU_TYPE) a = torch.rand(10, 22, device=GPU_TYPE)
b = torch.rand(22, 30, device=GPU_TYPE) b = torch.rand(22, 30, device=GPU_TYPE)
@ -1297,7 +1297,7 @@ class TestMaxAutotune(TestCase):
) )
# Test duck typing. # Test duck typing.
with fresh_inductor_cache(): with fresh_cache():
reset_counters() reset_counters()
compile_results = torch.compile(func_test1, dynamic=True)(a, b, a, b) compile_results = torch.compile(func_test1, dynamic=True)(a, b, a, b)
@ -1313,7 +1313,7 @@ class TestMaxAutotune(TestCase):
x = torch.matmul(x, x) x = torch.matmul(x, x)
return x return x
with fresh_inductor_cache(): with fresh_cache():
reset_counters() reset_counters()
input = torch.rand(10, 10, device=GPU_TYPE) input = torch.rand(10, 10, device=GPU_TYPE)
@ -1324,7 +1324,7 @@ class TestMaxAutotune(TestCase):
self.assertEqual(hits(), 36) self.assertEqual(hits(), 36)
self.assertEqual(misses(), 4) self.assertEqual(misses(), 4)
with fresh_inductor_cache(): with fresh_cache():
reset_counters() reset_counters()
input = torch.rand(10, 10, device=GPU_TYPE) input = torch.rand(10, 10, device=GPU_TYPE)
@ -1343,7 +1343,7 @@ class TestMaxAutotune(TestCase):
b = torch.matmul(torch.cat([x, z], 1), torch.cat([y, m, l], 0)) b = torch.matmul(torch.cat([x, z], 1), torch.cat([y, m, l], 0))
return a, b return a, b
with fresh_inductor_cache(): with fresh_cache():
a = torch.rand(10, 22, device=GPU_TYPE) a = torch.rand(10, 22, device=GPU_TYPE)
b = torch.rand(22, 30, device=GPU_TYPE) b = torch.rand(22, 30, device=GPU_TYPE)
c = torch.rand(10, 11, device=GPU_TYPE) c = torch.rand(10, 11, device=GPU_TYPE)
@ -1384,7 +1384,7 @@ class TestMaxAutotune(TestCase):
] ]
# Valid cache hit. # Valid cache hit.
with fresh_inductor_cache(): with fresh_cache():
torch._dynamo.utils.counters.clear() torch._dynamo.utils.counters.clear()
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b) compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
eager_results = func_test1(a, b, a, b) eager_results = func_test1(a, b, a, b)
@ -1424,7 +1424,7 @@ class TestMaxAutotune(TestCase):
] ]
# Valid cache hit. # Valid cache hit.
with fresh_inductor_cache(): with fresh_cache():
torch._dynamo.utils.counters.clear() torch._dynamo.utils.counters.clear()
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b) compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
eager_results = func_test1(a, b, a, b) eager_results = func_test1(a, b, a, b)
@ -1543,7 +1543,7 @@ class TestMaxAutotunePrecompile(TestCase):
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
@fresh_inductor_cache() @fresh_cache()
@config.patch(search_autotune_cache=True) @config.patch(search_autotune_cache=True)
def test_search_autotune_cache(self): def test_search_autotune_cache(self):
def fn(a, b, c): def fn(a, b, c):
@ -1811,12 +1811,12 @@ class TestMaxAutotuneRemoteCache(TestCase):
os.environ.pop("TRITON_CACHE_MANAGER", None) os.environ.pop("TRITON_CACHE_MANAGER", None)
with config.patch({"max_autotune": True}): with config.patch({"max_autotune": True}):
for _ in range(4): for _ in range(4):
with fresh_inductor_cache(): with fresh_cache():
torch.compile(mm, dynamic=dynamic)(a, b) torch.compile(mm, dynamic=dynamic)(a, b)
reset() reset()
with torch.compiler.config.patch( with torch.compiler.config.patch(
{"cache_key_tag": "test"} {"cache_key_tag": "test"}
), fresh_inductor_cache(): ), fresh_cache():
torch.compile(mm, dynamic=dynamic)(a, b) torch.compile(mm, dynamic=dynamic)(a, b)
reset() reset()
@ -1825,12 +1825,10 @@ class TestMaxAutotuneRemoteCache(TestCase):
global_stats.reset() global_stats.reset()
for _ in range(4): for _ in range(4):
with fresh_inductor_cache(): with fresh_cache():
torch.compile(f, dynamic=dynamic)(x, y) torch.compile(f, dynamic=dynamic)(x, y)
reset() reset()
with torch.compiler.config.patch( with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
{"cache_key_tag": "test"}
), fresh_inductor_cache():
torch.compile(mm, dynamic=dynamic)(a, b) torch.compile(mm, dynamic=dynamic)(a, b)
reset() reset()
global_stats.report() global_stats.report()
View File
@ -13,7 +13,7 @@ from torch._inductor.fx_passes.pad_mm import (
should_pad_mm_bf16, should_pad_mm_bf16,
) )
from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.inductor_utils import HAS_CUDA
@ -362,7 +362,7 @@ class PadMMTest(TestCase):
self.assertEqual(out, inps[0] @ inps[1]) self.assertEqual(out, inps[0] @ inps[1])
@inductor_config.patch(force_shape_pad=True) @inductor_config.patch(force_shape_pad=True)
@fresh_inductor_cache() @fresh_cache()
def test_pad_addmm_2d_bias(self): def test_pad_addmm_2d_bias(self):
@torch.compile() @torch.compile()
def foo(input, x, y): def foo(input, x, y):
@ -419,7 +419,7 @@ class PadMMTest(TestCase):
res2, bmm_expected_result res2, bmm_expected_result
), "BMM results are not identical" ), "BMM results are not identical"
@fresh_inductor_cache() @fresh_cache()
def test_exclude_padding(self): def test_exclude_padding(self):
@torch.compile() @torch.compile()
def mm(a, b): def mm(a, b):
@ -448,7 +448,7 @@ class PadMMTest(TestCase):
repr(local_cache) repr(local_cache)
) )
@fresh_inductor_cache() @fresh_cache()
@inductor_config.patch(max_pointwise_cat_inputs=2) @inductor_config.patch(max_pointwise_cat_inputs=2)
def test_exclude_cat_padding(self): def test_exclude_cat_padding(self):
@torch.compile() @torch.compile()
@ -475,7 +475,7 @@ class PadMMTest(TestCase):
"No perf regression on H100+ with BF16", "No perf regression on H100+ with BF16",
) )
@skipIfRocm @skipIfRocm
@fresh_inductor_cache() @fresh_cache()
@inductor_config.patch( @inductor_config.patch(
post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}} post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
) )
@ -508,7 +508,7 @@ class PadMMTest(TestCase):
assert torch.allclose(res2, mm_expected_result), "MM results are not identical" assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
@fresh_inductor_cache() @fresh_cache()
@inductor_config.patch( @inductor_config.patch(
{ {
"triton.unique_kernel_names": "original_aten", "triton.unique_kernel_names": "original_aten",
View File
@ -16,7 +16,7 @@ import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl from torch._C import _disabled_torch_function_impl
from torch._dynamo.testing import CompileCounterWithBackend from torch._dynamo.testing import CompileCounterWithBackend
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
from torch.fx.experimental import sym_node from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
@ -3150,7 +3150,7 @@ class TestUnbacked(TestCase):
class TestUbackedOps(TestCase): class TestUbackedOps(TestCase):
@fresh_inductor_cache() @fresh_cache()
@skipIfTorchDynamo("not allowed to trace mark_unbacked") @skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape1(self): def test_unbacked_reshape1(self):
View File
@ -42,7 +42,7 @@ from torch._inductor.runtime.compile_tasks import (
_set_triton_ptxas_path, _set_triton_ptxas_path,
_worker_compile_triton, _worker_compile_triton,
) )
from torch._inductor.utils import clear_on_fresh_inductor_cache from torch._inductor.utils import clear_on_fresh_cache
from torch._inductor.virtualized import V from torch._inductor.virtualized import V
from torch.hub import _Faketqdm, tqdm from torch.hub import _Faketqdm, tqdm
from torch.utils._ordered_set import OrderedSet from torch.utils._ordered_set import OrderedSet
@ -162,7 +162,7 @@ def get_compile_threads() -> int:
return config.compile_threads return config.compile_threads
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CompiledTritonKernels: class CompiledTritonKernels:
""" """
In memory cache for storing compiled triton kernels. In memory cache for storing compiled triton kernels.
View File
@ -86,7 +86,7 @@ from torch._inductor.runtime.compile_tasks import _reload_python_module
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import ( from torch._inductor.utils import (
ALIGN_BYTES, ALIGN_BYTES,
clear_on_fresh_inductor_cache, clear_on_fresh_cache,
is_linux, is_linux,
is_windows, is_windows,
) )
@ -236,7 +236,7 @@ class CacheBase:
return system return system
@staticmethod @staticmethod
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
@functools.cache @functools.cache
def get_local_cache_path() -> Path: def get_local_cache_path() -> Path:
return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
@ -1595,7 +1595,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
return path, "" return path, ""
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CudaKernelParamCache: class CudaKernelParamCache:
cache: dict[str, dict[str, Any]] = {} cache: dict[str, dict[str, Any]] = {}
cache_clear = staticmethod(cache.clear) cache_clear = staticmethod(cache.clear)
@ -2291,7 +2291,7 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, No
# Precompiled headers are persistent past program runtime, but associated with one # Precompiled headers are persistent past program runtime, but associated with one
# specific compiler version and set of flags. We explicitly use default_cache_dir here # specific compiler version and set of flags. We explicitly use default_cache_dir here
# because these headers need to be global, rather than ignored by fresh_inductor_cache. # because these headers need to be global, rather than ignored by fresh_cache.
_HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers") _HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
_HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks") _HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")
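The distinction the comment draws: cache_dir() follows TORCHINDUCTOR_CACHE_DIR, which fresh_cache() points at a temporary directory, while default_cache_dir() is a fixed per-user location that fresh_cache() leaves untouched, so the precompiled headers survive across fresh-cache test runs. A short hedged sketch of that difference (nothing here beyond the two helpers and fresh_cache is part of the PR):

from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import fresh_cache

print(default_cache_dir())    # stable location; precompiled headers live under it
with fresh_cache():
    print(cache_dir())        # temporary directory scoped to this block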
@ -2378,7 +2378,7 @@ def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str:
) )
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CppCodeCache: class CppCodeCache:
"""Compiles and caches C++ libraries. Users of this class supply the source code to """Compiles and caches C++ libraries. Users of this class supply the source code to
be compiled, while compilation flags are set by CppBuilder.""" be compiled, while compilation flags are set by CppBuilder."""
@ -2587,7 +2587,7 @@ def _worker_compile_cpp(
# Customized Python binding for cpp kernels # Customized Python binding for cpp kernels
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CppPythonBindingsCodeCache(CppCodeCache): class CppPythonBindingsCodeCache(CppCodeCache):
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache_clear = staticmethod(cache.clear) cache_clear = staticmethod(cache.clear)
@ -2768,7 +2768,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
return cls.load_pybinding_async(*args, **kwargs)() return cls.load_pybinding_async(*args, **kwargs)()
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CppWrapperCodeCache(CppPythonBindingsCodeCache): class CppWrapperCodeCache(CppPythonBindingsCodeCache):
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache_clear = staticmethod(cache.clear) cache_clear = staticmethod(cache.clear)
@ -2837,7 +2837,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
return _get_cpp_wrapper_header(device) return _get_cpp_wrapper_header(device)
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class HalideCodeCache(CppPythonBindingsCodeCache): class HalideCodeCache(CppPythonBindingsCodeCache):
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
cache_clear = staticmethod(cache.clear) cache_clear = staticmethod(cache.clear)
@ -3140,10 +3140,10 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
target = "host-cuda" if device_type == "cuda" else "host" target = "host-cuda" if device_type == "cuda" else "host"
if cls._standalone_runtime_path: if cls._standalone_runtime_path:
assert not os.path.exists(cls._standalone_runtime_path) assert not os.path.exists(cls._standalone_runtime_path)
# We hit this case in unittests when we run with fresh_inductor_cache() # We hit this case in unittests when we run with fresh_cache()
# Generating a fresh runtime over and over causes errors because we initialize # Generating a fresh runtime over and over causes errors because we initialize
# cuda hundreds of times in the same process and run out of file descriptors. # cuda hundreds of times in the same process and run out of file descriptors.
# Workaround by jail breaking the current fresh_inductor_cache(). # Workaround by jail breaking the current fresh_cache().
base = default_cache_dir() base = default_cache_dir()
else: else:
base = cache_dir() base = cache_dir()
@ -3239,7 +3239,7 @@ def touch(filename: str) -> None:
open(filename, "a").close() open(filename, "a").close()
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class PyCodeCache: class PyCodeCache:
# Track the loaded modules so we can remove the on-disk artifacts when # Track the loaded modules so we can remove the on-disk artifacts when
# clearing the cache. Note also that we may load the same path more # clearing the cache. Note also that we may load the same path more
@ -3625,7 +3625,7 @@ class DLLWrapper:
self.close() self.close()
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CUDACodeCache: class CUDACodeCache:
""" """
A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS. A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
@ -3792,7 +3792,7 @@ class CUDACodeCache:
fh.write(error_json) fh.write(error_json)
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class ROCmCodeCache: class ROCmCodeCache:
@dataclasses.dataclass @dataclasses.dataclass
class CacheEntry: class CacheEntry:
View File
@ -4,7 +4,7 @@ import shutil
from typing import Optional from typing import Optional
import torch import torch
from torch._inductor.utils import clear_on_fresh_inductor_cache from torch._inductor.utils import clear_on_fresh_cache
from ... import config from ... import config
@ -12,7 +12,7 @@ from ... import config
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
@functools.lru_cache(1) @functools.lru_cache(1)
def get_cuda_arch() -> Optional[str]: def get_cuda_arch() -> Optional[str]:
try: try:
@ -27,7 +27,7 @@ def get_cuda_arch() -> Optional[str]:
return None return None
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
@functools.lru_cache(1) @functools.lru_cache(1)
def get_cuda_version() -> Optional[str]: def get_cuda_version() -> Optional[str]:
try: try:
View File
@ -12,7 +12,7 @@ from torch._inductor.codecache import cutlass_key
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
from torch._inductor.runtime.cache_dir_utils import cache_dir from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_inductor_cache from torch._inductor.utils import clear_on_fresh_cache
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -47,7 +47,7 @@ def _generate_config_filename(request_key: str) -> str:
return f"{CONFIG_PREFIX}_{request_key}.json" return f"{CONFIG_PREFIX}_{request_key}.json"
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
@functools.cache @functools.cache
def maybe_fetch_ops() -> Optional[list[Any]]: def maybe_fetch_ops() -> Optional[list[Any]]:
""" """
View File
@ -13,7 +13,7 @@ from typing import Any, Optional
import sympy import sympy
import torch import torch
from torch._inductor.utils import clear_on_fresh_inductor_cache from torch._inductor.utils import clear_on_fresh_cache
from ... import config from ... import config
from ...ir import Layout from ...ir import Layout
@ -250,7 +250,7 @@ class CUTLASSArgs:
self.architectures = _normalize_cuda_arch(self.architectures) self.architectures = _normalize_cuda_arch(self.architectures)
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
@functools.cache @functools.cache
def _gen_ops_cached(arch, version) -> dict[Any, Any]: def _gen_ops_cached(arch, version) -> dict[Any, Any]:
# Note: Cache needs to be specific for cuda architecture and version # Note: Cache needs to be specific for cuda architecture and version
View File
@ -13,7 +13,7 @@ import torch.utils._pytree as pytree
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.select_algorithm import create_inputs_key from torch._inductor.select_algorithm import create_inputs_key
from torch._inductor.utils import clear_on_fresh_inductor_cache from torch._inductor.utils import clear_on_fresh_cache
from ... import ir from ... import ir
from ...config import cuda as inductor_cuda_config from ...config import cuda as inductor_cuda_config
@ -405,7 +405,7 @@ int main(int argc, char** argv) {
""" # noqa: B950 """ # noqa: B950
@clear_on_fresh_inductor_cache @clear_on_fresh_cache
class CUTLASSGemmTemplate(CUTLASSTemplate, ABC): class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
""" """
CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels
View File
@ -75,7 +75,7 @@ from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import ( from torch._inductor.utils import (
BoxedBool, BoxedBool,
count_tangents, count_tangents,
fresh_inductor_cache, fresh_cache,
get_all_devices, get_all_devices,
InputType, InputType,
is_gpu, is_gpu,
@ -675,7 +675,7 @@ def with_fresh_cache_if_config() -> Generator[None, None, None]:
# Don't delete the cache dir because it has to survive beyond the # Don't delete the cache dir because it has to survive beyond the
# compile_fx call. Let's put the temp dirs under the default cache # compile_fx call. Let's put the temp dirs under the default cache
# dir so they're easier to locate. # dir so they're easier to locate.
with fresh_inductor_cache(dir=cache_dir(), delete=False): with fresh_cache(dir=cache_dir(), delete=False):
yield yield
else: else:
yield yield
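The dir=cache_dir(), delete=False combination above is what lets the temporary directory outlive the compile_fx call. A minimal sketch of the same idea from user code; the wrapper function and inputs are illustrative only:

import torch
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import fresh_cache

def compile_keeping_artifacts(fn, *args):
    # keep the per-run temp dir (created under the current cache dir) for later inspection
    with fresh_cache(dir=cache_dir(), delete=False):
        return torch.compile(fn)(*args)

out = compile_keeping_artifacts(lambda x: (x @ x).relu(), torch.randn(8, 8))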
View File
@ -13,7 +13,7 @@ from torch._inductor.compile_worker.subproc_pool import (
SubprocKind, SubprocKind,
SubprocPool, SubprocPool,
) )
from torch._inductor.utils import clear_inductor_caches from torch._inductor.utils import clear_caches
from .compile_fx_ext import ( from .compile_fx_ext import (
_OutOfProcessFxCompile, _OutOfProcessFxCompile,
@ -77,14 +77,14 @@ class _SubprocessFxCompile(_OutOfProcessFxCompile):
# tmpdir still exists and fails to compile. # tmpdir still exists and fails to compile.
# #
# TODO: We probably should be using a separate tmpdir in the worker # TODO: We probably should be using a separate tmpdir in the worker
# anyway... but we should probably still respect clear_inductor_caches() # anyway... but we should probably still respect clear_caches()
# in the parent... maybe? # in the parent... maybe?
# #
# TODO: We could be less aggressive by keeping a clock which gets # TODO: We could be less aggressive by keeping a clock which gets
# incremented when we clear the cache, send the clock to the worker and # incremented when we clear the cache, send the clock to the worker and
# only clear caches if the clock changed since last time. # only clear caches if the clock changed since last time.
# #
clear_inductor_caches() clear_caches()
torch._inductor.metrics.reset() torch._inductor.metrics.reset()
# TODO: turn off config.fx_graph_async_compile # TODO: turn off config.fx_graph_async_compile
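A minimal sketch of the clock-based scheme the TODO above describes; this is purely hypothetical and not part of this PR. The parent bumps a counter whenever it clears caches, ships it with each job, and the worker clears only when the value has moved:

from torch._inductor.utils import clear_caches

class CacheClock:
    # parent side: bumped alongside every clear_caches() in the parent process
    def __init__(self) -> None:
        self.value = 0

    def bump(self) -> int:
        self.value += 1
        return self.value

_last_seen_clock = -1

def worker_maybe_clear(job_clock: int) -> None:
    # worker side: clear caches only when the parent's clock has advanced
    global _last_seen_clock
    if job_clock != _last_seen_clock:
        clear_caches()
        _last_seen_clock = job_clock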
View File
@ -39,15 +39,15 @@ def triton_cache_dir(device: int) -> str:
@contextmanager @contextmanager
def temporary_cache_dir(directory: str) -> Generator[None, None, None]: def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
from torch._inductor.utils import clear_inductor_caches from torch._inductor.utils import clear_caches
original = os.environ.get("TORCHINDUCTOR_CACHE_DIR") original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
try: try:
clear_inductor_caches() clear_caches()
yield yield
finally: finally:
clear_inductor_caches() clear_caches()
if original is None: if original is None:
del os.environ["TORCHINDUCTOR_CACHE_DIR"] del os.environ["TORCHINDUCTOR_CACHE_DIR"]
else: else:
View File
@ -28,7 +28,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
from torch._inductor.utils import clear_on_fresh_inductor_cache from torch._inductor.utils import clear_on_fresh_cache
from torch.utils._filelock import FileLock from torch.utils._filelock import FileLock
from torch.utils._ordered_set import OrderedSet from torch.utils._ordered_set import OrderedSet
@ -1291,7 +1291,7 @@ class TritonTemplate(KernelTemplate):
self.debug = debug self.debug = debug
self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache() self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
clear_on_fresh_inductor_cache(self._generated_code_cache) clear_on_fresh_cache(self._generated_code_cache)
# When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
# by adding all inputs. # by adding all inputs.
self.prologue_loads_all_inputs = prologue_loads_all_inputs self.prologue_loads_all_inputs = prologue_loads_all_inputs
@ -2123,7 +2123,7 @@ class AlgorithmSelectorCache(PersistentCache):
# list of callbacks that are called after benchmarking # list of callbacks that are called after benchmarking
self.feedback_saver_fns: list[FeedbackFunction] = [] self.feedback_saver_fns: list[FeedbackFunction] = []
clear_on_fresh_inductor_cache(self) clear_on_fresh_cache(self)
def cache_clear(self) -> None: def cache_clear(self) -> None:
self.precompile_cache.clear() self.precompile_cache.clear()
View File
@ -8,7 +8,7 @@ from torch._dynamo.test_case import (
) )
from torch._functorch import config as functorch_config from torch._functorch import config as functorch_config
from torch._inductor import config from torch._inductor import config
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
@ -41,7 +41,7 @@ class TestCase(DynamoTestCase):
os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1" os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
and os.environ.get("TORCH_COMPILE_DEBUG") != "1" and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
): ):
self._inductor_test_stack.enter_context(fresh_inductor_cache()) self._inductor_test_stack.enter_context(fresh_cache())
def tearDown(self) -> None: def tearDown(self) -> None:
super().tearDown() super().tearDown()
View File
@ -1015,29 +1015,6 @@ def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
return input_devices | out_devices return input_devices | out_devices
_registered_caches: list[Any] = []
def clear_on_fresh_inductor_cache(obj: Any) -> Any:
"""
Use this decorator to register any caches that should be cache_clear'd
with fresh_inductor_cache().
"""
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
raise AttributeError(f"{obj} does not have a cache_clear method")
_registered_caches.append(obj)
return obj
def clear_inductor_caches() -> None:
"""
Clear all registered caches.
"""
for obj in _registered_caches:
obj.cache_clear()
import gc import gc
@ -1070,19 +1047,42 @@ def unload_xpu_triton_pyds() -> None:
gc.collect() gc.collect()
_registered_caches: list[Any] = []
def clear_on_fresh_cache(obj: Any) -> Any:
"""
Use this decorator to register any caches that should be cache_clear'd
with fresh_cache().
"""
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
raise AttributeError(f"{obj} does not have a cache_clear method")
_registered_caches.append(obj)
return obj
def clear_caches() -> None:
"""
Clear all registered caches.
"""
for obj in _registered_caches:
obj.cache_clear()
@contextlib.contextmanager @contextlib.contextmanager
def fresh_inductor_cache( def fresh_cache(
cache_entries: Optional[dict[str, Any]] = None, cache_entries: Optional[dict[str, Any]] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
delete: bool = True, delete: bool = True,
) -> Iterator[None]: ) -> Iterator[None]:
""" """
Contextmanager that provides a clean tmp cachedir for inductor. Contextmanager that provides a clean tmp cachedir for pt2 caches.
Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
generated with this cache instance. generated with this cache instance.
""" """
clear_inductor_caches() clear_caches()
inductor_cache_dir = tempfile.mkdtemp(dir=dir) inductor_cache_dir = tempfile.mkdtemp(dir=dir)
try: try:
@ -1123,7 +1123,13 @@ def fresh_inductor_cache(
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
raise raise
finally: finally:
clear_inductor_caches() clear_caches()
# Deprecated functions -- only keeping them for BC reasons
clear_on_fresh_inductor_cache = clear_on_fresh_cache
clear_inductor_caches = clear_caches
fresh_inductor_cache = fresh_cache
def argsort(seq: Sequence[Any]) -> list[int]: def argsort(seq: Sequence[Any]) -> list[int]:
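Putting the renamed pieces together: a minimal sketch of the new entry points (the compiled function and inputs are illustrative only; the old *_inductor_* names keep working through the aliases above).

import torch
from torch._inductor.utils import clear_caches, fresh_cache, fresh_inductor_cache

def f(x):
    return torch.mm(x, x).relu()

x = torch.randn(8, 8)

cache_entries: dict = {}
with fresh_cache(cache_entries=cache_entries):
    torch.compile(f)(x)
# cache_entries is now populated with the filenames and sizes generated during
# this run (per the docstring above); registered in-memory caches were cleared
# on entry and again on exit.

assert fresh_inductor_cache is fresh_cache   # deprecated alias kept for BC
clear_caches()                               # can also be invoked directly

As elsewhere in this diff, fresh_cache() also works as a decorator, e.g. the @fresh_cache() decorators on TestKernelBenchmark and PadMMTest methods.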
View File
@ -15,7 +15,7 @@ from benchmark_utils import ( # type: ignore[import-not-found]
) )
import torch import torch
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported] class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
@ -59,7 +59,7 @@ class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unim
) )
b = b.to(dtype=dtype_right) b = b.to(dtype=dtype_right)
with fresh_inductor_cache(): with fresh_cache():
def mixed_mm(A, B): def mixed_mm(A, B):
return torch.mm(A, B.to(A.dtype)) return torch.mm(A, B.to(A.dtype))
View File
@ -16,7 +16,7 @@ from benchmark_utils import ( # type: ignore[import-not-found]
) )
import torch import torch
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported] class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
@ -57,7 +57,7 @@ class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimporte
dtype_right=dtype, dtype_right=dtype,
) )
with fresh_inductor_cache(): with fresh_cache():
def mixed_mm(A: Any, B: Any) -> Any: def mixed_mm(A: Any, B: Any) -> Any:
return torch.mm(A, B) return torch.mm(A, B)
View File
@ -18,7 +18,7 @@ import torch
from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found] from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found]
get_alignment_size_dtype, get_alignment_size_dtype,
) )
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_cache
class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported] class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
@ -74,7 +74,7 @@ class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimpo
print(f"transpose_left={transpose_left} transpose_right={transpose_right}") print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}") print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")
with fresh_inductor_cache(): with fresh_cache():
def mm(a: Any, b: Any) -> Any: def mm(a: Any, b: Any) -> Any:
return torch.mm(a, b) return torch.mm(a, b)