mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor][NCU] Add kernel name filtering, and allow custom metrics (#150872)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150872 Approved by: https://github.com/FindHao Co-authored-by: Yueming Hao <yhao@meta.com>
This commit is contained in:
parent
b117a6c47b
commit
82cb202de7
|
|
@ -1,3 +1,4 @@
|
|||
import argparse
|
||||
import dataclasses
|
||||
import datetime
|
||||
import tempfile
|
||||
|
|
@ -316,12 +317,17 @@ def perf_profile(
|
|||
|
||||
|
||||
def ncu_analyzer(
|
||||
benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
|
||||
benchmark_name: str,
|
||||
benchmark_compiled_module_fn: BenchmarkCallableType,
|
||||
args: argparse.Namespace,
|
||||
) -> None:
|
||||
import inspect
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
kernel_regex = args.ncu_kernel_regex
|
||||
metrics = args.ncu_metrics
|
||||
|
||||
module_file = inspect.getfile(benchmark_compiled_module_fn)
|
||||
module_dir = os.path.dirname(module_file)
|
||||
module_name = os.path.splitext(os.path.basename(module_file))[0]
|
||||
|
|
@ -345,17 +351,28 @@ def ncu_analyzer(
|
|||
"function",
|
||||
"--print-units",
|
||||
"base",
|
||||
"--set",
|
||||
"full",
|
||||
"--import-source",
|
||||
"yes",
|
||||
"--force-overwrite",
|
||||
"--export",
|
||||
ncu_output,
|
||||
]
|
||||
|
||||
if kernel_regex:
|
||||
ncu_cmd.extend(["--kernel-name", f"regex:{kernel_regex}"])
|
||||
|
||||
if metrics:
|
||||
ncu_cmd.extend(["--metrics", metrics])
|
||||
else:
|
||||
ncu_cmd.extend(["--set", "full"])
|
||||
|
||||
ncu_cmd.extend(
|
||||
[
|
||||
"python",
|
||||
"-c",
|
||||
python_cmd,
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
subprocess.run(ncu_cmd, check=True)
|
||||
|
|
@ -421,6 +438,25 @@ def compiled_module_main(
|
|||
action="store_true",
|
||||
help="Whether to run ncu analysis",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ncu-kernel-regex",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Filter kernels profiled by NCU using a regex (e.g., '^triton_.*'). "
|
||||
"Maps to '--kernel-name regex:<regex>'. "
|
||||
"If None, NCU will profile all kernels."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ncu-metrics",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Comma-separated list of NCU metrics to collect (e.g., 'dram__bytes.sum.per_second'). "
|
||||
"If None, NCU will use '--set full'."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.benchmark_kernels:
|
||||
|
|
@ -449,4 +485,8 @@ def compiled_module_main(
|
|||
benchmark_compiled_module_fn,
|
||||
)
|
||||
if args.ncu:
|
||||
ncu_analyzer(benchmark_name, benchmark_compiled_module_fn)
|
||||
ncu_analyzer(
|
||||
benchmark_name,
|
||||
benchmark_compiled_module_fn,
|
||||
args=args,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user