[Inductor][NCU] Add kernel name filtering, and allow custom metrics (#150872)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150872
Approved by: https://github.com/FindHao

Co-authored-by: Yueming Hao <yhao@meta.com>
This commit is contained in:
Will Feng 2025-05-04 00:41:06 -07:00 committed by PyTorch MergeBot
parent b117a6c47b
commit 82cb202de7

View File

@ -1,3 +1,4 @@
import argparse
import dataclasses
import datetime
import tempfile
@ -316,12 +317,17 @@ def perf_profile(
def ncu_analyzer(
benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
benchmark_name: str,
benchmark_compiled_module_fn: BenchmarkCallableType,
args: argparse.Namespace,
) -> None:
import inspect
import os
import subprocess
kernel_regex = args.ncu_kernel_regex
metrics = args.ncu_metrics
module_file = inspect.getfile(benchmark_compiled_module_fn)
module_dir = os.path.dirname(module_file)
module_name = os.path.splitext(os.path.basename(module_file))[0]
@ -345,18 +351,29 @@ def ncu_analyzer(
"function",
"--print-units",
"base",
"--set",
"full",
"--import-source",
"yes",
"--force-overwrite",
"--export",
ncu_output,
"python",
"-c",
python_cmd,
]
if kernel_regex:
ncu_cmd.extend(["--kernel-name", f"regex:{kernel_regex}"])
if metrics:
ncu_cmd.extend(["--metrics", metrics])
else:
ncu_cmd.extend(["--set", "full"])
ncu_cmd.extend(
[
"python",
"-c",
python_cmd,
]
)
try:
subprocess.run(ncu_cmd, check=True)
print(f"\nNCU profiling results for benchmark {benchmark_name}:")
@ -421,6 +438,25 @@ def compiled_module_main(
action="store_true",
help="Whether to run ncu analysis",
)
parser.add_argument(
"--ncu-kernel-regex",
type=str,
default=None,
help=(
"Filter kernels profiled by NCU using a regex (e.g., '^triton_.*'). "
"Maps to '--kernel-name regex:<regex>'. "
"If None, NCU will profile all kernels."
),
)
parser.add_argument(
"--ncu-metrics",
type=str,
default=None,
help=(
"Comma-separated list of NCU metrics to collect (e.g., 'dram__bytes.sum.per_second'). "
"If None, NCU will use '--set full'."
),
)
args = parser.parse_args()
if args.benchmark_kernels:
@ -449,4 +485,8 @@ def compiled_module_main(
benchmark_compiled_module_fn,
)
if args.ncu:
ncu_analyzer(benchmark_name, benchmark_compiled_module_fn)
ncu_analyzer(
benchmark_name,
benchmark_compiled_module_fn,
args=args,
)