from __future__ import annotations

import json
import os
import typing
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing_extensions import override, TypeAlias

from torch._inductor import config


try:
    import redis
except ImportError:
    redis = None  # type: ignore[assignment]


if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]


_T = TypeVar("_T")
_U = TypeVar("_U")


class RemoteCacheBackend(Generic[_T]):
    """
    A backend implementation for accessing a remote/distributed cache. Only
    works with bytes in/out. For structured data use a RemoteCache.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[_T]:
        pass

    @abstractmethod
    def put(self, key: str, data: _T) -> None:
        pass


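# Illustrative sketch only (not part of the original module): a minimal
# in-memory backend showing how RemoteCacheBackend is meant to be subclassed.
# The class name is hypothetical; real callers use RedisRemoteCacheBackend
# defined further below.
class _InMemoryCacheBackend(RemoteCacheBackend[bytes]):
    def __init__(self) -> None:
        # Plain dict standing in for a remote key/value store.
        self._store: Dict[str, bytes] = {}

    def get(self, key: str) -> Optional[bytes]:
        return self._store.get(key)

    def put(self, key: str, data: bytes) -> None:
        self._store[key] = data

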
# Serde that encodes from _T to _U and decodes from _U to _T.
class RemoteCacheSerde(Generic[_T, _U]):
    @abstractmethod
    def encode(self, data: _T) -> _U:
        pass

    @abstractmethod
    def decode(self, data: _U) -> _T:
        pass


JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]


class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    def encode(self, data: JsonDataTy) -> bytes:
        return bytes(json.dumps(data), "ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        return json.loads(data)


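# For example (illustrative): the JSON serde round-trips structured data
# through ASCII bytes, e.g.
#   serde = RemoteCacheJsonSerde()
#   serde.decode(serde.encode({"hit": True, "ms": 1.5}))  # -> {"hit": True, "ms": 1.5}

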
class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    def encode(self, data: _T) -> _T:
        return data

    def decode(self, data: _T) -> _T:
        return data


class RemoteCache(Generic[_T]):
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        # Support for testing.
        if (override_cls := self.__class__.backend_override_cls) is not None:
            self.backend = override_cls()
        else:
            self.backend = backend
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        sample = self._create_sample()
        result = self._get(key, sample)
        self._log_sample(sample)
        return result

    def put(self, key: str, value: _T) -> None:
        sample = self._create_sample()
        self._put(key, value, sample)
        self._log_sample(sample)

    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:  # type: ignore[override]
        return self.serde.decode(data)  # type: ignore[arg-type]

    def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
        return self.serde.encode(value)

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        if data := self.backend.get(key):
            return self._decode(data, sample)
        return None

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        data = self._encode(value, sample)
        self.backend.put(key, data)

    def _create_sample(self) -> Optional[Sample]:
        return None

    def _log_sample(self, sample: Optional[Sample]) -> None:
        pass


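# Illustrative sketch only (not part of the original module): how RemoteCache
# composes a backend with a serde. Uses the hypothetical in-memory backend
# sketched above; production callers use the Redis classes below.
def _example_remote_cache_usage() -> None:
    cache: RemoteCache[JsonDataTy] = RemoteCache(
        _InMemoryCacheBackend(), RemoteCacheJsonSerde()
    )
    # Structured (JSON-compatible) values go in; the serde turns them into
    # bytes for the backend and back again on the way out.
    cache.put("autotune/some-kernel", {"config": "BLOCK=64", "time_ms": 0.42})
    assert cache.get("autotune/some-kernel") == {"config": "BLOCK=64", "time_ms": 0.42}
    # A miss returns None.
    assert cache.get("missing-key") is None

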
class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.
    """

    _key_fmt: str
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        if not redis:
            # We had trouble importing redis - just skip init.
            return

        self._key_fmt = f"pt2:{cache_id}:{{key}}"
        self._redis = redis.Redis(
            host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
        )

    def __get_key(self, key: str) -> str:
        return self._key_fmt.format(key=key)

    @override
    def get(self, key: str) -> Optional[bytes]:
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return None

        try:
            value = self._redis.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
            return None

        # In theory redis.get() can return an Awaitable as well...
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def put(self, key: str, data: bytes) -> None:
        if not self._redis:
            # Either redis wasn't found or we already had some trouble...
            return

        try:
            self._redis.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None


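# Configuration note (illustrative): the Redis backend reads its endpoint from
# TORCHINDUCTOR_REDIS_HOST / TORCHINDUCTOR_REDIS_PORT (defaulting to
# localhost:6379) and namespaces keys as "pt2:<cache_id>:<key>", so e.g.
#   TORCHINDUCTOR_REDIS_HOST=my-redis-host python compile_workload.py
# (hypothetical host and script names) would point the cache at a shared
# Redis instance.

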
class RedisRemoteCache(RemoteCache[JsonDataTy]):
    def __init__(self, key: str) -> None:
        # Special test handling: if we're just going to override the backend
        # anyway, don't require redis.
        if self.__class__.backend_override_cls:
            # This is totally bogus but it works for now...
            backend = typing.cast(RemoteCacheBackend[bytes], None)
        else:
            backend = RedisRemoteCacheBackend(key)
        serde = RemoteCacheJsonSerde()
        super().__init__(backend, serde)


class RemoteAutotuneCache(RedisRemoteCache):
    pass


class RemoteFxGraphCache(RedisRemoteCache):
    pass