pytorch/torch/_inductor/remote_cache.py
Aaron Gokaslan 31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00

199 lines
5.6 KiB
Python

from __future__ import annotations
import json
import os
import typing
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing_extensions import override, TypeAlias
from torch._inductor import config
# `redis` is an optional dependency. If it cannot be imported the name is bound
# to None, and RedisRemoteCacheBackend then skips initialization and behaves as
# a no-op backend (get() misses, put() drops data).
try:
    import redis
except ImportError:
    redis = None  # type: ignore[assignment]

# `Sample` is the type of the per-operation record passed to
# RemoteCache._create_sample / _log_sample. Inside fbcode it aliases `Sample`
# from rfe.scubadata.scubadata_py3. Elsewhere it is a placeholder alias; in this
# file it only appears in annotations, and the base _create_sample returns None.
if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]
# Generic parameters for the cache classes below. In RemoteCacheSerde, _T is
# the decoded value type and _U the encoded form handed to a backend.
_T = TypeVar("_T")
_U = TypeVar("_U")
class RemoteCacheBackend(Generic[_T]):
    """
    Raw key/value storage for a remote or distributed cache.

    A backend deals only in its payload type ``_T`` (``bytes`` for
    RedisRemoteCacheBackend) and knows nothing about how values are encoded.
    To store structured data, wrap it in a RemoteCache together with a
    RemoteCacheSerde.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced at instantiation time. A
    subclass that forgets to override a method silently inherits the base
    implementation, which returns None.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[_T]:
        """Return the payload stored under ``key``, or None if there is none."""

    @abstractmethod
    def put(self, key: str, data: _T) -> None:
        """Store ``data`` under ``key``."""
class RemoteCacheSerde(Generic[_T, _U]):
    """
    Two-way codec between a value type ``_T`` and its stored form ``_U``.

    ``encode`` maps ``_T -> _U`` before a value is handed to a backend.
    ``decode`` maps ``_U -> _T`` after the stored form is read back.
    """

    @abstractmethod
    def encode(self, data: _T) -> _U:
        """Convert ``data`` into its stored representation."""

    @abstractmethod
    def decode(self, data: _U) -> _T:
        """Rebuild a value from its stored representation ``data``."""
# Recursive type alias for JSON-compatible data: None, a scalar
# (int/float/str/bool), or dicts with string keys and lists whose elements are
# themselves JsonDataTy.
JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]
class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    """Serde that stores JSON-compatible values as ASCII-encoded JSON bytes."""

    def encode(self, data: JsonDataTy) -> bytes:
        # json.dumps escapes non-ASCII characters by default (ensure_ascii=True),
        # so the resulting text is always representable in ASCII.
        text = json.dumps(data)
        return text.encode("ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        # json.loads accepts bytes directly (UTF-8/16/32 are auto-detected).
        return json.loads(data)
class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    """Identity serde: values reach the backend exactly as given, unconverted."""

    def encode(self, data: _T) -> _T:
        # Stored form and value form are the same type; nothing to convert.
        return data

    def decode(self, data: _T) -> _T:
        # Mirror of encode: hand the stored object straight back.
        return data
class RemoteCache(Generic[_T]):
    """
    Cache front end that stores ``_T`` values through a RemoteCacheBackend.

    Values pass through ``serde.encode`` on the way in and ``serde.decode`` on
    the way out, so the backend only ever sees the encoded form.

    Each public ``get``/``put`` obtains a sample from ``_create_sample`` and
    hands it to ``_log_sample`` once the operation returns normally. It is not
    logged if the operation raises. Both hooks do nothing in this base class.
    """

    # Test hook: when set on a class, __init__ ignores the backend argument and
    # builds one by calling this zero-argument factory instead.
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        factory = self.__class__.backend_override_cls
        self.backend: RemoteCacheBackend[Any] = (
            backend if factory is None else factory()
        )
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        """Look up ``key``; return the decoded value, or None on a miss."""
        record = self._create_sample()
        found = self._get(key, record)
        self._log_sample(record)
        return found

    def put(self, key: str, value: _T) -> None:
        """Encode ``value`` and store it in the backend under ``key``."""
        record = self._create_sample()
        self._put(key, value, record)
        self._log_sample(record)

    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:  # type: ignore[override]
        # ``sample`` is unused here; presumably available for subclasses that
        # want to record decode details.
        return self.serde.decode(data)  # type: ignore[arg-type]

    def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
        return self.serde.encode(value)

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        payload = self.backend.get(key)
        # NOTE(review): any falsy payload (e.g. b"", or 0/""/{} through a
        # passthrough serde) is reported as a miss, not only None.
        if not payload:
            return None
        return self._decode(payload, sample)

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        encoded = self._encode(value, sample)
        self.backend.put(key, encoded)

    def _create_sample(self) -> Optional[Sample]:
        """Hook: build a per-operation sample; the base class builds none."""
        return None

    def _log_sample(self, sample: Optional[Sample]) -> None:
        """Hook: emit ``sample``; a no-op in the base class."""
class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.

    Stores raw ``bytes`` under keys of the form ``pt2:<cache_id>:<key>``. The
    server is located via the TORCHINDUCTOR_REDIS_HOST (default "localhost")
    and TORCHINDUCTOR_REDIS_PORT (default 6379) environment variables.

    The backend is best-effort. It becomes a no-op, where ``get`` always misses
    and ``put`` silently drops data, in two cases:
    - the ``redis`` package could not be imported;
    - an operation has hit a redis ConnectionError.
    """

    # str.format template with a single ``{key}`` field; only set when redis
    # is importable.
    _key_fmt: str
    # Client handle; None when redis is unavailable or after a connection
    # failure has disabled this backend.
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        if not redis:
            # We had trouble importing redis - just skip init.
            return
        # Escape braces in cache_id so str.format() in __get_key only sees the
        # ``{key}`` field. An unescaped "{" or "}" in cache_id would otherwise
        # make every lookup raise KeyError/ValueError/IndexError.
        safe_id = cache_id.replace("{", "{{").replace("}", "}}")
        self._key_fmt = f"pt2:{safe_id}:{{key}}"
        self._redis = redis.Redis(
            host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
        )

    def __get_key(self, key: str) -> str:
        """Return the namespaced Redis key for ``key``."""
        return self._key_fmt.format(key=key)

    @override
    def get(self, key: str) -> Optional[bytes]:
        """
        Return the bytes stored under ``key``.

        Returns None if the key is absent or the backend is disabled.
        """
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return None

        try:
            value = self._redis.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
            return None

        # In theory redis.get() can return an Awaitable as well...
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def put(self, key: str, data: bytes) -> None:
        """Store ``data`` under ``key``; silently dropped if disabled."""
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return

        try:
            self._redis.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
class RedisRemoteCache(RemoteCache[JsonDataTy]):
    """
    RemoteCache of JSON-compatible values, stored as JSON bytes in Redis.

    ``key`` is forwarded to RedisRemoteCacheBackend as its cache id.
    """

    def __init__(self, key: str) -> None:
        # When a test sets backend_override_cls, RemoteCache.__init__ uses the
        # override rather than the backend passed in. In that case skip
        # constructing a real Redis backend (so redis isn't required) and pass
        # None disguised by a cast.
        backend = (
            typing.cast(RemoteCacheBackend[bytes], None)
            if self.__class__.backend_override_cls
            else RedisRemoteCacheBackend(key)
        )
        super().__init__(backend, RemoteCacheJsonSerde())
class RemoteAutotuneCache(RedisRemoteCache):
    """Distinct RedisRemoteCache subtype with no overrides (autotune cache by name)."""
class RemoteFxGraphCache(RedisRemoteCache):
    """Distinct RedisRemoteCache subtype with no overrides (FX graph cache by name)."""