pytorch/test/inductor/test_needs_exact_strides.py
gaoyvfeng 50f23ff6f8 rename-HAS_CUDA-to-HAS_CUDA_AND_TRITON (#159883)
Fixes #159399
"Modified torch.testing._internal.inductor_utils and test/inductor"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159883
Approved by: https://github.com/janeyx99
2025-08-08 15:44:52 +00:00

103 lines
3.1 KiB
Python

# Owner(s): ["module: inductor"]
import torch
import torch.utils._pytree as pytree
from torch._inductor.pattern_matcher import (
CallFunctionVarArgs,
PatternMatcherPass,
register_graph_pattern,
)
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_LINUX,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
class TestNeedsExactStrides(InductorTestCase):
@parametrize("dtype", [torch.float, torch.float8_e8m0fnu])
def test_custom_op(self, dtype):
device = "cuda" # float8_e8m0fnu errors on "cpu"
x = torch.ones(4, 4, 2, 2, device=device, dtype=torch.float8_e8m0fnu)
other = torch.ones(4, 4, 2, 2, device=device, dtype=torch.float8_e8m0fnu)
class _CustomPass(PatternMatcherPass):
def __init__(self) -> None:
super().__init__()
def __call__(self, g: torch.fx.Graph):
self.apply(g)
g = _CustomPass()
called = False
@register_graph_pattern(
CallFunctionVarArgs(torch.ops.aten.permute),
pass_dict=g,
)
def _(match, *args, **kwargs):
flat_args, spec = pytree.tree_flatten((args, kwargs))
def decomp(*flat_args):
args, kwargs = pytree.tree_unflatten(flat_args, spec)
return torch.ops.mylib.force_channels_last(
torch.ops.aten.permute(*args, **kwargs)
)
nonlocal called
called = True
match.replace_by_example(decomp, flat_args)
from torch._inductor import config
class TestPassed(RuntimeError):
pass
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define(
"force_channels_last(Tensor x) -> Tensor",
tags=[torch._C.Tag.flexible_layout],
)
def impl2(x):
return x.clone(memory_format=torch.channels_last)
lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd")
lib.define(
"add_op(Tensor x, Tensor y) -> Tensor",
)
def impl(x, y):
assert x.transpose(2, 3).is_contiguous()
assert y.is_contiguous()
return x.float() + y.float()
def meta(x, y):
return x.float() + y.float()
lib.impl("add_op", impl, "CompositeExplicitAutograd")
lib.impl("add_op", meta, "Meta")
def f(x, other):
return torch.ops.mylib.add_op.default(x.transpose(2, 3), other)
with config.patch(
post_grad_custom_post_pass=g,
):
try:
f_compile = torch.compile(f, fullgraph=True)
f_compile(x, other)
except TestPassed:
pass
assert called
instantiate_parametrized_tests(TestNeedsExactStrides)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA_AND_TRITON:
run_tests(needs="filelock")