[mm sampling] extract more triton information (#153099)

Summary:
# Why

capture more triton config information that was not being captured

# What

capture and extract

- group_m
- allow_tf32
- acc_type
- matrix_instr_nonkdim
- waves_per_eu
- kpack

to achieve this, add

- matrix_instr_nonkdim
- waves_per_eu
- kpack

to the info_dict of the TritonTemplateCaller

Test Plan:
with D74342290

```
buck2 run -c fbcode.rocm_arch=mi300 -m rocm621 mode/opt-amd-gpu  fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0 2>&1 | tee /tmp/tmp.52Igj8lthj/15.txt
```

(edited for clarity and brevity)

```
AutotuneMetrics03LogEntry(
    backend='Triton',
    exectime_ms=0.007449999917298555,
    perf_model_name='scripts.vandrei.pytorch_experiments.matmul_estimator_lib.estimate_matmul_time_new',
    perf_model_exectime_ms=0.009558684365573179,
    config_triton_block_m=16,
    config_triton_block_n=256,
    config_triton_block_k=128,
    config_triton_num_stages=2,
    config_triton_num_warps=8,
    config_triton_group_m=16,
    config_triton_allow_tf32='False',
    config_triton_acc_type='tl.float32',
    config_triton_matrix_instr_nonkdim=16,
    config_triton_waves_per_eu=1,
    config_triton_kpack=2,
    x_batch_dim=0,
    x_row_dim=8,
    x_col_dim=96,
    x_batch_stride=0,
    x_row_stride=96,
    x_col_stride=1,
    x_dtype='torch.float16',
    x_dtype_size=16,
    w_batch_dim=0,
    w_row_dim=96,
    w_col_dim=512,
    w_batch_stride=0,
    w_row_stride=512,
    w_col_stride=1,
    w_dtype='torch.float16',
    w_dtype_size=16,
    vendor='AMD',
    model='gfx942:sramecc+:xnack-',
    major=9,
    minor=4,
    sms=304,
    l2_cache=4194304,
    warp_size=64,
    regs_per_sm=65536,
    max_threads_per_sm=2048,
    total_mem=206141652992,
    hip_version='6.2.41134',
    triton_upstream_hash='3889f3f3b97b817741e308c173409927b7c4536f',
    environment='experiment-xzy-default',
    session_id='8a7001bd-652c-440c-bc56-4cb1e25146ea',
    [...]
)
```

Reviewed By: exclamaforte

Differential Revision: D74342286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153099
Approved by: https://github.com/exclamaforte, https://github.com/eellison
This commit is contained in:
Ruben Rodriguez Buchillon 2025-05-08 07:24:28 +00:00 committed by PyTorch MergeBot
parent 3c87529d23
commit f9df09da08

View File

@ -1395,6 +1395,9 @@ class TritonTemplate(KernelTemplate):
"GROUP_M": kwargs.get("GROUP_M", -1),
"allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
"acc_type": str(kwargs.get("ACC_TYPE", None)),
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
"kpack": kwargs.get("kpack", 2),
},
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,