pytorch/torch/_inductor/metrics.py
Piotr Gryko c0b8b7b90c [inductor] Enable mypy checking in torch/_inductor/metrics.py (#105793)
As suggested in https://github.com/pytorch/pytorch/issues/105230

Implements a small fix to enable mypy checking for torch/_inductor/metrics.py.

I ran into a circular import, which I handled with an `if TYPE_CHECKING:` guard (https://docs.python.org/3/library/typing.html#constant).
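A minimal sketch of the `if TYPE_CHECKING:` pattern, assuming a hypothetical module `hypothetical_scheduler` that stands in for `torch._inductor.scheduler`. Because `typing.TYPE_CHECKING` is `False` at runtime, the guarded import never executes, so it cannot complete an import cycle (here the module does not even need to exist), while a static checker such as mypy still processes the import.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Processed only by static checkers such as mypy. At runtime TYPE_CHECKING
    # is False, so this import never executes and cannot create an import cycle.
    from hypothetical_scheduler import SchedulerNode  # hypothetical module name


def num_elements(node: "SchedulerNode") -> int:
    # The quoted annotation is stored as a string and never evaluated at
    # runtime, so the missing runtime import does not matter here.
    return getattr(node, "numel", 0)


print(TYPE_CHECKING)            # False
print(num_elements(object()))   # 0
```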

There are then two options for writing the type annotations: either use the class names as strings, or use `from __future__ import annotations`:

```
If from __future__ import annotations is used, annotations are not evaluated at function definition time.
Instead, they are stored as strings in __annotations__.
This makes it unnecessary to use quotes around the annotation (see [PEP 563](https://peps.python.org/pep-0563/)).
```
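The difference between the two options can be checked directly. This sketch compiles the same hypothetical `count` function with the built-in `compile()`, using `__future__.annotations.compiler_flag` to simulate the future import: once with a quoted class name, once with an unquoted name under `from __future__ import annotations`, and once with an unquoted name and no future import. `SchedulerNode` is deliberately left undefined, as it would be at runtime when imported only under `TYPE_CHECKING`.

```python
import __future__

# The same function written in the two styles. SchedulerNode is never defined.
QUOTED = '''
def count(node: "SchedulerNode") -> int:
    return 0
'''
UNQUOTED = '''
def count(node: SchedulerNode) -> int:
    return 0
'''


def annotations_of(source: str, postponed: bool) -> dict:
    """Compile `source`, optionally as if it started with
    `from __future__ import annotations`, and return count.__annotations__."""
    flags = __future__.annotations.compiler_flag if postponed else 0
    namespace: dict = {}
    code = compile(source, "<sketch>", "exec", flags=flags, dont_inherit=True)
    exec(code, namespace)
    return namespace["count"].__annotations__


# Option 1: quoted name, no future import.
print(annotations_of(QUOTED, postponed=False))
# {'node': 'SchedulerNode', 'return': <class 'int'>}

# Option 2: unquoted name with the future import; every annotation is a string.
print(annotations_of(UNQUOTED, postponed=True))
# {'node': 'SchedulerNode', 'return': 'int'}

# Neither option: the undefined name has to be evaluated and fails.
try:
    annotations_of(UNQUOTED, postponed=False)
except NameError as exc:
    print("unquoted, no future import:", exc)
```

Only the last variant fails. Depending on the Python version, the `NameError` surfaces either when the function is defined or, on versions that evaluate annotations lazily, when `__annotations__` is first read; the `try` block covers both.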

I'm open to suggestions if this does not meet your coding guidelines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105793
Approved by: https://github.com/Skylion007
2023-08-03 22:43:57 +00:00


from __future__ import annotations

from typing import List, Tuple, TYPE_CHECKING, Union

# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import (
        ExternKernelSchedulerNode,
        NopKernelSchedulerNode,
        SchedulerNode,
    )

# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0


# reset all counters
def reset():
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    nodes_num_elem.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0
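The module is a set of plain module-level counters that other code increments and `reset()` zeroes. The sketch below is a hypothetical, self-contained stand-in (built in memory with `types.ModuleType`, so PyTorch is not required; the name `metrics_sketch` and the `"buf0"` entry are illustrative only) showing two consequences of that design. Integer counters must be read and updated through the module attribute, because a name copied out of the module is only a snapshot. `nodes_num_elem` is cleared in place rather than rebound, so any alias to the list stays in sync.

```python
import types

# Hypothetical in-memory stand-in for a counters module like metrics.py.
SOURCE = '''
generated_kernel_count = 0
nodes_num_elem = []

def reset():
    global generated_kernel_count
    generated_kernel_count = 0
    nodes_num_elem.clear()  # mutate in place instead of rebinding the name
'''
metrics = types.ModuleType("metrics_sketch")
exec(SOURCE, metrics.__dict__)  # reset() now uses metrics.__dict__ as its globals

# Update counters through the module attribute.
metrics.generated_kernel_count += 1
metrics.nodes_num_elem.append(("buf0", 1024))

# A copied int does not follow later updates.
snapshot = metrics.generated_kernel_count
metrics.generated_kernel_count += 1
print(snapshot, metrics.generated_kernel_count)  # 1 2

# The list is shared, so an alias observes reset() clearing it.
alias = metrics.nodes_num_elem
metrics.reset()
print(metrics.generated_kernel_count, alias)  # 0 []
```

Clearing the list in place is also why `nodes_num_elem.clear()` in `reset()` works even without rebinding: any code that already holds a reference to the list sees it emptied.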