pytorch/torch/utils/_stats.py
Kasperi Apell a7915c56f6 Propagate callable parameter types using ParamSpec (#142306) (#143797)
The codebase has a few locations where callable parameter type information is lost because the unpacked *args and **kwargs are typed as Any. Refactor these instances to retain type information using typing_extensions.ParamSpec.

Also, in these functions, enforce return type with TypeVar.

Addresses #142306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143797
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
2024-12-29 23:03:14 +00:00

28 lines
989 B
Python

# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
# AND SCRUB AWAY TORCH NOTIONS THERE.
import collections
import functools
from typing import Callable, OrderedDict, TypeVar
from typing_extensions import ParamSpec
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
_P = ParamSpec("_P")
_R = TypeVar("_R")
def count_label(label: str) -> None:
    """Increment the tally recorded under ``label`` in ``simple_call_counter``.

    A label that has never been seen starts from zero, so after this call
    its count is at least one. New labels are appended in first-seen order.
    """
    simple_call_counter[label] = simple_call_counter.get(label, 0) + 1
def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so each call increments its entry in ``simple_call_counter``.

    The counter key is ``fn.__qualname__``. Callables that have no
    ``__qualname__`` (e.g. ``functools.partial`` objects or instances of
    classes defining ``__call__``) are keyed by their type's ``__qualname__``
    instead of raising ``AttributeError`` on every call.

    The count is incremented before ``fn`` runs, so calls that raise are
    still counted.

    Args:
        fn: The callable to instrument.

    Returns:
        A wrapper with the same parameters and return type as ``fn``.
    """
    # Resolve the key once at decoration time instead of on every call;
    # functools.wraps tolerates a missing __qualname__, so we must too.
    label = getattr(fn, "__qualname__", type(fn).__qualname__)

    @functools.wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        simple_call_counter[label] = simple_call_counter.get(label, 0) + 1
        return fn(*args, **kwargs)

    return wrapper