pytorch/torch/utils/_functools.py
PyTorch MergeBot bdf7cb9d9c Revert "[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)"
This reverts commit e20c9bf288.

Reverted https://github.com/pytorch/pytorch/pull/165410 on behalf of https://github.com/clee2000 due to sorry I'm going to revert this since I want to try to back out some other things that are conflicting with this, there is nothing wrong with this PR, rebasing and resolving the merge conflicts should be enough, sorry for the churn ([comment](https://github.com/pytorch/pytorch/pull/165410#issuecomment-3427532373))
2025-10-21 16:27:54 +00:00

48 lines
1.5 KiB
Python

import functools
from collections.abc import Callable
from typing import Concatenate, TypeVar
from typing_extensions import ParamSpec
_P = ParamSpec("_P")
_T = TypeVar("_T")
_C = TypeVar("_C")
# Sentinel used to indicate that cache lookup failed.
_cache_sentinel = object()
def cache_method(
f: Callable[Concatenate[_C, _P], _T],
) -> Callable[Concatenate[_C, _P], _T]:
"""
Like `@functools.cache` but for methods.
`@functools.cache` (and similarly `@functools.lru_cache`) shouldn't be used
on methods because it caches `self`, keeping it alive
forever. `@cache_method` ignores `self` so won't keep `self` alive (assuming
no cycles with `self` in the parameters).
Footgun warning: This decorator completely ignores self's properties so only
use it when you know that self is frozen or won't change in a meaningful
way (such as the wrapped function being pure).
"""
cache_name = "_cache_method_" + f.__name__
@functools.wraps(f)
def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T:
assert not kwargs
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
# pyrefly: ignore # unbound-name
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
# pyrefly: ignore # unbound-name
cache[args] = value
return value
return wrap