Lower cache mocking to test more pytorch code (#133579)
Summary: Previously we mocked out FbRemoteFxGraphCacheBackend, which meant a whole bunch of the cache code went untested. Mock at a lower level instead (CacheClient, LocalAutotuneCacheBackend, ManifoldClient, Redis) so the tests cover a larger amount of the caching code.

Test Plan: unit tests

Reviewed By: oulgen

Differential Revision: D60937966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133579
Approved by: https://github.com/oulgen
This commit is contained in:
parent 32ed4a3beb
commit 68fcd54226
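For context, this is the test pattern the new helper enables — a minimal sketch, not a test from this commit; it assumes test/inductor is on sys.path so that mock_cache (added below) imports directly:

import torch
import torch._dynamo
from torch._inductor import config
from torch._inductor.utils import fresh_inductor_cache

from mock_cache import PatchCaches  # the helper added in this commit

PatchCaches.setUp()  # disables every cache config; tests opt back in explicitly
try:
    with config.patch({"fx_graph_remote_cache": True}), PatchCaches():
        for _ in range(4):
            with fresh_inductor_cache():
                torch.compile(lambda x: x + 1)(torch.ones(4))
            torch._dynamo.reset()
    # PatchCaches.__exit__ aggregated the per-cache counters:
    assert PatchCaches.num_get_miss == 1  # first compile misses the remote cache
    assert PatchCaches.num_put == 1       # ... and populates it
    assert PatchCaches.num_get_hit == 3   # the next three compiles hit
finally:
    PatchCaches.tearDown()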
requirements.txt
@@ -17,6 +17,7 @@ jinja2
 fsspec
 lintrunner
 ninja
+redis
 # setuptools was removed from default python install
 setuptools ; python_version >= "3.12"
 packaging
test/inductor/mock_cache.py (new file, 290 lines)
@@ -0,0 +1,290 @@
# Owner(s): ["module: inductor"]
import contextlib
import dataclasses
import sys
import threading
import unittest.mock
from types import TracebackType
from typing import Callable, Generator, Optional, Tuple, Type, Union
from typing_extensions import override, Self

import torch
from torch._inductor import config
from torch._inductor.remote_cache import RemoteCacheBackend


# The cache state is thread-local so if we're running multiple tests at once
# they won't cross contaminate. However - it needs to be "global" because we
# allow code to create new cache clients which refer to the same cache (because
# it's a remote cache).
class _MockCacheState(threading.local):
    def __init__(self, name: str):
        self.reset()
        self._name = name
        self._cache = {}
        self._clients = {}  # Used for Manifold

    def reset(self):
        self.num_init = 0
        self.num_put = 0
        self.num_get_hit = 0
        self.num_get_miss = 0

    def report(self):
        print(
            "".join(
                [
                    f"{self._name} cache: ",
                    f"init: {self.num_init}, ",
                    f"puts: {self.num_put}, ",
                    f"misses: {self.num_get_miss}, ",
                    f"hits: {self.num_get_hit}, ",
                ]
            ),
            file=sys.stderr,
        )


class _MockLocalAutotuneCacheBackend(RemoteCacheBackend):
    _state = _MockCacheState("Local")

    def __init__(self):
        state = self._state
        state.num_init += 1

    @override
    def get(self, key: str) -> Optional[bytes]:
        assert isinstance(key, str)

        state = self._state
        if key in state._cache:
            state.num_get_hit += 1
            return state._cache[key]
        else:
            state.num_get_miss += 1

    @override
    def put(self, key: str, data: bytes) -> None:
        assert isinstance(key, str)
        assert isinstance(data, bytes)

        state = self._state
        state.num_put += 1
        state._cache[key] = data


class _MockRedisRemoteCache:
    _state = _MockCacheState("Redis")

    def __init__(self, *args, **kwargs):
        state = self._state
        state.num_init += 1

    def get(self, key: Union[bytes, str]) -> Optional[Union[bytes, str, int, float]]:
        assert isinstance(key, (bytes, str))

        state = self._state

        if key in state._cache:
            state.num_get_hit += 1
        else:
            state.num_get_miss += 1
        return state._cache.get(key)

    def set(self, key: Union[bytes, str], data: Union[bytes, str, int, float]) -> None:
        assert isinstance(key, (bytes, str))
        assert isinstance(data, (bytes, str, int, float)), type(data)

        state = self._state

        # According to https://redis-py.readthedocs.io/en/stable/commands.html#redis.commands.core.CoreCommands.set
        # redis accepts Union[bytes, memoryview, str, int, float]
        state.num_put += 1
        state._cache[key] = data


@dataclasses.dataclass
class CacheDecl:
    qname: str
    cls: Type[object]
    f: Optional[Callable[..., object]] = None

    def patch(self) -> contextlib.AbstractContextManager:
        return unittest.mock.patch(self.qname, self.f or self.cls)


_CACHES = (
    CacheDecl(
        "torch._inductor.runtime.triton_heuristics.LocalAutotuneCache",
        _MockLocalAutotuneCacheBackend,
    ),
    CacheDecl("redis.Redis", _MockRedisRemoteCache),
)

# List of configs for each cache
_CACHE_CONFIG_EN = (
    "fx_graph_cache",
    "fx_graph_remote_cache",
    "autotune_local_cache",
    "autotune_remote_cache",
    # "bundled_autotune_cache",
)


def _has_redis():
    import importlib

    return importlib.util.find_spec("redis") is not None


class PatchCaches(contextlib.AbstractContextManager):
    num_init = 0
    num_put = 0
    num_get_miss = 0
    num_get_hit = 0
    _savedCacheState = {}

    @staticmethod
    def get_caches() -> Tuple[CacheDecl, ...]:
        if config.is_fbcode():
            from .fb.mock_cache import FB_CACHES

            return _CACHES + FB_CACHES
        else:
            return _CACHES

    def __init__(self):
        self._contexts = []
        for decl in self.get_caches():
            self._contexts.append(decl.patch())

    @classmethod
    def reset(cls):
        """
        Reset the patched cache states as well as the PatchCaches
        aggregation.
        """
        cls.num_init = 0
        cls.num_put = 0
        cls.num_get_miss = 0
        cls.num_get_hit = 0

        for decl in cls.get_caches():
            decl.cls._state.reset()

    @classmethod
    def update(cls):
        """
        Update PatchCaches' state with the values from all the patched caches.
        """
        cls.num_init = sum(decl.cls._state.num_init for decl in cls.get_caches())
        cls.num_put = sum(decl.cls._state.num_put for decl in cls.get_caches())
        cls.num_get_miss = sum(
            decl.cls._state.num_get_miss for decl in cls.get_caches()
        )
        cls.num_get_hit = sum(decl.cls._state.num_get_hit for decl in cls.get_caches())

    @classmethod
    def setUp(cls):
        # If we don't have redis available then fake it since we'll be mocking it anyway.
        if not _has_redis():

            class FakeRedisModule:
                class Redis:
                    pass

            sys.modules["redis"] = FakeRedisModule()

        # If this test is using PatchCaches then disable all the caches by
        # default, letting the tests turn them on explicitly. This is because
        # tests using PatchCaches will often want to check stats explicitly.
        cls._savedCacheState = {}
        for name in _CACHE_CONFIG_EN:
            if hasattr(config, name):
                cls._savedCacheState[name] = getattr(config, name)
            setattr(config, name, False)

        for decl in cls.get_caches():
            if hasattr(decl.cls, "setUp"):
                decl.cls.setUp()

    @classmethod
    def tearDown(cls):
        for decl in cls.get_caches()[::-1]:
            if hasattr(decl.cls, "tearDown"):
                decl.cls.tearDown()

        # Restore cache defaults
        for name in _CACHE_CONFIG_EN:
            delattr(config, name)
            if name in cls._savedCacheState:
                setattr(config, name, cls._savedCacheState[name])

    @classmethod
    def report(cls):
        """
        Report cache state for all patched caches.
        """
        for decl in cls.get_caches():
            decl.cls._state.report()
        print(
            "".join(
                [
                    "All caches: ",
                    f"init: {cls.num_init}, ",
                    f"puts: {cls.num_put}, ",
                    f"misses: {cls.num_get_miss}, ",
                    f"hits: {cls.num_get_hit}",
                ]
            ),
            file=sys.stderr,
        )

    def __enter__(self) -> Self:
        """
        Start mocking the patched caches.
        """
        self.reset()

        for ctx in self._contexts:
            ctx.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """
        Stop mocking the patched caches.
        """
        for ctx in self._contexts[::-1]:
            ctx.__exit__(exc_type, exc_value, traceback)

        self.update()


@contextlib.contextmanager
def patch_fbcode(state: bool) -> Generator[None, None, None]:
    if hasattr(torch.version, "git_version"):
        # Currently non-fbcode
        if state:
            old = torch.version.git_version
            delattr(torch.version, "git_version")
            try:
                yield
            finally:
                torch.version.git_version = old
        else:
            yield
    else:
        # Currently fbcode
        if state:
            yield
        else:
            torch.version.git_version = "12345+"
            try:
                yield
            finally:
                delattr(torch.version, "git_version")
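An aside on the design (illustration only, not part of the commit): because _MockCacheState subclasses threading.local, every thread that touches the shared state object gets its own counters, while all threads still share the one object — which is what keeps concurrently running tests from cross-contaminating each other's stats. A standalone demonstration of that mechanism:

import threading

class State(threading.local):
    # For threading.local subclasses, __init__ re-runs in every thread that
    # first touches the instance, giving each thread fresh attributes.
    def __init__(self):
        self.num_put = 0

state = State()  # one shared object, per-thread attribute storage

def worker():
    state.num_put += 1
    assert state.num_put == 1  # other threads' increments are invisible here

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert state.num_put == 0  # the main thread's copy was never incremented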
test/inductor/test_codecache.py
@@ -1,7 +1,5 @@
 # Owner(s): ["module: inductor"]
-import base64
 import functools
-import json
 import os
 import pickle
 import unittest
@@ -42,10 +40,16 @@ from torch.testing._internal.inductor_utils import (
 from torch.utils._triton import has_triton
 
 
+try:
+    from .mock_cache import PatchCaches
+except ImportError:
+    from mock_cache import PatchCaches  # @manual
+
+
 HAS_TRITON = has_triton()
 
 if HAS_TRITON:
-    import triton
+    import triton  # @manual
 
     from torch.testing._internal.triton_utils import add_kernel
@@ -106,6 +110,11 @@ class TestFxGraphCache(TestCase):
     def setUp(self):
         super().setUp()
         counters.clear()
+        PatchCaches.setUp()
+
+    def tearDown(self):
+        super().tearDown()
+        PatchCaches.tearDown()
 
     def reset(self):
         torch._dynamo.reset()
@@ -168,56 +177,23 @@ class TestFxGraphCache(TestCase):
         a = torch.rand(25, dtype=dtype, device=device)
         b = torch.rand(5, 5, dtype=dtype, device=device)
 
-        cache = {}
-        num_get = 0
-        num_put = 0
-
-        class MyCache:
-            def __init__(self, key, is_autotune=False):
-                pass
-
-            def get(self, filename):
-                nonlocal cache
-                nonlocal num_get
-                if filename not in cache:
-                    return None
-                ret = json.loads(cache[filename])
-                num_get += 1
-                if config.is_fbcode():
-                    return base64.b64decode(ret["data"]) if ret is not None else ret
-                else:
-                    return base64.b64decode(ret) if ret is not None else ret
-
-            def put(self, filename, data):
-                nonlocal cache
-                nonlocal num_put
-                if config.is_fbcode():
-                    data["data"] = base64.b64encode(data["data"]).decode("ascii")
-                else:
-                    data = base64.b64encode(data).decode("ascii")
-                cache[filename] = json.dumps(data)
-                num_put += 1
-
-        cache_module = (
-            "torch._inductor.fb.remote_cache.FbRemoteFxGraphCacheBackend"
-            if config.is_fbcode()
-            else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
-        )
-
         with config.patch(
             {
                 "fx_graph_cache": False,
                 "fx_graph_remote_cache": True,
             }
-        ), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
+        ), patch.dict(os.environ), PatchCaches():
             os.environ.pop("TRITON_CACHE_MANAGER", None)
             for _ in range(4):
                 with fresh_inductor_cache():
                     compiled_fn = torch.compile(fn, dynamic=dynamic)
                     self.assertEqual(fn(a, b), compiled_fn(a, b))
                 reset()
-            self.assertEqual(num_get, 3)
-            self.assertEqual(num_put, 1)
+
+        PatchCaches.report()
+        self.assertEqual(PatchCaches.num_get_hit, 3)
+        self.assertEqual(PatchCaches.num_get_miss, 1)
+        self.assertEqual(PatchCaches.num_put, 1)
 
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
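The 3/1/1 expectations above follow from compiling four times against one shared remote cache: the first iteration misses and writes, the next three hit. A hypothetical standalone model of that accounting (names are illustrative, not from the test):

cache = {}
hits = misses = puts = 0

def get(key):
    global hits, misses
    if key in cache:
        hits += 1
        return cache[key]
    misses += 1
    return None

def put(key, value):
    global puts
    puts += 1
    cache[key] = value

for _ in range(4):  # four fresh compiles of the same graph
    if get("fx-graph-key") is None:
        put("fx-graph-key", b"compiled artifact")

assert (hits, misses, puts) == (3, 1, 1)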
test/inductor/test_max_autotune.py
@@ -1,5 +1,4 @@
 # Owner(s): ["module: inductor"]
-import json
 import os
 import unittest
 from typing import Callable, List, Optional
@@ -35,6 +34,12 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 
+try:
+    from .mock_cache import PatchCaches
+except ImportError:
+    from mock_cache import PatchCaches  # @manual
+
+
 torch.set_float32_matmul_precision("high")
 if HAS_CUDA:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")
@@ -219,80 +224,6 @@ class TestMaxAutotune(TestCase):
         with config.patch({"max_autotune": True}):
             torch.compile(mm, dynamic=dynamic)(a, b)
 
-    @skipIfRocm
-    @parametrize("dynamic", (False, True))
-    def test_max_autotune_remote_caching(self, dynamic: bool):
-        from unittest.mock import patch
-
-        def mm(a, b):
-            a = torch.sin(a)
-            return a @ b
-
-        a = torch.randn(100, 10).cuda()
-        b = torch.randn(10, 100).cuda()
-
-        class Model(torch.nn.Module):
-            def forward(self, x, y):
-                return x + y
-
-        def f(x, y):
-            return Model()(x, y)
-
-        x = torch.randn(100, 100).cuda()
-        y = torch.randn(100, 100).cuda()
-
-        cache = {}
-        num_get = 0
-        num_put = 0
-
-        class MyCache:
-            def __init__(self, key, is_autotune=False):
-                pass
-
-            def get(self, filename):
-                nonlocal cache
-                nonlocal num_get
-                if filename not in cache:
-                    return None
-                ret = json.loads(cache[filename])
-                num_get += 1
-                return ret
-
-            def put(self, filename, data):
-                nonlocal cache
-                nonlocal num_put
-                cache[filename] = json.dumps(data)
-                num_put += 1
-
-        cache_module = (
-            "torch._inductor.fb.remote_cache.FbRemoteAutotuneCacheBackend"
-            if config.is_fbcode()
-            else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
-        )
-
-        with config.patch(
-            {
-                "autotune_local_cache": False,
-                "autotune_remote_cache": True,
-            }
-        ), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
-            os.environ.pop("TRITON_CACHE_MANAGER", None)
-            with config.patch({"max_autotune": True}):
-                for _ in range(4):
-                    with fresh_inductor_cache():
-                        torch.compile(mm, dynamic=dynamic)(a, b)
-                    reset()
-                self.assertEqual(num_get, 3)
-                self.assertEqual(num_put, 1)
-                num_get = 0
-                num_put = 0
-                for _ in range(4):
-                    with fresh_inductor_cache():
-                        torch.compile(f, dynamic=dynamic)(x, y)
-                    reset()
-                self.assertEqual(num_get, 3)
-                self.assertEqual(num_put, 1)
-
     @skipIfRocm
     def test_precompilation_threads(self):
         import threading
@@ -777,6 +708,72 @@ class TestMaxAutotune(TestCase):
         self.assertIn("NoValidChoicesError", str(context.exception))
 
 
+@instantiate_parametrized_tests
+class TestMaxAutotuneRemoteCache(TestCase):
+    def setUp(self):
+        super().setUp()
+        PatchCaches.setUp()
+
+    def tearDown(self):
+        super().tearDown()
+        PatchCaches.tearDown()
+
+    @skipIfRocm
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_remote_caching(self, dynamic: bool):
+        from unittest.mock import patch
+
+        if not config.is_fbcode():
+            self.skipTest("Redis for autotune is currently broken")
+
+        def mm(a, b):
+            a = torch.sin(a)
+            return a @ b
+
+        a = torch.randn(100, 10).cuda()
+        b = torch.randn(10, 100).cuda()
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        def f(x, y):
+            return Model()(x, y)
+
+        x = torch.randn(100, 100).cuda()
+        y = torch.randn(100, 100).cuda()
+
+        with config.patch(
+            {
+                "autotune_local_cache": False,
+                "autotune_remote_cache": True,
+            }
+        ), patch.dict(os.environ), PatchCaches():
+            os.environ.pop("TRITON_CACHE_MANAGER", None)
+            with config.patch({"max_autotune": True}):
+                for _ in range(4):
+                    with fresh_inductor_cache():
+                        torch.compile(mm, dynamic=dynamic)(a, b)
+                    reset()
+
+                PatchCaches.update()
+                PatchCaches.report()
+                self.assertEqual(PatchCaches.num_get_hit, 3)
+                self.assertEqual(PatchCaches.num_get_miss, 1)
+                self.assertEqual(PatchCaches.num_put, 1)
+
+                PatchCaches.reset()
+                for _ in range(4):
+                    with fresh_inductor_cache():
+                        torch.compile(f, dynamic=dynamic)(x, y)
+                    reset()
+                PatchCaches.update()
+                PatchCaches.report()
+                self.assertEqual(PatchCaches.num_get_hit, 3)
+                self.assertEqual(PatchCaches.num_get_miss, 1)
+                self.assertEqual(PatchCaches.num_put, 1)
+
+
 class TestBenchmarkRequest(BenchmarkRequest):
     def __init__(
         self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
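Note how the new test calls PatchCaches.update() before reading the aggregate counters: inside the with-block, __exit__ has not yet run, so the class-level sums are stale until update() is invoked, and reset() zeroes both the per-cache state and the aggregate between the two sub-runs. A sketch of that lifecycle using the mock directly (assumes test/inductor is on sys.path):

from mock_cache import PatchCaches, _MockRedisRemoteCache

PatchCaches.reset()
r = _MockRedisRemoteCache()
r.set("k", b"v")   # one put
r.get("k")         # one hit
r.get("missing")   # one miss

assert PatchCaches.num_put == 0  # aggregate is stale until update()
PatchCaches.update()             # fold per-cache counters into the aggregate
assert (PatchCaches.num_put, PatchCaches.num_get_hit, PatchCaches.num_get_miss) == (1, 1, 1)

PatchCaches.reset()              # a second phase starts from zero
assert PatchCaches.num_put == 0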
torch/_inductor/codecache.py
@@ -68,6 +68,8 @@ T = TypeVar("T")
 if TYPE_CHECKING:
     from collections.abc import KeysView
 
+    from .remote_cache import RemoteCacheBackend
+
 
 """
 codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
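The new import sits under TYPE_CHECKING, so RemoteCacheBackend is visible to the type checker without paying for (or risking) a runtime import. The general pattern, with a hypothetical module name:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Never executed at runtime, so the module need not even be importable here.
    from some_backend_module import Backend  # hypothetical

# Variable annotations are not evaluated at runtime, so the string form
# (or `from __future__ import annotations`) keeps this import-safe:
remote_cache: "Optional[Backend]" = None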
@@ -1173,7 +1175,7 @@ class FxGraphCache:
         compiled_graph: CompiledFxGraph,
         example_inputs: List[torch.Tensor],
         local: bool,
-        remote_cache: None,
+        remote_cache: Optional[RemoteCacheBackend],
     ) -> None:
         """
         Store a serialized CompiledFxGraph on disk.
@@ -1220,17 +1222,16 @@ class FxGraphCache:
             write_atomic(path, content, make_dirs=True)
 
             if remote_cache:
+                time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6)
                 cache_data = (
                     {
                         "data": content,
-                        "time_taken_ms": int(
-                            disk_compiled_graph._time_taken_ns // 1e6
-                        ),  # Convert from NS to MS
+                        "time_taken_ms": time_taken_ms,
                     }
                     if config.is_fbcode()
                     else content
                 )
-                remote_cache.put(key, cache_data)
+                remote_cache.put(key, cache_data)  # type: ignore[arg-type]
         except Exception:
             log.warning("fx graph unable to write to cache", exc_info=True)
             counters["inductor"]["fxgraph_cache_write_error"] += 1
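Besides being hoisted out of the dict literal, time_taken_ms gains an `or 0` guard, so a missing _time_taken_ns no longer raises. A quick standalone check of the conversion, with illustrative values:

# _time_taken_ns may be None; `or 0` substitutes zero before the ns -> ms conversion
time_taken_ns = None
assert int((time_taken_ns or 0) // 1e6) == 0      # previously: TypeError on None // 1e6

time_taken_ns = 2_500_000_000                     # 2.5 seconds in nanoseconds
assert int((time_taken_ns or 0) // 1e6) == 2500   # // 1e6 yields a float, hence int()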
@@ -1291,7 +1292,7 @@ class FxGraphCache:
         cache_info["key"] = key
         cache_info["components"] = debug_lines
 
-        remote_cache = None
+        remote_cache: Optional[RemoteCacheBackend] = None
         if remote:
             cache_id = "fx-graph-v1"
             try: