Do not trace into triton/backends (#126083)

Fixes #125807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126083
Approved by: https://github.com/yanboliang, https://github.com/jansel
Oguz Ulgen 2024-05-22 12:51:07 -07:00 committed by PyTorch MergeBot
parent 558c4413ce
commit cc61d03ac9
2 changed files with 19 additions and 1 deletion


@@ -1316,6 +1316,23 @@ def forward(self, x_1, output_1):
         x = torch.randn(4, device="cuda")
         f(x, x)
 
+    @requires_cuda
+    @skipIfRocm
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_triton_kernel_num_ctas(self, backend):
+        @triton.jit
+        def kernel(X):
+            return
+
+        @torch.compile(backend=backend)
+        def f(x):
+            kernel[(1,)](x, num_ctas=1)
+            kernel.run(x, num_ctas=1, grid=(1,), warmup=False)
+            return x
+
+        x = torch.randn(4, device="cuda")
+        f(x)
+
     @requires_cuda
     @skipIfRocm
     @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
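For reference, the launch pattern the new test exercises looks roughly like this outside the test harness — a minimal sketch, assuming a CUDA device and a Triton build that accepts num_ctas as a launch option (the kernel and function names below are made up):

import torch
import triton

@triton.jit
def _noop_kernel(x_ptr):
    # The body is irrelevant; only the launch path through Triton's runtime matters.
    return

@torch.compile(backend="inductor")
def launch(x):
    # num_ctas is forwarded through Triton's launch machinery; with this commit,
    # Dynamo skips frames under triton/backends instead of trying to trace them.
    _noop_kernel[(1,)](x, num_ctas=1)
    return x

x = torch.randn(4, device="cuda")
launch(x)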


@@ -3267,6 +3267,7 @@ SKIP_DIRS = [
     "<frozen importlib",
     "<__array_function__ internals>",
     _config_module.__file__,
+    "triton/backends",
 ]
 SKIP_DIRS.extend(filter(None, (_module_dir(m) for m in BUILTIN_SKIPLIST)))
@@ -3307,7 +3308,7 @@ FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"}
 def _recompile_re():
     global SKIP_DIRS_RE
-    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")
+    SKIP_DIRS_RE = re.compile(rf"^[^\s<]*({'|'.join(map(re.escape, SKIP_DIRS))})")
 
 
 def add(import_name: str):
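The regex tweak is what makes a relative entry such as "triton/backends" effective: the old pattern only matched when a SKIP_DIRS entry was a prefix of the filename, while the new one also matches after a non-whitespace, non-'<' prefix such as a site-packages path. A small self-contained sketch of the difference (the directory names below are made up, not the real SKIP_DIRS contents):

import re

# Hypothetical stand-ins for a few SKIP_DIRS entries: an absolute module
# directory, a special frame name, and the new relative fragment.
skip_dirs = [
    "/usr/lib/python3.11/site-packages/torch/",
    "<frozen importlib",
    "triton/backends",
]

old_re = re.compile(f"^({'|'.join(map(re.escape, skip_dirs))})")
new_re = re.compile(rf"^[^\s<]*({'|'.join(map(re.escape, skip_dirs))})")

path = "/usr/lib/python3.11/site-packages/triton/backends/nvidia/driver.py"

print(bool(old_re.match(path)))  # False: "triton/backends" is not a prefix of path
print(bool(new_re.match(path)))  # True: [^\s<]* absorbs the site-packages prefix
print(bool(new_re.match("<frozen importlib._bootstrap>")))  # True: a zero-length prefix is still allowed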