# Owner(s): ["module: inductor"]
import json
import unittest

import torch
import torch._dynamo.test_case
import torch._inductor.utils

from torch._inductor import config
from torch.profiler import ProfilerActivity

from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM

from torch.utils._triton import has_triton

HAS_TRITON = has_triton()


class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_triton_launch(self):
        # Verify that we get some sort of CPU-side indication of triton kernel launches
        # in the profile traces. Currently, those appear as `cuLaunchKernel`. If this
        # detail changes, the test can be updated or removed.
        @torch.compile
        def fn(x, y):
            return (x + y).sin().cos()

        x, y = (torch.rand((4, 4), device="cuda") for _ in range(2))

        with torch.profiler.profile() as prof:
            fn(x, y)

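        # Export the trace to a temporary file and read it back as JSON so the raw
        # chrome-trace events can be inspected directly.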
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace_json = json.load(f)

        self.assertTrue("traceEvents" in trace_json)
        events = trace_json["traceEvents"]

        def nameMatchesLaunchKernel(event_name):
            return "cuLaunchKernel" in event_name

        self.assertTrue(
            any(
                ("name" in event and nameMatchesLaunchKernel(event["name"]))
                for event in events
            )
        )

    def _test_profiling_kernel_names(self, fn, args, kernel_name_str: str):
        """
        We expect a record_function event to be added on the CPU side, surrounding
        the launch of each triton kernel.
        """
        fn_opt = torch.compile(fn)

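        # Warm up first so that compilation (and any autotuning) happens outside
        # the profiled region below.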
        for _ in range(2):
            fn_opt(*args)

        with torch.profiler.profile(activities=[ProfilerActivity.CPU]) as prof:
            fn_opt(*args)

        # The name of the kernel is expected to match the name of the kernel in debug
        # files etc. The name could change in the future, but it seems reasonable that
        # the name should always contain "triton" and `kernel_name_str` - e.g. if the
        # kernel contains a sin op, its name should probably contain "sin".
        # If this changes in the future, feel free to change the assertion here.
        # Debugging tips: you can add prof.export_chrome_trace("test.json") inline in
        # this test, and then view test.json in chrome://tracing to see the trace.
        self.assertTrue(
            any(
                (
                    hasattr(event, "name")
                    and kernel_name_str in event.name
                    and "triton" in event.name
                )
                for event in prof.events()
            )
        )

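    # Pointwise case: the compiled (x + y).sin().cos() should produce a fused
    # triton kernel whose profiled name contains "sin".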
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_pointwise(self):
        def fn(x, y):
            return (x + y).sin().cos()

        args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

        self._test_profiling_kernel_names(fn, args, "sin")

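    # Template (gemm) case: with max_autotune restricted to the TRITON gemm backend,
    # the matmul should run as a triton kernel whose profiled name contains "mm".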
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_template(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return x @ y

            args = [torch.rand((4, 4), device="cuda") for _ in range(2)]

            self._test_profiling_kernel_names(fn, args, "mm")

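    # Foreach case: torch._foreach_add over lists of tensors should produce a
    # foreach triton kernel whose profiled name contains "_for_".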
    @unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
    def test_inductor_profiling_kernel_names_foreach(self):
        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            def fn(x, y):
                return torch._foreach_add(x, y)

            x = [torch.rand((4, 4), device="cuda") for _ in range(3)]
            y = [torch.rand((4, 4), device="cuda") for _ in range(3)]

            args = (x, y)

            self._test_profiling_kernel_names(fn, args, "_for_")


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_WITH_ROCM:
        run_tests()