""" Python polyfills for operator """ from __future__ import annotations import operator from typing import Any, overload, TYPE_CHECKING, TypeVar from typing_extensions import TypeVarTuple, Unpack from ..decorators import substitute_in_graph if TYPE_CHECKING: from collections.abc import Callable, Iterable # Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) __all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"] _T = TypeVar("_T") _T1 = TypeVar("_T1") _T2 = TypeVar("_T2") _Ts = TypeVarTuple("_Ts") _U = TypeVar("_U") _U1 = TypeVar("_U1") _U2 = TypeVar("_U2") _Us = TypeVarTuple("_Us") @overload # pyrefly: ignore [inconsistent-overload] def attrgetter(attr: str, /) -> Callable[[Any], _U]: ... @overload # pyrefly: ignore [inconsistent-overload] def attrgetter( attr1: str, attr2: str, /, *attrs: str ) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... # Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter @substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc] def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]: if len(attrs) == 0: raise TypeError("attrgetter expected 1 argument, got 0") if any(not isinstance(attr, str) for attr in attrs): raise TypeError("attribute name must be a string") def resolve_attr(obj: Any, attr: str) -> Any: for name in attr.split("."): obj = getattr(obj, name) return obj if len(attrs) == 1: attr = attrs[0] def getter(obj: Any) -> Any: return resolve_attr(obj, attr) else: def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] return tuple(resolve_attr(obj, attr) for attr in attrs) return getter @overload # pyrefly: ignore [inconsistent-overload] def itemgetter(item: _T, /) -> Callable[[Any], _U]: ... @overload # pyrefly: ignore [inconsistent-overload] def itemgetter( item1: _T1, item2: _T2, /, *items: Unpack[_Ts] ) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... # Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter @substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc] def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]: if len(items) == 0: raise TypeError("itemgetter expected 1 argument, got 0") if len(items) == 1: item = items[0] def getter(obj: Any) -> Any: return obj[item] else: def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] return tuple(obj[item] for item in items) return getter # Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller @substitute_in_graph(operator.methodcaller, is_embedded_type=True) # type: ignore[arg-type] def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]: if not isinstance(name, str): raise TypeError("method name must be a string") def caller(obj: Any) -> Any: return getattr(obj, name)(*args, **kwargs) return caller # Reference: https://docs.python.org/3/library/operator.html#operator.countOf @substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc] def countOf(a: Iterable[_T], b: _T, /) -> int: return sum(it is b or it == b for it in a)