Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00.
Rename inductor cache (#156128)

Requested by Simon on a different PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156128
Approved by: https://github.com/xmfan
This commit is contained in:
parent 45382b284d
commit a2a75be0f8
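For orientation before the diff: this commit is a mechanical rename of the Inductor cache helpers in torch._inductor.utils, fresh_inductor_cache to fresh_cache and clear_inductor_caches to clear_caches. The sketch below is illustrative only and is not taken from the patch; it shows how a caller that imported the old names looks after the rename, using the same context-manager, decorator, and reset patterns that appear in the updated call sites below.

```python
# Illustrative sketch of the rename (not part of the patch).
# Before this commit:
#   from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
# After this commit:
import torch
from torch._inductor.utils import clear_caches, fresh_cache


def fn(x, y):
    return x + y


# fresh_cache() (formerly fresh_inductor_cache()) runs compilation against a
# fresh, temporary Inductor cache directory; it is also used as a decorator,
# as in the @fresh_cache() test decorators in the diff below.
with fresh_cache():
    compiled = torch.compile(fn)
    compiled(torch.rand(10), torch.rand(10))

# clear_caches() (formerly clear_inductor_caches()) is paired with
# torch._dynamo.reset() in the test reset()/setup() helpers below.
torch._dynamo.reset()
clear_caches()
```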
@@ -8,7 +8,7 @@ import sys
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -62,7 +62,7 @@ def _run_torchbench_from_args(
-with fresh_inductor_cache():
+with fresh_cache():
@@ -51,7 +51,7 @@ from torch._logging.scribe import open_source_signpost
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -3416,7 +3416,7 @@ def maybe_fresh_cache(args):
-return fresh_inductor_cache()
+return fresh_cache()
@@ -3,7 +3,7 @@ import timeit
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache
@@ -20,7 +20,7 @@ def main():
-with fresh_inductor_cache():
+with fresh_cache():
@@ -30,7 +30,7 @@ def main():
-clear_inductor_caches()
+clear_caches()
@@ -3,7 +3,7 @@ import sys
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -50,7 +50,7 @@ class Benchmark(BenchmarkBase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -55,7 +55,7 @@ class Benchmark(BenchmarkBase):
-fresh_inductor_cache(),
+fresh_cache(),
@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -94,7 +94,7 @@ class Benchmark(BenchmarkBase):
-fresh_inductor_cache(),
+fresh_cache(),
@@ -3,7 +3,7 @@ import sys
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -31,7 +31,7 @@ class Benchmark(BenchmarkBase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -3,7 +3,7 @@ import sys
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -45,7 +45,7 @@ class Benchmark(BenchmarkBase):
-with fresh_inductor_cache(), torch._inductor.config.patch(max_autotune=True):
+with fresh_cache(), torch._inductor.config.patch(max_autotune=True):
@@ -4,7 +4,7 @@ from benchmark_base import BenchmarkBase
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -67,7 +67,7 @@ class Benchmark(BenchmarkBase):
-fresh_inductor_cache(),
+fresh_cache(),
@@ -246,7 +246,7 @@ def run_single_experiment_group(
-torch._inductor.utils.clear_inductor_caches()
+torch._inductor.utils.clear_caches()
@@ -13,7 +13,7 @@ from torch._inductor.fx_passes.micro_pipeline_tp import (
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_triton_code
@@ -81,7 +81,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -134,7 +134,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -173,7 +173,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -201,7 +201,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -248,7 +248,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -321,7 +321,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -350,7 +350,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -403,7 +403,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -465,7 +465,7 @@ class MicroPipelineTPTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -9,7 +9,7 @@ import torch
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_triton_code
@@ -464,7 +464,7 @@ class TestWithNCCL(MultiProcessTestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -510,7 +510,7 @@ class TestWithNCCL(MultiProcessTestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -736,7 +736,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -773,7 +773,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -819,7 +819,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -843,7 +843,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -869,7 +869,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -904,7 +904,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -931,7 +931,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -965,7 +965,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -987,7 +987,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -1013,7 +1013,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -1049,7 +1049,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -1097,7 +1097,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -1134,7 +1134,7 @@ class CompileTest(TestCase):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -1218,11 +1218,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
-with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
-self.rank, self.world_size
-):
+with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
@@ -1252,7 +1250,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -10,7 +10,7 @@ import torch.distributed as dist
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
+from torch._inductor.utils import fresh_cache, run_and_get_triton_code
@@ -1020,7 +1020,7 @@ class LoweringTest(MultiProcContinousTest):
-@fresh_inductor_cache()
+@fresh_cache()
@@ -25,7 +25,7 @@ from torch._inductor import config as inductor_config
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -165,7 +165,7 @@ class AOTAutogradCacheTests(InductorTestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -201,7 +201,7 @@ class AOTAutogradCacheTests(InductorTestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -221,7 +221,7 @@ class AOTAutogradCacheTests(InductorTestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -798,7 +798,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
-with torch._inductor.utils.fresh_inductor_cache():
+with torch._inductor.utils.fresh_cache():
@@ -54,7 +54,7 @@ from torch._dynamo.testing import (
-from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
+from torch._inductor.utils import fresh_cache, run_and_get_code
@@ -8087,7 +8087,7 @@ utils_device.CURRENT_DEVICE == None""".split(
-with fresh_inductor_cache():
+with fresh_cache():
@@ -12,7 +12,7 @@ import torch._inductor.mock_cache as mock_cache
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache
@@ -24,7 +24,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
-self._test_stack.enter_context(fresh_inductor_cache())
+self._test_stack.enter_context(fresh_cache())
@@ -35,7 +35,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
-clear_inductor_caches()
+clear_caches()
@@ -244,7 +244,7 @@ class PgoTest(torch._dynamo.test_case.TestCase):
-clear_inductor_caches()
+clear_caches()
@@ -43,7 +43,7 @@ import torch.utils._pytree as pytree
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -5527,7 +5527,7 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -18,7 +18,7 @@ import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -157,7 +157,7 @@ class TestAOTInductorPackage(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -3,7 +3,7 @@ import torch
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
@@ -32,7 +32,7 @@ class TestAsyncCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -7,7 +7,7 @@ import torch
-from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
+from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
@@ -283,7 +283,7 @@ if HAS_CUDA:
-@fresh_inductor_cache()
+@fresh_cache()
@@ -298,7 +298,7 @@ if HAS_CUDA:
-@fresh_inductor_cache()
+@fresh_cache()
@@ -38,7 +38,7 @@ from torch._inductor.graph import GraphLowering
-from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache
+from torch._inductor.utils import clear_caches, fresh_cache
@@ -146,7 +146,7 @@ class TestFxGraphCache(TestCase):
-clear_inductor_caches()
+clear_caches()
@@ -379,16 +379,14 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
-with torch.compiler.config.patch(
-{"cache_key_tag": "test"}
-), fresh_inductor_cache():
+with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
@@ -426,7 +424,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -456,7 +454,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -470,7 +468,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -503,7 +501,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -519,7 +517,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -531,7 +529,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -555,7 +553,7 @@ class TestFxGraphCache(TestCase):
-with torch.compiler.config.patch(job_id=self.id()), fresh_inductor_cache():
+with torch.compiler.config.patch(job_id=self.id()), fresh_cache():
@@ -582,7 +580,7 @@ class TestFxGraphCache(TestCase):
-with torch.compiler.config.patch({"job_id": self.id()}), fresh_inductor_cache():
+with torch.compiler.config.patch({"job_id": self.id()}), fresh_cache():
@@ -617,7 +615,7 @@ class TestFxGraphCache(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -639,7 +637,7 @@ class TestFxGraphCache(TestCase):
-), fresh_inductor_cache():
+), fresh_cache():
@@ -1582,7 +1580,7 @@ class TestStandaloneCompile(TestCase):
-clear_inductor_caches()
+clear_caches()
@@ -1638,7 +1636,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1647,7 +1645,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1679,7 +1677,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1698,7 +1696,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1707,7 +1705,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1731,7 +1729,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1743,7 +1741,7 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1791,16 +1789,16 @@ class TestStandaloneCompile(TestCase):
-with fresh_inductor_cache():
+with fresh_cache():
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import fresh_cache
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1832,7 +1830,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1854,7 +1852,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -1890,7 +1888,7 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
-with fresh_inductor_cache():
+with fresh_cache():
@@ -2440,7 +2438,7 @@ class TestAutotuneCache(TestCase):
-clear_inductor_caches()
+clear_caches()
@ -2750,20 +2748,20 @@ class TestRemoteAOTAutogradCache(TestCase):
|
||||||
|
|
||||||
class TestUtils(TestCase):
|
class TestUtils(TestCase):
|
||||||
@config.patch({"fx_graph_remote_cache": False})
|
@config.patch({"fx_graph_remote_cache": False})
|
||||||
def test_fresh_inductor_cache(self):
|
def test_fresh_cache(self):
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
a = torch.rand(10)
|
a = torch.rand(10)
|
||||||
b = torch.rand(10)
|
b = torch.rand(10)
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
self.assertEqual(len(PyCodeCache.modules), 0)
|
self.assertEqual(len(PyCodeCache.modules), 0)
|
||||||
res1 = torch.compile(fn)(a, b)
|
res1 = torch.compile(fn)(a, b)
|
||||||
cache_dir1 = cache_dir()
|
cache_dir1 = cache_dir()
|
||||||
|
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
self.assertEqual(len(PyCodeCache.modules), 0)
|
self.assertEqual(len(PyCodeCache.modules), 0)
|
||||||
res2 = torch.compile(fn)(a, b)
|
res2 = torch.compile(fn)(a, b)
|
||||||
cache_dir2 = cache_dir()
|
cache_dir2 = cache_dir()
|
||||||
|
|
|
||||||
|
|
@ -919,9 +919,9 @@ class CompiledOptimizerTests(TestCase):
|
||||||
import torch._dynamo
|
import torch._dynamo
|
||||||
import torch._inductor
|
import torch._inductor
|
||||||
from torch._dynamo.debug_utils import aot_graph_input_parser
|
from torch._dynamo.debug_utils import aot_graph_input_parser
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
kwargs = aot_graph_input_parser(forward)
|
kwargs = aot_graph_input_parser(forward)
|
||||||
torch.compile(forward)(**kwargs)
|
torch.compile(forward)(**kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1814,9 +1814,9 @@ class CudaReproTests(TestCase):
|
||||||
|
|
||||||
m = ToyModel().to(device="cuda:0")
|
m = ToyModel().to(device="cuda:0")
|
||||||
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
|
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
cm = torch.compile(m, mode="max-autotune")
|
cm = torch.compile(m, mode="max-autotune")
|
||||||
out = cm(input_tensor)
|
out = cm(input_tensor)
|
||||||
out2 = m(input_tensor)
|
out2 = m(input_tensor)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from torch._inductor.codecache import CUDACodeCache
|
||||||
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
|
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
|
||||||
from torch._inductor.exc import CUDACompileError
|
from torch._inductor.exc import CUDACompileError
|
||||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
|
|
||||||
_SOURCE_CODE = r"""
|
_SOURCE_CODE = r"""
|
||||||
|
|
@ -40,7 +40,7 @@ int saxpy(int n, float a, float *x, float *y) {
|
||||||
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
|
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
|
||||||
class TestCUDACodeCache(InductorTestCase):
|
class TestCUDACodeCache(InductorTestCase):
|
||||||
def test_cuda_load(self):
|
def test_cuda_load(self):
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
# Test both .o and .so compilation.
|
# Test both .o and .so compilation.
|
||||||
(
|
(
|
||||||
object_file_path,
|
object_file_path,
|
||||||
|
|
@ -67,13 +67,13 @@ class TestCUDACodeCache(InductorTestCase):
|
||||||
torch.testing.assert_close(y, expected_y)
|
torch.testing.assert_close(y, expected_y)
|
||||||
|
|
||||||
def test_compilation_error(self):
|
def test_compilation_error(self):
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
|
error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
|
||||||
with self.assertRaises(CUDACompileError):
|
with self.assertRaises(CUDACompileError):
|
||||||
CUDACodeCache.compile(error_source_code, "o")
|
CUDACodeCache.compile(error_source_code, "o")
|
||||||
|
|
||||||
def test_async_compile(self):
|
def test_async_compile(self):
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
async_compile = AsyncCompile()
|
async_compile = AsyncCompile()
|
||||||
compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
|
compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
|
||||||
async_compile.wait(globals())
|
async_compile.wait(globals())
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
||||||
from torch._inductor.utils import clear_inductor_caches
|
from torch._inductor.utils import clear_caches
|
||||||
from torch.export import Dim
|
from torch.export import Dim
|
||||||
from torch.testing._internal.logging_utils import log_settings
|
from torch.testing._internal.logging_utils import log_settings
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ from torch._inductor.exc import InductorError
|
||||||
from torch._inductor.ir import FixedLayout
|
from torch._inductor.ir import FixedLayout
|
||||||
from torch._inductor.select_algorithm import NoValidChoicesError
|
from torch._inductor.select_algorithm import NoValidChoicesError
|
||||||
from torch._inductor.test_case import run_tests, TestCase
|
from torch._inductor.test_case import run_tests, TestCase
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
|
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_cuda import (
|
from torch.testing._internal.common_cuda import (
|
||||||
|
|
@ -173,7 +173,7 @@ class TestCutlassBackend(TestCase):
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
|
|
||||||
def run_evt_test(self, model, op, shape, num_fusions=1):
|
def run_evt_test(self, model, op, shape, num_fusions=1):
|
||||||
M, N = shape
|
M, N = shape
|
||||||
|
|
@ -618,7 +618,7 @@ class TestCutlassBackend(TestCase):
|
||||||
]
|
]
|
||||||
for x_shape in x_shapes:
|
for x_shape in x_shapes:
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
(
|
(
|
||||||
|
|
@ -1065,7 +1065,7 @@ class TestCutlassBackend(TestCase):
|
||||||
def select_no_algorithm(*args, **kwargs):
|
def select_no_algorithm(*args, **kwargs):
|
||||||
raise NoValidChoicesError
|
raise NoValidChoicesError
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
with config.patch(
|
with config.patch(
|
||||||
{
|
{
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
|
|
@ -1113,7 +1113,7 @@ class TestCutlassBackend(TestCase):
|
||||||
def select_no_algorithm(*args, **kwargs):
|
def select_no_algorithm(*args, **kwargs):
|
||||||
raise NoValidChoicesError
|
raise NoValidChoicesError
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
with config.patch(
|
with config.patch(
|
||||||
{
|
{
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
|
|
@ -1187,7 +1187,7 @@ class TestCutlassBackend(TestCase):
|
||||||
raise NoValidChoicesError
|
raise NoValidChoicesError
|
||||||
|
|
||||||
def run_test(use_fast_accum):
|
def run_test(use_fast_accum):
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
with config.patch(
|
with config.patch(
|
||||||
{
|
{
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
|
|
@ -1266,7 +1266,7 @@ class TestCutlassBackend(TestCase):
|
||||||
def select_no_algorithm(*args, **kwargs):
|
def select_no_algorithm(*args, **kwargs):
|
||||||
raise NoValidChoicesError
|
raise NoValidChoicesError
|
||||||
|
|
||||||
with fresh_inductor_cache(), config.patch(
|
with fresh_cache(), config.patch(
|
||||||
{
|
{
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
"max_autotune_gemm_backends": "CUTLASS",
|
"max_autotune_gemm_backends": "CUTLASS",
|
||||||
|
|
@ -1324,7 +1324,7 @@ class TestCutlassBackend(TestCase):
|
||||||
def select_no_algorithm(*args, **kwargs):
|
def select_no_algorithm(*args, **kwargs):
|
||||||
raise NoValidChoicesError
|
raise NoValidChoicesError
|
||||||
|
|
||||||
with fresh_inductor_cache(), config.patch(
|
with fresh_cache(), config.patch(
|
||||||
{
|
{
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
"max_autotune_gemm_backends": "CUTLASS",
|
"max_autotune_gemm_backends": "CUTLASS",
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor import config, test_operators
|
from torch._inductor import config, test_operators
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
from torch.testing._internal.common_utils import skipIfWindows
|
from torch.testing._internal.common_utils import skipIfWindows
|
||||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||||
from torch.testing._internal.logging_utils import multiple_logs_to_string
|
from torch.testing._internal.logging_utils import multiple_logs_to_string
|
||||||
|
|
@ -44,7 +44,7 @@ class TestDebugTrace(test_torchinductor.TestCase):
|
||||||
"torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
|
"torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aakhundov): make this work with fresh_inductor_cache
|
# TODO(aakhundov): make this work with fresh_cache
|
||||||
# instead of force_disable_caches. currently, with the latter
|
# instead of force_disable_caches. currently, with the latter
|
||||||
# enabled, we get `inductor [('fxgraph_cache_hit', 1)]` in
|
# enabled, we get `inductor [('fxgraph_cache_hit', 1)]` in
|
||||||
# the counters: so the cache is actually hit and the test fails.
|
# the counters: so the cache is actually hit and the test fails.
|
||||||
|
|
@ -263,7 +263,7 @@ op2.node.kernel = extern_kernels.mm""",
|
||||||
# no failure
|
# no failure
|
||||||
with self.assertLogs(
|
with self.assertLogs(
|
||||||
logging.getLogger("torch._inductor.debug"), level=logging.WARNING
|
logging.getLogger("torch._inductor.debug"), level=logging.WARNING
|
||||||
), fresh_inductor_cache():
|
), fresh_cache():
|
||||||
m = ToyModel().to(device=GPU_TYPE)
|
m = ToyModel().to(device=GPU_TYPE)
|
||||||
m = torch.compile(m, mode="max-autotune")
|
m = torch.compile(m, mode="max-autotune")
|
||||||
input_tensor = torch.randn(100).to(device=GPU_TYPE)
|
input_tensor = torch.randn(100).to(device=GPU_TYPE)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch._inductor.metrics as metrics
|
||||||
import torch.utils.flop_counter
|
import torch.utils.flop_counter
|
||||||
from torch._dynamo.utils import counters
|
from torch._dynamo.utils import counters
|
||||||
from torch._inductor.ir import FixedLayout
|
from torch._inductor.ir import FixedLayout
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
from torch.testing._internal.common_cuda import SM70OrLater
|
from torch.testing._internal.common_cuda import SM70OrLater
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
dtypes,
|
dtypes,
|
||||||
|
|
@ -77,7 +77,7 @@ class TestScheduler(TestCase):
|
||||||
for op, example_inputs, kwargs in tc:
|
for op, example_inputs, kwargs in tc:
|
||||||
comp = torch.compile(op)
|
comp = torch.compile(op)
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
comp(*example_inputs, **kwargs)
|
comp(*example_inputs, **kwargs)
|
||||||
self.assertEqual(metrics.num_bytes_accessed, 0)
|
self.assertEqual(metrics.num_bytes_accessed, 0)
|
||||||
self.assertEqual(any(m[1] for m in metrics.node_runtimes), False)
|
self.assertEqual(any(m[1] for m in metrics.node_runtimes), False)
|
||||||
|
|
@ -108,7 +108,7 @@ class TestScheduler(TestCase):
|
||||||
|
|
||||||
comp = torch.compile(op)
|
comp = torch.compile(op)
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
comp(*example_inputs, **kwargs)
|
comp(*example_inputs, **kwargs)
|
||||||
self.assertEqual(enba, metrics.num_bytes_accessed)
|
self.assertEqual(enba, metrics.num_bytes_accessed)
|
||||||
nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0)
|
nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0)
|
||||||
|
|
@ -152,7 +152,7 @@ class TestScheduler(TestCase):
|
||||||
comp = torch.compile(op, options=options)
|
comp = torch.compile(op, options=options)
|
||||||
# next two lines are required, otherwise the flops will be cached from previous runs of this function.
|
# next two lines are required, otherwise the flops will be cached from previous runs of this function.
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
# actually run to set the counters
|
# actually run to set the counters
|
||||||
comp(*example_inputs, **kwargs)
|
comp(*example_inputs, **kwargs)
|
||||||
with FlopCounterMode() as mode:
|
with FlopCounterMode() as mode:
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from torch._dynamo.testing import rand_strided
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
from torch._inductor.codecache import PyCodeCache
|
from torch._inductor.codecache import PyCodeCache
|
||||||
from torch._inductor.test_case import run_tests, TestCase
|
from torch._inductor.test_case import run_tests, TestCase
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_cuda import xfailIfSM89
|
from torch.testing._internal.common_cuda import xfailIfSM89
|
||||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_BIG_GPU
|
||||||
|
|
@ -152,7 +152,7 @@ class TestKernelBenchmark(TestCase):
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
||||||
)
|
)
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
def test_matmul_triton_kernel_benchmark(self):
|
def test_matmul_triton_kernel_benchmark(self):
|
||||||
M = 12544
|
M = 12544
|
||||||
N = 256
|
N = 256
|
||||||
|
|
@ -170,7 +170,7 @@ class TestKernelBenchmark(TestCase):
|
||||||
@config.patch(
|
@config.patch(
|
||||||
max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
|
max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
|
||||||
)
|
)
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
def test_mm_triton_kernel_benchmark(self):
|
def test_mm_triton_kernel_benchmark(self):
|
||||||
M = 2048
|
M = 2048
|
||||||
N = 2432
|
N = 2432
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ from torch.utils._triton import has_triton_tma_device
|
||||||
aten = torch.ops.aten
|
aten = torch.ops.aten
|
||||||
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
|
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
|
||||||
from torch._inductor.test_case import run_tests, TestCase
|
from torch._inductor.test_case import run_tests, TestCase
|
||||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
|
from torch._inductor.utils import fresh_cache, run_and_get_code
|
||||||
from torch._inductor.virtualized import V
|
from torch._inductor.virtualized import V
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
|
|
@ -325,7 +325,7 @@ class TestMaxAutotune(TestCase):
|
||||||
|
|
||||||
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
|
torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
|
||||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
|
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA")
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
|
|
@ -470,7 +470,7 @@ class TestMaxAutotune(TestCase):
|
||||||
FileCheck().check_not("extern_kernels.convolution").run(code[0])
|
FileCheck().check_not("extern_kernels.convolution").run(code[0])
|
||||||
self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)
|
self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@config.patch(max_autotune=True, max_fusion_size=2)
|
@config.patch(max_autotune=True, max_fusion_size=2)
|
||||||
def test_jit_fusion_matches_aot_fusion(self):
|
def test_jit_fusion_matches_aot_fusion(self):
|
||||||
# In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
|
# In this example, AOTInductor's JIT-compile will fuse(buf1, buf2) due
|
||||||
|
|
@ -563,7 +563,7 @@ class TestMaxAutotune(TestCase):
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
return x @ y
|
return x @ y
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
act = torch.compile(f)(x, y)
|
act = torch.compile(f)(x, y)
|
||||||
ref = f(x, y)
|
ref = f(x, y)
|
||||||
self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))
|
self.assertTrue(torch.allclose(act, ref, atol=4 * 1e-3, rtol=4 * 1e-3))
|
||||||
|
|
@ -1146,7 +1146,7 @@ class TestMaxAutotune(TestCase):
|
||||||
self.assertEqual(generate_and_load_args - 1, make_key_args)
|
self.assertEqual(generate_and_load_args - 1, make_key_args)
|
||||||
self.assertEqual(generate_and_load_args, 16)
|
self.assertEqual(generate_and_load_args, 16)
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@config.patch(
|
@config.patch(
|
||||||
{
|
{
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
|
|
@ -1213,7 +1213,7 @@ class TestMaxAutotune(TestCase):
|
||||||
b = torch.rand(22, 30, device=GPU_TYPE)
|
b = torch.rand(22, 30, device=GPU_TYPE)
|
||||||
|
|
||||||
# Valid cache hit.
|
# Valid cache hit.
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
reset_counters()
|
reset_counters()
|
||||||
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
||||||
eager_results = func_test1(a, b, a, b)
|
eager_results = func_test1(a, b, a, b)
|
||||||
|
|
@ -1246,7 +1246,7 @@ class TestMaxAutotune(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
|
# Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
a = torch.rand(10, 22, device=GPU_TYPE)
|
a = torch.rand(10, 22, device=GPU_TYPE)
|
||||||
b = torch.rand(22, 30, device=GPU_TYPE)
|
b = torch.rand(22, 30, device=GPU_TYPE)
|
||||||
|
|
||||||
|
|
@ -1297,7 +1297,7 @@ class TestMaxAutotune(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test duck typing.
|
# Test duck typing.
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
reset_counters()
|
reset_counters()
|
||||||
|
|
||||||
compile_results = torch.compile(func_test1, dynamic=True)(a, b, a, b)
|
compile_results = torch.compile(func_test1, dynamic=True)(a, b, a, b)
|
||||||
|
|
@ -1313,7 +1313,7 @@ class TestMaxAutotune(TestCase):
|
||||||
x = torch.matmul(x, x)
|
x = torch.matmul(x, x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
reset_counters()
|
reset_counters()
|
||||||
input = torch.rand(10, 10, device=GPU_TYPE)
|
input = torch.rand(10, 10, device=GPU_TYPE)
|
||||||
|
|
||||||
|
|
@ -1324,7 +1324,7 @@ class TestMaxAutotune(TestCase):
|
||||||
self.assertEqual(hits(), 36)
|
self.assertEqual(hits(), 36)
|
||||||
self.assertEqual(misses(), 4)
|
self.assertEqual(misses(), 4)
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
reset_counters()
|
reset_counters()
|
||||||
input = torch.rand(10, 10, device=GPU_TYPE)
|
input = torch.rand(10, 10, device=GPU_TYPE)
|
||||||
|
|
||||||
|
|
@ -1343,7 +1343,7 @@ class TestMaxAutotune(TestCase):
|
||||||
b = torch.matmul(torch.cat([x, z], 1), torch.cat([y, m, l], 0))
|
b = torch.matmul(torch.cat([x, z], 1), torch.cat([y, m, l], 0))
|
||||||
return a, b
|
return a, b
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
a = torch.rand(10, 22, device=GPU_TYPE)
|
a = torch.rand(10, 22, device=GPU_TYPE)
|
||||||
b = torch.rand(22, 30, device=GPU_TYPE)
|
b = torch.rand(22, 30, device=GPU_TYPE)
|
||||||
c = torch.rand(10, 11, device=GPU_TYPE)
|
c = torch.rand(10, 11, device=GPU_TYPE)
|
||||||
|
|
@ -1384,7 +1384,7 @@ class TestMaxAutotune(TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
# Valid cache hit.
|
# Valid cache hit.
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
torch._dynamo.utils.counters.clear()
|
torch._dynamo.utils.counters.clear()
|
||||||
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
||||||
eager_results = func_test1(a, b, a, b)
|
eager_results = func_test1(a, b, a, b)
|
||||||
|
|
@ -1424,7 +1424,7 @@ class TestMaxAutotune(TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
# Valid cache hit.
|
# Valid cache hit.
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
torch._dynamo.utils.counters.clear()
|
torch._dynamo.utils.counters.clear()
|
||||||
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
compile_results = torch.compile(func_test1, dynamic=False)(a, b, a, b)
|
||||||
eager_results = func_test1(a, b, a, b)
|
eager_results = func_test1(a, b, a, b)
|
||||||
|
|
@ -1543,7 +1543,7 @@ class TestMaxAutotunePrecompile(TestCase):
|
||||||
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
|
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
|
||||||
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
|
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@config.patch(search_autotune_cache=True)
|
@config.patch(search_autotune_cache=True)
|
||||||
def test_search_autotune_cache(self):
|
def test_search_autotune_cache(self):
|
||||||
def fn(a, b, c):
|
def fn(a, b, c):
|
||||||
|
|
@ -1811,12 +1811,12 @@ class TestMaxAutotuneRemoteCache(TestCase):
|
||||||
os.environ.pop("TRITON_CACHE_MANAGER", None)
|
os.environ.pop("TRITON_CACHE_MANAGER", None)
|
||||||
with config.patch({"max_autotune": True}):
|
with config.patch({"max_autotune": True}):
|
||||||
for _ in range(4):
|
for _ in range(4):
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
torch.compile(mm, dynamic=dynamic)(a, b)
|
torch.compile(mm, dynamic=dynamic)(a, b)
|
||||||
reset()
|
reset()
|
||||||
with torch.compiler.config.patch(
|
with torch.compiler.config.patch(
|
||||||
{"cache_key_tag": "test"}
|
{"cache_key_tag": "test"}
|
||||||
), fresh_inductor_cache():
|
), fresh_cache():
|
||||||
torch.compile(mm, dynamic=dynamic)(a, b)
|
torch.compile(mm, dynamic=dynamic)(a, b)
|
||||||
reset()
|
reset()
|
||||||
|
|
||||||
|
|
@ -1825,12 +1825,10 @@ class TestMaxAutotuneRemoteCache(TestCase):
|
||||||
|
|
||||||
global_stats.reset()
|
global_stats.reset()
|
||||||
for _ in range(4):
|
for _ in range(4):
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
torch.compile(f, dynamic=dynamic)(x, y)
|
torch.compile(f, dynamic=dynamic)(x, y)
|
||||||
reset()
|
reset()
|
||||||
with torch.compiler.config.patch(
|
with torch.compiler.config.patch({"cache_key_tag": "test"}), fresh_cache():
|
||||||
{"cache_key_tag": "test"}
|
|
||||||
), fresh_inductor_cache():
|
|
||||||
torch.compile(mm, dynamic=dynamic)(a, b)
|
torch.compile(mm, dynamic=dynamic)(a, b)
|
||||||
reset()
|
reset()
|
||||||
global_stats.report()
|
global_stats.report()
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from torch._inductor.fx_passes.pad_mm import (
|
||||||
should_pad_mm_bf16,
|
should_pad_mm_bf16,
|
||||||
)
|
)
|
||||||
from torch._inductor.test_case import run_tests, TestCase
|
from torch._inductor.test_case import run_tests, TestCase
|
||||||
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
|
from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_utils import skipIfRocm
|
from torch.testing._internal.common_utils import skipIfRocm
|
||||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
from torch.testing._internal.inductor_utils import HAS_CUDA
|
||||||
|
|
@ -362,7 +362,7 @@ class PadMMTest(TestCase):
|
||||||
self.assertEqual(out, inps[0] @ inps[1])
|
self.assertEqual(out, inps[0] @ inps[1])
|
||||||
|
|
||||||
@inductor_config.patch(force_shape_pad=True)
|
@inductor_config.patch(force_shape_pad=True)
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
def test_pad_addmm_2d_bias(self):
|
def test_pad_addmm_2d_bias(self):
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def foo(input, x, y):
|
def foo(input, x, y):
|
||||||
|
|
@ -419,7 +419,7 @@ class PadMMTest(TestCase):
|
||||||
res2, bmm_expected_result
|
res2, bmm_expected_result
|
||||||
), "BMM results are not identical"
|
), "BMM results are not identical"
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
def test_exclude_padding(self):
|
def test_exclude_padding(self):
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def mm(a, b):
|
def mm(a, b):
|
||||||
|
|
@ -448,7 +448,7 @@ class PadMMTest(TestCase):
|
||||||
repr(local_cache)
|
repr(local_cache)
|
||||||
)
|
)
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@inductor_config.patch(max_pointwise_cat_inputs=2)
|
@inductor_config.patch(max_pointwise_cat_inputs=2)
|
||||||
def test_exclude_cat_padding(self):
|
def test_exclude_cat_padding(self):
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
|
|
@ -475,7 +475,7 @@ class PadMMTest(TestCase):
|
||||||
"No perf regression on H100+ with BF16",
|
"No perf regression on H100+ with BF16",
|
||||||
)
|
)
|
||||||
@skipIfRocm
|
@skipIfRocm
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@inductor_config.patch(
|
@inductor_config.patch(
|
||||||
post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
|
post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
|
||||||
)
|
)
|
||||||
|
|
@ -508,7 +508,7 @@ class PadMMTest(TestCase):
|
||||||
|
|
||||||
assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
|
assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
|
||||||
|
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@inductor_config.patch(
|
@inductor_config.patch(
|
||||||
{
|
{
|
||||||
"triton.unique_kernel_names": "original_aten",
|
"triton.unique_kernel_names": "original_aten",
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
||||||
from torch import sym_int, SymBool, SymFloat, SymInt
|
from torch import sym_int, SymBool, SymFloat, SymInt
|
||||||
from torch._C import _disabled_torch_function_impl
|
from torch._C import _disabled_torch_function_impl
|
||||||
from torch._dynamo.testing import CompileCounterWithBackend
|
from torch._dynamo.testing import CompileCounterWithBackend
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
from torch.fx.experimental import sym_node
|
from torch.fx.experimental import sym_node
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
|
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
|
||||||
|
|
@ -3150,7 +3150,7 @@ class TestUnbacked(TestCase):
|
||||||
|
|
||||||
|
|
||||||
class TestUbackedOps(TestCase):
|
class TestUbackedOps(TestCase):
|
||||||
@fresh_inductor_cache()
|
@fresh_cache()
|
||||||
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
||||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
def test_unbacked_reshape1(self):
|
def test_unbacked_reshape1(self):
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ from torch._inductor.runtime.compile_tasks import (
|
||||||
_set_triton_ptxas_path,
|
_set_triton_ptxas_path,
|
||||||
_worker_compile_triton,
|
_worker_compile_triton,
|
||||||
)
|
)
|
||||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
from torch._inductor.utils import clear_on_fresh_cache
|
||||||
from torch._inductor.virtualized import V
|
from torch._inductor.virtualized import V
|
||||||
from torch.hub import _Faketqdm, tqdm
|
from torch.hub import _Faketqdm, tqdm
|
||||||
from torch.utils._ordered_set import OrderedSet
|
from torch.utils._ordered_set import OrderedSet
|
||||||
|
|
@ -162,7 +162,7 @@ def get_compile_threads() -> int:
|
||||||
return config.compile_threads
|
return config.compile_threads
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CompiledTritonKernels:
|
class CompiledTritonKernels:
|
||||||
"""
|
"""
|
||||||
In memory cache for storing compiled triton kernels.
|
In memory cache for storing compiled triton kernels.
|
||||||
|
|
|
||||||
|
|
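For illustration, a minimal sketch of how the renamed registration decorator can be applied to a user-defined cache; MyKernelCache is a made-up name and not part of this change. clear_on_fresh_cache only requires a callable cache_clear attribute, and clear_caches() resets every registered cache.

from torch._inductor.utils import clear_caches, clear_on_fresh_cache


@clear_on_fresh_cache
class MyKernelCache:  # hypothetical example cache, not part of this change
    cache: dict[str, str] = {}
    cache_clear = staticmethod(cache.clear)


MyKernelCache.cache["kernel_src"] = "..."
clear_caches()  # invokes cache_clear() on every registered cache
assert not MyKernelCache.cache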
@ -86,7 +86,7 @@ from torch._inductor.runtime.compile_tasks import _reload_python_module
|
||||||
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
|
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
|
||||||
from torch._inductor.utils import (
|
from torch._inductor.utils import (
|
||||||
ALIGN_BYTES,
|
ALIGN_BYTES,
|
||||||
clear_on_fresh_inductor_cache,
|
clear_on_fresh_cache,
|
||||||
is_linux,
|
is_linux,
|
||||||
is_windows,
|
is_windows,
|
||||||
)
|
)
|
||||||
|
|
@ -236,7 +236,7 @@ class CacheBase:
|
||||||
return system
|
return system
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def get_local_cache_path() -> Path:
|
def get_local_cache_path() -> Path:
|
||||||
return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
|
return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
|
||||||
|
|
@ -1595,7 +1595,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
|
||||||
return path, ""
|
return path, ""
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CudaKernelParamCache:
|
class CudaKernelParamCache:
|
||||||
cache: dict[str, dict[str, Any]] = {}
|
cache: dict[str, dict[str, Any]] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
|
|
@ -2291,7 +2291,7 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, No
|
||||||
|
|
||||||
# Precompiled headers are persistent past program runtime, but associated with one
|
# Precompiled headers are persistent past program runtime, but associated with one
|
||||||
# specific compiler version and set of flags. We explicitly use default_cache_dir here
|
# specific compiler version and set of flags. We explicitly use default_cache_dir here
|
||||||
# because these headers need to be global, rather than ignored by fresh_inductor_cache.
|
# because these headers need to be global, rather than ignored by fresh_cache.
|
||||||
_HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
|
_HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
|
||||||
_HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")
|
_HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")
|
||||||
|
|
||||||
|
|
@ -2378,7 +2378,7 @@ def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CppCodeCache:
|
class CppCodeCache:
|
||||||
"""Compiles and caches C++ libraries. Users of this class supply the source code to
|
"""Compiles and caches C++ libraries. Users of this class supply the source code to
|
||||||
be compiled, while compilation flags are set by CppBuilder."""
|
be compiled, while compilation flags are set by CppBuilder."""
|
||||||
|
|
@ -2587,7 +2587,7 @@ def _worker_compile_cpp(
|
||||||
|
|
||||||
|
|
||||||
# Customized Python binding for cpp kernels
|
# Customized Python binding for cpp kernels
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CppPythonBindingsCodeCache(CppCodeCache):
|
class CppPythonBindingsCodeCache(CppCodeCache):
|
||||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
|
|
@ -2768,7 +2768,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||||
return cls.load_pybinding_async(*args, **kwargs)()
|
return cls.load_pybinding_async(*args, **kwargs)()
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
|
|
@ -2837,7 +2837,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||||
return _get_cpp_wrapper_header(device)
|
return _get_cpp_wrapper_header(device)
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class HalideCodeCache(CppPythonBindingsCodeCache):
|
class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||||
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
|
|
@ -3140,10 +3140,10 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||||
target = "host-cuda" if device_type == "cuda" else "host"
|
target = "host-cuda" if device_type == "cuda" else "host"
|
||||||
if cls._standalone_runtime_path:
|
if cls._standalone_runtime_path:
|
||||||
assert not os.path.exists(cls._standalone_runtime_path)
|
assert not os.path.exists(cls._standalone_runtime_path)
|
||||||
# We hit this case in unittests when we run with fresh_inductor_cache()
|
# We hit this case in unittests when we run with fresh_cache()
|
||||||
# Generating a fresh runtime over and over causes errors because we initialize
|
# Generating a fresh runtime over and over causes errors because we initialize
|
||||||
# cuda hundreds of times in the same process and run out of file descriptors.
|
# cuda hundreds of times in the same process and run out of file descriptors.
|
||||||
# Workaround by jail breaking the current fresh_inductor_cache().
|
# Workaround by jail breaking the current fresh_cache().
|
||||||
base = default_cache_dir()
|
base = default_cache_dir()
|
||||||
else:
|
else:
|
||||||
base = cache_dir()
|
base = cache_dir()
|
||||||
|
|
@ -3239,7 +3239,7 @@ def touch(filename: str) -> None:
|
||||||
open(filename, "a").close()
|
open(filename, "a").close()
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class PyCodeCache:
|
class PyCodeCache:
|
||||||
# Track the loaded modules so we can remove the on-disk artifacts when
|
# Track the loaded modules so we can remove the on-disk artifacts when
|
||||||
# clearing the cache. Note also that we may load the same path more
|
# clearing the cache. Note also that we may load the same path more
|
||||||
|
|
@ -3625,7 +3625,7 @@ class DLLWrapper:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CUDACodeCache:
|
class CUDACodeCache:
|
||||||
"""
|
"""
|
||||||
A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
|
A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS.
|
||||||
|
|
@ -3792,7 +3792,7 @@ class CUDACodeCache:
|
||||||
fh.write(error_json)
|
fh.write(error_json)
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class ROCmCodeCache:
|
class ROCmCodeCache:
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import shutil
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
from torch._inductor.utils import clear_on_fresh_cache
|
||||||
|
|
||||||
from ... import config
|
from ... import config
|
||||||
|
|
||||||
|
|
@ -12,7 +12,7 @@ from ... import config
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
@functools.lru_cache(1)
|
@functools.lru_cache(1)
|
||||||
def get_cuda_arch() -> Optional[str]:
|
def get_cuda_arch() -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
|
|
@ -27,7 +27,7 @@ def get_cuda_arch() -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
@functools.lru_cache(1)
|
@functools.lru_cache(1)
|
||||||
def get_cuda_version() -> Optional[str]:
|
def get_cuda_version() -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from torch._inductor.codecache import cutlass_key
|
||||||
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
|
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
|
||||||
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
|
||||||
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
from torch._inductor.utils import clear_on_fresh_cache
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -47,7 +47,7 @@ def _generate_config_filename(request_key: str) -> str:
|
||||||
return f"{CONFIG_PREFIX}_{request_key}.json"
|
return f"{CONFIG_PREFIX}_{request_key}.json"
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def maybe_fetch_ops() -> Optional[list[Any]]:
|
def maybe_fetch_ops() -> Optional[list[Any]]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from typing import Any, Optional
|
||||||
import sympy
|
import sympy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
from torch._inductor.utils import clear_on_fresh_cache
|
||||||
|
|
||||||
from ... import config
|
from ... import config
|
||||||
from ...ir import Layout
|
from ...ir import Layout
|
||||||
|
|
@ -250,7 +250,7 @@ class CUTLASSArgs:
|
||||||
self.architectures = _normalize_cuda_arch(self.architectures)
|
self.architectures = _normalize_cuda_arch(self.architectures)
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def _gen_ops_cached(arch, version) -> dict[Any, Any]:
|
def _gen_ops_cached(arch, version) -> dict[Any, Any]:
|
||||||
# Note: Cache needs to be specific for cuda architecture and version
|
# Note: Cache needs to be specific for cuda architecture and version
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import torch.utils._pytree as pytree
|
||||||
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
|
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
|
||||||
from torch._inductor.scheduler import BaseSchedulerNode
|
from torch._inductor.scheduler import BaseSchedulerNode
|
||||||
from torch._inductor.select_algorithm import create_inputs_key
|
from torch._inductor.select_algorithm import create_inputs_key
|
||||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
from torch._inductor.utils import clear_on_fresh_cache
|
||||||
|
|
||||||
from ... import ir
|
from ... import ir
|
||||||
from ...config import cuda as inductor_cuda_config
|
from ...config import cuda as inductor_cuda_config
|
||||||
|
|
@ -405,7 +405,7 @@ int main(int argc, char** argv) {
|
||||||
""" # noqa: B950
|
""" # noqa: B950
|
||||||
|
|
||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_cache
|
||||||
class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
||||||
"""
|
"""
|
||||||
CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels
|
CUTLASS GEMM Template, which is used to generate CUTLASS GEMM kernels
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||||
from torch._inductor.utils import (
|
from torch._inductor.utils import (
|
||||||
BoxedBool,
|
BoxedBool,
|
||||||
count_tangents,
|
count_tangents,
|
||||||
fresh_inductor_cache,
|
fresh_cache,
|
||||||
get_all_devices,
|
get_all_devices,
|
||||||
InputType,
|
InputType,
|
||||||
is_gpu,
|
is_gpu,
|
||||||
|
|
@ -675,7 +675,7 @@ def with_fresh_cache_if_config() -> Generator[None, None, None]:
|
||||||
# Don't delete the cache dir because it has to survive beyond the
|
# Don't delete the cache dir because it has to survive beyond the
|
||||||
# compile_fx call. Let's put the temp dirs under the default cache
|
# compile_fx call. Let's put the temp dirs under the default cache
|
||||||
# dir so they're easier to locate.
|
# dir so they're easier to locate.
|
||||||
with fresh_inductor_cache(dir=cache_dir(), delete=False):
|
with fresh_cache(dir=cache_dir(), delete=False):
|
||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
|
|
|
||||||
|
|
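As a sketch of the keyword arguments used above (the compiled function and input are made up): dir= controls where the temporary cache directory is created and delete=False keeps it on disk after the block exits, which is what with_fresh_cache_if_config relies on.

import torch
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.utils import fresh_cache

with fresh_cache(dir=cache_dir(), delete=False):
    compiled = torch.compile(lambda x: x * 2)
    compiled(torch.randn(8))
    # compile artifacts are written to a temp dir under the default cache dir
    # and are left in place after the context exits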
@ -13,7 +13,7 @@ from torch._inductor.compile_worker.subproc_pool import (
|
||||||
SubprocKind,
|
SubprocKind,
|
||||||
SubprocPool,
|
SubprocPool,
|
||||||
)
|
)
|
||||||
from torch._inductor.utils import clear_inductor_caches
|
from torch._inductor.utils import clear_caches
|
||||||
|
|
||||||
from .compile_fx_ext import (
|
from .compile_fx_ext import (
|
||||||
_OutOfProcessFxCompile,
|
_OutOfProcessFxCompile,
|
||||||
|
|
@ -77,14 +77,14 @@ class _SubprocessFxCompile(_OutOfProcessFxCompile):
|
||||||
# tmpdir still exists and fails to compile.
|
# tmpdir still exists and fails to compile.
|
||||||
#
|
#
|
||||||
# TODO: We probably should be using a separate tmpdir in the worker
|
# TODO: We probably should be using a separate tmpdir in the worker
|
||||||
# anyway... but we should probably still respect clear_inductor_caches()
|
# anyway... but we should probably still respect clear_caches()
|
||||||
# in the parent... maybe?
|
# in the parent... maybe?
|
||||||
#
|
#
|
||||||
# TODO: We could be less aggressive by keeping a clock which gets
|
# TODO: We could be less aggressive by keeping a clock which gets
|
||||||
# incremented when we clear the cache, send the clock to the worker and
|
# incremented when we clear the cache, send the clock to the worker and
|
||||||
# only clear caches if the clock changed since last time.
|
# only clear caches if the clock changed since last time.
|
||||||
#
|
#
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
torch._inductor.metrics.reset()
|
torch._inductor.metrics.reset()
|
||||||
|
|
||||||
# TODO: turn off config.fx_graph_async_compile
|
# TODO: turn off config.fx_graph_async_compile
|
||||||
|
|
|
||||||
|
|
@ -39,15 +39,15 @@ def triton_cache_dir(device: int) -> str:
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
|
def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
|
||||||
from torch._inductor.utils import clear_inductor_caches
|
from torch._inductor.utils import clear_caches
|
||||||
|
|
||||||
original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
|
original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
|
||||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
|
os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
|
||||||
try:
|
try:
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
if original is None:
|
if original is None:
|
||||||
del os.environ["TORCHINDUCTOR_CACHE_DIR"]
|
del os.environ["TORCHINDUCTOR_CACHE_DIR"]
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
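A hedged usage sketch for temporary_cache_dir shown above; the import path is an assumption based on where cache_dir is imported from elsewhere in this change, and the compiled function is made up.

import tempfile

import torch
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir  # path assumed

with tempfile.TemporaryDirectory() as tmp:
    with temporary_cache_dir(tmp):
        # TORCHINDUCTOR_CACHE_DIR now points at tmp; registered caches were cleared
        torch.compile(lambda x: x + 1)(torch.ones(4))
    # on exit the caches are cleared again and the original env var is restored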
@ -28,7 +28,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
|
||||||
from torch._dynamo.device_interface import get_interface_for_device
|
from torch._dynamo.device_interface import get_interface_for_device
|
||||||
from torch._dynamo.testing import rand_strided
|
from torch._dynamo.testing import rand_strided
|
||||||
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
|
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
|
||||||
from torch._inductor.utils import clear_on_fresh_inductor_cache
|
from torch._inductor.utils import clear_on_fresh_cache
|
||||||
from torch.utils._filelock import FileLock
|
from torch.utils._filelock import FileLock
|
||||||
from torch.utils._ordered_set import OrderedSet
|
from torch.utils._ordered_set import OrderedSet
|
||||||
|
|
||||||
|
|
@ -1291,7 +1291,7 @@ class TritonTemplate(KernelTemplate):
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
|
self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
|
||||||
self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
|
self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
|
||||||
clear_on_fresh_inductor_cache(self._generated_code_cache)
|
clear_on_fresh_cache(self._generated_code_cache)
|
||||||
# When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
|
# When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
|
||||||
# by adding all inputs.
|
# by adding all inputs.
|
||||||
self.prologue_loads_all_inputs = prologue_loads_all_inputs
|
self.prologue_loads_all_inputs = prologue_loads_all_inputs
|
||||||
|
|
@ -2123,7 +2123,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||||
# list of callbacks that are called after benchmarking
|
# list of callbacks that are called after benchmarking
|
||||||
self.feedback_saver_fns: list[FeedbackFunction] = []
|
self.feedback_saver_fns: list[FeedbackFunction] = []
|
||||||
|
|
||||||
clear_on_fresh_inductor_cache(self)
|
clear_on_fresh_cache(self)
|
||||||
|
|
||||||
def cache_clear(self) -> None:
|
def cache_clear(self) -> None:
|
||||||
self.precompile_cache.clear()
|
self.precompile_cache.clear()
|
||||||
|
|
|
||||||
|
|
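The call above registers an instance rather than decorating a class; a minimal sketch of that pattern (names are illustrative, not Inductor classes):

from torch._inductor.utils import clear_on_fresh_cache


class ChoiceTimingCache:  # illustrative only
    def __init__(self) -> None:
        self._timings: dict[str, float] = {}
        # register this instance so fresh_cache()/clear_caches() reset it
        clear_on_fresh_cache(self)

    def cache_clear(self) -> None:
        self._timings.clear()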
@ -8,7 +8,7 @@ from torch._dynamo.test_case import (
|
||||||
)
|
)
|
||||||
from torch._functorch import config as functorch_config
|
from torch._functorch import config as functorch_config
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
|
|
||||||
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
|
def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
|
||||||
|
|
@ -41,7 +41,7 @@ class TestCase(DynamoTestCase):
|
||||||
os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
|
os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1"
|
||||||
and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
|
and os.environ.get("TORCH_COMPILE_DEBUG") != "1"
|
||||||
):
|
):
|
||||||
self._inductor_test_stack.enter_context(fresh_inductor_cache())
|
self._inductor_test_stack.enter_context(fresh_cache())
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
|
||||||
|
|
@ -1015,29 +1015,6 @@ def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
|
||||||
return input_devices | out_devices
|
return input_devices | out_devices
|
||||||
|
|
||||||
|
|
||||||
_registered_caches: list[Any] = []
|
|
||||||
|
|
||||||
|
|
||||||
def clear_on_fresh_inductor_cache(obj: Any) -> Any:
|
|
||||||
"""
|
|
||||||
Use this decorator to register any caches that should be cache_clear'd
|
|
||||||
with fresh_inductor_cache().
|
|
||||||
"""
|
|
||||||
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
|
|
||||||
raise AttributeError(f"{obj} does not have a cache_clear method")
|
|
||||||
|
|
||||||
_registered_caches.append(obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def clear_inductor_caches() -> None:
|
|
||||||
"""
|
|
||||||
Clear all registered caches.
|
|
||||||
"""
|
|
||||||
for obj in _registered_caches:
|
|
||||||
obj.cache_clear()
|
|
||||||
|
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1070,19 +1047,42 @@ def unload_xpu_triton_pyds() -> None:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
_registered_caches: list[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
def clear_on_fresh_cache(obj: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Use this decorator to register any caches that should be cache_clear'd
|
||||||
|
with fresh_cache().
|
||||||
|
"""
|
||||||
|
if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
|
||||||
|
raise AttributeError(f"{obj} does not have a cache_clear method")
|
||||||
|
|
||||||
|
_registered_caches.append(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def clear_caches() -> None:
|
||||||
|
"""
|
||||||
|
Clear all registered caches.
|
||||||
|
"""
|
||||||
|
for obj in _registered_caches:
|
||||||
|
obj.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def fresh_inductor_cache(
|
def fresh_cache(
|
||||||
cache_entries: Optional[dict[str, Any]] = None,
|
cache_entries: Optional[dict[str, Any]] = None,
|
||||||
dir: Optional[str] = None,
|
dir: Optional[str] = None,
|
||||||
delete: bool = True,
|
delete: bool = True,
|
||||||
) -> Iterator[None]:
|
) -> Iterator[None]:
|
||||||
"""
|
"""
|
||||||
Contextmanager that provides a clean tmp cachedir for inductor.
|
Contextmanager that provides a clean tmp cachedir for pt2 caches.
|
||||||
|
|
||||||
Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
|
Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
|
||||||
generated with this cache instance.
|
generated with this cache instance.
|
||||||
"""
|
"""
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
|
|
||||||
inductor_cache_dir = tempfile.mkdtemp(dir=dir)
|
inductor_cache_dir = tempfile.mkdtemp(dir=dir)
|
||||||
try:
|
try:
|
||||||
|
|
@ -1123,7 +1123,13 @@ def fresh_inductor_cache(
|
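For illustration (the compiled lambdas and shapes are assumptions), the renamed helper works both as a context manager and, as the test hunks above show, as a decorator; cache_entries is populated with the filenames and sizes generated inside the block.

import torch
from torch._inductor.utils import fresh_cache

entries: dict = {}
with fresh_cache(cache_entries=entries):
    torch.compile(lambda a, b: a + b)(torch.ones(4), torch.ones(4))
# entries now records the files written into the temporary cache dir


@fresh_cache()
def compile_in_isolation() -> None:
    torch.compile(lambda x: x.sin())(torch.randn(8))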
||||||
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
|
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
clear_inductor_caches()
|
clear_caches()
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated functions -- only keeping them for BC reasons
|
||||||
|
clear_on_fresh_inductor_cache = clear_on_fresh_cache
|
||||||
|
clear_inductor_caches = clear_caches
|
||||||
|
fresh_inductor_cache = fresh_cache
|
||||||
|
|
||||||
|
|
||||||
def argsort(seq: Sequence[Any]) -> list[int]:
|
def argsort(seq: Sequence[Any]) -> list[int]:
|
||||||
|
|
|
||||||
|
|
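The old names stay importable as plain aliases (per the hunk above), so existing call sites keep working while new code can use the shorter names:

from torch._inductor.utils import fresh_cache, fresh_inductor_cache

assert fresh_inductor_cache is fresh_cache  # deprecated alias, kept for backward compatibility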
@ -15,7 +15,7 @@ from benchmark_utils import ( # type: ignore[import-not-found]
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
|
class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
|
||||||
|
|
@ -59,7 +59,7 @@ class BenchmarkRunnerMixedMM(BenchmarkRunner): # type: ignore[misc, no-any-unim
|
||||||
)
|
)
|
||||||
b = b.to(dtype=dtype_right)
|
b = b.to(dtype=dtype_right)
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
|
|
||||||
def mixed_mm(A, B):
|
def mixed_mm(A, B):
|
||||||
return torch.mm(A, B.to(A.dtype))
|
return torch.mm(A, B.to(A.dtype))
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from benchmark_utils import ( # type: ignore[import-not-found]
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
|
class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
|
||||||
|
|
@ -57,7 +57,7 @@ class BenchmarkRunnerMM(BenchmarkRunner): # type: ignore[misc, no-any-unimporte
|
||||||
dtype_right=dtype,
|
dtype_right=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
|
|
||||||
def mixed_mm(A: Any, B: Any) -> Any:
|
def mixed_mm(A: Any, B: Any) -> Any:
|
||||||
return torch.mm(A, B)
|
return torch.mm(A, B)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import torch
|
||||||
from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found]
|
from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found]
|
||||||
get_alignment_size_dtype,
|
get_alignment_size_dtype,
|
||||||
)
|
)
|
||||||
from torch._inductor.utils import fresh_inductor_cache
|
from torch._inductor.utils import fresh_cache
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
|
class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimported]
|
||||||
|
|
@ -74,7 +74,7 @@ class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimpo
|
||||||
print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
|
print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
|
||||||
print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")
|
print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")
|
||||||
|
|
||||||
with fresh_inductor_cache():
|
with fresh_cache():
|
||||||
|
|
||||||
def mm(a: Any, b: Any) -> Any:
|
def mm(a: Any, b: Any) -> Any:
|
||||||
return torch.mm(a, b)
|
return torch.mm(a, b)
|
||||||
|
|
|
||||||