Support passing num_ctas to user-defined Triton kernels under torch.compile.

Adds test_triton_kernel_num_ctas, which compiles (eager, aot_eager and
inductor backends) a function that launches a no-op Triton kernel with
num_ctas=1, once through the kernel[grid](...) form and once through
kernel.run(...).

In torch/_dynamo/trace_rules.py, "triton/backends" is added to SKIP_DIRS and
the SKIP_DIRS_RE pattern now accepts a leading run of characters that are
neither whitespace nor "<" before a SKIP_DIRS entry, so an entry such as
"triton/backends" can match in the middle of a file path and not only at its
start.

diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
index b49ffefedf9..dd5bbce4aa7 100644
--- a/test/inductor/test_triton_kernels.py
+++ b/test/inductor/test_triton_kernels.py
@@ -1316,6 +1316,23 @@ def forward(self, x_1, output_1):
         x = torch.randn(4, device="cuda")
         f(x, x)
 
+    @requires_cuda
+    @skipIfRocm
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_triton_kernel_num_ctas(self, backend):
+        @triton.jit
+        def kernel(X):
+            return
+
+        @torch.compile(backend=backend)
+        def f(x):
+            kernel[(1,)](x, num_ctas=1)
+            kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
+            return x
+
+        x = torch.randn(4, device="cuda")
+        f(x)
+
     @requires_cuda
     @skipIfRocm
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index b43f4477371..3e6b50f4c50 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -3267,6 +3267,7 @@
 SKIP_DIRS = [
     "<",
     _config_module.__file__,
+    "triton/backends",
 ]
 SKIP_DIRS.extend(filter(None, (_module_dir(m) for m in BUILTIN_SKIPLIST)))
 
@@ -3307,7 +3308,7 @@ FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"}
 
 def _recompile_re():
     global SKIP_DIRS_RE
-    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")
+    SKIP_DIRS_RE = re.compile(rf"^[^\s<]*({'|'.join(map(re.escape, SKIP_DIRS))})")
 
 
 def add(import_name: str):