# Owner(s): ["module: inductor"]

import torch
import torch.utils._pytree as pytree
from torch._inductor.pattern_matcher import (
    CallFunctionVarArgs,
    PatternMatcherPass,
    register_graph_pattern,
)
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
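

# This test checks that Inductor preserves the exact strides a custom op
# expects: a post-grad pattern rewrites aten.permute into an op tagged
# flexible_layout, and the eager implementation of mylib.add_op asserts on
# the memory layout of the tensors it actually receives.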
class TestNeedsExactStrides(InductorTestCase):
    @parametrize("dtype", [torch.float, torch.float8_e8m0fnu])
    def test_custom_op(self, dtype):
        device = "cuda"  # float8_e8m0fnu errors on "cpu"
        x = torch.ones(4, 4, 2, 2, device=device, dtype=dtype)
        other = torch.ones(4, 4, 2, 2, device=device, dtype=dtype)
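
        # A minimal pattern-matcher pass: calling the instance applies every
        # pattern registered against it to the given FX graph.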
        class _CustomPass(PatternMatcherPass):
            def __init__(self) -> None:
                super().__init__()

            def __call__(self, g: torch.fx.Graph):
                self.apply(g)

        g = _CustomPass()
        called = False
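
        # Match every call to aten.permute in the post-grad graph and hand it
        # to the replacement function below, e.g. permute(x, [0, 1, 3, 2])
        # becomes mylib.force_channels_last(aten.permute(x, [0, 1, 3, 2])).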
        @register_graph_pattern(
            CallFunctionVarArgs(torch.ops.aten.permute),
            pass_dict=g,
        )
        def _(match, *args, **kwargs):
            flat_args, spec = pytree.tree_flatten((args, kwargs))
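
            # replace_by_example traces decomp with the flattened arguments
            # and splices the traced subgraph in place of the matched permute.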
            def decomp(*flat_args):
                args, kwargs = pytree.tree_unflatten(flat_args, spec)
                return torch.ops.mylib.force_channels_last(
                    torch.ops.aten.permute(*args, **kwargs)
                )

            nonlocal called
            called = True
            match.replace_by_example(decomp, flat_args)

        from torch._inductor import config

        class TestPassed(RuntimeError):
            pass
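
        # _scoped_library registers "mylib" as a temporary fragment that is
        # torn down when the block exits. The flexible_layout tag tells
        # Inductor that force_channels_last accepts inputs with any strides,
        # so the compiler is free to change its input layouts.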
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define(
                "force_channels_last(Tensor x) -> Tensor",
                tags=[torch._C.Tag.flexible_layout],
            )

            def impl2(x):
                return x.clone(memory_format=torch.channels_last)

            lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd")

            lib.define(
                "add_op(Tensor x, Tensor y) -> Tensor",
            )
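
            # add_op carries no layout tag, so Inductor must pass it inputs
            # with the strides eager mode would produce: x arrives as a
            # non-contiguous transposed view and y stays contiguous.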
            def impl(x, y):
                assert x.transpose(2, 3).is_contiguous()
                assert y.is_contiguous()
                return x.float() + y.float()

            def meta(x, y):
                return x.float() + y.float()

            lib.impl("add_op", impl, "CompositeExplicitAutograd")
            lib.impl("add_op", meta, "Meta")

            def f(x, other):
                return torch.ops.mylib.add_op.default(x.transpose(2, 3), other)
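
            # Compile f with the custom pass installed as the post-grad pass.
            # TestPassed is tolerated if raised during the run; the check that
            # matters is that the permute pattern actually fired.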
            with config.patch(
                post_grad_custom_post_pass=g,
            ):
                try:
                    f_compile = torch.compile(f, fullgraph=True)
                    f_compile(x, other)
                except TestPassed:
                    pass
                assert called


instantiate_parametrized_tests(TestNeedsExactStrides)
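
# float8_e8m0fnu and the layout checks need a GPU, so the suite is gated on
# Linux with both CUDA and Triton available.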
if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")