pytorch/torch/_inductor/metrics.py
Simon Fan aca3d1433c Estimate Scheduler node runtimes (#106426)
Completed as a starter task with @Chillee.

This PR adds a method under BaseSchedulerNode to estimate the node's runtime in seconds.

We use a heuristic-based approach: first classify each operation as memory-bandwidth-bound or compute-bound, then estimate its runtime accordingly (a sketch follows below):
- memory-bandwidth-bound: we count the number of bytes read and written
- compute-bound: we count the FLOPs required by the operation

One use case is as a cost model for scheduling: https://github.com/pytorch/pytorch/pull/100762
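For illustration, here is a minimal sketch of a roofline-style estimate in this spirit. The device constants and function name (`DRAM_BW_BYTES_PER_S`, `PEAK_FLOPS_PER_S`, `estimate_runtime_seconds`) are hypothetical placeholders, not the PR's actual implementation:

```
# Hypothetical roofline-style estimate: take the max of the time to move the
# bytes and the time to do the arithmetic. Constants are made-up placeholders.
DRAM_BW_BYTES_PER_S = 2.0e12   # assumed device memory bandwidth
PEAK_FLOPS_PER_S = 3.0e14      # assumed device peak compute throughput

def estimate_runtime_seconds(bytes_read_written: int, flop_count: int) -> float:
    memory_time = bytes_read_written / DRAM_BW_BYTES_PER_S  # bandwidth-bound term
    compute_time = flop_count / PEAK_FLOPS_PER_S            # compute-bound term
    return max(memory_time, compute_time)
```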

```
(pytorch-3.10) [14:08:02] ~/local/pytorch (xmfan/estimate_snode_runtime) > python3 test/inductor/test_perf.py -k EstimateSnodeRuntimeTests
[(ExternKernelSchedulerNode(name='buf0'), 400)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000), (SchedulerNode(name='buf1'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26), (SchedulerNode(name='buf1'), 7.187055238190188e-09)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26)]
.[(ExternKernelSchedulerNode(name='buf0'), 34600)]
[(ExternKernelSchedulerNode(name='buf0'), 3.22687496698039e-24)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 7776176)]
[(ExternKernelSchedulerNode(name='buf0'), 4.63240241413653e-21)]
.[(FusedSchedulerNode(nodes=buf0_buf1), 210)]
[(FusedSchedulerNode(nodes=buf0_buf1), 5.030938666733132e-10)]
.[(ExternKernelSchedulerNode(name='buf0'), 300)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(SchedulerNode(name='buf0'), 20)]
[(SchedulerNode(name='buf0'), 4.7913701587934585e-11)]
.
----------------------------------------------------------------------
Ran 10 tests in 14.311s
OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106426
Approved by: https://github.com/Chillee
2023-08-17 17:23:30 +00:00

53 lines
1.3 KiB
Python

from __future__ import annotations

from typing import List, Tuple, TYPE_CHECKING, Union

# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import (
        BaseSchedulerNode,
        ExternKernelSchedulerNode,
        NopKernelSchedulerNode,
        SchedulerNode,
    )

# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0

# (node, number of elements) pairs for scheduler nodes
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []

# (node, estimated runtime in seconds) pairs for scheduler nodes
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0


# reset all counters
def reset():
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    nodes_num_elem.clear()
    node_runtimes.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0
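
As a usage note, here is a minimal sketch of reading these metrics around a compiled run. `fn` is a hypothetical `torch.compile`'d callable, and whether `node_runtimes` is populated depends on Inductor's debug settings:

```
# Minimal usage sketch, assuming `fn` is a hypothetical torch.compile'd callable.
from torch._inductor import metrics

metrics.reset()                        # clear all counters before the run
fn()                                   # triggers Inductor compilation
print(metrics.generated_kernel_count)  # kernels generated during this compile
print(metrics.node_runtimes)           # (node, estimated seconds) pairs, if populated
```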