mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Working as a starter task with @Chillee.

This PR adds a method under BaseSchedulerNode to estimate the node's runtime in seconds. We use a heuristic-based approach, first considering whether the operation is memory-bandwidth bound or compute bound:
- memory-bandwidth bound: we compute the number of bytes that are read or written
- compute bound: we compute the FLOPS required by the operation

One use case is as a cost model for scheduling: https://github.com/pytorch/pytorch/pull/100762

```
(pytorch-3.10) [14:08:02] ~/local/pytorch (xmfan/estimate_snode_runtime) > python3 test/inductor/test_perf.py -k EstimateSnodeRuntimeTests
[(ExternKernelSchedulerNode(name='buf0'), 400)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000), (SchedulerNode(name='buf1'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26), (SchedulerNode(name='buf1'), 7.187055238190188e-09)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26)]
.[(ExternKernelSchedulerNode(name='buf0'), 34600)]
[(ExternKernelSchedulerNode(name='buf0'), 3.22687496698039e-24)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 7776176)]
[(ExternKernelSchedulerNode(name='buf0'), 4.63240241413653e-21)]
.[(FusedSchedulerNode(nodes=buf0_buf1), 210)]
[(FusedSchedulerNode(nodes=buf0_buf1), 5.030938666733132e-10)]
.[(ExternKernelSchedulerNode(name='buf0'), 300)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(SchedulerNode(name='buf0'), 20)]
[(SchedulerNode(name='buf0'), 4.7913701587934585e-11)]
.
----------------------------------------------------------------------
Ran 10 tests in 14.311s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106426
Approved by: https://github.com/Chillee
53 lines
1.3 KiB
Python
53 lines
1.3 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import List, Tuple, TYPE_CHECKING, Union
|
|
|
|
# Prevent circular import
|
|
if TYPE_CHECKING:
|
|
from torch._inductor.scheduler import (
|
|
BaseSchedulerNode,
|
|
ExternKernelSchedulerNode,
|
|
NopKernelSchedulerNode,
|
|
SchedulerNode,
|
|
)
|
|
|
|
# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
# (scheduler node, int) pairs; emptied in place by reset().
# NOTE(review): the int is presumably the node's element count -- confirm
# against the code that appends to this list.
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []
# (scheduler node, float) pairs; emptied in place by reset().
# NOTE(review): the float is presumably the node's estimated runtime in
# seconds -- confirm against the code that appends to this list.
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0
|
|
|
|
|
|
# reset all counters
|
|
def reset() -> None:
    """Restore every module-level metric to its initial state.

    The integer counters are rebound to ``0``. The two list metrics,
    ``nodes_num_elem`` and ``node_runtimes``, are emptied in place with
    ``clear()`` rather than rebound, so any existing reference to those list
    objects sees the reset as well.
    """
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    # Mutate in place: the list objects keep their identity.
    nodes_num_elem.clear()
    node_runtimes.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0
|