pytorch/test/profiler/test_execution_trace.py
Sheng Fu f0a5a0d298 OSS: Capture triton kernel in ET (#124775)
This diff captures Triton kernels in the execution trace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124775
Approved by: https://github.com/briancoutinho, https://github.com/aaronenyeshi
2024-04-27 18:01:18 +00:00
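
For reference, the ExecutionTraceObserver API exercised throughout this file
follows the pattern below (a minimal sketch; the file name and workload are
illustrative):

    from torch.profiler import ExecutionTraceObserver

    et = ExecutionTraceObserver().register_callback("trace.et.json")
    et.start()
    ...  # run the workload to capture, e.g. the payload() helper below
    et.stop()
    et.unregister_callback()  # finalizes and closes the JSON output file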


# Owner(s): ["oncall: profiler"]

# If tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is to turn off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    pass

import json
import sys
import tempfile
import unittest
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import _dynamo as torchdynamo
from torch.autograd import (
    _record_function_with_args_enter,
    _record_function_with_args_exit,
)
from torch.profiler import (
    ExecutionTraceObserver,
    kineto_available,
    profile,
    record_function,
    supported_activities,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._triton import has_triton

Json = Dict[str, Any]


class TestExecutionTrace(TestCase):
    def payload(self, use_cuda=False):
        u = torch.randn(3, 4, 5, requires_grad=True)
        with record_function("## TEST 1 ##", "1, 2, 3"):
            inf_val = float("inf")
            neg_inf_val = float("-inf")
            nan_val = float("nan")
            rf_handle = _record_function_with_args_enter(
                "## TEST 2 ##",
                1,
                False,
                2.5,
                [u, u],
                (u, u),
                "hello",
                u,
                inf_val,
                neg_inf_val,
                nan_val,
            )
            x = torch.randn(10, 10, requires_grad=True)
            if use_cuda:
                x = x.cuda()
            y = torch.randn(10, 10, requires_grad=True)
            if use_cuda:
                y = y.cuda()
            z = x + y + x * y + x * y
            z.backward(z)
            gelu = nn.GELU()
            m = torch.randn(2)
            _ = gelu(m)
            if use_cuda:
                z = z.cpu()
            _record_function_with_args_exit(rf_handle)

    def get_execution_trace_root(self, output_file_name) -> Json:
        nodes = []
        with open(output_file_name) as f:
            et_graph = json.load(f)
            assert "nodes" in et_graph
            nodes = et_graph["nodes"]
        return nodes
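
    # For reference, each node in the execution trace JSON roughly has the
    # shape below (a sketch based only on the fields these tests read; real
    # nodes carry additional attributes):
    #
    #   {
    #       "name": "aten::add", "id": 42,
    #       "attrs": [{"name": "rf_id", "value": 10}, ...],
    #       "inputs": {"values": [...]}, "outputs": {"values": [...]},
    #   }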
    def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
        """Returns a sorted list of rf_id (record function ids) in the execution trace"""

        def get_rf_id(node):
            attrs = node["attrs"]
            for a in attrs:
                if a["name"] == "rf_id":
                    return a["value"]
            return None

        rf_ids_ = (
            get_rf_id(n)
            for n in nodes
            if n["name"] != "[pytorch|profiler|execution_trace|process]"
            and n["name"] != "[pytorch|profiler|execution_trace|thread]"
        )
        return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)

    def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
        """Returns a sorted list of record function IDs for CPU operators and user annotations"""
        ops_and_annotations = (
            e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
        )
        return sorted(
            e.get("args", {}).get("Record function id", -1) for e in ops_and_annotations
        )
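
    # The two helpers above extract the same identifier from the two trace
    # formats: every op recorded while both observers are active carries a
    # record function ID, so sorting and comparing the two lists checks that
    # the execution trace and the Kineto trace cover the same window of ops.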
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_execution_trace_with_kineto(self):
        trace_called_num = 0

        def trace_handler(p):
            nonlocal trace_called_num
            trace_called_num += 1

        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace and kineto data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        kt = tempfile.NamedTemporaryFile(
            mode="w+t", suffix=".kineto.json", delete=False
        )
        kt.close()

        with profile(
            activities=supported_activities(),
            # With skip_first=3, wait=1, warmup=1, and repeat=1, the profiler
            # is active for exactly 2 of the 10 iterations below.
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            on_trace_ready=trace_handler,
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(use_cuda=use_cuda)
                p.step()
            self.assertEqual(fp.name, p.execution_trace_observer.get_output_file_path())

        # Uncomment for debugging
        # print("Output kineto = ", kt.name)
        # print("Output ET = ", fp.name)
        p.export_chrome_trace(kt.name)
        self.assertEqual(trace_called_num, 1)

        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        self.assertTrue(found_root_node)
        # Since the profiler trace is active for 2 iterations
        self.assertEqual(loop_count, 2)

        # Compare the collected execution trace and Kineto trace
        # in terms of record function IDs (rf_id) and external IDs;
        # both of these should match for the same trace window.
        with open(kt.name) as f:
            kineto = json.load(f)
            events = kineto["traceEvents"]

        # Look up rf_ids in both the execution trace and the Kineto trace.
        rf_ids_et = self.get_execution_trace_rf_ids(nodes)
        rf_ids_kineto = self.get_kineto_rf_ids(events)

        self.assertCountEqual(rf_ids_et, rf_ids_kineto)
        self.assertListEqual(
            rf_ids_et,
            rf_ids_kineto,
            msg=f"ET and kineto rf_id should exactly match\n"
            f"  rf_ids_et = {rf_ids_et}\n"
            f"  rf_ids_kineto = {rf_ids_kineto}\n",
        )
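
    # This test drives ExecutionTraceObserver directly (register_callback,
    # start, stop, unregister_callback) instead of going through
    # torch.profiler.profile.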
    def test_execution_trace_alone(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0

        et = ExecutionTraceObserver().register_callback(fp.name)
        et.start()
        for idx in range(5):
            expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)
        et.stop()

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        # Expected tensor object tuple size, in the form of:
        # [tensor_id, storage_id, offset, numel, itemsize, device_str]
        tensor_tuple_size = 6
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1

            # Check if the tensor tuple representation size is correct.
            if n["name"] == "## TEST 2 ##":
                assert len(n["inputs"]["values"][3][0]) == tensor_tuple_size
        assert found_root_node
        assert loop_count == expected_loop_events
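
    # The test below exercises the Triton kernel capture added in this change:
    # Inductor-compiled kernels should appear in the execution trace as
    # "triton_..." nodes carrying a non-empty "kernel_file" attribute.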
    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(not TEST_CUDA or not has_triton(), "need CUDA and triton to run")
    def test_execution_trace_with_pt2(self):
        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        a, b, c = (torch.randn(4, 4, requires_grad=True).to("cuda") for _ in range(3))

        inputs = [a, b, c]
        with torch._inductor.config.patch(compile_threads=1):
            fn(*inputs)

        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False)
        fp.close()

        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    fn(*inputs)
                p.step()

        nodes = self.get_execution_trace_root(fp.name)
        found_captured_triton_kernel_node = False
        for n in nodes:
            assert "name" in n
            if "triton_" in n["name"]:
                for attr in n["attrs"]:
                    if attr["name"] == "kernel_file" and attr["value"] != "":
                        found_captured_triton_kernel_node = True
                        assert len(n["inputs"]["values"]) > 0
                        assert len(n["outputs"]["values"]) == 0
        assert found_captured_triton_kernel_node

    def test_execution_trace_start_stop(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0
        et = ExecutionTraceObserver().register_callback(fp.name)
        for idx in range(10):
            if idx == 3:
                et.start()
            elif idx == 5:
                et.stop()
            elif idx == 8:
                et.start()
            elif idx == 9:
                et.stop()
            if et._execution_trace_running:
                expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        assert found_root_node
        assert loop_count == expected_loop_events

    def test_execution_trace_repeat_in_loop(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        iter_list = {3, 4, 6, 8}
        expected_loop_events = len(iter_list)
        output_files = []
        for idx in range(10):
            if idx in iter_list:
                # Create a temp file to save execution trace data.
                fp = tempfile.NamedTemporaryFile(
                    "w+t", suffix=".et.json", delete=False
                )
                fp.close()
                output_files.append(fp.name)
                et = ExecutionTraceObserver().register_callback(fp.name)
                et.start()
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)
            if idx in iter_list:
                et.stop()
                et.unregister_callback()

        event_count = 0
        for et_file in output_files:
            nodes = self.get_execution_trace_root(et_file)
            found_root_node = False
            for n in nodes:
                assert "name" in n
                if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                    assert n["id"] == 1
                    found_root_node = True
                if n["name"].startswith("## LOOP "):
                    event_count += 1
            assert found_root_node
        assert event_count == expected_loop_events

    def test_execution_trace_no_capture(self):
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        et = ExecutionTraceObserver().register_callback(fp.name)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        # Initialize before the loop so the final assert fails cleanly
        # (instead of raising NameError) when no root node is found.
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
        assert found_root_node

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
    def test_execution_trace_nested_tensor(self):
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()

        observer = ExecutionTraceObserver().register_callback(fp.name)

        def fn(nt):
            return nt.sin().cos()

        with torch.profiler.profile(execution_trace_observer=observer) as prof:
            for i in range(3):
                values = torch.rand((8 + i, 4 + i))
                offsets = torch.tensor([0, 2, 4, 6, 8 + i])
                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
                fn(nt)

        nodes = self.get_execution_trace_root(fp.name)
        found_cos = False
        for n in nodes:
            assert "name" in n
            if "cos" in n["name"]:
                found_cos = True
        assert found_cos


if __name__ == "__main__":
    run_tests()