See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129763
Approved by: https://github.com/jansel
88 lines · 2.6 KiB · Python
# Owner(s): ["module: inductor"]
import sys
import unittest

import torch

from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

try:
    import triton  # noqa: F401
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")  # noqa: B904

from torch._inductor import config
from torch._inductor.runtime.hints import TRITON_MAX_BLOCK
from torch._inductor.runtime.triton_heuristics import triton_config
from torch._inductor.test_case import run_tests, TestCase


class TestTritonHeuristics(TestCase):
    device_type = GPU_TYPE

    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
        """
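        # triton_config's positional args appear to be (size_hints, x, y):
        # the requested XBLOCK/YBLOCK of 64 may be rescaled by the heuristic,
        # so check that every emitted block size stays within TRITON_MAX_BLOCK.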
        cfg = triton_config([2048, 2], 64, 64)
        for label in "XYZ":
            key = f"{label}BLOCK"
            if key not in cfg.kwargs:
                continue
            self.assertTrue(cfg.kwargs[key] <= TRITON_MAX_BLOCK[label])
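
    # Shared body for the zgrid tests below: compile a small
    # reshape/permute/addmm graph on an input large enough to need an
    # oversized launch grid.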
    def _test_artificial_zgrid(self):
        def forward(primals_1, primals_2, primals_5):
            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
            primals_5 = None
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            permute = None
            view_1 = torch.ops.aten.reshape.default(clone, [-1, 4])
            clone = None
            permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1)
            primals_2 = None
            return addmm
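
        # s0 * s1 is ~134M elements, presumably chosen so the kernel needs
        # more blocks than fit in a single launch-grid dimension, forcing
        # inductor to spill into the (artificial) z grid dimension.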
        s0 = 16777472
        s1 = 8

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        torch._dynamo.mark_dynamic(args[-1], 0)
        foo_c = torch.compile(forward)

        self.assertEqual(forward(*args), foo_c(*args))
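
        # Exercise the compiled function again with fresh inputs of the same
        # shape; this should reuse the compiled kernel rather than recompile.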
        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        self.assertEqual(forward(*args), foo_c(*args))

    @skipIfXpu
    def test_artificial_zgrid(self):
        self._test_artificial_zgrid()
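
    # Same graph, but with inductor's cpp_wrapper codegen patched on, so the
    # grid computation is also exercised through the C++ wrapper path.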
    @skipIfXpu
    @config.patch("cpp_wrapper", True)
    def test_artificial_grid_cpp_wrapper(self):
        self._test_artificial_zgrid()


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()