Summary:
Support parallel reduction for GroupNorm by improving the parallelization heuristic: when the range of the first inner (reduction) loop is much larger than the combined range of all outer loops, move the starting depth of parallelization to that first inner loop.
I tested the Inductor benchmark suites with this PR on CPU. One TorchBench model (pytorch_CycleGAN_and_pix2pix) achieved a ~45% performance improvement, and two diffusion models (Stable Diffusion and Latent Consistency Model (LCM)) achieved ~2% performance improvements.
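Roughly, the heuristic can be pictured as follows (an illustrative sketch only; the function name and threshold below are made up, not the actual Inductor code):
```
# Illustrative sketch only -- not the real Inductor heuristic.
# "ranges" is a hypothetical list of loop extents, outer-to-inner, and
# "reduction_depth" marks where the reduction (inner) loops begin.
def pick_parallel_depth(ranges, reduction_depth, threshold=32):
    outer_numel = 1
    for r in ranges[:reduction_depth]:
        outer_numel *= r
    first_inner = ranges[reduction_depth] if reduction_depth < len(ranges) else 1
    # If the first reduction loop dwarfs all outer loops combined,
    # start parallelizing at the reduction loop instead of the outer loops.
    if first_inner > threshold * outer_numel:
        return reduction_depth
    return 0  # otherwise parallelize from the outermost loop


# For the GroupNorm example below: the outer ranges are [2, 2] and the first
# reduction loop has range 28224, so parallelization moves to that loop.
print(pick_parallel_depth([2, 2, 28224, 32], reduction_depth=2))  # -> 2
```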
Example:
```
import torch
import torch.nn as nn

class GN(nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GN, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)

x = torch.randn(2, 64, 168, 168).to(memory_format=torch.channels_last)
m = GN(2, 64).eval()
compiled_m = torch.compile(m)
with torch.no_grad():
    out = compiled_m(x)
```
Generated code:
- Before:
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       float* out_ptr3,
                       float* out_ptr4)
{
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(56448L));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(28224L); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(16L))
                            {
                                {
                                    if(C10_LIKELY(x3 >= static_cast<int64_t>(0) && x3 < static_cast<int64_t>(32L)))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + 32L*x1 + 64L*x2 + 1806336L*x0), static_cast<int64_t>(16));
                                        tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                                    }
                                }
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        #pragma omp single
        {
            {
                #pragma GCC ivdep
                for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
                    {
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(16L))
                        {
                            {
                                if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(32L)))
                                {
                                    auto tmp0 = out_ptr1[static_cast<int64_t>(x1 + 2L*x0)];
                                    auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                                    auto tmp9 = out_ptr0[static_cast<int64_t>(x1 + 2L*x0)];
                                    auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                                    auto tmp1 = static_cast<float>(903168.0);
                                    auto tmp2 = tmp0 / tmp1;
                                    auto tmp3 = static_cast<float>(1e-05);
                                    auto tmp4 = decltype(tmp2)(tmp2 + tmp3);
                                    auto tmp5 = 1 / std::sqrt(tmp4);
                                    auto tmp7 = at::vec::Vectorized<float>(tmp5);
                                    auto tmp8 = tmp7 * tmp6;
                                    auto tmp10 = decltype(tmp9)(-tmp9);
                                    auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                    auto tmp12 = tmp11 * tmp8;
                                    auto tmp14 = tmp12 + tmp13;
                                    tmp8.store(out_ptr2 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                                    tmp14.store(out_ptr3 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                                }
                            }
                        }
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(28224L); x1+=static_cast<int64_t>(1L))
                {
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
                    {
                        {
                            if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(64L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::Vectorized<float>::loadu(out_ptr2 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp3 = at::vec::Vectorized<float>::loadu(out_ptr3 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp2 = tmp0 * tmp1;
                                auto tmp4 = tmp2 + tmp3;
                                tmp4.store(out_ptr4 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')
```
- After:
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       float* out_ptr3,
                       float* out_ptr4)
{
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
        {
            #pragma GCC ivdep
            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
            {
                {
                    Welford<float> tmp_acc0 = Welford<float>();
                    Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                    Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                    Welford<at::vec::Vectorized<float>> tmp_acc0_vec_arr[56];
                    for (int i = 0; i < 56; i++)
                    {
                        tmp_acc0_vec_arr[i] = Welford<at::vec::Vectorized<float>>();
                    }
                    Welford<float> tmp_acc0_arr[56];
                    for (int i = 0; i < 56; i++)
                    {
                        tmp_acc0_arr[i] = Welford<float>();
                    }
                    Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec_arr[56];
                    for (int i = 0; i < 56; i++)
                    {
                        masked_tmp_acc0_vec_arr[i] = Welford<at::vec::Vectorized<float>>();
                    }
                    #pragma omp parallel num_threads(56)
                    {
                        int tid = omp_get_thread_num();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(1008L));
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec_local = Welford<at::vec::Vectorized<float>>();
                        Welford<float> tmp_acc0_local = Welford<float>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec_local = Welford<at::vec::Vectorized<float>>();
                        #pragma omp for
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(28224L); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(16L))
                            {
                                {
                                    if(C10_LIKELY(x3 >= static_cast<int64_t>(0) && x3 < static_cast<int64_t>(32L)))
                                    {
                                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + 32L*x1 + 64L*x2 + 1806336L*x0), static_cast<int64_t>(16));
                                        tmp_acc0_vec_local = welford_combine(tmp_acc0_vec_local, tmp0, &wrecps0);
                                    }
                                }
                            }
                        }
                        tmp_acc0_vec_arr[tid] = tmp_acc0_vec_local;
                        tmp_acc0_arr[tid] = tmp_acc0_local;
                        masked_tmp_acc0_vec_arr[tid] = masked_tmp_acc0_vec_local;
                    }
                    for (int tid = 0; tid < 56; tid++)
                    {
                        tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp_acc0_vec_arr[tid]);
                    }
                    for (int tid = 0; tid < 56; tid++)
                    {
                        tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                    }
                    for (int tid = 0; tid < 56; tid++)
                    {
                        masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, masked_tmp_acc0_vec_arr[tid]);
                    }
                    tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                    tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                    out_ptr0[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.mean);
                    out_ptr1[static_cast<int64_t>(x1 + 2L*x0)] = static_cast<float>(tmp_acc0.m2);
                }
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
        {
            #pragma GCC ivdep
            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2L); x1+=static_cast<int64_t>(1L))
            {
                for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(32L); x2+=static_cast<int64_t>(16L))
                {
                    {
                        if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(32L)))
                        {
                            auto tmp0 = out_ptr1[static_cast<int64_t>(x1 + 2L*x0)];
                            auto tmp6 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                            auto tmp9 = out_ptr0[static_cast<int64_t>(x1 + 2L*x0)];
                            auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x2 + 32L*x1), static_cast<int64_t>(16));
                            auto tmp1 = static_cast<float>(903168.0);
                            auto tmp2 = tmp0 / tmp1;
                            auto tmp3 = static_cast<float>(1e-05);
                            auto tmp4 = decltype(tmp2)(tmp2 + tmp3);
                            auto tmp5 = 1 / std::sqrt(tmp4);
                            auto tmp7 = at::vec::Vectorized<float>(tmp5);
                            auto tmp8 = tmp7 * tmp6;
                            auto tmp10 = decltype(tmp9)(-tmp9);
                            auto tmp11 = at::vec::Vectorized<float>(tmp10);
                            auto tmp12 = tmp11 * tmp8;
                            auto tmp14 = tmp12 + tmp13;
                            tmp8.store(out_ptr2 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                            tmp14.store(out_ptr3 + static_cast<int64_t>(x2 + 32L*x1 + 64L*x0));
                        }
                    }
                }
            }
        }
    }
    #pragma omp parallel num_threads(56)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(2L); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(28224L); x1+=static_cast<int64_t>(1L))
                {
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(64L); x2+=static_cast<int64_t>(16L))
                    {
                        {
                            if(C10_LIKELY(x2 >= static_cast<int64_t>(0) && x2 < static_cast<int64_t>(64L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::Vectorized<float>::loadu(out_ptr2 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp3 = at::vec::Vectorized<float>::loadu(out_ptr3 + static_cast<int64_t>(x2 + 64L*x0), static_cast<int64_t>(16));
                                auto tmp2 = tmp0 * tmp1;
                                auto tmp4 = tmp2 + tmp3;
                                tmp4.store(out_ptr4 + static_cast<int64_t>(x2 + 64L*x1 + 1806336L*x0));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')
```
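The difference visible above: the reduction over the 28224-element spatial loop is now split across threads with `#pragma omp for`, each thread accumulates into a thread-local Welford state stored in a per-thread slot (`tmp_acc0_vec_arr[tid]`), and the slots are combined serially after the parallel region, instead of parallelizing only over the 2x2 outer loops. A plain-Python sketch of that reduction structure (illustrative only; it uses Chan's Welford combine and a thread pool as a stand-in for OpenMP):
```
from concurrent.futures import ThreadPoolExecutor

import torch


def welford(chunk):
    # Sequential Welford over one chunk: returns (count, mean, M2).
    count, mean, m2 = 0, 0.0, 0.0
    for v in chunk.tolist():
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)
    return count, mean, m2


def welford_combine(a, b):
    # Chan et al. combine of two (count, mean, M2) partial states.
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return a
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2


x = torch.randn(28224)
chunks = torch.chunk(x, 8)  # stand-in for the per-thread "#pragma omp for" split
with ThreadPoolExecutor(max_workers=8) as pool:
    partials = list(pool.map(welford, chunks))  # analogue of tmp_acc0_vec_arr[tid]

acc = (0, 0.0, 0.0)
for p in partials:  # serial combine after the parallel region
    acc = welford_combine(acc, p)

count, mean, m2 = acc
print(mean, m2 / count)  # matches the mean and population variance of x
```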
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144020
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel, https://github.com/jgong5
455 lines · 14 KiB · Python
from __future__ import annotations

import csv
import dataclasses
import inspect
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, cast, Optional, TYPE_CHECKING, Union

from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
from torch.utils._ordered_set import OrderedSet


# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import BaseSchedulerNode


# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: list[
    tuple[
        BaseSchedulerNode,
        int,
    ]
] = []
node_runtimes: list[tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0


@dataclasses.dataclass
class CppOuterLoopFusedCount:
    inner_kernel_number: int
    local_buffer_number: int = 0


# The length counts the number of outer loop fusions.
cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = []

num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0

num_loop_reordering = 0

# counter for parallel reduction.
parallel_reduction_count = 0


# reset all counters
def reset() -> None:
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global cpp_outer_loop_fused_inner_counts
    global num_comprehensive_padding
    global num_matches_for_scatter_upon_const_tensor
    global num_loop_reordering
    global parallel_reduction_count

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    nodes_num_elem.clear()
    node_runtimes.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    cpp_outer_loop_fused_inner_counts.clear()
    num_comprehensive_padding = 0
    num_matches_for_scatter_upon_const_tensor = 0
    num_loop_reordering = 0
    parallel_reduction_count = 0


@dataclass
class CachedMetricsDeltas:
    """
    The subset of metrics we want to update across cache hits, e.g., the
    FxGraphCache.
    """

    generated_kernel_count: int
    generated_cpp_vec_kernel_count: int
    ir_nodes_pre_fusion: int
    cpp_to_dtype_count: int
    num_bytes_accessed: int
    num_matches_for_scatter_upon_const_tensor: int


def get_metric_fields() -> list[str]:
    return [field.name for field in dataclasses.fields(CachedMetricsDeltas)]


class CachedMetricsHelper:
    """
    A helper class to help calculate and apply counter deltas for those
    metrics we want to save with cache entries (e.g., FxGraphCache) and
    apply on a cache hit.
    """

    def __init__(self) -> None:
        self.cached_metrics = {}
        for metric in get_metric_fields():
            self.cached_metrics[metric] = globals()[metric]

    def get_deltas(self) -> CachedMetricsDeltas:
        delta_metrics = {}
        for metric in get_metric_fields():
            delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric]

        return CachedMetricsDeltas(**delta_metrics)

    @staticmethod
    def apply_deltas(delta: CachedMetricsDeltas) -> None:
        for metric in get_metric_fields():
            globals()[metric] += getattr(delta, metric)
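
# Example usage (the surrounding compilation/cache machinery is only sketched):
#
#     helper = CachedMetricsHelper()            # snapshot the counters
#     ...run codegen, populating the counters above...
#     deltas = helper.get_deltas()              # store these with the cache entry
#     ...
#     CachedMetricsHelper.apply_deltas(deltas)  # replay the counters on a cache hit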


REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {}


@dataclass
class MetricTable:
    table_name: str
    column_names: list[str]

    num_rows_added: int = 0

    def add_row(
        self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]]
    ) -> None:
        if self.table_name not in enabled_metric_tables():
            return

        row_dict = row_fn()
        assert len(self.column_names) == len(row_dict), (
            f"{len(self.column_names)} v.s. {len(row_dict)}"
        )
        assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
            f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
        )

        bn = get_benchmark_name()
        # assert bn is not None
        row = [bn] + [row_dict[column_name] for column_name in self.column_names]
        assert all(isinstance(i, str) for i in row)
        self._write_row(cast(list[str], row))

    def output_filename(self) -> str:
        return f"metric_table_{self.table_name}.csv"

    def write_header(self) -> None:
        filename = self.output_filename()
        with open(filename, "w") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(["model_name"] + self.column_names)

    def _write_row(self, row: list[str]) -> None:
        filename = self.output_filename()
        if self.num_rows_added == 0 and not os.path.exists(filename):
            self.write_header()

        self.num_rows_added += 1

        for idx, orig_val in enumerate(row):
            if isinstance(orig_val, float):
                new_val = f"{orig_val:.6f}"
            elif orig_val is None:
                new_val = ""
            else:
                new_val = orig_val
            row[idx] = new_val

        with open(filename, "a") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(row)

    @staticmethod
    def register_table(name: str, column_names: list[str]) -> None:
        table = MetricTable(name, column_names)
        REGISTERED_METRIC_TABLES[name] = table
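
# Rows are added lazily: MetricTable.add_row takes a callable so the row dict is
# only built when the table is actually enabled. Illustrative only (the table
# name and columns below are made up):
#
#     MetricTable.register_table("my_table", ["col_a", "col_b"])
#     get_metric_table("my_table").add_row(lambda: {"col_a": "foo", "col_b": "1.0"})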


MetricTable.register_table(
    "slow_fusion",
    [
        "kernel1_path",
        "kernel1_latency",
        "kernel2_path",
        "kernel2_latency",
        "fused_kernel_path",
        "fused_kernel_latency",
        "slow_down_ratio",
    ],
)

# track the fusion statistics for each graph
MetricTable.register_table(
    "graph_stats",
    [
        "graph_id",
        "num_nodes_before_fusion",
        "num_nodes_after_fusion",
    ],
)

# track the perf difference between persistent reduction and non-persistent
# reductions
MetricTable.register_table(
    "persistent_red_perf",
    [
        "kernel0_path",
        "kernel1_path",
        "kernel2_path",
        "kernel3_path",
        "kernel0_latency",
        "kernel1_latency",
        "kernel2_latency",
        "kernel3_latency",
        "size_hints",
        "reduction_hint",
    ],
)

# Log the fusion failures due to indexing mismatch
MetricTable.register_table(
    "fusion_failure_due_to_indexing_mismatch",
    [
        "pre_grad_graph_id",
        "post_grad_graph_id",
        "node1_name",
        "node2_name",
        "node1_debug_str",
        "node2_debug_str",
        "common_buffer_names",
        "failure_reason",
    ],
)

# Log metadata for pointwise/reduction kernels, e.g., model name, kernel path,
# numel, rnumel, reduction hint
MetricTable.register_table(
    "kernel_metadata",
    [
        "kernel_name",
        "kernel_path",
        "kernel_category",  # pointwise/reduction/foreach etc.
        "size_hints",
        "reduction_hint",
        "line_of_code",
        "num_load",
        "num_store",
        "num_for_loop",
        "num_atomic_add",
        "num_args",
        # xyz numel can differ from size_hints since size_hints are rounded
        # up to the nearest power of 2.
        # Inductor will burn the xyz numel into the kernel code for static
        # shape kernels.
        # Logging them is helpful for finding unaligned shapes for reductions.
        "xnumel",
        "ynumel",
        "rnumel",
        "kernel_args_num_gb",
    ],
)


def _parse_kernel_fn_code(kernel_module_code: str) -> str:
    """
    kernel_module_code is the Python module that contains the kernel function code.
    The kernel function is the proper Triton kernel function annotated with
    @triton.jit.
    """
    from .codecache import PyCodeCache
    from .wrapper_benchmark import get_triton_kernel

    mod = PyCodeCache.load(kernel_module_code)
    kernel = get_triton_kernel(mod)
    # kernel is a CachingAutotuner; kernel.fn is the JITFunction;
    # kernel.fn.fn is the function being decorated by triton.jit
    return inspect.getsource(kernel.fn.fn)


def _parse_kernel_line_of_code(proper_kernel_fn_code: str) -> int:
    """
    Return the number of lines of code for the kernel, excluding the decorators.
    """
    return len(proper_kernel_fn_code.splitlines())


def _parse_size_hints(kernel_module_code: str, kernel_category: str) -> Optional[str]:
    if kernel_category == "foreach":
        # foreach kernels do not have size_hints
        return None
    m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
    assert m, "size_hints missing!"
    return m.group(1)


def _parse_reduction_hint(
    kernel_category: str, kernel_module_code: str
) -> Optional[str]:
    if kernel_category not in ("reduction", "persistent_reduction"):
        return None
    m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
    assert m, "reduction_hint not found in kernel source code!"
    return m.group(1)


def _count_pattern(proper_kernel_fn_code: str, pattern: str) -> int:
    return proper_kernel_fn_code.count(pattern)


def _count_args(proper_kernel_fn_code: str) -> int:
    def_line = proper_kernel_fn_code.splitlines()[0]
    assert def_line.startswith("def ")
    start_idx = def_line.index("(")
    end_idx = def_line.index("):")
    decl_csv = def_line[start_idx + 1 : end_idx]
    comps = decl_csv.split(",")
    return len(comps)


def _parse_proper_kernel_fn_code(kernel_fn_code: str) -> str:
    """
    Skip the decorators.
    """
    start_pos = kernel_fn_code.index("def ")
    return kernel_fn_code[start_pos:]


def _parse_numel(proper_kernel_fn_code: str, numel_arg_name: str) -> Optional[int]:
    m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
    if m:
        return int(m.group(1))
    else:
        return None


def _parse_kernel_args_num_gb(
    kernel_fn_code: str, kernel_category: str
) -> Optional[float]:
    """
    inductor_meta looks like:
        inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
    """
    m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
    if m:
        return float(m.group(1))
    else:
        """
        There are a few cases where the kernel_num_gb field can be missing:
        1. the field will be missing if config.benchmark_kernel and
           config.profile_bandwidth are false
        2. even if config.benchmark_kernel or config.profile_bandwidth is true,
           foreach kernels do not have a kernel_num_gb field in the metadata
        """
        return None


def log_kernel_metadata(
    kernel_name: str, kernel_path: str, kernel_module_code: str
) -> None:
    """
    A utility to log kernel metadata. We may parse metadata from the kernel
    source code here.

    It's fine to parse the generated kernel code here since the logging is
    disabled by default; otherwise it would hurt compilation time.
    """
    from .wrapper_benchmark import get_kernel_category_by_source_code

    kernel_category = get_kernel_category_by_source_code(kernel_module_code)
    reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
    size_hints = _parse_size_hints(kernel_module_code, kernel_category)
    kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)

    proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)

    # the number of lines of code, excluding the decorators
    kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)

    get_metric_table("kernel_metadata").add_row(
        lambda: {
            "kernel_name": kernel_name,
            "kernel_path": kernel_path,
            "kernel_category": kernel_category,
            "size_hints": size_hints,
            "reduction_hint": reduction_hint,
            "line_of_code": kernel_line_of_code,
            "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"),
            "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"),
            "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "),
            "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"),
            "num_args": _count_args(proper_kernel_fn_code),
            "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"),
            "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"),
            "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"),
            "kernel_args_num_gb": _parse_kernel_args_num_gb(
                kernel_fn_code, kernel_category
            ),
        }
    )


def purge_old_log_files() -> None:
    """
    Purge the old log files at the beginning when the benchmark script runs.
    This should be done in the parent process rather than in the child
    processes running each individual model.
    """
    for name, table in REGISTERED_METRIC_TABLES.items():
        if name in enabled_metric_tables():
            filename = table.output_filename()
            if os.path.exists(filename):
                os.unlink(filename)

            table.write_header()


def enabled_metric_tables() -> OrderedSet[str]:
    return enabled_metric_tables_impl(config.enabled_metric_tables)


@lru_cache
def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
    enabled = OrderedSet[str]()
    for name in config_str.split(","):
        name = name.strip()
        if not name:
            continue
        assert name in REGISTERED_METRIC_TABLES, (
            f"Metric table name {name} is not registered"
        )
        enabled.add(name)
    return enabled
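
# Tables are enabled via config.enabled_metric_tables, a comma-separated string
# of registered table names. Illustrative only:
#
#     from torch._inductor import config
#     config.enabled_metric_tables = "kernel_metadata,graph_stats"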


def is_metric_table_enabled(name: str) -> bool:
    return name in enabled_metric_tables()


def get_metric_table(name: str) -> MetricTable:
    assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
    return REGISTERED_METRIC_TABLES[name]