Lower cache mocking to test more pytorch code (#133579)
Summary: Previously we were mocking out FbRemoteFxGraphCacheBackend, which meant that a whole bunch of the cache code went untested. Mock at a lower level instead (CacheClient, LocalAutotuneCacheBackend, ManifoldClient, Redis) so that a larger amount of the caching code is covered.

Test Plan: unit tests

Reviewed By: oulgen

Differential Revision: D60937966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133579
Approved by: https://github.com/oulgen
This commit is contained in:
parent 32ed4a3beb
commit 68fcd54226
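The idea of the change, in miniature: instead of replacing the top-level cache backend, replace only the lowest-level client it talks to, so every layer above the client still executes for real. A minimal sketch, assuming the redis package is importable; FakeRedis here is a hypothetical stand-in, not code from this diff:

```python
import unittest.mock


class FakeRedis:
    """Hypothetical stand-in for redis.Redis; only get/set are stubbed."""

    def __init__(self, *args, **kwargs):
        self._store = {}

    def get(self, key):
        return self._store.get(key)

    def set(self, key, value):
        self._store[key] = value


# Everything above redis.Redis (key construction, serialization, the real
# backend logic) still runs; only the network client is faked.
with unittest.mock.patch("redis.Redis", FakeRedis):
    pass  # exercise the cache-using code here
```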
@@ -17,6 +17,7 @@ jinja2
 fsspec
 lintrunner
 ninja
+redis
 # setuptools was removed from default python install
 setuptools ; python_version >= "3.12"
 packaging
test/inductor/mock_cache.py (new file, 290 lines)
@@ -0,0 +1,290 @@
+# Owner(s): ["module: inductor"]
+import contextlib
+import dataclasses
+import sys
+import threading
+import unittest.mock
+from types import TracebackType
+from typing import Callable, Generator, Optional, Tuple, Type, Union
+
+from typing_extensions import override, Self
+
+import torch
+from torch._inductor import config
+from torch._inductor.remote_cache import RemoteCacheBackend
+
+
+# The cache state is thread-local so if we're running multiple tests at once
+# they won't cross contaminate. However - it needs to be "global" because we
+# allow code to create new cache clients which refer to the same cache (because
+# it's a remote cache).
+class _MockCacheState(threading.local):
+    def __init__(self, name: str):
+        self.reset()
+        self._name = name
+        self._cache = {}
+        self._clients = {}  # Used for Manifold
+
+    def reset(self):
+        self.num_init = 0
+        self.num_put = 0
+        self.num_get_hit = 0
+        self.num_get_miss = 0
+
+    def report(self):
+        print(
+            "".join(
+                [
+                    f"{self._name} cache: ",
+                    f"init: {self.num_init}, ",
+                    f"puts: {self.num_put}, ",
+                    f"misses: {self.num_get_miss}, ",
+                    f"hits: {self.num_get_hit}, ",
+                ]
+            ),
+            file=sys.stderr,
+        )
+
+
+class _MockLocalAutotuneCacheBackend(RemoteCacheBackend):
+    _state = _MockCacheState("Local")
+
+    def __init__(self):
+        state = self._state
+        state.num_init += 1
+
+    @override
+    def get(self, key: str) -> Optional[bytes]:
+        assert isinstance(key, str)
+
+        state = self._state
+        if key in state._cache:
+            state.num_get_hit += 1
+            return state._cache[key]
+        else:
+            state.num_get_miss += 1
+
+    @override
+    def put(self, key: str, data: bytes) -> None:
+        assert isinstance(key, str)
+        assert isinstance(data, bytes)
+
+        state = self._state
+        state.num_put += 1
+        state._cache[key] = data
+
+
+class _MockRedisRemoteCache:
+    _state = _MockCacheState("Redis")
+
+    def __init__(self, *args, **kwargs):
+        state = self._state
+        state.num_init += 1
+
+    def get(self, key: Union[bytes, str]) -> Optional[Union[bytes, str, int, float]]:
+        assert isinstance(key, (bytes, str))
+
+        state = self._state
+
+        if key in state._cache:
+            state.num_get_hit += 1
+        else:
+            state.num_get_miss += 1
+        return state._cache.get(key)
+
+    def set(self, key: Union[bytes, str], data: Union[bytes, str, int, float]) -> None:
+        assert isinstance(key, (bytes, str))
+        assert isinstance(data, (bytes, str, int, float)), type(data)
+
+        state = self._state
+
+        # According to https://redis-py.readthedocs.io/en/stable/commands.html#redis.commands.core.CoreCommands.set
+        # redis accepts Union[bytes, memoryview, str, int, float]
+        state.num_put += 1
+        state._cache[key] = data
+
+
+@dataclasses.dataclass
+class CacheDecl:
+    qname: str
+    cls: Type[object]
+    f: Optional[Callable[..., object]] = None
+
+    def patch(self) -> contextlib.AbstractContextManager:
+        return unittest.mock.patch(self.qname, self.f or self.cls)
+
+
+_CACHES = (
+    CacheDecl(
+        "torch._inductor.runtime.triton_heuristics.LocalAutotuneCache",
+        _MockLocalAutotuneCacheBackend,
+    ),
+    CacheDecl("redis.Redis", _MockRedisRemoteCache),
+)
+
+# List of configs for each cache
+_CACHE_CONFIG_EN = (
+    "fx_graph_cache",
+    "fx_graph_remote_cache",
+    "autotune_local_cache",
+    "autotune_remote_cache",
+    # "bundled_autotune_cache",
+)
+
+
+def _has_redis():
+    import importlib
+
+    return importlib.util.find_spec("redis") is not None
+
+
+class PatchCaches(contextlib.AbstractContextManager):
+    num_init = 0
+    num_put = 0
+    num_get_miss = 0
+    num_get_hit = 0
+    _savedCacheState = {}
+
+    @staticmethod
+    def get_caches() -> Tuple[CacheDecl, ...]:
+        if config.is_fbcode():
+            from .fb.mock_cache import FB_CACHES
+
+            return _CACHES + FB_CACHES
+        else:
+            return _CACHES
+
+    def __init__(self):
+        self._contexts = []
+        for decl in self.get_caches():
+            self._contexts.append(decl.patch())
+
+    @classmethod
+    def reset(cls):
+        """
+        Reset the patched cache states as well as the PatchCaches
+        aggregation.
+        """
+        cls.num_init = 0
+        cls.num_put = 0
+        cls.num_get_miss = 0
+        cls.num_get_hit = 0
+
+        for decl in cls.get_caches():
+            decl.cls._state.reset()
+
+    @classmethod
+    def update(cls):
+        """
+        Update PatchCaches' state with the values from all the patched caches.
+        """
+        cls.num_init = sum(decl.cls._state.num_init for decl in cls.get_caches())
+        cls.num_put = sum(decl.cls._state.num_put for decl in cls.get_caches())
+        cls.num_get_miss = sum(
+            decl.cls._state.num_get_miss for decl in cls.get_caches()
+        )
+        cls.num_get_hit = sum(decl.cls._state.num_get_hit for decl in cls.get_caches())
+
+    @classmethod
+    def setUp(cls):
+        # If we don't have redis available then fake it since we'll be mocking it anyway.
+        if not _has_redis():
+
+            class FakeRedisModule:
+                class Redis:
+                    pass
+
+            sys.modules["redis"] = FakeRedisModule()
+
+        # If this test is using PatchCaches then disable all the caches by
+        # default, letting the tests turn them on explicitly. This is because
+        # tests using PatchCaches will often want to check stats explicitly.
+        cls._savedCacheState = {}
+        for name in _CACHE_CONFIG_EN:
+            if hasattr(config, name):
+                cls._savedCacheState[name] = getattr(config, name)
+            setattr(config, name, False)
+
+        for decl in cls.get_caches():
+            if hasattr(decl.cls, "setUp"):
+                decl.cls.setUp()
+
+    @classmethod
+    def tearDown(cls):
+        for decl in cls.get_caches()[::-1]:
+            if hasattr(decl.cls, "tearDown"):
+                decl.cls.tearDown()
+
+        # Restore cache defaults
+        for name in _CACHE_CONFIG_EN:
+            delattr(config, name)
+            if name in cls._savedCacheState:
+                setattr(config, name, cls._savedCacheState[name])
+
+    @classmethod
+    def report(cls):
+        """
+        Report cache state for all patched caches.
+        """
+        for decl in cls.get_caches():
+            decl.cls._state.report()
+        print(
+            "".join(
+                [
+                    "All caches: ",
+                    f"init: {cls.num_init}, ",
+                    f"puts: {cls.num_put}, ",
+                    f"misses: {cls.num_get_miss}, ",
+                    f"hits: {cls.num_get_hit}",
+                ]
+            ),
+            file=sys.stderr,
+        )
+
+    def __enter__(self) -> Self:
+        """
+        Start mocking the patched caches.
+        """
+        self.reset()
+
+        for ctx in self._contexts:
+            ctx.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        """
+        Stop mocking the patched caches.
+        """
+        for ctx in self._contexts[::-1]:
+            ctx.__exit__(exc_type, exc_value, traceback)
+
+        self.update()
+
+
+@contextlib.contextmanager
+def patch_fbcode(state: bool) -> Generator[None, None, None]:
+    if hasattr(torch.version, "git_version"):
+        # Currently non-fbcode
+        if state:
+            old = torch.version.git_version
+            delattr(torch.version, "git_version")
+            try:
+                yield
+            finally:
+                torch.version.git_version = old
+        else:
+            yield
+    else:
+        # Currently fbcode
+        if state:
+            yield
+        else:
+            torch.version.git_version = "12345+"
+            try:
+                yield
+            finally:
+                delattr(torch.version, "git_version")
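A minimal sketch of how a test drives the new helper, patterned on the test changes below. The TestCase base and the compiled function are illustrative, and the import assumes test/inductor is on sys.path (the tests below use a try/except fallback for the same reason):

```python
import unittest

import torch
from torch._inductor import config

from mock_cache import PatchCaches  # the new helper above


class ExampleCacheTest(unittest.TestCase):  # illustrative, not from the diff
    def setUp(self):
        super().setUp()
        PatchCaches.setUp()  # saves cache configs, disables them, fakes redis if absent

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()  # restores the saved config values

    def test_remote_cache_stats(self):
        with config.patch({"fx_graph_remote_cache": True}), PatchCaches():
            fn = torch.compile(lambda x: x.sin())
            fn(torch.randn(8))
        # __exit__ ran update(), so the aggregated counters are current here.
        PatchCaches.report()  # prints per-cache and aggregate init/put/miss/hit counts
```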
test/inductor/test_codecache.py
@@ -1,7 +1,5 @@
 # Owner(s): ["module: inductor"]
-import base64
 import functools
-import json
 import os
 import pickle
 import unittest
@@ -42,10 +40,16 @@ from torch.testing._internal.inductor_utils import (
 from torch.utils._triton import has_triton
 
 
+try:
+    from .mock_cache import PatchCaches
+except ImportError:
+    from mock_cache import PatchCaches  # @manual
+
+
 HAS_TRITON = has_triton()
 
 if HAS_TRITON:
-    import triton
+    import triton  # @manual
 
     from torch.testing._internal.triton_utils import add_kernel
 
@@ -106,6 +110,11 @@ class TestFxGraphCache(TestCase):
     def setUp(self):
         super().setUp()
         counters.clear()
+        PatchCaches.setUp()
+
+    def tearDown(self):
+        super().tearDown()
+        PatchCaches.tearDown()
 
     def reset(self):
         torch._dynamo.reset()
@@ -168,56 +177,23 @@ class TestFxGraphCache(TestCase):
         a = torch.rand(25, dtype=dtype, device=device)
         b = torch.rand(5, 5, dtype=dtype, device=device)
 
-        cache = {}
-        num_get = 0
-        num_put = 0
-
-        class MyCache:
-            def __init__(self, key, is_autotune=False):
-                pass
-
-            def get(self, filename):
-                nonlocal cache
-                nonlocal num_get
-                if filename not in cache:
-                    return None
-                ret = json.loads(cache[filename])
-                num_get += 1
-                if config.is_fbcode():
-                    return base64.b64decode(ret["data"]) if ret is not None else ret
-                else:
-                    return base64.b64decode(ret) if ret is not None else ret
-
-            def put(self, filename, data):
-                nonlocal cache
-                nonlocal num_put
-                if config.is_fbcode():
-                    data["data"] = base64.b64encode(data["data"]).decode("ascii")
-                else:
-                    data = base64.b64encode(data).decode("ascii")
-                cache[filename] = json.dumps(data)
-                num_put += 1
-
-        cache_module = (
-            "torch._inductor.fb.remote_cache.FbRemoteFxGraphCacheBackend"
-            if config.is_fbcode()
-            else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
-        )
-
         with config.patch(
             {
                 "fx_graph_cache": False,
                 "fx_graph_remote_cache": True,
             }
-        ), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
+        ), patch.dict(os.environ), PatchCaches():
             os.environ.pop("TRITON_CACHE_MANAGER", None)
             for _ in range(4):
                 with fresh_inductor_cache():
                     compiled_fn = torch.compile(fn, dynamic=dynamic)
                     self.assertEqual(fn(a, b), compiled_fn(a, b))
                 reset()
-            self.assertEqual(num_get, 3)
-            self.assertEqual(num_put, 1)
+
+        PatchCaches.report()
+        self.assertEqual(PatchCaches.num_get_hit, 3)
+        self.assertEqual(PatchCaches.num_get_miss, 1)
+        self.assertEqual(PatchCaches.num_put, 1)
 
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
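The expected counts in the assertions above follow from the loop structure: each of the four iterations wipes the local inductor cache, so only the mocked remote store persists across iterations, giving one miss plus one put on the first compile and hits on the remaining three. A toy model of that counting (plain Python, not PyTorch code):

```python
# Toy model of the hit/miss arithmetic behind the assertions above.
store = {}  # stands in for the persistent mocked remote cache
hits = misses = puts = 0

for _ in range(4):  # four compiles, each with a fresh *local* cache
    key = "fx-graph-key"  # same function -> same cache key every time
    if key in store:
        hits += 1
    else:
        misses += 1
        store[key] = "compiled artifact"
        puts += 1

assert (misses, puts, hits) == (1, 1, 3)
```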
test/inductor/test_max_autotune.py
@@ -1,5 +1,4 @@
 # Owner(s): ["module: inductor"]
-import json
 import os
 import unittest
 from typing import Callable, List, Optional
@@ -35,6 +34,12 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 
+try:
+    from .mock_cache import PatchCaches
+except ImportError:
+    from mock_cache import PatchCaches  # @manual
+
+
 torch.set_float32_matmul_precision("high")
 if HAS_CUDA:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")
@@ -219,80 +224,6 @@ class TestMaxAutotune(TestCase):
         with config.patch({"max_autotune": True}):
             torch.compile(mm, dynamic=dynamic)(a, b)
 
-    @skipIfRocm
-    @parametrize("dynamic", (False, True))
-    def test_max_autotune_remote_caching(self, dynamic: bool):
-        from unittest.mock import patch
-
-        def mm(a, b):
-            a = torch.sin(a)
-            return a @ b
-
-        a = torch.randn(100, 10).cuda()
-        b = torch.randn(10, 100).cuda()
-
-        class Model(torch.nn.Module):
-            def forward(self, x, y):
-                return x + y
-
-        def f(x, y):
-            return Model()(x, y)
-
-        x = torch.randn(100, 100).cuda()
-        y = torch.randn(100, 100).cuda()
-
-        cache = {}
-        num_get = 0
-        num_put = 0
-
-        class MyCache:
-            def __init__(self, key, is_autotune=False):
-                pass
-
-            def get(self, filename):
-                nonlocal cache
-                nonlocal num_get
-                if filename not in cache:
-                    return None
-                ret = json.loads(cache[filename])
-                num_get += 1
-                return ret
-
-            def put(self, filename, data):
-                nonlocal cache
-                nonlocal num_put
-                cache[filename] = json.dumps(data)
-                num_put += 1
-
-        cache_module = (
-            "torch._inductor.fb.remote_cache.FbRemoteAutotuneCacheBackend"
-            if config.is_fbcode()
-            else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
-        )
-
-        with config.patch(
-            {
-                "autotune_local_cache": False,
-                "autotune_remote_cache": True,
-            }
-        ), patch.dict(os.environ), patch(cache_module, MyCache, create=True):
-            os.environ.pop("TRITON_CACHE_MANAGER", None)
-            with config.patch({"max_autotune": True}):
-                for _ in range(4):
-                    with fresh_inductor_cache():
-                        torch.compile(mm, dynamic=dynamic)(a, b)
-                    reset()
-                self.assertEqual(num_get, 3)
-                self.assertEqual(num_put, 1)
-            num_get = 0
-            num_put = 0
-            for _ in range(4):
-                with fresh_inductor_cache():
-                    torch.compile(f, dynamic=dynamic)(x, y)
-                reset()
-            self.assertEqual(num_get, 3)
-            self.assertEqual(num_put, 1)
-
     @skipIfRocm
     def test_precompilation_threads(self):
         import threading
@@ -777,6 +708,72 @@ class TestMaxAutotune(TestCase):
         self.assertIn("NoValidChoicesError", str(context.exception))
 
 
+@instantiate_parametrized_tests
+class TestMaxAutotuneRemoteCache(TestCase):
+    def setUp(self):
+        super().setUp()
+        PatchCaches.setUp()
+
+    def tearDown(self):
+        super().tearDown()
+        PatchCaches.tearDown()
+
+    @skipIfRocm
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_remote_caching(self, dynamic: bool):
+        from unittest.mock import patch
+
+        if not config.is_fbcode():
+            self.skipTest("Redis for autotune is currently broken")
+
+        def mm(a, b):
+            a = torch.sin(a)
+            return a @ b
+
+        a = torch.randn(100, 10).cuda()
+        b = torch.randn(10, 100).cuda()
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        def f(x, y):
+            return Model()(x, y)
+
+        x = torch.randn(100, 100).cuda()
+        y = torch.randn(100, 100).cuda()
+
+        with config.patch(
+            {
+                "autotune_local_cache": False,
+                "autotune_remote_cache": True,
+            }
+        ), patch.dict(os.environ), PatchCaches():
+            os.environ.pop("TRITON_CACHE_MANAGER", None)
+            with config.patch({"max_autotune": True}):
+                for _ in range(4):
+                    with fresh_inductor_cache():
+                        torch.compile(mm, dynamic=dynamic)(a, b)
+                    reset()
+
+                PatchCaches.update()
+                PatchCaches.report()
+                self.assertEqual(PatchCaches.num_get_hit, 3)
+                self.assertEqual(PatchCaches.num_get_miss, 1)
+                self.assertEqual(PatchCaches.num_put, 1)
+
+            PatchCaches.reset()
+            for _ in range(4):
+                with fresh_inductor_cache():
+                    torch.compile(f, dynamic=dynamic)(x, y)
+                reset()
+            PatchCaches.update()
+            PatchCaches.report()
+            self.assertEqual(PatchCaches.num_get_hit, 3)
+            self.assertEqual(PatchCaches.num_get_miss, 1)
+            self.assertEqual(PatchCaches.num_put, 1)
+
+
 class TestBenchmarkRequest(BenchmarkRequest):
     def __init__(
         self, value: float, multi_device: bool, parent_visible_devices: Optional[str]
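Note the explicit PatchCaches.update() calls in the added test: unlike the fx-graph test, these assertions run inside the with block, before __exit__ has refreshed the aggregates, so the test re-sums manually and uses reset() between its two phases. The pattern in miniature (run_phase_one and run_phase_two are hypothetical placeholders for compile workloads):

```python
from mock_cache import PatchCaches  # assumes test/inductor is on sys.path


def run_phase_one():  # hypothetical placeholder for a first compile workload
    pass


def run_phase_two():  # hypothetical placeholder for a second workload
    pass


with PatchCaches():
    run_phase_one()
    PatchCaches.update()  # aggregates are stale inside the block until re-summed
    PatchCaches.report()

    PatchCaches.reset()  # zero per-cache states and aggregates for phase two
    run_phase_two()
    PatchCaches.update()
    PatchCaches.report()
```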
torch/_inductor/codecache.py
@@ -68,6 +68,8 @@ T = TypeVar("T")
 if TYPE_CHECKING:
     from collections.abc import KeysView
 
+    from .remote_cache import RemoteCacheBackend
+
 
 """
 codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
@@ -1173,7 +1175,7 @@ class FxGraphCache:
         compiled_graph: CompiledFxGraph,
         example_inputs: List[torch.Tensor],
         local: bool,
-        remote_cache: None,
+        remote_cache: Optional[RemoteCacheBackend],
     ) -> None:
         """
         Store a serialized CompiledFxGraph on disk.
@@ -1220,17 +1222,16 @@ class FxGraphCache:
             write_atomic(path, content, make_dirs=True)
 
             if remote_cache:
+                time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6)
                 cache_data = (
                     {
                         "data": content,
-                        "time_taken_ms": int(
-                            disk_compiled_graph._time_taken_ns // 1e6
-                        ),  # Convert from NS to MS
+                        "time_taken_ms": time_taken_ms,
                     }
                     if config.is_fbcode()
                     else content
                 )
-                remote_cache.put(key, cache_data)
+                remote_cache.put(key, cache_data)  # type: ignore[arg-type]
         except Exception:
             log.warning("fx graph unable to write to cache", exc_info=True)
             counters["inductor"]["fxgraph_cache_write_error"] += 1
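The hoisted time_taken_ms also picks up an `or 0` guard, so a missing _time_taken_ns no longer raises. The ns-to-ms conversion checked on concrete values (standalone snippet, not from the diff):

```python
# 1 ms == 1e6 ns; `or 0` makes a None measurement read as zero.
time_taken_ns = 2_500_000_000  # 2.5 seconds in nanoseconds
assert int((time_taken_ns or 0) // 1e6) == 2500

time_taken_ns = None
assert int((time_taken_ns or 0) // 1e6) == 0  # previously a TypeError
```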
@@ -1291,7 +1292,7 @@ class FxGraphCache:
         cache_info["key"] = key
         cache_info["components"] = debug_lines
 
-        remote_cache = None
+        remote_cache: Optional[RemoteCacheBackend] = None
         if remote:
             cache_id = "fx-graph-v1"
             try: