Lower cache mocking to test more pytorch code (#133579)

Summary: Previously we mocked out FbRemoteFxGraphCacheBackend, which meant a large chunk of the cache code went untested. Mock at a lower level instead (CacheClient, LocalAutotuneCacheBackend, ManifoldClient, Redis) so the tests cover more of the caching code.
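(Illustrative sketch, not part of this commit.) Patching at the client layer keeps the real backend code in the test path. Assuming the redis package is importable, the idea looks roughly like this; _FakeRedis is a hypothetical stand-in, not the mock class added below:

import unittest.mock

class _FakeRedis:
    # In-memory stand-in for redis.Redis. Only the transport is replaced, so
    # the real RedisRemoteCacheBackend still runs its key handling and
    # encode/decode logic during the test. The store is class-level so every
    # client instance sees the same "remote" cache.
    _store = {}

    def __init__(self, *args, **kwargs):
        pass

    def get(self, key):
        return self._store.get(key)

    def set(self, key, value):
        self._store[key] = value

with unittest.mock.patch("redis.Redis", _FakeRedis):
    pass  # run the compile-under-test here; cache traffic lands in _FakeRedis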

Test Plan: unit tests

Reviewed By: oulgen

Differential Revision: D60937966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133579
Approved by: https://github.com/oulgen
Author: Aaron Orenstein, 2024-08-19 16:32:36 +00:00; committed by PyTorch MergeBot
parent 32ed4a3beb
commit 68fcd54226
5 changed files with 388 additions and 123 deletions


@@ -17,6 +17,7 @@ jinja2
fsspec
lintrunner
ninja
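# redis: exercised by the inductor remote-cache tests (see test/inductor/mock_cache.py)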
redis
# setuptools was removed from default python install
setuptools ; python_version >= "3.12"
packaging

test/inductor/mock_cache.py (new file, 290 lines)

@@ -0,0 +1,290 @@
# Owner(s): ["module: inductor"]
import contextlib
import dataclasses
import sys
import threading
import unittest.mock
from types import TracebackType
from typing import Callable, Generator, Optional, Tuple, Type, Union
from typing_extensions import override, Self
import torch
from torch._inductor import config
from torch._inductor.remote_cache import RemoteCacheBackend
# The cache state is thread-local so that tests running concurrently won't
# cross-contaminate. However, it needs to be "global" because we allow code to
# create new cache clients which refer to the same cache (because it's a
# remote cache).
class _MockCacheState(threading.local):
def __init__(self, name: str):
self.reset()
self._name = name
self._cache = {}
self._clients = {} # Used for Manifold
def reset(self):
self.num_init = 0
self.num_put = 0
self.num_get_hit = 0
self.num_get_miss = 0
def report(self):
print(
"".join(
[
f"{self._name} cache: ",
f"init: {self.num_init}, ",
f"puts: {self.num_put}, ",
f"misses: {self.num_get_miss}, ",
f"hits: {self.num_get_hit}, ",
]
),
file=sys.stderr,
)
class _MockLocalAutotuneCacheBackend(RemoteCacheBackend):
_state = _MockCacheState("Local")
def __init__(self):
state = self._state
state.num_init += 1
@override
def get(self, key: str) -> Optional[bytes]:
assert isinstance(key, str)
state = self._state
if key in state._cache:
state.num_get_hit += 1
return state._cache[key]
else:
state.num_get_miss += 1
@override
def put(self, key: str, data: bytes) -> None:
assert isinstance(key, str)
assert isinstance(data, bytes)
state = self._state
state.num_put += 1
state._cache[key] = data
class _MockRedisRemoteCache:
_state = _MockCacheState("Redis")
def __init__(self, *args, **kwargs):
state = self._state
state.num_init += 1
def get(self, key: Union[bytes, str]) -> Optional[Union[bytes, str, int, float]]:
assert isinstance(key, (bytes, str))
state = self._state
if key in state._cache:
state.num_get_hit += 1
else:
state.num_get_miss += 1
return state._cache.get(key)
def set(self, key: Union[bytes, str], data: Union[bytes, str, int, float]) -> None:
assert isinstance(key, (bytes, str))
assert isinstance(data, (bytes, str, int, float)), type(data)
state = self._state
# According to https://redis-py.readthedocs.io/en/stable/commands.html#redis.commands.core.CoreCommands.set
# redis accepts Union[bytes, memoryview, str, int, float]
state.num_put += 1
state._cache[key] = data
@dataclasses.dataclass
class CacheDecl:
qname: str
cls: Type[object]
f: Optional[Callable[..., object]] = None
def patch(self) -> contextlib.AbstractContextManager:
return unittest.mock.patch(self.qname, self.f or self.cls)
_CACHES = (
CacheDecl(
"torch._inductor.runtime.triton_heuristics.LocalAutotuneCache",
_MockLocalAutotuneCacheBackend,
),
CacheDecl("redis.Redis", _MockRedisRemoteCache),
)
# List of configs for each cache
_CACHE_CONFIG_EN = (
"fx_graph_cache",
"fx_graph_remote_cache",
"autotune_local_cache",
"autotune_remote_cache",
# "bundled_autotune_cache",
)
def _has_redis():
    import importlib.util
    return importlib.util.find_spec("redis") is not None
class PatchCaches(contextlib.AbstractContextManager):
num_init = 0
num_put = 0
num_get_miss = 0
num_get_hit = 0
_savedCacheState = {}
@staticmethod
def get_caches() -> Tuple[CacheDecl, ...]:
if config.is_fbcode():
from .fb.mock_cache import FB_CACHES
return _CACHES + FB_CACHES
else:
return _CACHES
def __init__(self):
self._contexts = []
for decl in self.get_caches():
self._contexts.append(decl.patch())
@classmethod
def reset(cls):
"""
Reset the patched cache states as well as the PatchCaches
aggregation.
"""
cls.num_init = 0
cls.num_put = 0
cls.num_get_miss = 0
cls.num_get_hit = 0
for decl in cls.get_caches():
decl.cls._state.reset()
@classmethod
def update(cls):
"""
Update PatchCaches' state with the values from all the patched caches.
"""
cls.num_init = sum(decl.cls._state.num_init for decl in cls.get_caches())
cls.num_put = sum(decl.cls._state.num_put for decl in cls.get_caches())
cls.num_get_miss = sum(
decl.cls._state.num_get_miss for decl in cls.get_caches()
)
cls.num_get_hit = sum(decl.cls._state.num_get_hit for decl in cls.get_caches())
@classmethod
def setUp(cls):
# If we don't have redis available then fake it since we'll be mocking it anyway.
if not _has_redis():
class FakeRedisModule:
class Redis:
pass
sys.modules["redis"] = FakeRedisModule()
# If this test is using PatchCaches then disable all the caches by
# default, letting the tests turn them on explicitly. This is because
# tests using PatchCaches will often want to check stats explicitly.
cls._savedCacheState = {}
for name in _CACHE_CONFIG_EN:
if hasattr(config, name):
cls._savedCacheState[name] = getattr(config, name)
setattr(config, name, False)
for decl in cls.get_caches():
if hasattr(decl.cls, "setUp"):
decl.cls.setUp()
@classmethod
def tearDown(cls):
for decl in cls.get_caches()[::-1]:
if hasattr(decl.cls, "tearDown"):
decl.cls.tearDown()
# Restore cache defaults
for name in _CACHE_CONFIG_EN:
delattr(config, name)
if name in cls._savedCacheState:
setattr(config, name, cls._savedCacheState[name])
@classmethod
def report(cls):
"""
Report cache state for all patched caches.
"""
for decl in cls.get_caches():
decl.cls._state.report()
print(
"".join(
[
"All caches: ",
f"init: {cls.num_init}, ",
f"puts: {cls.num_put}, ",
f"misses: {cls.num_get_miss}, ",
f"hits: {cls.num_get_hit}",
]
),
file=sys.stderr,
)
def __enter__(self) -> Self:
"""
Start mocking the patched caches.
"""
self.reset()
for ctx in self._contexts:
ctx.__enter__()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Stop mocking the patched caches.
"""
for ctx in self._contexts[::-1]:
ctx.__exit__(exc_type, exc_value, traceback)
self.update()
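# Note: config.is_fbcode() keys off torch.version.git_version -- present in
# OSS builds, absent in fbcode -- so patch_fbcode below simulates either
# environment by removing or installing that attribute for the duration of
# the context.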
@contextlib.contextmanager
def patch_fbcode(state: bool) -> Generator[None, None, None]:
if hasattr(torch.version, "git_version"):
# Currently non-fbcode
if state:
old = torch.version.git_version
delattr(torch.version, "git_version")
try:
yield
finally:
torch.version.git_version = old
else:
yield
else:
# Currently fbcode
if state:
yield
else:
torch.version.git_version = "12345+"
try:
yield
finally:
delattr(torch.version, "git_version")


@@ -1,7 +1,5 @@
# Owner(s): ["module: inductor"]
import base64
import functools
import json
import os
import pickle
import unittest
@@ -42,10 +40,16 @@ from torch.testing._internal.inductor_utils import (
from torch.utils._triton import has_triton
try:
from .mock_cache import PatchCaches
except ImportError:
from mock_cache import PatchCaches # @manual
HAS_TRITON = has_triton()
if HAS_TRITON:
import triton
import triton # @manual
from torch.testing._internal.triton_utils import add_kernel
@@ -106,6 +110,11 @@ class TestFxGraphCache(TestCase):
def setUp(self):
super().setUp()
counters.clear()
PatchCaches.setUp()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
def reset(self):
torch._dynamo.reset()
@@ -168,56 +177,23 @@
a = torch.rand(25, dtype=dtype, device=device)
b = torch.rand(5, 5, dtype=dtype, device=device)
cache = {}
num_get = 0
num_put = 0
class MyCache:
def __init__(self, key, is_autotune=False):
pass
def get(self, filename):
nonlocal cache
nonlocal num_get
if filename not in cache:
return None
ret = json.loads(cache[filename])
num_get += 1
if config.is_fbcode():
return base64.b64decode(ret["data"]) if ret is not None else ret
else:
return base64.b64decode(ret) if ret is not None else ret
def put(self, filename, data):
nonlocal cache
nonlocal num_put
if config.is_fbcode():
data["data"] = base64.b64encode(data["data"]).decode("ascii")
else:
data = base64.b64encode(data).decode("ascii")
cache[filename] = json.dumps(data)
num_put += 1
cache_module = (
"torch._inductor.fb.remote_cache.FbRemoteFxGraphCacheBackend"
if config.is_fbcode()
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
)
with config.patch(
{
"fx_graph_cache": False,
"fx_graph_remote_cache": True,
}
), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)
for _ in range(4):
with fresh_inductor_cache():
compiled_fn = torch.compile(fn, dynamic=dynamic)
self.assertEqual(fn(a, b), compiled_fn(a, b))
reset()
self.assertEqual(num_get, 3)
self.assertEqual(num_put, 1)
PatchCaches.report()
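# Four compiles, each under a fresh local inductor cache: the first remote
# lookup misses and populates the cache; the remaining three hit.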
self.assertEqual(PatchCaches.num_get_hit, 3)
self.assertEqual(PatchCaches.num_get_miss, 1)
self.assertEqual(PatchCaches.num_put, 1)
@requires_triton()
@config.patch({"fx_graph_cache": True})


@@ -1,5 +1,4 @@
# Owner(s): ["module: inductor"]
import json
import os
import unittest
from typing import Callable, List, Optional
@@ -35,6 +34,12 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
try:
from .mock_cache import PatchCaches
except ImportError:
from mock_cache import PatchCaches # @manual
torch.set_float32_matmul_precision("high")
if HAS_CUDA:
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
@@ -219,80 +224,6 @@ class TestMaxAutotune(TestCase):
with config.patch({"max_autotune": True}):
torch.compile(mm, dynamic=dynamic)(a, b)
@skipIfRocm
@parametrize("dynamic", (False, True))
def test_max_autotune_remote_caching(self, dynamic: bool):
from unittest.mock import patch
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(100, 10).cuda()
b = torch.randn(10, 100).cuda()
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
def f(x, y):
return Model()(x, y)
x = torch.randn(100, 100).cuda()
y = torch.randn(100, 100).cuda()
cache = {}
num_get = 0
num_put = 0
class MyCache:
def __init__(self, key, is_autotune=False):
pass
def get(self, filename):
nonlocal cache
nonlocal num_get
if filename not in cache:
return None
ret = json.loads(cache[filename])
num_get += 1
return ret
def put(self, filename, data):
nonlocal cache
nonlocal num_put
cache[filename] = json.dumps(data)
num_put += 1
cache_module = (
"torch._inductor.fb.remote_cache.FbRemoteAutotuneCacheBackend"
if config.is_fbcode()
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
)
with config.patch(
{
"autotune_local_cache": False,
"autotune_remote_cache": True,
}
), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
os.environ.pop("TRITON_CACHE_MANAGER", None)
with config.patch({"max_autotune": True}):
for _ in range(4):
with fresh_inductor_cache():
torch.compile(mm, dynamic=dynamic)(a, b)
reset()
self.assertEqual(num_get, 3)
self.assertEqual(num_put, 1)
num_get = 0
num_put = 0
for _ in range(4):
with fresh_inductor_cache():
torch.compile(f, dynamic=dynamic)(x, y)
reset()
self.assertEqual(num_get, 3)
self.assertEqual(num_put, 1)
@skipIfRocm
def test_precompilation_threads(self):
import threading
@@ -777,6 +708,72 @@ class TestMaxAutotune(TestCase):
self.assertIn("NoValidChoicesError", str(context.exception))
@instantiate_parametrized_tests
class TestMaxAutotuneRemoteCache(TestCase):
def setUp(self):
super().setUp()
PatchCaches.setUp()
def tearDown(self):
super().tearDown()
PatchCaches.tearDown()
@skipIfRocm
@parametrize("dynamic", (False, True))
def test_max_autotune_remote_caching(self, dynamic: bool):
from unittest.mock import patch
if not config.is_fbcode():
self.skipTest("Redis for autotune is currently broken")
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(100, 10).cuda()
b = torch.randn(10, 100).cuda()
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
def f(x, y):
return Model()(x, y)
x = torch.randn(100, 100).cuda()
y = torch.randn(100, 100).cuda()
with config.patch(
{
"autotune_local_cache": False,
"autotune_remote_cache": True,
}
), patch.dict(os.environ), PatchCaches():
os.environ.pop("TRITON_CACHE_MANAGER", None)
with config.patch({"max_autotune": True}):
for _ in range(4):
with fresh_inductor_cache():
torch.compile(mm, dynamic=dynamic)(a, b)
reset()
PatchCaches.update()
PatchCaches.report()
self.assertEqual(PatchCaches.num_get_hit, 3)
self.assertEqual(PatchCaches.num_get_miss, 1)
self.assertEqual(PatchCaches.num_put, 1)
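# Reset the aggregated stats and repeat with a second model: its first
# compile misses and populates the cache, the next three hit.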
PatchCaches.reset()
for _ in range(4):
with fresh_inductor_cache():
torch.compile(f, dynamic=dynamic)(x, y)
reset()
PatchCaches.update()
PatchCaches.report()
self.assertEqual(PatchCaches.num_get_hit, 3)
self.assertEqual(PatchCaches.num_get_miss, 1)
self.assertEqual(PatchCaches.num_put, 1)
class TestBenchmarkRequest(BenchmarkRequest):
def __init__(
self, value: float, multi_device: bool, parent_visible_devices: Optional[str]


@@ -68,6 +68,8 @@ T = TypeVar("T")
if TYPE_CHECKING:
from collections.abc import KeysView
from .remote_cache import RemoteCacheBackend
"""
codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
@@ -1173,7 +1175,7 @@ class FxGraphCache:
compiled_graph: CompiledFxGraph,
example_inputs: List[torch.Tensor],
local: bool,
remote_cache: None,
remote_cache: Optional[RemoteCacheBackend],
) -> None:
"""
Store a serialized CompiledFxGraph on disk.
@@ -1220,17 +1222,16 @@ class FxGraphCache:
write_atomic(path, content, make_dirs=True)
if remote_cache:
time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6)
cache_data = (
{
"data": content,
"time_taken_ms": int(
disk_compiled_graph._time_taken_ns // 1e6
), # Convert from NS to MS
"time_taken_ms": time_taken_ms,
}
if config.is_fbcode()
else content
)
remote_cache.put(key, cache_data)
remote_cache.put(key, cache_data) # type: ignore[arg-type]
except Exception:
log.warning("fx graph unable to write to cache", exc_info=True)
counters["inductor"]["fxgraph_cache_write_error"] += 1
@@ -1291,7 +1292,7 @@ class FxGraphCache:
cache_info["key"] = key
cache_info["components"] = debug_lines
remote_cache = None
remote_cache: Optional[RemoteCacheBackend] = None
if remote:
cache_id = "fx-graph-v1"
try: