# Owner(s): ["oncall: profiler"]
# ruff: noqa: F841

import collections
import gc
import json
import mmap
import os
import pickle
import random
import re
import struct
import subprocess
import sys
import tempfile
import threading
import time
import unittest
from dataclasses import dataclass, field
from typing import Optional, TYPE_CHECKING
from unittest.mock import patch

import expecttest

import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
from torch._inductor.utils import is_big_gpu
from torch.autograd.profiler import KinetoStepTracker, profile as _profile
from torch.autograd.profiler_legacy import profile as _profile_legacy
from torch.profiler import (
    _utils,
    DeviceType,
    kineto_available,
    profile,
    ProfilerAction,
    ProfilerActivity,
    record_function,
    supported_activities,
)
from torch.profiler._pattern_matcher import (
    Conv2dBiasFollowedByBatchNorm2dPattern,
    ExtraCUDACopyPattern,
    ForLoopIndexingPattern,
    FP32MatMulPattern,
    GradNotSetToNonePattern,
    MatMulDimInFP16Pattern,
    NamePattern,
    OptimizerSingleTensorPattern,
    Pattern,
    report_all_anti_patterns,
    SynchronizedDataLoaderPattern,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_ARM64,
    IS_JETSON,
    IS_LINUX,
    IS_WINDOWS,
    parametrize,
    run_tests,
    serialTest,
    skipIfTorchDynamo,
    TemporaryDirectoryName,
    TemporaryFileName,
    TEST_CUDA,
    TEST_WITH_CROSSREF,
    TEST_WITH_ROCM,
    TEST_XPU,
    TestCase,
)


if TYPE_CHECKING:
    from torch.autograd.profiler_util import FunctionEvent

# If tqdm is not shut down properly, it leaves its monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is to turn off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    pass

try:
    import psutil

    HAS_PSUTIL = True
except ModuleNotFoundError:
    HAS_PSUTIL = False
    psutil = None


@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestProfilerCUDA(TestCase):
    def test_mem_leak(self):
        """Checks that there's no memory leak when using profiler with CUDA"""
        t = torch.rand(1, 1).cuda()
        p = psutil.Process()
        last_rss = collections.deque(maxlen=5)
        for _ in range(10):
            with _profile(use_cuda=True):
                for _ in range(1024):
                    t = torch.mm(t, t)

            gc.collect()
            torch.cuda.empty_cache()
            last_rss.append(p.memory_info().rss)

        # When CUDA events were leaking, memory grew by ~7 MB between the
        # profiler invocations above.
        is_increasing = all(
            last_rss[idx] > last_rss[idx - 1] for idx in range(1, len(last_rss))
        )
        max_diff = -1
        for idx in range(1, len(last_rss)):
            max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1])
        self.assertTrue(
            not (is_increasing and max_diff > 100 * 1024),
            msg=f"memory usage is increasing, {str(last_rss)}",
        )

    def test_custom_module_input_op_ids(self):
        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return x

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_nvtx runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_nvtx(record_shapes=True) as prof:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = custom_layer(z)
            q = s.sum()
            q.backward()

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_cudagraph_profiling_workaround(self):
        import subprocess

        # repro taken from #75504
        # Launch in a separate process to catch hanging/illegal memory errors
        # and to make sure CUPTI isn't already initialized.
        p = subprocess.check_call(
            [
                sys.executable,
                "-c",
                """
import os
import torch
from torch.profiler import ProfilerActivity, profile

def add_one(in_: torch.Tensor):
    return in_ + 1

sample_arg = torch.zeros(10, device="cuda").requires_grad_(True)

# add this before cuda graphs are created
torch.profiler._utils._init_for_cuda_graphs()

add_one_graphed = torch.cuda.graphs.make_graphed_callables(add_one, sample_args=(sample_arg,))
zeros = torch.zeros(10, device="cuda")
out = add_one_graphed(zeros)
assert out[0] == 1

with profile(activities=[ProfilerActivity.CPU]):
    add_one_graphed(zeros)

with profile(activities=[ProfilerActivity.CUDA]):
    add_one_graphed(zeros)
""",
            ],
            universal_newlines=True,
            timeout=60,
        )

        # ^ this will throw an exception if the script fails.


@unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
class TestProfilerITT(TestCase):
    def test_custom_module_input_op_ids(self):
        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return x

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_itt runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_itt(record_shapes=True) as prof:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = custom_layer(z)
            q = s.sum()
            q.backward()


@instantiate_parametrized_tests
class TestProfiler(TestCase):
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    def test_source(self):
        """Checks that source code attribution works for eager, TS and autograd mode"""
        # avoid automatic inlining
        prev_opt = torch._C._get_graph_executor_optimize()
        torch._C._set_graph_executor_optimize(False)

        @torch.jit.script
        def ts_method_2(x, y):
            return torch.matmul(x, y)

        @torch.jit.script
        def ts_method_1(x, y, z):
            a = x + z
            w = ts_method_2(x, y) + a
            return w.sum()

        class DummyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=False
                )

            def forward(self, x):
                return self.conv(x)

        mod = DummyModule()

        def call_module(x):
            return mod(x)

        with _profile(
            with_stack=True,
            use_kineto=kineto_available(),
            experimental_config=_ExperimentalConfig(verbose=True),
        ) as p:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            w = ts_method_1(x, y, z)
            v = 2 * w
            v.backward()
            a = torch.randn(2, 3, 2, 2, requires_grad=True)
            b = call_module(a)
            c = b.sum()
            c.backward()

        for e in p.function_events:
            if "aten::add" in e.name or "AddBackward" in e.name:
                self.assertTrue(any("test_profiler" in entry for entry in e.stack))
                self.assertTrue(
                    any(
                        (
                            "test_source" in entry
                            or "ts_method_1" in entry
                            or "ts_method_2" in entry
                        )
                        for entry in e.stack
                    )
                )

        if kineto_available():
            with TemporaryFileName(mode="w+") as fname:
                p.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]

                def extract(pattern: str):
                    matches = [e for e in events if re.search(pattern, e["name"])]
                    self.assertEqual(
                        len(matches), 1, repr([e["name"] for e in matches])
                    )
                    return matches[0]

                module_event = extract(r"DummyModule_0")
                wrapper_event = extract(r"call_module")
                self.assertEqual(
                    module_event["args"]["Python parent id"],
                    wrapper_event["args"]["Python id"],
                )

        torch._C._set_graph_executor_optimize(prev_opt)

    @parametrize(
        "name,thread_spec",
        {
            "basic": ((False, False),),
            "multiple_preexisting": ((False, False),) * 2,
            "open_in_scope": ((True, False),),
            "close_in_scope": ((False, True),),
            "complex": (
                # Large number of background threads
                (False, False),
                (False, False),
                (False, False),
                (False, False),
                # some of which finish during profiling
                (False, True),
                (False, True),
                # And the profiled section is also multithreaded
                (True, False),
                (True, True),
            ),
        }.items(),
        name_fn=lambda name, thread_spec: name,
    )
    @serialTest()
    @parametrize("work_in_main_thread", [True, False])
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_source_multithreaded(self, name, thread_spec, work_in_main_thread):
        """Test various threading configurations.

        `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a
        thread. The first bool indicates if the thread should be started under
        the profiler context and the second is if it should be joined under the
        profiler context.
        """

        timeout = 15
        num_threads = len(thread_spec) + 1  # Main thread
        start_barrier = threading.Barrier(num_threads, timeout=timeout)
        end_barrier = threading.Barrier(num_threads, timeout=timeout)

        class Task(threading.Thread):
            def __init__(self) -> None:
                self._end_gate = threading.Event()
                super().__init__(daemon=True)
                self.start()
                self.finished = False

            def run(self):
                self._run(self._end_gate)

            def release(self):
                self._end_gate.set()

            @staticmethod
            def _run(end_gate=None):
                def known_preexisting_function():
                    start_barrier.wait()

                # Fixed point that we can use to test capture of functions
                # which are already running when profiling is enabled.
                known_preexisting_function()

                model = torch.nn.Sequential(
                    torch.nn.Linear(10, 10),
                    torch.nn.ReLU(),
                )

                def invoked_during_run():
                    pass

                invoked_during_run()

                _ = model(torch.rand(4, 10))
                end_barrier.wait()

                if end_gate is not None:
                    end_gate.wait(timeout=timeout)

        threads = {}

        def add_threads(context: bool):
            for idx, (start_under_profiler, _) in enumerate(thread_spec):
                if start_under_profiler == context:
                    assert idx not in threads
                    threads[idx] = Task()

        def join_threads(context: bool):
            for idx, (_, end_under_profiler) in enumerate(thread_spec):
                if end_under_profiler == context:
                    threads[idx].release()

            for idx, (_, end_under_profiler) in enumerate(thread_spec):
                t = threads[idx]
                if end_under_profiler == context:
                    t.join(timeout=timeout)

        try:
            add_threads(False)
            with torch.profiler.profile(with_stack=True) as prof:
                # Threads added while the profiler is running will not be observed
                # since there is no way to hook into Python's thread start call to
                # register the observer. These are here purely to verify safety.
                add_threads(True)

                if work_in_main_thread:
                    Task._run()
                else:
                    start_barrier.wait()
                    end_barrier.wait()

                join_threads(True)
            join_threads(False)

        finally:
            # It is very important that we clean up everything because the
            # Python tracer will detect ALL active threads. (Even orphans from
            # prior failed tests.) If we don't clean up properly we can
            # contaminate subsequent tests.
            start_barrier.abort()
            end_barrier.abort()
            for t in threads.values():
                t.release()

            for t in threads.values():
                t.join(timeout=timeout)

            for t in threads.values():
                self.assertFalse(t.is_alive())

        roots = prof.profiler.kineto_results.experimental_event_tree()
        nodes = [
            node
            for node in _utils.traverse_dfs(roots)
            if isinstance(node.extra_fields, _ExtraFields_PyCall)
        ]
        tid_counts = collections.Counter([node.start_tid for node in nodes])

        prior_threads = sum(
            not start_under_profiler for start_under_profiler, _ in thread_spec
        )
        expected_threads = prior_threads + 1
        self.assertEqual(
            len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}"
        )
        self.assertEqual(len(nodes), sum(tid_counts.values()))

        # Profiler uses uint64_t max as a placeholder until TID can be determined.
        no_tid = 2**64 - 1
        self.assertFalse(no_tid in tid_counts)

        worker_threads = prior_threads + (1 if work_in_main_thread else 0)

        observed_preexisting = [
            node.start_tid
            for node in nodes
            if "known_preexisting_function" in node.name
        ]
        self.assertEqual(len(observed_preexisting), worker_threads)
        self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))

        observed_during_run = [
            node.start_tid for node in nodes if "invoked_during_run" in node.name
        ]
        self.assertEqual(len(observed_during_run), worker_threads)
        self.assertEqual(len(observed_during_run), len(set(observed_during_run)))

    def payload(self, use_cuda=False):
        x = torch.randn(10, 10)
        if use_cuda:
            x = x.cuda()
        y = torch.randn(10, 10)
        if use_cuda:
            y = y.cuda()
        z = torch.mm(x, y)
        z = z + y
        if use_cuda:
            z = z.cpu()

    def _check_stats(self, profiler_stats):
        self.assertGreater(profiler_stats.profiling_window_duration_sec, 0)
        self.assertGreater(profiler_stats.number_of_events, 0)
        self.assertGreater(profiler_stats.profiler_prepare_call_duration_us, 0)
        self.assertGreater(profiler_stats.profiler_enable_call_duration_us, 0)
        self.assertGreater(profiler_stats.profiler_disable_call_duration_us, 0)
        self.assertGreater(profiler_stats.parse_kineto_call_duration_us, 0)
        self.assertGreater(
            profiler_stats.function_events_build_tree_call_duration_us, 0
        )

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_kineto(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with _profile(use_cuda=use_cuda, use_kineto=True):
            self.payload(use_cuda=use_cuda)

        # rerun to avoid initial start overhead
        with _profile(use_cuda=use_cuda, use_kineto=True) as p:
            self.payload(use_cuda=use_cuda)

        self.assertTrue("aten::mm" in str(p))

        output = p.key_averages().table(
            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
            row_limit=-1,
        )
        # print(output)
        found_gemm = False
        found_memcpy = False
        found_mm = False
        for e in p.function_events:
            if "aten::mm" in e.name:
                found_mm = True
            if "gemm" in e.name.lower() or "Cijk" in e.name:
                found_gemm = True
            if "memcpy" in e.name.lower() or "__amd_rocclr_copyBuffer" in e.name:
                found_memcpy = True
        if use_cuda:
            self.assertTrue(found_gemm)
            self.assertTrue(found_memcpy)
        else:
            self.assertTrue(found_mm)
        self._check_stats(p._stats)
        # p.export_chrome_trace("/tmp/test_trace.json")

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed")
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_kineto_multigpu(self):
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            for gpu_id in [0, 1]:
                x = torch.randn(10, 10).cuda(gpu_id)
                y = torch.randn(10, 10).cuda(gpu_id)
                z = x.matmul(y)

        found_gemm_0 = False
        found_gemm_1 = False
        found_cuda = False
        for evt in prof.events():
            if "gemm" in evt.name.lower() and evt.device_type == DeviceType.CUDA:
                if evt.device_index == 0:
                    found_gemm_0 = True
                elif evt.device_index == 1:
                    found_gemm_1 = True
            if "cuda" in evt.name.lower() and evt.device_type == DeviceType.CPU:
                found_cuda = True

        self.assertTrue(found_gemm_0)
        self.assertTrue(found_gemm_1)
        self.assertTrue(found_cuda)
        self._check_stats(prof._stats())

    def test_memory_profiler(self):
        def run_profiler(tensor_creation_fn):
            # collecting allocs / deallocs
            with _profile(
                profile_memory=True,
                record_shapes=True,
                use_kineto=kineto_available(),
            ) as prof:
                x = None
                with record_function("test_user_scope_alloc"):
                    x = tensor_creation_fn()
                with record_function("test_user_scope_dealloc"):
                    del x
            return prof.key_averages(group_by_input_shape=True)

        def check_metrics(stats, metric, allocs=None, deallocs=None):
            stat_metrics = {}
            # print(stats)
            for stat in stats:
                stat_metrics[stat.key] = getattr(stat, metric)
            # print(stat_metrics)
            if allocs is not None:
                for alloc_fn in allocs:
                    self.assertTrue(alloc_fn in stat_metrics)
                    self.assertGreater(
                        stat_metrics[alloc_fn], 0, f"alloc_fn = {alloc_fn}"
                    )
            if deallocs is not None:
                for dealloc_fn in deallocs:
                    self.assertTrue(dealloc_fn in stat_metrics)
                    self.assertLess(
                        stat_metrics[dealloc_fn], 0, f"dealloc_fn = {dealloc_fn}"
                    )

        def create_cpu_tensor():
            return torch.rand(10, 10)

        def create_cuda_tensor():
            return torch.rand(10, 10).cuda()

        def create_xpu_tensor():
            return torch.rand(10, 10).xpu()

        def create_mkldnn_tensor():
            return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

        stats = run_profiler(create_cpu_tensor)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=[
                "aten::empty",
                "aten::rand",
                "test_user_scope_alloc",
            ],
            deallocs=[
                "test_user_scope_dealloc",
            ],
        )

        if kineto_available():
            with TemporaryFileName(mode="w+") as fname:
                with profile(profile_memory=True) as prof:
                    x = None
                    with record_function("test_user_scope_alloc"):
                        x = create_cpu_tensor()
                    with record_function("test_user_scope_dealloc"):
                        del x
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    trace = json.load(f)
                    assert "traceEvents" in trace
                    events = trace["traceEvents"]
                    found_memory_events = False
                    for evt in events:
                        assert "name" in evt
                        if evt["name"] == "[memory]":
                            found_memory_events = True
                            assert "args" in evt
                            assert "Addr" in evt["args"]
                            assert "Device Type" in evt["args"]
                            assert "Device Id" in evt["args"]
                            assert "Bytes" in evt["args"]

                            # Memory should be an instantaneous event.
                            assert "dur" not in evt["args"]
                            assert "cat" not in evt["args"]
                    assert found_memory_events

        if torch.cuda.is_available():
            create_cuda_tensor()
            stats = run_profiler(create_cuda_tensor)
            check_metrics(
                stats,
                "device_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::to",
                    "aten::empty_strided",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "aten::rand",
                    "aten::empty",
                ],
            )

        if torch.xpu.is_available():
            create_xpu_tensor()
            stats = run_profiler(create_xpu_tensor)
            check_metrics(
                stats,
                "device_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::to",
                    "aten::empty_strided",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "aten::rand",
                    "aten::empty",
                ],
            )

        if torch.backends.mkldnn.is_available():
            create_mkldnn_tensor()
            stats = run_profiler(create_mkldnn_tensor)
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::rand",
                    "aten::empty",
                    "aten::to_mkldnn",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )

        # check top-level memory events
        with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
            x = torch.rand(10, 10)
            del x
            if torch.cuda.is_available():
                y = torch.rand(10, 10).cuda()
                del y
            elif torch.xpu.is_available():
                y = torch.rand(10, 10).to("xpu")
                del y
            gc.collect()
        stats = prof.key_averages(group_by_input_shape=True)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=["aten::rand", "aten::empty"],
            deallocs=["[memory]"],
        )
        if torch.cuda.is_available():
            check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
        elif torch.xpu.is_available():
            check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])

    @unittest.skipIf(
        IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
    )
    def test_oom_tracing(self):
        def run_profiler(tensor_creation_fn):
            with _profile(profile_memory=True, record_shapes=True) as prof:
                with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"):
                    x = tensor_creation_fn()
                return prof

        def create_cuda_tensor_oom():
            device = torch.device("cuda:0")
            return torch.empty(
                1024, 1024, 1024, 1024, dtype=torch.float32, device=device
            )

        def check_trace(fname):
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
                self.assertTrue("traceEvents" in trace)
                events = trace["traceEvents"]
                found_out_of_memory_events = False
                for evt in events:
                    self.assertTrue("name" in evt)
                    if evt["name"] == "[OutOfMemory]":
                        found_out_of_memory_events = True
                        self.assertTrue("args" in evt)
                        self.assertTrue("Device Type" in evt["args"])
                        self.assertTrue("Device Id" in evt["args"])
                        self.assertTrue("Bytes" in evt["args"])

                        # Memory should be an instantaneous event.
                        self.assertTrue("dur" not in evt["args"])
                        self.assertTrue("cat" not in evt["args"])
                self.assertTrue(found_out_of_memory_events)

        if torch.cuda.is_available():
            with TemporaryFileName(mode="w+") as fname:
                prof = run_profiler(create_cuda_tensor_oom)
                check_trace(fname)

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_module_hierarchy(self):
        class A(nn.Module):
            def my_new_method(self, x):
                return x * 3

            def forward_impl_(self, x, y):
                return self.my_new_method(x) + y

            def forward(self, x, y):
                y = y - 2
                return self.forward_impl_(x, y)

        class B(nn.Module):
            def forward(self, x):
                return x + 2

        class C(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.A0 = A()
                self.B0 = B()

            def call_b(self, x):
                return self.B0.forward(x)

            def forward(self, x, y):
                return self.A0.forward(x, y) + self.call_b(x)

        model = C()
        model = torch.jit.script(model)
        input_a = torch.rand(128, 128)
        input_b = torch.rand(128, 128)
        op_to_module_hierarchy = {}
        op_to_module_hierarchy["aten::sub"] = ["TOP(C)::forward.A0(A)::forward."]
        op_to_module_hierarchy["aten::mul"] = [
            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method."
        ]
        op_to_module_hierarchy["aten::add"] = [
            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.",
            "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.",
            "TOP(C)::forward.",
        ]
        with TemporaryFileName(mode="w+") as fname:
            with profile(
                activities=[torch.profiler.ProfilerActivity.CPU],
                with_modules=True,
            ) as prof:
                model(input_a, input_b)
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
                assert "traceEvents" in trace
                events = trace["traceEvents"]
                found_memory_events = False
                for evt in events:
                    assert "name" in evt
                    if "args" in evt:
                        op_name = evt["name"]
                        if "Module Hierarchy" in evt["args"]:
                            hierarchy = evt["args"]["Module Hierarchy"]
                            if op_name in op_to_module_hierarchy:
                                assert hierarchy in op_to_module_hierarchy[op_name]

    def test_high_level_trace(self):
        """Checks that python side high level events are recorded."""

        class RepeatedDataset(torch.utils.data.Dataset):
            def __init__(self, N, D_in, D_out):
                self.N = N
                self.x = torch.randn(N, D_in)
                self.y = torch.randn(N, D_out)

            def __len__(self):
                return self.N

            def __getitem__(self, idx):
                return self.x, self.y

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)

            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred

        class CustomSGD(torch.optim.SGD):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

        def train():
            for data in dataloader:
                x, y = data[0], data[1]
                y_pred = model(x)
                loss = criterion(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        N, D_in, H, D_out = 8, 10, 5, 2
        model = TwoLayerNet(D_in, H, D_out)
        criterion = torch.nn.MSELoss(reduction="sum")
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
        ds = RepeatedDataset(N, D_in, D_out)
        dataloader = torch.utils.data.DataLoader(ds, batch_size=1)

        try:
            train()
        except Exception:
            self.assertTrue(False, "Expected no exception without profiling.")

        # Create multiple instances; each function is expected to be hooked only once.
        # Nested wrappers (repeated patching) would make the following test fail.
        optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4)
        dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1)

        def judge(expected_event_count, prof):
            actual_event_count = {}
            for e in prof.function_events:
                if "#" in e.name:
                    key = e.name
                    if key in expected_event_count.keys():
                        actual_event_count[key] = (
                            actual_event_count.setdefault(key, 0) + 1
                        )
            for key, count in expected_event_count.items():
                self.assertTrue(
                    (key in actual_event_count.keys())
                    and (count == actual_event_count[key])
                )

        with _profile(use_kineto=kineto_available()) as prof:
            train()
        expected_event_count = {
            # "+1" because the final iteration will enter __next__ but skip the loop body.
            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
            "Optimizer.step#SGD.step": N,
            "Optimizer.zero_grad#SGD.zero_grad": N,
        }
        judge(expected_event_count, prof)

        # Test on pickle/unpickle. Expect to work in multi-processing.
        optimizer = pickle.loads(pickle.dumps(optimizer))
        with _profile(use_kineto=kineto_available()) as prof:
            train()
        judge(expected_event_count, prof)

        # Test on customized optimizer.
        optimizer = CustomSGD(model.parameters(), lr=1e-4)
        with _profile(use_kineto=kineto_available()) as prof:
            train()
        expected_event_count = {
            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
            "Optimizer.step#CustomSGD.step": N,
            "Optimizer.zero_grad#CustomSGD.zero_grad": N,
        }
        judge(expected_event_count, prof)

    def test_flops(self):
        model = torch.nn.Sequential(
            nn.Conv2d(16, 33, 18),
            nn.ReLU(),
            nn.Linear(243, 243),
            nn.ReLU(),
        )
        inputs = torch.randn(40, 16, 18, 260)
        nested_tensor = torch.nested.nested_tensor(
            [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged
        )
        with _profile(
            record_shapes=True, with_flops=True, use_kineto=kineto_available()
        ) as prof:
            model(inputs)
            # test that nested tensor won't cause exception during flop compute
            nested_tensor = nested_tensor + nested_tensor
        profiler_output = prof.key_averages(group_by_input_shape=True).table(
            sort_by="cpu_time_total", row_limit=10
        )
        self.assertRegex(profiler_output, "Total M?FLOPs")
        if not (kineto_available() and torch.cuda.is_available()):
            return

        with profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_flops=True,
        ) as kineto_profiler:
            model(inputs)
        profiler_output = kineto_profiler.key_averages().table(
            sort_by="self_cuda_time_total", row_limit=-1
        )
        self.assertRegex(profiler_output, "Total M?FLOPs")

    def test_override_time_units(self):
        US_IN_SECOND = 1000.0 * 1000.0
        US_IN_MS = 1000.0

        model = torch.nn.Sequential(
            nn.Conv2d(16, 33, 18),
            nn.ReLU(),
            nn.Linear(243, 243),
            nn.ReLU(),
        )
        inputs = torch.randn(40, 16, 18, 260)
        with _profile() as prof:
            model(inputs)

        profiler_output = prof.key_averages().table(time_unit="s")
        self.assertRegex(profiler_output, r".*(\.[0-9]{3}s).*")
        self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}ms).*")
        self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}us).*")
        for event in prof.key_averages():
            cpu_time_str_s = f"{event.cpu_time / US_IN_SECOND:.3f}s"
            cpu_time_total_str_s = f"{event.cpu_time_total / US_IN_SECOND:.3f}s"
            self.assertTrue(cpu_time_str_s in profiler_output)
            self.assertTrue(cpu_time_total_str_s in profiler_output)

        profiler_output = prof.key_averages().table(time_unit="ms")
        self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}s).*")
        self.assertRegex(profiler_output, r".*(\.[0-9]{3}ms).*")
        self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}us).*")
        for event in prof.key_averages():
            cpu_time_str_ms = f"{event.cpu_time / US_IN_MS:.3f}ms"
            cpu_time_total_str_ms = f"{event.cpu_time_total / US_IN_MS:.3f}ms"
            self.assertTrue(cpu_time_str_ms in profiler_output)
            self.assertTrue(cpu_time_total_str_ms in profiler_output)

        profiler_output = prof.key_averages().table(time_unit="us")
        self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}s).*")
        self.assertNotRegex(profiler_output, r".*(\.[0-9]{3}ms).*")
        self.assertRegex(profiler_output, r".*(\.[0-9]{3}us).*")
        for event in prof.key_averages():
            cpu_time_str_us = f"{event.cpu_time:.3f}us"
            cpu_time_total_str_us = f"{event.cpu_time_total:.3f}us"
            self.assertTrue(cpu_time_str_us in profiler_output)
            self.assertTrue(cpu_time_total_str_us in profiler_output)

    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
    def test_kineto_profiler_api(self):
        called_num = [0]

        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with profile(activities=supported_activities()):
            self.payload(use_cuda=use_cuda)

        def trace_handler(p):
            output = p.key_averages().table(
                sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
                row_limit=-1,
            )
            # print(output)
            # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
            called_num[0] += 1

        initial_step = KinetoStepTracker.current_step()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
            on_trace_ready=trace_handler,
        ) as p:
            for _ in range(8):
                self.payload(use_cuda=use_cuda)
                p.step()

        self.assertEqual(called_num[0], 2)
        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)

        # case without schedule
        with profile(activities=supported_activities()) as p:
            self.payload(use_cuda=use_cuda)
            self.payload(use_cuda=use_cuda)
        output = p.key_averages().table(
            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
            row_limit=-1,
        )
        # print(output)

        test_schedule = torch.profiler.schedule(
            skip_first=3, wait=2, warmup=1, active=4, repeat=2
        )
        test_schedule_expected_outputs = [
            # skip first 3
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            # ----
            # repeat No. 1 begin
            # wait 2
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            # warmup 1
            ProfilerAction.WARMUP,
            # active 4 begin
            ProfilerAction.RECORD,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            # active 4 end
            # repeat No. 1 end
            # ---
            # repeat No. 2 begin
            # wait 2
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            # warmup 1
            ProfilerAction.WARMUP,
            # active 4 begin
            ProfilerAction.RECORD,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            # active 4 end
            # repeat No. 2 end
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
        ]
        for step in range(len(test_schedule_expected_outputs)):
            self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])

    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
    def test_kineto_profiler_multiple_steppers(self):
        niters = 8
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        net = SimpleNet()
        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
        opt.zero_grad()
        inputs = torch.rand(10)

        with profile(activities=supported_activities()):
            self.payload(use_cuda=use_cuda)

        def optimizer_step():
            """This simulates a step() hook in the optimizer"""
            KinetoStepTracker.increment_step("yet_another_step")

        initial_step = KinetoStepTracker.current_step()

        def run_batch():
            out = net(inputs)
            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
            loss.backward()
            opt.step()
            # Manually call the hook. TODO: Remove this once we add the
            # profiler step hooks in the Optimizer class that will get triggered above.
            # See https://github.com/pytorch/pytorch/issues/88446
            optimizer_step()

        for _ in range(niters):
            run_batch()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
        ) as p:
            for _ in range(niters):
                run_batch()
                p.step()

        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)

    def test_export_stacks(self):
        with _profile(
            with_stack=True,
            use_kineto=kineto_available(),
            experimental_config=_ExperimentalConfig(verbose=True),
        ) as p:
            x = torch.randn(10, 10)
            y = torch.randn(10, 10)
            z = torch.mm(x, y)
            z = z + y

        with TemporaryFileName(mode="w+") as fname:
            p.export_stacks(fname)
            with open(fname) as f:
                lines = f.readlines()
            assert len(lines) > 0, "Empty stacks file"
            for line in lines:
                is_int = False
                try:
                    assert int(line.split(" ")[-1]) > 0, "Invalid stacks record"
                    is_int = True
                except ValueError:
                    pass
                assert is_int, "Invalid stacks record"

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_tensorboard_trace_handler(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with _profile(use_cuda=use_cuda, use_kineto=True):
            self.payload(use_cuda=use_cuda)

        with TemporaryDirectoryName() as dname:
            with profile(
                activities=[torch.profiler.ProfilerActivity.CPU]
                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(dname),
            ) as p:
                for _ in range(18):
                    self.payload(use_cuda=use_cuda)
                    p.step()

            self.assertTrue(os.path.exists(dname))
            file_num = 0
            for file_name in os.listdir(dname):
                parts = file_name.split(".")
                self.assertTrue(len(parts) > 4)
                self.assertTrue(
                    parts[-4].isdigit() and int(parts[-4]) > 0,
                    "Wrong tracing file name pattern",
                )
                self.assertEqual(parts[-3:], ["pt", "trace", "json"])
                file_num += 1
            self.assertEqual(file_num, 3)

        # test case for gzip file format
        with TemporaryDirectoryName() as dname:
            p = profile(
                activities=[torch.profiler.ProfilerActivity.CPU]
                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    dname, use_gzip=True
                ),
            )
            p.start()
            for _ in range(18):
                self.payload(use_cuda=use_cuda)
                p.step()
            p.stop()

            self.assertTrue(os.path.exists(dname))
            file_num = 0
            for file_name in os.listdir(dname):
                parts = file_name.split(".")
                self.assertTrue(len(parts) > 4)
                self.assertTrue(
                    parts[-5].isdigit() and int(parts[-5]) > 0,
                    "Wrong tracing file name pattern",
                )
                self.assertEqual(parts[-4:], ["pt", "trace", "json", "gz"])
                file_num += 1
            self.assertEqual(file_num, 3)

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_profiler_metadata(self):
        t1, t2 = torch.ones(1), torch.ones(1)
        with profile() as prof:
            torch.add(t1, t2)
            prof.add_metadata("test_key1", "test_value1")
            prof.add_metadata_json("test_key2", "[1,2,3]")

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
                assert "test_key1" in trace
                assert trace["test_key1"] == "test_value1"
                assert "test_key2" in trace
                assert trace["test_key2"] == [1, 2, 3]

    def _test_profiler_tracing(self, use_kineto):
        with _profile(use_kineto=use_kineto) as prof:
            t1, t2 = torch.ones(1), torch.ones(1)
            torch.add(t1, t2)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            # read the trace and expect valid json
            # if the JSON generated by export_chrome_trace is not valid, this will throw and fail the test.
            with open(fname) as f:
                json.load(f)

        # test empty trace
        with _profile(use_kineto=use_kineto) as prof:
            pass
        # saving an empty trace
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            if use_kineto:
                with open(fname) as f:
                    contents = json.load(f)
                    # Some builds may not have the logger observer,
                    # so only check for the warning when one is present
                    if "WARNING" in contents:
                        found_empty_warning = False
                        for warning in contents["WARNING"]:
                            if "No Valid Trace Events" in warning:
                                found_empty_warning = True
                        self.assertTrue(found_empty_warning)

        # Same test but for cuda.
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        if not use_cuda:
            return

        device = torch.device("cuda:0")
        with _profile(use_cuda=True, use_kineto=use_kineto) as prof:
            t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
            torch.add(t1, t2)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            # Now validate the json
            with open(fname) as f:
                json.load(f)

    def test_profiler_tracing(self):
        self._test_profiler_tracing(False)
        if kineto_available():
            self._test_profiler_tracing(True)

    def test_profiler_op_event_args(self):
        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
        with _profile(record_shapes=True) as prof:
            a = torch.ones((64, 32), dtype=torch.float32)
            c = torch.cat([a, a]).sin()
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                op_events = [
                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
                ]
                for e in op_events:
                    args = e["args"]
                    if e["name"] == "aten::ones":
                        self.assertEqual(
                            args["Input type"],
                            ["ScalarList", "Scalar", "", "", "Scalar"],
                        )
                        self.assertEqual(
                            args["Concrete Inputs"], ["[64, 32]", "6", "", "", "False"]
                        )

                    if e["name"] == "aten::cat":
                        self.assertEqual(args["Input Dims"], [[[64, 32], [64, 32]], []])
                        self.assertEqual(args["Input type"], ["TensorList", "Scalar"])

                    # check that each op has record function id
                    self.assertGreaterEqual(
                        args.get("Record function id", -1),
                        0,
                        f"Failed finding record function for op = {e}",
                    )

    def test_profiler_strides(self):
        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
        base_tensor = torch.randn(1024, dtype=torch.float32)
        a = base_tensor.as_strided((16, 16), (17, 1), 0)
        b = base_tensor.as_strided((16, 16), (25, 2), 272)
        with _profile(record_shapes=True) as prof:
            c = torch.add(a, b)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                op_events = [
                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
                ]
                for e in op_events:
                    args = e["args"]
                    if e["name"] == "aten::add":
                        self.assertEqual(args["Input Strides"], [[17, 1], [25, 2], []])

    def test_profiler_fwd_bwd_link(self):
        with _profile(use_kineto=True) as prof:
            t1, t2 = (
                torch.ones(1, requires_grad=True),
                torch.ones(1, requires_grad=True),
            )
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                events = j["traceEvents"]
                ts_to_name = {}
                flow_s_to_ts = {}
                flow_f_to_ts = {}
                for e in events:
                    if e["ph"] == "X":
                        ts_to_name[e["ts"]] = e["name"]
                    if (
                        "cat" in e
                        and "name" in e
                        and e["cat"] == "fwdbwd"
                        and e["name"] == "fwdbwd"
                    ):
                        if e["ph"] == "s":
                            flow_s_to_ts[e["id"]] = e["ts"]
                        elif e["ph"] == "f":
                            flow_f_to_ts[e["id"]] = e["ts"]

                self.assertEqual(len(flow_s_to_ts), 2)
                self.assertEqual(len(flow_f_to_ts), 2)
                self.assertIn(1, flow_s_to_ts)
                self.assertIn(1, flow_f_to_ts)
                self.assertIn(2, flow_s_to_ts)
                self.assertIn(2, flow_f_to_ts)
                s_ts_1 = flow_s_to_ts[1]
                f_ts_1 = flow_f_to_ts[1]
                s_ts_2 = flow_s_to_ts[2]
                f_ts_2 = flow_f_to_ts[2]
                self.assertTrue(
                    all(
                        ts in ts_to_name.keys()
                        for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]
                    )
                )
                self.assertTrue(
                    ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits"
                )
                self.assertTrue(ts_to_name[s_ts_2] == "aten::add")

    def test_profiler_disable_fwd_bwd_link(self):
        try:
            torch._C._profiler._set_fwd_bwd_enabled_val(False)

            with _profile(use_kineto=True) as prof:
                t1, t2 = (
                    torch.ones(1, requires_grad=True),
                    torch.ones(1, requires_grad=True),
                )
                z = torch.add(t1, t2)
                y = torch.ones(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
                loss.backward()

            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    j = json.load(f)
                    events = j["traceEvents"]

                    for e in events:
                        self.assertNotEqual(e.get("cat", None), "fwdbwd")
        finally:
            torch._C._profiler._set_fwd_bwd_enabled_val(True)

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_cuda_sync_events(self):
        device = torch.device("cuda:0")
        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)

        def workload() -> None:
            torch.add(t1, t2)
            torch.cuda.synchronize()
            torch.add(t1, t2)

        def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
            with _profile(
                use_kineto=True,
                use_cuda=True,
                experimental_config=exp_config,
            ) as prof:
                workload()

            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    j = json.load(f)
                    cats = {e.get("cat", None) for e in j["traceEvents"]}
            self.assertTrue(
                "cuda_sync" in cats,
                f"Expected to find a cuda_sync event, found categories = {cats}",
            )

        print("Testing enable_cuda_sync_events in _ExperimentalConfig")
        trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))

        print("Testing _profiler._set_cuda_sync_enabled_val()")
        try:
            torch._C._profiler._set_cuda_sync_enabled_val(True)
            trace_and_check(exp_config=None)
        finally:
            torch._C._profiler._set_cuda_sync_enabled_val(False)

    def test_profiler_type(self):
        profiler_type = torch._C._autograd._profiler_type
        ActiveProfilerType = torch._C._profiler.ActiveProfilerType
        self.assertEqual(profiler_type(), ActiveProfilerType.NONE)

        # Autograd profiler
        with _profile_legacy():
            self.assertEqual(profiler_type(), ActiveProfilerType.LEGACY)

        # Kineto profiler
        with profile():
            self.assertEqual(profiler_type(), ActiveProfilerType.KINETO)

    def test_profiler_correlation_id(self):
        """
        We expect the correlation_id to be unique across multiple invocations of the
        profiler, so we reuse id_uniqueness_set across all iterations.
        """
        id_uniqueness_set = set()
        model = torch.nn.Sequential(
            nn.Conv2d(16, 33, 18),
            nn.ReLU(),
            nn.Linear(243, 243),
            nn.ReLU(),
        )
        inputs = torch.randn(40, 16, 18, 260)
        uint32_max = 2**32 - 1
        for _ in range(5):
            with profile() as prof:
                model(inputs)
            for event in prof.profiler.kineto_results.events():
                corr_id = event.correlation_id()
                if (corr_id) and event.device_type() == DeviceType.CPU:
                    self.assertTrue(corr_id not in id_uniqueness_set)
                    id_uniqueness_set.add(corr_id)
                    self.assertTrue(corr_id < uint32_max)

    def test_nested_tensor_with_shapes(self):
        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        c = torch.randn(4, 4)
        inp = torch.nested.nested_tensor([a, b])
        with torch.profiler.profile(record_shapes=True) as prof:
            torch.nn.functional.linear(inp, c, None)
        for e in prof.events():
            if e.name in ("aten::mm", "aten::addmm"):
                # intentionally vague tests to protect against possible future changes
                # of mm to addmm or other impl, or changing internal order of args
                self.assertTrue(len(e.input_shapes) > 0)
                self.assertTrue(len(e.input_shapes[0]) > 0)

    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_kineto_profiler_with_environment_variable(self):
        script = """
import torch
import torch.nn as nn
from torch.profiler import supported_activities, profile
from torch.autograd.profiler import KinetoStepTracker

class SimpleNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))

def payload(use_cuda=False):
    x = torch.randn(10, 10)
    if use_cuda:
        x = x.cuda()
    y = torch.randn(10, 10)
    if use_cuda:
        y = y.cuda()
    z = torch.mm(x, y)
    z = z + y
    if use_cuda:
        z = z.cpu()

niters = 8
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
net = SimpleNet()
opt = torch.optim.SGD(net.parameters(), lr=0.01)
opt.zero_grad()
inputs = torch.rand(10)

with profile(activities=supported_activities()):
    payload(use_cuda=use_cuda)

initial_step = KinetoStepTracker.current_step()

def run_batch():
    out = net(inputs)
    loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
    loss.backward()
    opt.step()

for _ in range(niters):
    run_batch()

with profile(
    activities=supported_activities(),
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=2),
) as p:
    for _ in range(niters):
        run_batch()
        p.step()
assert KinetoStepTracker.current_step() == initial_step + 2 * niters
"""
        try:
            subprocess.check_output(
                [sys.executable, "-W", "always", "-c", script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
        except subprocess.CalledProcessError as e:
            if e.returncode != 0:
                self.assertTrue(
                    False,
                    "Kineto is not working properly with the Dynolog environment variable",
                )

    def test_concrete_inputs_profiling(self):
        x = torch.rand(2, 6)
        with profile(record_shapes=True) as p:
            y = x.as_strided([4, 3], [1, 4])

        found = False
        for e in p.events():
            # compare by equality; `in ("aten::as_strided")` is a substring test on a str
            if e.name == "aten::as_strided":
                found = True
                self.assertTrue(len(e.input_shapes) > 0)
                self.assertTrue(len(e.concrete_inputs) > 0)
                self.assertEqual([2, 6], e.input_shapes[0])
                self.assertEqual([4, 3], e.concrete_inputs[1])
                self.assertEqual([1, 4], e.concrete_inputs[2])

        self.assertTrue(found, "Expected to find aten::as_strided but did not")

    def test_concrete_inputs_profiling_toggling(self):
        try:
            for before, after in [(True, False), (False, True)]:
                x = torch.rand(2, 6)
                torch._C._profiler._set_record_concrete_inputs_enabled_val(before)
                with profile(record_shapes=True) as p:
                    y = x.as_strided([4, 3], [1, 4])
                    torch._C._profiler._set_record_concrete_inputs_enabled_val(after)

                found = False
                for e in p.events():
                    if e.name == "aten::as_strided":
                        found = True
                        self.assertTrue(len(e.input_shapes))

                self.assertTrue(found, "Expected to find aten::as_strided but did not")
        finally:
            torch._C._profiler._set_record_concrete_inputs_enabled_val(True)

    def test_record_function_fast(self):
        x, y = (torch.rand((4, 4)) for _ in range(2))
        with profile(record_shapes=True) as p:
            for _ in range(4):
                # Test first with no optional args
                with torch._C._profiler._RecordFunctionFast("add_test_fast_rf1"):
                    x.add(y)

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "add_test_fast_rf1"]), 4
        )
        for e in p.events():
            if e.name == "add_test_fast_rf1":
                self.assertTrue(e.input_shapes == [])
                self.assertTrue(e.kwinputs == {})
        with profile(record_shapes=True) as p:
            # add optional args
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_fast_rf2", [x, y], {"stream": 0, "grid": "lambda x : x + 1"}
            )
            for _ in range(4):
                with cm:
                    x.add(y)

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "add_test_fast_rf2"]), 4
        )

        for e in p.events():
            if e.name == "add_test_fast_rf2":
                self.assertTrue(e.input_shapes == [[4, 4], [4, 4]])
                self.assertTrue(e.kwinputs == {"stream": 0, "grid": "lambda x : x + 1"})

        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_fast_rf3", input_values=["hi"], keyword_values={"hi": "hello"}
            )
            for _ in range(4):
                try:
                    with cm:
                        x.add(y)
                        raise ValueError
                        x.relu()
                except ValueError:
                    pass

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "add_test_fast_rf3"]), 4
        )
        self.assertFalse(any((e.name and "relu" in e.name) for e in p.events()))

        for e in p.events():
            if e.name == "add_test_fast_rf3":
                self.assertTrue(e.input_shapes == [[]])

        with profile() as p:
            for _ in range(4):
                with torch._C._profiler._RecordFunctionFast(
                    "add_test_fast_rf4", [x, y]
                ):
                    x.add(y)
                    with torch._C._profiler._RecordFunctionFast("add_test_fast_rf5"):
                        x.relu()

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "add_test_fast_rf4"]), 4
        )

        for e in p.events():
            if e.name == "add_test_fast_rf4":
                self.assertTrue(e.input_shapes == [])

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "add_test_fast_rf5"]), 4
        )

        with profile(record_shapes=True) as p:
            # test optional args with tuple
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_fast_rf6",
                (
                    x,
                    y,
                ),
            )
            for _ in range(4):
                with cm:
                    x.add(y)

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "add_test_fast_rf6"]), 4
        )

        for e in p.events():
            if e.name == "add_test_fast_rf6":
                self.assertTrue(e.input_shapes == [[4, 4], [4, 4]])

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_op_event_kwargs(self):
        x, y = (torch.rand((4, 4)) for _ in range(2))
        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_kwinputs",
                [x, y],
                {
                    "stream": 0,
                    "grid": "lambda x : x + 1",
                    "debug": 'debug"',
                    "boolean": True,
                },
            )
            for _ in range(4):
                with cm:
                    x.add(y)
        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                op_events = [
                    e
                    for e in j["traceEvents"]
                    if e.get("name", "") == "add_test_kwinputs"
                ]
                self.assertTrue(len(op_events) > 0)
                for e in op_events:
                    args = e["args"]
                    self.assertTrue("stream" in args)
                    self.assertTrue("grid" in args)
                    self.assertTrue("boolean" in args)
                    self.assertTrue(args["stream"] == 0)
                    self.assertTrue(args["grid"] == "lambda x : x + 1")
                    self.assertTrue(args["debug"] == "None")
                    self.assertTrue(args["boolean"])
                    self.assertTrue(e["cat"] == "cpu_op")

        with profile(record_shapes=True) as p1:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_kwinputs",
                [x, y],
                {"stream": "test", "grid": [1, 2], "scope": "user_scope"},
            )
            for _ in range(4):
                with cm:
                    x.add(y)
        with TemporaryFileName(mode="w+") as fname1:
            p1.export_chrome_trace(fname1)
            with open(fname1) as f1:
                j = json.load(f1)
                op_events = [
                    e
                    for e in j["traceEvents"]
                    if e.get("name", "") == "add_test_kwinputs"
                ]
                self.assertTrue(len(op_events) > 0)
                for e in op_events:
                    args = e["args"]
                    self.assertTrue("stream" not in args)
                    self.assertTrue("grid" not in args)
                    self.assertTrue(e["cat"] == "user_annotation")

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_op_event_kwargs_list_of_strings(self):
        x, y = (torch.rand((4, 4)) for _ in range(2))
        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_kwinputs_string_list",
                [x, y],
                {
                    "string_list": ["hello", "world", "test"],
                    "int_param": 42,
                    "string_param": "single_string",
                },
            )
            for _ in range(4):
                with cm:
                    x.add(y)

        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                op_events = [
                    e
                    for e in j["traceEvents"]
                    if e.get("name", "") == "add_test_kwinputs_string_list"
                ]
                self.assertTrue(len(op_events) > 0)
                for e in op_events:
                    args = e["args"]
                    self.assertTrue("string_list" in args)
                    self.assertTrue("int_param" in args)
                    self.assertTrue("string_param" in args)

                    # Check that the list of strings is properly serialized
                    # The list should be formatted as a JSON array by ivalueListToStr
                    self.assertEqual(args["string_list"], ["hello", "world", "test"])
                    self.assertEqual(args["int_param"], 42)
                    self.assertEqual(args["string_param"], "single_string")
                    self.assertTrue(e["cat"] == "cpu_op")

        # Test mixed types that should be filtered out
        with profile(record_shapes=True) as p1:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_kwinputs_string_list_filtered",
                [x, y],
                {
                    "valid_string_list": ["valid1", "valid2"],
                    "mixed_list": ["string", 123],  # Should be filtered out
                    "non_string_list": [1, 2, 3],  # Should be filtered out
                    "valid_int": 100,
                },
            )
            for _ in range(4):
                with cm:
                    x.add(y)

        with TemporaryFileName(mode="w+") as fname1:
            p1.export_chrome_trace(fname1)
            with open(fname1) as f1:
                j = json.load(f1)
                op_events = [
                    e
                    for e in j["traceEvents"]
                    if e.get("name", "") == "add_test_kwinputs_string_list_filtered"
                ]
                self.assertTrue(len(op_events) > 0)
                for e in op_events:
                    args = e["args"]
                    # Only valid types should be present
                    self.assertTrue("valid_string_list" in args)
                    self.assertTrue("valid_int" in args)
                    # Invalid lists should be filtered out
                    self.assertTrue("mixed_list" not in args)
                    self.assertTrue("non_string_list" not in args)

                    # Check values
                    self.assertEqual(args["valid_string_list"], ["valid1", "valid2"])
                    self.assertEqual(args["valid_int"], 100)
                    self.assertTrue(e["cat"] == "cpu_op")

    def test_is_profiler_enabled(self):
        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)

        with profile() as p:
            self.assertTrue(torch.autograd.profiler._is_profiler_enabled)

        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)

        with torch.autograd.profiler.profile() as p:
            self.assertTrue(torch.autograd.profiler._is_profiler_enabled)

        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)

    def test_guarded_record_function_fast(self):
        x, y = (torch.rand((4, 4)) for _ in range(2))

        with profile() as p:
            cm = torch._C._profiler._RecordFunctionFast("guarded_rff")
            for _ in range(4):
                if torch.autograd.profiler._is_profiler_enabled:
                    with cm:
                        x.add(y)
                else:
                    x.add(y)

        self.assertGreaterEqual(
            len([e for e in p.events() if e.name == "guarded_rff"]), 4
        )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_event_list(self):
        # As far as we know, EventList is part of the legacy profiler and/or is used
        # when Kineto is not available. This test has basic sanity checks to guard
        # against obvious regressions.
        x, y = (torch.rand((4, 4), requires_grad=True, device="cuda") for _ in range(2))
        with profile(with_stack=True) as p:
            z = (x @ y).relu().sum()
            z.backward()

        event_list = torch.autograd.profiler_util.EventList(p.events())
        # event_list._build_tree()

        with TemporaryFileName(mode="w+") as fname:
            event_list.export_chrome_trace(fname)
            with open(fname) as f:
                json.load(f)

        event_list.table()

    def _check_all_gpu_present(self, gpu_dict, max_gpu_count):
        for i in range(max_gpu_count):
            self.assertEqual(gpu_dict["GPU " + str(i)], 1)

    # Do JSON sanity testing: check that all events fall between the profiler start
    # and end, and that GPU values are present in the trace if CUDA is used
    def _validate_basic_json(self, traceEvents, cuda_available=False):
        MAX_GPU_COUNT = 8
        PROFILER_IDX = -4
        RECORD_END = -1
        RECORD_START = -2
        traceEventProfiler = traceEvents[PROFILER_IDX]

        self.assertTrue(traceEventProfiler["name"] == "PyTorch Profiler (0)")
        self.assertTrue(traceEvents[RECORD_END]["name"] == "Record Window End")
        self.assertTrue(
            traceEvents[RECORD_START]["name"] == "Iteration Start: PyTorch Profiler"
        )
        # check that the profiler starts/ends within the record interval
        self.assertGreaterEqual(
            traceEventProfiler["ts"],
            traceEvents[RECORD_START]["ts"],
            "Profiler starts before record!",
        )
        self.assertLessEqual(
            traceEventProfiler["ts"] + traceEventProfiler["dur"],
            traceEvents[RECORD_END]["ts"],
            "Profiler ends after record end!",
        )

        gpu_dict = collections.defaultdict(int)
        for i, traceEvent in enumerate(traceEvents):
            if (
                i == len(traceEvents) + RECORD_END
                or i == len(traceEvents) + RECORD_START
            ):
                continue
            # make sure all valid trace events are within the bounds of the profiler
            if "ts" in traceEvent:
                self.assertGreaterEqual(
                    traceEvent["ts"],
                    traceEventProfiler["ts"],
                    "Trace event is out of bounds",
                )
            # Some Python events seem to run slightly past the profiler end, probably
            # because of clock inaccuracies, so compare event end times to RECORD_END
            if "dur" in traceEvent:
                self.assertLessEqual(
                    traceEvent["ts"] + traceEvent["dur"],
                    traceEvents[RECORD_END]["ts"],
                    "Trace event ends too late!",
                )
            gpu_value = traceEvent.get("args", {}).get("labels", None)
            if gpu_value and "GPU" in gpu_value:
                gpu_dict[gpu_value] += 1
                # Max PID offset is 5M, based on the pytorch/kineto include header:
                # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35
                kExceedMaxPid = 5000000
                self.assertTrue(
                    traceEvents[i + 1]["args"]["sort_index"]
                    == kExceedMaxPid + int(gpu_value.split()[1])
                )

        # TODO add checking gpu count if cpuOnly_ is true or not

    def _test_chrome_trace_basic_helper(self, with_cuda=False):
        if with_cuda:
            device = "cuda"
        else:
            device = "cpu"
        x, y = (torch.rand(4, 4).to(device) for _ in range(2))

        with profile(with_stack=True) as p:
            torch.add(x, y)
        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as f:
                report = json.load(f)
                self._validate_basic_json(report["traceEvents"], with_cuda)

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_basic_chrome_trace(self):
        self._test_chrome_trace_basic_helper()
        if torch.cuda.is_available():
            self._test_chrome_trace_basic_helper(with_cuda=True)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_time_scale(self):
        MARGIN_ERROR = 0.5
        SEC_TO_US = 1000 * 1000
        WAIT_TIME = 10
        with profile() as p:
            with torch.profiler.record_function("test_span"):
                for _ in range(WAIT_TIME):
                    torch.rand(4, 4)
                    time.sleep(1)

        events = p.events()

        # make sure function events are scaled appropriately
        self.assertTrue(events[0].name == "test_span")
        test_span = events[0]
        self.assertGreaterEqual(
            test_span.cpu_time / SEC_TO_US,
            WAIT_TIME - MARGIN_ERROR,
            "event out of range",
        )
        self.assertLessEqual(
            test_span.cpu_time / SEC_TO_US,
            WAIT_TIME + MARGIN_ERROR,
            "event out of range",
        )

        # make sure tracing is scaled appropriately
        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as f:
                report = json.load(f)
            events = report["traceEvents"]
            for event in events:
                if event["name"] == "test_span":
                    self.assertGreaterEqual(
                        event["dur"] / SEC_TO_US,
                        WAIT_TIME - MARGIN_ERROR,
                        "profiling out of range",
                    )
                    self.assertLessEqual(
                        event["dur"] / SEC_TO_US,
                        WAIT_TIME + MARGIN_ERROR,
                        "profiling out of range",
                    )

    def _schedule_helper(self, warmup, active, repeat, acc_events=True):
        with profile(
            schedule=torch.profiler.schedule(
                skip_first=0,
                wait=0,
                warmup=warmup,
                active=active,
                repeat=repeat,
            ),
            acc_events=acc_events,
        ) as prof:
            for _ in range(100):
                torch.add(1, 2)
                prof.step()
        for ev in prof.key_averages():
            if ev.key == "aten::add":
                return ev.count
        return 0

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_schedule_function_count(self):
        self.assertEqual(self._schedule_helper(warmup=0, active=1, repeat=1), 1)
        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=0), 100)
        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=10), 50)
        self.assertEqual(self._schedule_helper(warmup=1, active=5, repeat=0), 83)
        self.assertEqual(self._schedule_helper(warmup=10, active=10, repeat=4), 40)
        self.assertEqual(self._schedule_helper(warmup=50, active=1, repeat=0), 1)
        self.assertEqual(
            self._schedule_helper(warmup=0, active=5, repeat=0, acc_events=False), 0
        )
        self.assertEqual(
            self._schedule_helper(warmup=10, active=10, repeat=4, acc_events=False), 10
        )

    def _step_helper_func(self, prof):
        time.sleep(0.1)
        torch.randn(1, 3, 224, 224)
        prof.step()

    def _partial_overlap(self, prof_step, step_helper_func):
        p_start = prof_step["ts"]
        p_end = prof_step["ts"] + prof_step["dur"]
        h_start = step_helper_func["ts"]
        h_end = step_helper_func["ts"] + step_helper_func["dur"]

        if p_start < h_start and p_end < h_end and p_end > h_start:
            return True
        if p_start > h_start and p_start < h_end and p_end > h_end:
            return True
        return False

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_cpu_annotation_overlap(self):
        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True,
            schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1),
            experimental_config=torch._C._profiler._ExperimentalConfig(
                adjust_profiler_step=True
            ),
        ) as prof:
            for i in range(5):
                self._step_helper_func(prof)
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            prof_steps = []
            step_helper_funcs = []
            with open(fname) as f:
                report = json.load(f)
                for event in report["traceEvents"]:
                    if "ProfilerStep" in event["name"]:
                        prof_steps.append(event)
                    if "step_helper_func" in event["name"]:
                        step_helper_funcs.append(event)
            self.assertEqual(len(prof_steps), 5)
            self.assertEqual(len(step_helper_funcs), 5)
            for i in range(len(step_helper_funcs)):
                for j in range(len(step_helper_funcs)):
                    self.assertTrue(
                        not self._partial_overlap(prof_steps[i], step_helper_funcs[j])
                    )

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_user_annotation(self):
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with profile(activities=supported_activities()) as p:
            with torch.profiler.record_function("test_user_annotation"):
                self.payload(use_cuda=use_cuda)

        for evt in p.key_averages():
            if evt.key == "test_user_annotation":
                self.assertTrue(evt.is_user_annotation)
            else:
                self.assertFalse(evt.is_user_annotation)

    @unittest.skipUnless(TEST_CUDA or TEST_XPU, "requires gpu")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_basic_profile(self):
        # test a really basic profile to make sure no erroneous aten ops are run
        x = torch.randn(4, device="cuda")
        with torch.profiler.profile(with_stack=True) as p:
            x *= 2
        names = [e.name for e in p.events()]
        for name in names:
            if name.startswith("aten") and name != "aten::mul_":
                self.assertTrue(False, "Found unexpected event: " + name)
        self.assertTrue("aten::mul_" in names)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_dynamic_toggle(self):
        acc = torch.accelerator.current_accelerator()
        self.assertIsNotNone(acc)
        device = acc.type
        gpu_activity = getattr(ProfilerActivity, device.upper(), None)
        self.assertIsNotNone(gpu_activity)
        activities = [ProfilerActivity.CPU, gpu_activity]
        with profile(activities=activities) as p:
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to(device) for _ in range(2))
                torch.add(x, y)

        self.assertTrue(any("aten" in e.name for e in p.events()))

        self.assertTrue(any(device in e.name for e in p.events()))

        self.assertTrue(any("kernel" in e.name.lower() for e in p.events()))

        with profile(activities=activities) as p1:
            p1.toggle_collection_dynamic(False, [gpu_activity])
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to(device) for _ in range(2))
                torch.add(x, y)

        self.assertTrue(any("aten" in e.name for e in p1.events()))

        self.assertTrue(all(device not in e.name for e in p1.events()))

        self.assertTrue(all("kernel" not in e.name.lower() for e in p1.events()))

        with profile(activities=activities) as p2:
            p2.toggle_collection_dynamic(False, activities)
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to(device) for _ in range(2))
                torch.add(x, y)

        self.assertTrue(len(p2.events()) == 0)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_lazy_build_tree(self):
        with profile() as p:
            self.payload()

        stats = p._stats()
        # Test that the tree is not built
        self.assertEqual(stats.function_events_build_tree_call_duration_us, 0)
        self.assertEqual(stats.number_of_events, 0)

        # Test that the tree is built on demand
        p.events()
        self.assertGreater(stats.function_events_build_tree_call_duration_us, 0)
        self.assertGreater(stats.number_of_events, 0)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    @unittest.skipIf(
        torch.cuda.is_available(), "CUDA complains about forking after init"
    )
    @unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows")
    def test_forked_process(self):
        # Induce a pid cache by running the profiler with payload
        def validate_forked_json(profiler):
            nonlocal cpu_op_found, parent_tid, child_pid
            with TemporaryFileName(mode="w+") as fname:
                profiler.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]
                    for event in events:
                        if "cat" in event and event["cat"] == "cpu_op":
                            self.assertEqual(event["pid"], child_pid)
                            self.assertNotEqual(event["tid"], parent_tid)
                            cpu_op_found = True

        cpu_op_found = False
        parent_tid = threading.current_thread().ident
        with profile() as p:
            self.payload()
        pid = os.fork()
        if pid == 0:
            child_pid = os.getpid()
            with profile() as p:
                self.payload()
            validate_forked_json(p)
            self.assertTrue(cpu_op_found)
            os._exit(0)
        else:
            os.waitpid(pid, 0)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_skip_first_wait(self):
        # Other tests test when skip_first_wait is false (default) so just test the true case
        test_schedule = torch.profiler.schedule(
            skip_first=3, wait=5, warmup=1, active=2, repeat=2, skip_first_wait=1
        )
        test_schedule_expected_outputs = [
            # repeat No. 1 begin
            # skip first 3
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            # warmup 1
            ProfilerAction.WARMUP,
            # active 1 begin
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            # active 1 end
            # repeat No. 1 end
            # ---
            # repeat No. 2 begin
            # wait 5
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            # warmup 1
            ProfilerAction.WARMUP,
            # active 2 begin
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            # active 2 end
            # repeat No. 2 end
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
        ]
        for step in range(len(test_schedule_expected_outputs)):
            self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_disable_external_correlation(self):
        cuda_external_id_events = {"cuda_runtime", "gpu_memcpy", "kernel"}
        activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

        def check_correlations(event, disable_external_correlation):
            if "cat" in event and event["cat"] in cuda_external_id_events:
                if disable_external_correlation:
                    self.assertTrue("External id" not in event["args"])
                elif event["name"] != "cudaDeviceSynchronize":
                    self.assertTrue("External id" in event["args"])
                    self.assertTrue(event["args"]["External id"] > 0)

        def validate_json(prof, disable_external_correlation):
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]
                seen_event_types = set()
                for event in events:
                    check_correlations(event, disable_external_correlation)
                    if "cat" in event:
                        seen_event_types.add(event["cat"])
                self.assertTrue(cuda_external_id_events.issubset(seen_event_types))

        # Run with External Id for CUDA events on and off
        for disable_external_correlation in [False, True]:
            with profile(
                activities=activities,
                experimental_config=torch._C._profiler._ExperimentalConfig(
                    disable_external_correlation=disable_external_correlation
                ),
            ) as prof:
                self.payload(use_cuda=True)
            validate_json(prof, disable_external_correlation)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @unittest.skipIf(
        "RelWithAssert" in torch.__config__.show(),
        "failing in debug build, see https://github.com/pytorch/pytorch/pull/150059 for example",
    )
    def test_profile_all_threads(self):
        profiling_started = threading.Event()
        profiling_ended = threading.Event()
        n_rep = 5

        def prep_inputs():
            return [torch.randn(1024, 1024, device="cuda") for _ in range(2)]

        def main_thread_fn(profile_all_threads, returned_events):
            x, y = prep_inputs()
            experimental_config = torch._C._profiler._ExperimentalConfig(
                profile_all_threads=profile_all_threads
            )
            with torch.profiler.profile(
                experimental_config=experimental_config, record_shapes=True
            ) as p:
                profiling_started.set()
                for _ in range(n_rep):
                    _ = x @ y
                profiling_ended.wait()
            returned_events.append(p.events())

        def side_thread_fn():
            x, y = prep_inputs()
            profiling_started.wait()
            for _ in range(n_rep):
                _ = x @ y
            profiling_ended.set()

        def main_with_thread_fn(profile_all_threads):
            x, y = prep_inputs()
            experimental_config = torch._C._profiler._ExperimentalConfig(
                profile_all_threads=profile_all_threads
            )
            with torch.profiler.profile(
                experimental_config=experimental_config, record_shapes=True
            ) as p:
                side_thread = threading.Thread(target=side_thread_fn)
                side_thread.start()
                for _ in range(n_rep):
                    _ = x @ y
                side_thread.join()
            return p.events()

        for profile_all_threads in (True, False):
            returned_events = []
            main_thread = threading.Thread(
                target=main_thread_fn, args=(profile_all_threads, returned_events)
            )
            side_thread = threading.Thread(target=side_thread_fn)
            main_thread.start()
            side_thread.start()
            main_thread.join()
            side_thread.join()

            def verify_events(events):
                mm_events = collections.defaultdict(int)
                for e in events:
                    if e.name == "aten::mm":
                        mm_events[e.thread] += 1
                        self.assertEqual(e.input_shapes, [[1024, 1024], [1024, 1024]])
                self.assertEqual(len(mm_events), 1 + int(profile_all_threads))
                for v in mm_events.values():
                    self.assertEqual(v, n_rep)

            verify_events(returned_events[0])
            # test spawning thread from within the profiled region
            events = main_with_thread_fn(profile_all_threads)
            verify_events(events)
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_python_gc_event(self):
        activities = [ProfilerActivity.CPU]

        def payload():
            x = torch.randn(10, 10)
            y = torch.randn(10, 10)
            with record_function("pre_gc"):
                torch.mm(x, y)
            gc.collect()
            with record_function("post_gc"):
                torch.mm(x, y)

        def validate_json(prof, gc_collection_on):
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]
                    # Find required events
                    if gc_collection_on:
                        pre_gc = next(
                            (e for e in events if e["name"] == "pre_gc"), None
                        )
                        post_gc = next(
                            (e for e in events if e["name"] == "post_gc"), None
                        )
                        python_gc_events = [
                            e for e in events if e["name"] == "Python GC"
                        ]
                        # Assert all required events are present
                        self.assertIsNotNone(pre_gc, "pre_gc event is missing")
                        self.assertIsNotNone(post_gc, "post_gc event is missing")
                        self.assertTrue(
                            len(python_gc_events) > 0, "No Python GC events found"
                        )
                        # Calculate boundaries
                        pre_gc_end = pre_gc["ts"] + pre_gc.get("dur", 0)
                        post_gc_start = post_gc["ts"]
                        # Assert each Python GC event is correctly placed
                        for python_gc in python_gc_events:
                            python_gc_start = python_gc["ts"]
                            python_gc_end = python_gc["ts"] + python_gc.get("dur", 0)
                            self.assertTrue(
                                python_gc_start > pre_gc_end
                                and python_gc_end < post_gc_start,
                                f"Python GC event at {python_gc_start} is not correctly placed.",
                            )
                    else:
                        python_gc_events = [
                            e for e in events if e["name"] == "Python GC"
                        ]
                        self.assertTrue(
                            len(python_gc_events) == 0,
                            "Python GC event found when flag off",
                        )

        for gc_flag in [True, False]:
            with profile(
                activities=activities,
                experimental_config=torch._C._profiler._ExperimentalConfig(
                    record_python_gc_info=gc_flag
                ),
                with_stack=True,
            ) as prof:
                payload()
            validate_json(prof, gc_flag)


class SimpleNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))


@dataclass(frozen=True)
class MockKinetoEvent:
    _name: str
    _start_us: int
    _duration_us: int
    _linked_correlation_id: int
    _device_type: int

    @property
    def name(self) -> str:
        return self._name

    def start_ns(self) -> int:
        return self._start_us * 1000

    def duration_ns(self) -> int:
        return self._duration_us * 1000

    def linked_correlation_id(self) -> int:
        return self._linked_correlation_id

    def device_type(self) -> DeviceType:
        return DeviceType.CUDA if self._device_type == 1 else DeviceType.CPU


@dataclass(frozen=True)
class MockProfilerEvent:
    _name: str
    id: int
    start_time_ns: int
    duration_time_ns: int
    correlation_id: int = 0
    children: list["MockProfilerEvent"] = field(default_factory=list)
    parent: Optional["MockProfilerEvent"] = None

    @property
    def end_time_ns(self):
        return self.start_time_ns + self.duration_time_ns

    @property
    def name(self) -> str:
        return self._name

    def __post__init__(self, parent, children):
        object.__setattr__(self, "parent", parent)
        object.__setattr__(self, "children", children)


class MockNode:
    def __init__(self, name, children) -> None:
        self.name = name
        self.children = [MockNode(name, i) for name, i in children.items()]


class TestExperimentalUtils(TestCase):
    def make_tree(self) -> list[MockNode]:
        tree = {
            "root_0": {
                "1": {"2": {}},
                "3": {
                    "4": {},
                    "5": {},
                },
            },
            "root_1": {
                "6": {},
                "7": {},
                "8": {
                    "9": {"10": {}},
                },
            },
        }
        return [MockNode(name, i) for name, i in tree.items()]

    def test_dfs(self) -> None:
        self.assertEqual(
            " ".join(i.name for i in _utils.traverse_dfs(self.make_tree())),
            "root_0 1 2 3 4 5 root_1 6 7 8 9 10",
        )

    def test_bfs(self) -> None:
        self.assertEqual(
            " ".join(i.name for i in _utils.traverse_bfs(self.make_tree())),
            "root_0 root_1 1 3 6 7 8 2 4 5 9 10",
        )

    @staticmethod
    def generate_mock_profile():
        cuda_events = [
            MockKinetoEvent("cudaLaunchKernel", 400, 100, 1, 0),
            MockKinetoEvent("cudaLaunchKernel", 500, 100, 2, 0),
            MockKinetoEvent("cudaLaunchKernel", 600, 100, 3, 0),
            MockKinetoEvent("cudaLaunchKernel", 700, 100, 4, 0),
            MockKinetoEvent("cudaLaunchKernel", 800, 100, 5, 0),
            MockKinetoEvent("cudaLaunchKernel", 1500, 100, 6, 0),
            MockKinetoEvent("GPU", 900, 100, 1, 1),
            MockKinetoEvent("GPU", 1000, 100, 2, 1),
            MockKinetoEvent("GPU", 1100, 100, 3, 1),
            MockKinetoEvent("GPU", 1200, 100, 4, 1),
            MockKinetoEvent("GPU", 1300, 100, 5, 1),
            MockKinetoEvent("GPU", 1700, 100, 6, 1),
        ]
        cpu_events = [
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 1, 0, 100000),
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 2, 100000, 100000),
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 3, 200000, 100000),
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 4, 300000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 5, 400000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 6, 500000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 7, 600000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 8, 700000, 100000),
            MockProfilerEvent("CPU (After GPU)", 9, 800000, 100000),
            MockProfilerEvent("CPU (After GPU)", 10, 900000, 100000),
            MockProfilerEvent("CPU (After GPU)", 11, 1100000, 100000),
            MockProfilerEvent("CPU (After GPU)", 12, 1200000, 500000),
        ]

        profiler = unittest.mock.Mock()
        profiler.kineto_results = unittest.mock.Mock()
        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
            return_value=cpu_events
        )
        return profiler

    @staticmethod
    def load_mock_profile():
        accept = expecttest.ACCEPT
        json_file_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "profiler_utils_mock_events.json",
        )
        if accept and torch.cuda.is_available():

            def garbage_code(x):
                for i in range(5):
                    x[0, i] = i

            x = torch.ones((4096, 4096), device="cuda")
            x = x @ x
            with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                with_stack=True,
            ) as prof:
                for _ in range(5):
                    x = x @ x
                garbage_code(x)
                for _ in range(5):
                    x = x @ x

            kineto_events = [
                {
                    "_name": e.name,
                    "_start_ns": e.start_ns(),
                    "_duration_ns": e.duration_ns(),
                    "_linked_correlation_id": e.linked_correlation_id(),
                    "_device_type": 1 if e.device_type() == DeviceType.CUDA else 0,
                }
                for e in prof.profiler.kineto_results.events()
            ]

            def EventTreeDFS(event_tree):
                from collections import deque

                stack = deque(event_tree)
                while stack:
                    curr_event = stack.pop()
                    yield curr_event
                    for child_event in curr_event.children:
                        stack.append(child_event)

            profiler_events = [
                {
                    "_name": e.name,
                    "id": e.id,
                    "start_time_ns": e.start_time_ns,
                    "duration_time_ns": e.duration_time_ns,
                    "correlation_id": e.correlation_id,
                    "children": [child.id for child in e.children],
                    "parent": e.parent.id if e.parent else None,
                }
                for e in EventTreeDFS(
                    prof.profiler.kineto_results.experimental_event_tree()
                )
            ]

            with open(json_file_path, "w") as f:
                json.dump([kineto_events, profiler_events], f)

        assert os.path.exists(json_file_path)
        with open(json_file_path) as f:
            kineto_events, profiler_events = json.load(f)

        cuda_events = [MockKinetoEvent(*event.values()) for event in kineto_events]
        cpu_events = []
        id_map = {}
        for e in profiler_events:
            event = MockProfilerEvent(**e)
            id_map[event.id] = event
            cpu_events.append(event)
        for event in cpu_events:
            parent = None if event.parent is None else id_map[event.parent]
            children = [id_map[child] for child in event.children]
            event.__post__init__(parent, children)
        cpu_events = [event for event in cpu_events if event.parent is None]
        profiler = unittest.mock.Mock()
        profiler.kineto_results = unittest.mock.Mock()
        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
            return_value=cpu_events
        )
        return profiler

    def test_utils_compute_self_time(self):
        with profile() as prof:
            t1, t2 = (
                torch.ones(1, requires_grad=True),
                torch.ones(1, requires_grad=True),
            )
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()
        basic_eval = _utils.BasicEvaluation(prof.profiler)
        metrics = basic_eval.metrics
        self.assertTrue(len(metrics) > 0)
        for event_key, event_metrics in metrics.items():
            self.assertEqual(
                event_metrics.self_time_ns,
                event_key.event.duration_time_ns
                - sum(child.duration_time_ns for child in event_key.event.children),
            )

    def test_utils_intervals_overlap(self):
        event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5))
        intervals = [
            _utils.Interval(0, 9),
            _utils.Interval(1, 2),
            _utils.Interval(2, 3),
            _utils.Interval(3, 4),
            _utils.Interval(4, 5),
            _utils.Interval(8, 12),
        ]
        self.assertEqual(event.intervals_overlap(intervals), 5)

    def test_utils_compute_queue_depth(self):
        def format_queue_depth(queue_depth_list, events):
            res = ""
            for data, event in zip(queue_depth_list, events):
                res += f"{data.queue_depth} [{event.name}]\n"
            return res

        # We have to use Mock because time series data is too flaky to test
        profiler = self.generate_mock_profile()
        basic_evaluation = _utils.BasicEvaluation(profiler)
        self.assertExpectedInline(
            format_queue_depth(
                basic_evaluation.queue_depth_list, basic_evaluation.cuda_events
            ),
            """\
1 [cudaLaunchKernel]
2 [cudaLaunchKernel]
3 [cudaLaunchKernel]
4 [cudaLaunchKernel]
5 [cudaLaunchKernel]
4 [GPU]
3 [GPU]
2 [GPU]
1 [GPU]
0 [GPU]
1 [cudaLaunchKernel]
0 [GPU]
""",
        )
        self.assertExpectedInline(
            format_queue_depth(
                [basic_evaluation.metrics[k] for k in basic_evaluation.event_keys],
                basic_evaluation.events,
            ),
            """\
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
1 [CPU (After cudaLaunchKernel)]
2 [CPU (After cudaLaunchKernel)]
3 [CPU (After cudaLaunchKernel)]
4 [CPU (After cudaLaunchKernel)]
5 [CPU (After GPU)]
4 [CPU (After GPU)]
2 [CPU (After GPU)]
1 [CPU (After GPU)]
""",
        )

    def test_utils_compute_queue_depth_when_no_cuda_events(self):
        # For traces with only cpu events, we expect empty queue depth list
        x = torch.ones((1024, 1024))
        with profile() as prof:
            for _ in range(5):
                x = x @ x
        basic_evaluation = _utils.BasicEvaluation(prof.profiler)
        self.assertFalse(basic_evaluation.compute_queue_depth())

    def test_utils_compute_idle_time(self):
        profiler = self.generate_mock_profile()
        basic_evaluation = _utils.BasicEvaluation(profiler)
        expected_output = "\n".join(
            [
                f"{basic_evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name}]"
                for event_key in basic_evaluation.event_keys
            ]
        )
        self.assertExpectedInline(
            expected_output,
            """\
100000 [CPU (Before cudaLaunchKernel)]
100000 [CPU (Before cudaLaunchKernel)]
100000 [CPU (Before cudaLaunchKernel)]
100000 [CPU (Before cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After GPU)]
0 [CPU (After GPU)]
0 [CPU (After GPU)]
100000 [CPU (After GPU)]""",
        )

    @unittest.skipIf(IS_JETSON, "JSON not behaving as expected on Jetson")
    def test_utils_get_optimizable_events(self):
        basic_evaluation = _utils.BasicEvaluation(self.load_mock_profile())
        optimizable_events = basic_evaluation.get_optimizable_events(
            2, print_enable=False
        )
        expected_output = "\n".join(
            [f"{event_key.event.name}" for event_key in optimizable_events]
        )
        self.assertExpectedInline(
            expected_output,
            """\
aten::copy_""",
        )

    def test_profiler_name_pattern(self):
        x = torch.ones((4096, 4096))
        with profile() as prof:
            for _ in range(5):
                x = x @ x
                x = x + x
        matched_events = NamePattern(prof, "aten::mm").matched_events()
        output = "\n".join([f"{event.name}" for event in matched_events])
        self.assertExpectedInline(
            output,
            """\
aten::mm
aten::mm
aten::mm
aten::mm
aten::mm""",
        )

    # TODO: Add logic for CUDA version of test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_pattern_match_helper(self):
        x = torch.ones((100, 100))
        with profile() as prof:
            for _ in range(5):
                x = x @ x
                x = x + x
        event_tree = prof.profiler.kineto_results.experimental_event_tree()
        pattern = Pattern(prof)
        self.assertEqual([], pattern.siblings_of(event_tree[0])[0])
        self.assertEqual(event_tree[1:], pattern.siblings_of(event_tree[0])[1])
        child_nodes = event_tree[0].children
        self.assertEqual([], pattern.siblings_of(child_nodes[0])[0])
        self.assertEqual(child_nodes[1:], pattern.siblings_of(child_nodes[0])[1])
        self.assertEqual(
            event_tree[0], pattern.root_of(event_tree[0].children[0].children[0])
        )
        self.assertEqual(None, pattern.next_of(event_tree[-1]))
        self.assertEqual(event_tree[1], pattern.next_of(event_tree[0]))
        self.assertEqual(event_tree[0], pattern.prev_of(event_tree[1]))

    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_extra_cuda_copy_pattern(self):
        cases = (
            (0, lambda: torch.ones((100, 100), device="cuda")),
            (1, lambda: torch.ones((100, 100)).to("cuda")),
            (1, lambda: torch.zeros((100, 100)).to("cuda")),
            (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")),
            (1, lambda: torch.ones((100, 100)).cuda()),
            (1, lambda: torch.zeros((100, 100)).cuda()),
            (1, lambda: torch.empty((100, 100)).fill_(5).cuda()),
            (1, lambda: torch.rand((100, 100)).cuda()),
            (1, lambda: torch.randn((100, 100)).cuda()),
            (1, lambda: torch.full((100, 100), 10).cuda()),
            (0, lambda: torch.rand((100, 100)).to(dtype=torch.float16)),
            (0, lambda: torch.rand((100, 100)).half()),
            (0, lambda: torch.rand((100, 100), device="cuda").half()),
        )
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                fn()
            pattern = ExtraCUDACopyPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    def test_profiler_for_loop_indexing_pattern(self):
        x = torch.ones((100, 100))

        def case1():
            for i in range(100):
                x[i] = i

        def case2():
            y = 0
            for i in range(100):
                y += x[i]

        def case3():
            y = 1
            for i in range(100):
                y *= x[i]

        def case4():
            y = x
            for _ in range(100):
                y = y @ x

        def case5():
            for i in range(100):
                x[i, :] = torch.arange(100) + i

        cases = ((1, case1), (1, case2), (1, case3), (0, case4), (1, case5))
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True) as prof:
                fn()
            pattern = ForLoopIndexingPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_fp32_matmul_pattern(self):
        x = torch.ones((100, 100), device="cuda")
        with profile(with_stack=True) as prof:
            x = x @ x
        pattern = FP32MatMulPattern(prof)
        has_tf32 = 0 if pattern.skip else 1
        num_matched = len(pattern.matched_events())
        self.assertEqual(num_matched, has_tf32)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_extra_cuda_copy_pattern_benchmark(self):
        with profile(with_stack=True, record_shapes=True) as prof:
            x = torch.ones((100, 100)).to("cuda")
            x = torch.ones((50, 50)).to("cuda")
        pattern = ExtraCUDACopyPattern(prof)
        shapes_factor_map = pattern.benchmark(pattern.matched_events())
        self.assertEqual(len(shapes_factor_map), 2)

    def test_profiler_optimizer_single_tensor_pattern(self):
        x = torch.ones((100, 100))
        cases = (
            (1, lambda: torch.optim.Adam(model.parameters())),
            (1, lambda: torch.optim.SGD(model.parameters(), lr=0.01)),
            (1, lambda: torch.optim.AdamW(model.parameters())),
            (0, lambda: torch.optim.Adam(model.parameters(), foreach=True)),
            (0, lambda: torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)),
            (0, lambda: torch.optim.AdamW(model.parameters(), foreach=True)),
        )
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True) as prof:
                model = nn.Sequential(
                    nn.Linear(100, 100),
                    nn.ReLU(),
                    nn.Linear(100, 10),
                )
                optimizer = fn()
                optimizer.zero_grad()
                y_hat = model(x)
                loss = torch.nn.functional.cross_entropy(
                    y_hat, torch.randint(0, 10, (100,))
                )
                loss.backward()
                optimizer.step()
            pattern = OptimizerSingleTensorPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    def test_profiler_synchronized_dataloader_pattern(self):
        dataset = torch.rand((100, 100))
        sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
        async_dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=10, num_workers=4
        )
        with profile(with_stack=True) as prof:
            next(iter(sync_dataloader))
            next(iter(async_dataloader))
        pattern = SynchronizedDataLoaderPattern(prof)
        num_matched = len(pattern.matched_events())
        self.assertEqual(num_matched, 1)

    @skipIfTorchDynamo(
        "pattern checks for aten::_zero op which might not be there with torch.compile'd graph"
    )
    def test_profiler_grad_not_set_to_none_pattern(self):
        x = torch.ones((100, 100))
        model = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        optimizer = torch.optim.Adam(model.parameters())
        cases = (
            (0, lambda: optimizer.zero_grad()),
            (0, lambda: model.zero_grad()),
            (1, lambda: optimizer.zero_grad(set_to_none=False)),
            (1, lambda: model.zero_grad(set_to_none=False)),
        )
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True) as prof:
                y_hat = model(x)
                loss = torch.nn.functional.cross_entropy(
                    y_hat, torch.randint(0, 10, (100,))
                )
                loss.backward()
                optimizer.step()
                fn()
            pattern = GradNotSetToNonePattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self):
        x = torch.randn((1, 3, 32, 32))
        cases = (
            (1, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))),
            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3))),
            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))),
        )
        num_matched = []
        for _, model in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                model(x)
            pattern = Conv2dBiasFollowedByBatchNorm2dPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_matmul_dim_fp16_pattern(self):
        cases = (
            (1, torch.randn((201, 201), device="cuda", dtype=torch.float16)),
            (1, torch.randn((3, 97, 97), device="cuda", dtype=torch.float16)),
            (0, torch.randn((200, 200), device="cuda", dtype=torch.float16)),
            (0, torch.randn((3, 200, 200), device="cuda", dtype=torch.float16)),
        )
        num_matched = []
        for _, x in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                x @ x
            pattern = MatMulDimInFP16Pattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_pattern_matcher_json_report(self):
        x = torch.ones((100, 100))
        model = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        optimizer = torch.optim.Adam(model.parameters())
        with profile(with_stack=True, record_shapes=True) as prof:
            y_hat = model(x)
            loss = torch.nn.functional.cross_entropy(
                y_hat, torch.randint(0, 10, (100,))
            )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        with tempfile.TemporaryDirectory() as tmpdir:
            report_all_anti_patterns(prof, json_report_dir=tmpdir, print_enable=False)

            with open(os.path.join(tmpdir, "torchtidy_report.json")) as f:
                report = json.load(f)

            # It is platform dependent whether the path will include "profiler/"
            keys = [k for k in report.keys() if k.endswith("test_profiler.py")]
            self.assertEqual(len(keys), 1, f"{keys}")
            entry = report[keys[0]]

            self.assertTrue(len(entry) > 0)
            expected_fields = sorted(["line_number", "name", "url", "message"])
            for event in entry:
                actual_fields = sorted(event.keys())
                self.assertEqual(expected_fields, actual_fields)

    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
    def test_fuzz_symbolize(self):
        # generate some random addresses in the text section and make sure the
        # symbolizers do not throw exceptions/crash
        def get_text_sections():
            text_sections = []
            seen = set()
            for filename in os.listdir("/proc/self/map_files"):
                library = os.readlink("/proc/self/map_files/" + filename)
                if ".so" not in library or library in seen:
                    continue
                seen.add(library)
                with open(os.path.join("/proc/self/map_files", library), "rb") as f:
                    mm = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ)

                    def unpack(fmt, offset):
                        return struct.unpack(
                            fmt, mm[offset : offset + struct.calcsize(fmt)]
                        )

                    if mm[:4] != b"\x7fELF":
                        continue
                    # ELF64 header: e_shoff at byte 40, e_shentsize at 58,
                    # e_shnum at 60, e_shstrndx at 62
                    (section_headers_start,) = unpack("Q", 40)
                    (section_header_size,) = unpack("H", 58)
                    (num_section_headers,) = unpack("H", 60)
                    (shstrndx,) = unpack("H", 62)
                    # sh_offset (byte 24 of a section header) of the
                    # section-name string table
                    (shstrtab_offset,) = unpack(
                        "Q", section_headers_start + shstrndx * section_header_size + 24
                    )
                    for i in range(num_section_headers):
                        # sh_name: 4-byte offset into the section-name string table
                        (section_name_offset,) = unpack(
                            "I", section_headers_start + i * section_header_size
                        )
                        name_start = shstrtab_offset + section_name_offset
                        section_name = mm[name_start : name_start + 6]
                        if section_name != b".text\0":
                            continue
                        # sh_offset (byte 24) and sh_size (byte 32) of .text
                        (section_offset,) = unpack(
                            "Q", section_headers_start + i * section_header_size + 24
                        )
                        (section_size,) = unpack(
                            "Q", section_headers_start + i * section_header_size + 32
                        )
                        # map_files entries are named "<start>-<end>" in hex
                        start = int(filename.split("-")[0], 16) + section_offset
                        text_sections.append((start, section_size))
                        break
                    mm.close()
            return text_sections

        r = random.Random()
        r.seed(1)
        text_sections = get_text_sections()
        addrs = []
        for _ in range(200):
            s = r.randrange(0, len(text_sections))
            start, size = text_sections[s]
            addr = r.randrange(start, start + size)
            addrs.append(addr)
        fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
        dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr")
        addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
        self.assertEqual(len(fast), len(addrs))
        self.assertEqual(len(addr2line), len(fast))

    def test_profiler_overload_names(self):
        from torch.library import _scoped_library, fallthrough_kernel

        def validate_json(prof):
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]
                    self.assertTrue(
                        any("aten::add.Tensor" in e["name"] for e in events)
                    )
                    self.assertTrue(any("aten::add.out" in e["name"] for e in events))

        with _scoped_library("aten", "IMPL") as my_lib:
            my_lib.impl("add.Tensor", fallthrough_kernel, "CPU")
            experimental_config = torch._C._profiler._ExperimentalConfig(
                capture_overload_names=True
            )
            with profile(
                experimental_config=experimental_config,
                activities=[ProfilerActivity.CPU],
            ) as prof:
                torch.add(1, 5)

            # The following execution trace is expected
            #
            # Dispatch trace:
            # [call] op=[aten::add.Tensor], key=[AutogradCPU]
            # [redispatch] op=[aten::add.Tensor], key=[Undefined]
            # [call] op=[aten::empty.memory_format], key=[BackendSelect]
            # [redispatch] op=[aten::empty.memory_format], key=[CPU]
            # [call] op=[aten::add.out], key=[CPU]
            #
            # prof.table()
            # ---------------  ---------------  ------------  ------------  ------------  ------------  ------------  ------------
            #            Name    Overload Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
            # ---------------  ---------------  ------------  ------------  ------------  ------------  ------------  ------------
            #       aten::add           Tensor        71.97%     130.887us       100.00%     181.873us     181.873us             1
            #     aten::empty    memory_format         8.52%      15.489us         8.52%      15.489us      15.489us             1
            #       aten::add              out        19.52%      35.497us        19.52%      35.497us      35.497us             1
            # ---------------  ---------------  ------------  ------------  ------------  ------------  ------------  ------------

            # aten::add.out and aten::empty.memory_format are children of aten::add.Tensor
            aten_add_parent: list[FunctionEvent] = [
                event for event in prof.events() if len(event.cpu_children) == 2
            ]
            assert len(aten_add_parent) == 1
            aten_add_parent = aten_add_parent[0]
            assert aten_add_parent.overload_name == "Tensor"

            aten_add_out_event = [
                c for c in aten_add_parent.cpu_children if c.overload_name == "out"
            ]
            assert len(aten_add_out_event) == 1

            # Without group_by_overload_name, the overload name is ignored in the key averages
            key_averages = prof.key_averages()
            assert len(key_averages) == 2
            assert "Overload Name" not in key_averages.table()

            key_averages = prof.key_averages(group_by_overload_name=True)
            assert len(key_averages) == 3
            assert "Overload Name" in key_averages.table()

            validate_json(prof)

    def test_expose_kineto_event_metadata(self):
        def check_metadata(prof, op_name, metadata_key):
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]
                    found_op = False
                    for e in events:
                        if "name" in e and "args" in e and e["name"] == op_name:
                            assert metadata_key in e["args"], (
                                f"Metadata for '{op_name}' in Chrome trace did not contain '{metadata_key}'."
                            )
                            found_op = True
                    assert found_op, f"Could not find op '{op_name}' in Chrome trace."
            found_op = False
            for event in prof.events():
                if event.name == op_name:
                    assert metadata_key in event.metadata_json, (
                        f"Metadata for '{op_name}' in FunctionEvent did not contain '{metadata_key}'."
                    )
                    found_op = True
            assert found_op, f"Could not find op '{op_name}' in prof.events()."

        experimental_config = torch._C._profiler._ExperimentalConfig(
            expose_kineto_event_metadata=True
        )
        with profile(
            experimental_config=experimental_config,
            activities=[ProfilerActivity.CPU],
        ) as prof:
            torch.add(1, 5)
        check_metadata(prof, op_name="aten::add", metadata_key="Ev Idx")

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_profiler_debug_autotuner(self):
        """
        Checks that profiling events are still present when the kernel is run
        using the DebugAutotuner.
        """
        if not is_big_gpu():
            raise unittest.SkipTest("requires large gpu to max-autotune")
        in1 = torch.randn((256, 512), device="cuda", dtype=torch.float16)
        in2 = torch.randn((512, 768), device="cuda", dtype=torch.float16)

        def mm():
            return torch.mm(in1, in2)

        pb_mm = torch.compile(
            mm,
            options={
                "benchmark_kernel": True,
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
                "profile_bandwidth": True,
            },
        )
        comp_mm = torch.compile(
            mm,
            options={
                "benchmark_kernel": True,
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )

        with profile() as prof1:
            pb_mm()
        with profile() as prof2:
            comp_mm()

        def names(prof):
            return {
                ev.name
                for ev in prof.events()
                if "mm" in ev.name or "triton" in ev.name
            }

        n1 = names(prof1)
        n2 = names(prof2)
        self.assertEqual(n1, n2)


if __name__ == "__main__":
    run_tests()