mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This is step 3 of adding CUTLASS as an alternative Inductor backend. Full tests can be found in the last PR in the stack. Feature request: https://github.com/pytorch/pytorch/issues/106991. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107901 Approved by: https://github.com/jansel, https://github.com/aakhundov, https://github.com/kadeng ghstack dependencies: #107802, #107847
37 lines
915 B
Python
37 lines
915 B
Python
# Owner(s): ["module: inductor"]
|
|
|
|
import functools
import logging
import unittest

import torch

from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor.utils import do_bench, do_bench_using_profiling
|
|
|
|
|
|
# Module-level logger, used by the tests to report the benchmark results.
log: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TestBench(TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
x = torch.rand(1024, 10).cuda().half()
|
|
w = torch.rand(512, 10).cuda().half()
|
|
cls._bench_fn = functools.partial(torch.nn.functional.linear, x, w)
|
|
|
|
def test_do_bench(self):
|
|
res = do_bench(self._bench_fn)
|
|
log.warning("do_bench result: %s", res)
|
|
self.assertGreater(res, 0)
|
|
|
|
def test_do_bench_using_profiling(self):
|
|
res = do_bench_using_profiling(self._bench_fn)
|
|
log.warning("do_bench_using_profiling result: %s", res)
|
|
self.assertGreater(res, 0)
|
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the "cuda" requirement is forwarded to run_tests, which
    # presumably skips the whole run when no GPU is available — confirm
    # against torch._dynamo.test_case.run_tests.
    run_tests(needs="cuda")
|