mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
As suggested in https://github.com/pytorch/pytorch/issues/105230, this implements a small fix for torch/_inductor/metrics.py.

I ran into a circular import, which I handled using `if TYPE_CHECKING` (https://docs.python.org/3/library/typing.html#constant). There are then two options for describing the types: either use their class names as strings, or use `from __future__ import annotations`.

```
If from __future__ import annotations is used, annotations are not evaluated at function definition time. Instead, they are stored as strings in __annotations__. This makes it unnecessary to use quotes around the annotation (see [PEP 563](https://peps.python.org/pep-0563/)).
```

I'm open to suggestions if it does not meet your coding guidelines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105793
Approved by: https://github.com/Skylion007
50 lines
1.2 KiB
Python
50 lines
1.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import List, Tuple, TYPE_CHECKING, Union
|
|
|
|
# Prevent circular import
|
|
if TYPE_CHECKING:
|
|
from torch._inductor.scheduler import (
|
|
ExternKernelSchedulerNode,
|
|
NopKernelSchedulerNode,
|
|
SchedulerNode,
|
|
)
|
|
|
|
# counters for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
# (scheduler node, int) pairs. The node classes are imported only under
# TYPE_CHECKING, so this annotation relies on
# `from __future__ import annotations` to stay unevaluated at import time.
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0
|
|
|
|
|
|
# reset all counters
|
|
def reset():
    """Reset every module-level metric to its initial state.

    The integer counters are rebound to ``0``. ``nodes_num_elem`` is emptied
    in place rather than rebound, so any existing reference to that list
    observes the cleared contents.
    """
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    # Clear in place so the list object itself is preserved.
    nodes_num_elem.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0
|