In dynamo/inductor, it sometimes helps to gather metrics/statistics for each model at different levels: model level, graph level, kernel level, or per pair of fusion nodes. This kind of thing would be very easy to do with Scuba, but we only have Scuba in fbcode. This PR builds metric tables to solve part of the problem.

Q: Why not log to stdout/stderr directly?
A: Sometimes we need more structured data. E.g., it is helpful to gather all the stats in a CSV and then do post-processing (like computing a geomean). Also, the metric table tags each row with the model name, which is helpful.

Q: What's the difference with speedup_inductor.csv?
A: speedup_inductor.csv is a special case that gathers statistics at the model level, i.e., one row per model. Recording statistics at a finer granularity, like per graph, is also helpful.

Example use cases:
- As a follow-up to the benchmark-fusion PR, I want to gather all the 'slow' fusions and analyze them. With the metric table, I can easily log slow fusions for each model into a CSV file. Here is the log gathered for huggingface: https://gist.github.com/shunting314/964e73cc98368b301414ec7b7ad4c702
- To help understand the effect of the 'loop ordering after fusion' PR, it is helpful to gather stats like how many fusions happen for each graph. Previously we logged this metric to stderr directly, but logging these metrics in a structured way is more useful.
- Gather the number of registers, register spills, and shared memory usage for each kernel in each model, with the runnable kernel code logged.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109245
Approved by: https://github.com/jansel, https://github.com/mlazos
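To make the intended workflow concrete, here is a rough end-to-end sketch. The call site and all the values below are made up for illustration; `enabled_metric_tables` in the inductor config, `get_metric_table`, `add_row`, and the `slow_fusion` columns are the pieces this PR adds:

```python
import torch._inductor.config as inductor_config
from torch._inductor import metrics

# Opt in to the tables you care about (comma-separated names) before any row is
# logged; enabled_metric_tables() caches the parsed set via lru_cache.
inductor_config.enabled_metric_tables = "slow_fusion"

# Hypothetical call site in the fusion-benchmarking code. The lambda is only
# evaluated when the table is enabled, and the row is appended to
# metric_table_slow_fusion.csv, tagged with the current benchmark/model name.
metrics.get_metric_table("slow_fusion").add_row(
    lambda: {
        "kernel1_path": "/tmp/kernel1.py",  # illustrative values
        "kernel1_latency": 0.12,
        "kernel2_path": "/tmp/kernel2.py",
        "kernel2_latency": 0.34,
        "fused_kernel_path": "/tmp/fused.py",
        "fused_kernel_latency": 0.60,
        "slow_down_ratio": 1.3,
    }
)
```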
188 lines
4.8 KiB
Python
from __future__ import annotations

import csv
import os
from dataclasses import dataclass
from functools import lru_cache

from typing import List, Set, Tuple, TYPE_CHECKING, Union

from torch._inductor import config
from torch._inductor.utils import get_benchmark_name

# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import (
        BaseSchedulerNode,
        ExternKernelSchedulerNode,
        NopKernelSchedulerNode,
        SchedulerNode,
    )

# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0


# reset all counters
def reset():
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    nodes_num_elem.clear()
    node_runtimes.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0


REGISTERED_METRIC_TABLES = {}


@dataclass
class MetricTable:
    table_name: str
    column_names: List[str]

    num_rows_added: int = 0

    def add_row(self, row_fn):
        """
        Add a row to the table. row_fn is a callable that returns a dict
        mapping each column name to its value; it is only called when the
        table is enabled, so any expensive work can be deferred into it.
        """
        if self.table_name not in enabled_metric_tables():
            return

        row_dict = row_fn()
        assert len(self.column_names) == len(
            row_dict
        ), f"{len(self.column_names)} vs. {len(row_dict)}"
        assert set(self.column_names) == set(
            row_dict.keys()
        ), f"{set(self.column_names)} vs. {set(row_dict.keys())}"

        row = [
            get_benchmark_name(),
        ]
        row += [row_dict[column_name] for column_name in self.column_names]
        self._write_row(row)

    def output_filename(self):
        return f"metric_table_{self.table_name}.csv"

    def write_header(self):
        filename = self.output_filename()
        with open(filename, "w") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(["model_name"] + self.column_names)

    def _write_row(self, row):
        filename = self.output_filename()
        if self.num_rows_added == 0 and not os.path.exists(filename):
            self.write_header()

        self.num_rows_added += 1

        for idx, orig_val in enumerate(row):
            if isinstance(orig_val, float):
                new_val = f"{orig_val:.6f}"
            elif orig_val is None:
                new_val = ""
            else:
                new_val = orig_val
            row[idx] = new_val

        with open(filename, "a") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(row)

    @staticmethod
    def register_table(name, column_names):
        table = MetricTable(name, column_names)
        REGISTERED_METRIC_TABLES[name] = table


# track the 'slow' fusions found when benchmarking fusion candidates
MetricTable.register_table(
    "slow_fusion",
    [
        "kernel1_path",
        "kernel1_latency",
        "kernel2_path",
        "kernel2_latency",
        "fused_kernel_path",
        "fused_kernel_latency",
        "slow_down_ratio",
    ],
)

# track the fusion statistics for each graph
MetricTable.register_table(
    "graph_stats",
    [
        "graph_id",
        "num_nodes_before_fusion",
        "num_nodes_after_fusion",
    ],
)
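# Illustrative only: a pass that wants per-graph fusion stats could log a row
# into the "graph_stats" table roughly like the sketch below. The call site and
# the node counts are hypothetical; only the table/column names and the
# get_metric_table/add_row API are defined in this file.
#
#   get_metric_table("graph_stats").add_row(
#       lambda: {
#           "graph_id": 0,
#           "num_nodes_before_fusion": 42,
#           "num_nodes_after_fusion": 30,
#       }
#   )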


def purge_old_log_files():
    """
    Purge the old log files when the benchmark script starts.
    This should be done in the parent process rather than in the child
    processes that run each individual model.
    """
    for name, table in REGISTERED_METRIC_TABLES.items():
        if name in enabled_metric_tables():
            filename = table.output_filename()
            if os.path.exists(filename):
                os.unlink(filename)

            table.write_header()


@lru_cache
def enabled_metric_tables() -> Set[str]:
    config_str = config.enabled_metric_tables

    enabled = set()
    for name in config_str.split(","):
        name = name.strip()
        if not name:
            continue
        assert (
            name in REGISTERED_METRIC_TABLES
        ), f"Metric table name {name} is not registered"
        enabled.add(name)
    return enabled


def is_metric_table_enabled(name):
    return name in enabled_metric_tables()


def get_metric_table(name):
    assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
    return REGISTERED_METRIC_TABLES[name]
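

# Rough driver-side usage sketch (hypothetical benchmark-script code):
#
#   import torch._inductor.config as inductor_config
#   from torch._inductor import metrics
#
#   # enable one or more registered tables (comma-separated names)
#   inductor_config.enabled_metric_tables = "slow_fusion,graph_stats"
#
#   # in the parent process, before spawning per-model child processes:
#   metrics.purge_old_log_files()
#
#   # each enabled table then accumulates rows in metric_table_<table_name>.csv,
#   # with a leading model_name column identifying the benchmark.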