Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[mm sampling] extract more triton information (#153099)
Summary:
# Why
Capture Triton config information that was previously not being recorded during mm autotune sampling.
# What
Capture and extract:
- group_m
- allow_tf32
- acc_type
- matrix_instr_nonkdim
- waves_per_eu
- kpack
To achieve this, add
- matrix_instr_nonkdim
- waves_per_eu
- kpack
to the info_dict of the TritonTemplateCaller.
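The extraction pattern is simple: read each knob out of the template kwargs with a sentinel default, so entries stay well-defined when a backend (e.g. CUDA vs ROCm) does not supply them. A minimal sketch, assuming a hypothetical helper name (`build_info_dict` is illustrative, not the actual TritonTemplate code):

```python
# Sketch of the kwargs-extraction pattern described above. Each value is
# read with dict.get and a sentinel default; `build_info_dict` is a
# hypothetical helper, not PyTorch API.
def build_info_dict(kwargs):
    return {
        "GROUP_M": kwargs.get("GROUP_M", -1),
        "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
        "acc_type": str(kwargs.get("ACC_TYPE", None)),
        # ROCm-specific knobs, using the same defaults as this PR
        "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
        "waves_per_eu": kwargs.get("waves_per_eu", 0),
        "kpack": kwargs.get("kpack", 2),
    }

print(build_info_dict({"GROUP_M": 16, "ALLOW_TF32": False}))
```

Stringifying `allow_tf32` and `acc_type` keeps the info_dict uniformly serializable even when the kwarg is absent (`None` becomes the string `'None'`).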
Test Plan:
with D74342290
```
buck2 run -c fbcode.rocm_arch=mi300 -m rocm621 mode/opt-amd-gpu fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0 2>&1 | tee /tmp/tmp.52Igj8lthj/15.txt
```
(edited for clarity and brevity)
```
AutotuneMetrics03LogEntry(
backend='Triton',
exectime_ms=0.007449999917298555,
perf_model_name='scripts.vandrei.pytorch_experiments.matmul_estimator_lib.estimate_matmul_time_new',
perf_model_exectime_ms=0.009558684365573179,
config_triton_block_m=16,
config_triton_block_n=256,
config_triton_block_k=128,
config_triton_num_stages=2,
config_triton_num_warps=8,
config_triton_group_m=16,
config_triton_allow_tf32='False',
config_triton_acc_type='tl.float32',
config_triton_matrix_instr_nonkdim=16,
config_triton_waves_per_eu=1,
config_triton_kpack=2,
x_batch_dim=0,
x_row_dim=8,
x_col_dim=96,
x_batch_stride=0,
x_row_stride=96,
x_col_stride=1,
x_dtype='torch.float16',
x_dtype_size=16,
w_batch_dim=0,
w_row_dim=96,
w_col_dim=512,
w_batch_stride=0,
w_row_stride=512,
w_col_stride=1,
w_dtype='torch.float16',
w_dtype_size=16,
vendor='AMD',
model='gfx942:sramecc+:xnack-',
major=9,
minor=4,
sms=304,
l2_cache=4194304,
warp_size=64,
regs_per_sm=65536,
max_threads_per_sm=2048,
total_mem=206141652992,
hip_version='6.2.41134',
triton_upstream_hash='3889f3f3b97b817741e308c173409927b7c4536f',
environment='experiment-xzy-default',
session_id='8a7001bd-652c-440c-bc56-4cb1e25146ea',
[...]
)
```
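The new knobs surface as `config_triton_*` columns in the log entry above. A downstream consumer can recover the Triton config by prefix-filtering a parsed entry; a hedged sketch (the entry dict is hand-copied from the sample output, and the parsing of `AutotuneMetrics03LogEntry` itself is assumed):

```python
# Sketch: collect the Triton config columns from a parsed autotune log
# entry. The dict below is hand-copied from the sample output above;
# real entries would come from whatever deserializes the log.
entry = {
    "backend": "Triton",
    "config_triton_block_m": 16,
    "config_triton_group_m": 16,
    "config_triton_matrix_instr_nonkdim": 16,
    "config_triton_waves_per_eu": 1,
    "config_triton_kpack": 2,
    "vendor": "AMD",
}

PREFIX = "config_triton_"
triton_config = {
    key[len(PREFIX):]: value
    for key, value in entry.items()
    if key.startswith(PREFIX)
}
print(triton_config)
```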
Reviewed By: exclamaforte
Differential Revision: D74342286
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153099
Approved by: https://github.com/exclamaforte, https://github.com/eellison
This commit is contained in:
parent
3c87529d23
commit
f9df09da08
```diff
@@ -1395,6 +1395,9 @@ class TritonTemplate(KernelTemplate):
                 "GROUP_M": kwargs.get("GROUP_M", -1),
                 "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
                 "acc_type": str(kwargs.get("ACC_TYPE", None)),
+                "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
+                "waves_per_eu": kwargs.get("waves_per_eu", 0),
+                "kpack": kwargs.get("kpack", 2),
             },
             mutated_inputs=mutated_inputs,
             workspace_arg=workspace_arg,
```