[fx graph cache] Support freezing with FX graph caching (#136505)

Summary: The main changes to support freezing are:
1) When pickling constant tensors as part of the cache key calculation: if freezing has not been applied, keep the existing behavior (pickle the metadata and values). If freezing has been applied, pickle the values only if the constant will be inlined; otherwise, consider only the metadata (see the sketch after this list).
2) If freezing has been applied, modify what we store in the cache: instead of storing the constant attributes in the cache entry, store the _names_ of the constants, and then grab those constants from the GraphModule when we need to attach the attributes to a newly-loaded Python module. Since the cache lookup path loads the Python module, this requires threading a GraphModule argument through several call paths.
3) Since this feature means that we may need to reload the same Python module path more than once (but attach different constant attributes), I changed PyCodeCache.load_by_key_path to not store an in-memory map of path to module (since there may be more than one). I don't _think_ this will have any effect on performance; it's unclear why we were using an in-memory cache here anyway, since this function should only be called once for each module that needs to be loaded.
4) Several tests were removing on-disk PyCodeCache artifacts by iterating over the modules. I made this more straightforward by implementing a cache_clear method that removes the on-disk artifacts. Arguably, this should have been the implementation all along.
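To make bullets 1 and 2 concrete, here is a minimal, self-contained sketch of the two decisions (this is not the PyTorch source: the metadata tuple, the 8-element inline threshold, and the names reduce_tensor_for_cache_key / lookup_constants are simplified stand-ins for FxGraphCachePickler._reduce_tensor, extract_tensor_metadata_for_cache_key, GraphLowering.can_inline_constant, and CompiledFxGraph.get_constants as they appear in the diff below):

from typing import Dict, Optional

import torch


def reduce_tensor_for_cache_key(t: torch.Tensor, include_non_inlined: bool):
    # Stand-in for extract_tensor_metadata_for_cache_key(): just the properties
    # that must match for a cached graph to be reusable.
    metadata = (str(t.dtype), str(t.device), tuple(t.shape), tuple(t.stride()))
    # Inlined constants are baked into the generated code, so their values must
    # always contribute to the key. Without freezing (include_non_inlined=True),
    # the values of every constant are hashed, preserving the old behavior.
    if include_non_inlined or t.numel() <= 8:  # 8 is a stand-in inline threshold
        return (metadata, t.tolist())  # analogous to TensorMetadataAndValues
    # Frozen, non-inlined constant: hash only the metadata, so graphs that differ
    # only in frozen parameter values can share one cache entry.
    return metadata  # analogous to TensorMetadata


def lookup_constants(
    cached_constants: Optional[Dict[str, torch.Tensor]],
    allocated_constant_name: Optional[Dict[str, str]],
    gm: Optional[torch.nn.Module],  # a torch.fx.GraphModule in practice
) -> Dict[str, torch.Tensor]:
    # Normal case: the tensor values were serialized into the cache entry.
    if cached_constants is not None:
        return cached_constants
    # Freezing case: the entry stores only attribute names; fetch the live
    # tensors from the GraphModule that triggered this cache lookup.
    assert gm is not None and allocated_constant_name is not None
    return {name: getattr(gm, orig) for name, orig in allocated_constant_name.items()}

The key consequence is in the frozen path: two compilations that differ only in frozen parameter values produce the same cache key and can reuse the same compiled artifact, with the correct constants re-attached at module load time.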

Differential Revision: [D63542170](https://our.internmc.facebook.com/intern/diff/D63542170)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136505
Approved by: https://github.com/eellison
Sam Larsen 2024-10-31 15:19:32 -07:00 committed by PyTorch MergeBot
parent 7d644f025f
commit d8b606ecb5
8 changed files with 225 additions and 64 deletions

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: dynamo"]
import os
import unittest
from unittest.mock import patch
@ -53,8 +52,6 @@ class AOTAutogradCacheTests(InductorTestCase):
Clear unrelated caches, like dynamo and PyCodeCache
"""
torch._dynamo.reset()
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
torch._inductor.codecache.PyCodeCache.cache_clear()
@inductor_config.patch("fx_graph_remote_cache", False)

View File

@ -160,11 +160,12 @@ class TestFxGraphCache(TestCase):
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
self.reset()
# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@ -381,6 +382,24 @@ class TestFxGraphCache(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# Now pretend the constants are frozen params.
counters.clear()
self.reset()
with mock.patch(
"torch._inductor.codecache.has_frozen_params", return_value=True
):
# A call to fn1 should miss in the cache since we do not consider
# the constant values.
self.assertEqual(fn1(a), compiled_fn1(a))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
# A call to fn2 should hit for the same reason.
self.assertEqual(fn2(a), compiled_fn2(a))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@requires_cuda
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@ -417,8 +436,6 @@ class TestFxGraphCache(TestCase):
# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
self.reset()
self.assertEqual(fn(a, b, c), compiled_fn(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
@ -426,8 +443,6 @@ class TestFxGraphCache(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
# A third call with different score_mod should have a cache miss
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
self.reset()
self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
@ -647,14 +662,80 @@ class TestFxGraphCache(TestCase):
self.assertNotEqual(a, b)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"freezing": True})
@parametrize("device", (GPU_TYPE, "cpu"))
def test_freezing(self, device):
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
class MM(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(8, 8))
def forward(self, x):
return x @ self.param
dtype = torch.float16
# Populate a cache entry.
mod1 = MM().to(device=device, dtype=dtype)
with torch.no_grad():
x = torch.rand(8, 8).to(device=device, dtype=dtype)
out0 = mod1(x)
out1 = torch.compile(mod1)(x)
self.assertEqual(out0, out1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
counters.clear()
self.reset()
# Same nn.Module, but with different parameters should cache hit.
mod2 = MM().to(device=device, dtype=dtype)
self.assertNotEqual(mod1.param, mod2.param)
with torch.no_grad():
x = torch.rand(8, 8).to(device=device, dtype=dtype)
out0 = mod2(x)
out1 = torch.compile(mod2)(x)
self.assertEqual(out0, out1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
class TestFxGraphCacheHashing(TestCase):
def test_tensor_constants(self):
"""
Test the hashing of tensor constants.
"""
data = FxGraphCachePickler().dumps(torch.tensor(list(range(9))))
small = torch.tensor(list(range(8)))
large = torch.tensor(list(range(32)))
self.assertTrue(GraphLowering.can_inline_constant(small))
self.assertFalse(GraphLowering.can_inline_constant(large))
# By default, we hash the metadata and values independent of the size.
pickler = FxGraphCachePickler()
data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
# If include_non_inlined=False, we only hash the values of small tensors.
pickler = FxGraphCachePickler(False)
data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadata)
def test_hash_fake_tensors(self):
"""
@ -1133,13 +1214,13 @@ class TestUtils(TestCase):
b = torch.rand(10)
with fresh_inductor_cache():
self.assertEqual(len(PyCodeCache.cache.keys()), 0)
self.assertEqual(len(PyCodeCache.modules), 0)
res1 = torch.compile(fn)(a, b)
cache_dir1 = cache_dir()
torch._dynamo.reset()
with fresh_inductor_cache():
self.assertEqual(len(PyCodeCache.cache.keys()), 0)
self.assertEqual(len(PyCodeCache.modules), 0)
res2 = torch.compile(fn)(a, b)
cache_dir2 = cache_dir()

View File

@ -40,11 +40,11 @@ class TestKernelBenchmark(TestCase):
def setUp(self):
super().setUp()
PyCodeCache.cache.clear()
PyCodeCache.cache_clear()
def get_compiled_module(self):
compiled_module = None
for v in PyCodeCache.cache.values():
for v in PyCodeCache.modules:
if hasattr(v, "benchmark_compiled_module"):
self.assertTrue(
compiled_module is None, "Found multiple compiled modules"

View File

@ -14,7 +14,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
class TestTritonWrapper(TestCase):
def get_compiled_module(self):
compiled_module = None
for v in PyCodeCache.cache.values():
for v in PyCodeCache.modules:
if hasattr(v, "benchmark_compiled_module"):
self.assertTrue(
compiled_module is None, "Found multiple compiled modules"

View File

@ -526,10 +526,18 @@ class FxGraphCachePickler(pickle.Pickler):
data that allow us to compute a stable, but safe hash.
"""
def __init__(self) -> None:
def __init__(self, include_non_inlined: bool = True) -> None:
"""
Create an FX graph pickler. If include_non_inlined=True, then pickling will
include the _values_ for all Tensors. (Note that any such tensors are constants
attached as attributes to the GraphModule). Otherwise, pickling will include
only the metadata for these tensors.
"""
self._stream = io.BytesIO()
super().__init__(self._stream)
self.include_non_inlined = include_non_inlined
self.dispatch_table = copyreg.dispatch_table.copy()
self.dispatch_table.update(
{
@ -558,35 +566,38 @@ class FxGraphCachePickler(pickle.Pickler):
def _reduce_tensor(
self,
t: Tensor,
) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
"""
Custom reducer to pickle Tensors. If we see tensors, we know they're constants
stored as attributes on the GraphModule. Include the values in the key
calculation. Small tensors will be inlined, so we can't serve the same cache
entry for different values anyway. Large constants are treated as parameters, so
we could conceivably reuse a cache entry. To do that, however, PyCodeCache would
need more complexity to create a new module from its cache, but with the right
constants attached as attributes.
stored as attributes on the GraphModule.
"""
from .graph import GraphLowering
if t.is_mkldnn:
# TODO: These tensors don't currently pickle, so we can't cache a compiled
# graph containing them. Just fail now. If mkldnn tensors get pickling
# support, we can remove this.
raise BypassFxGraphCache("mkldnn tensors unpickleable")
# Very large tensors could be expensive to copy to cpu and hash. Let's at least
# report if we find slowness.
start = time()
values = t.tolist()
elapsed = time() - start
if elapsed > 1.0:
warnings.warn(
f"FX graph cache handling of a large constant took {elapsed:.1}s. "
"Please file an issue."
)
# If this is an inlined constant or include_non_inlined=True, then we include
# the metadata and the values.
metadata = extract_tensor_metadata_for_cache_key(t)
return (_ident, (TensorMetadataAndValues(metadata, values),))
if GraphLowering.can_inline_constant(t) or self.include_non_inlined:
# Very large tensors will be expensive to copy to cpu and hash. Let's at
# least report any slowness.
start = time()
values = t.tolist()
elapsed = time() - start
if elapsed > 1.0:
warnings.warn(
f"FX graph cache copying of a large constant took {elapsed:.1}s. "
"Please file an issue."
)
return (_ident, (TensorMetadataAndValues(metadata, values),))
# Otherwise, we just include the metadata.
return (_ident, (metadata,))
def _reduce_symint(self, s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]:
"""
@ -804,6 +815,10 @@ class FxGraphHashDetails:
return custom_pass.uuid()
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
return getattr(gm, "_has_frozen_params", False)
def compiled_fx_graph_hash(
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
@ -813,8 +828,13 @@ def compiled_fx_graph_hash(
"""
Generate a unique hash of the FX graph for caching.
"""
# To support caching when the graph has frozen params, we ignore the tensor values
# of non-inlined constants since they won't be included in the cache entry. Without
# freezing, we want to include the values of any constant attribute.
include_non_inlined = not has_frozen_params(gm)
details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
pickler = FxGraphCachePickler()
pickler = FxGraphCachePickler(include_non_inlined)
# The prefix distinguishes among the other kinds of objects we
# cache in this module.
key = "f" + pickler.get_hash(details)
@ -828,6 +848,7 @@ def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule],
) -> None:
"""
Checks for any reasons not to run cudagraphs and then
@ -873,7 +894,7 @@ def cudagraph_post_compile(
stack_traces=stack_traces,
is_backward=is_backward,
is_inference=is_inference,
constants=tuple(compiled_graph.constants.values()),
constants=tuple(compiled_graph.get_constants(gm).values()),
placeholders=placeholders,
mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
)
@ -1030,6 +1051,7 @@ class FxGraphCache:
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
gm: Optional[torch.fx.GraphModule],
) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
"""
Lookup a compiled graph in the cache by key. On a hit, return the
@ -1135,7 +1157,7 @@ class FxGraphCache:
graph.cache_key,
artifact_path,
graph.cache_linemap,
graph.constants,
graph.get_constants(gm),
).call
except OSError:
# Not expected, but in case the PyCodeCache entry is removed from
@ -1177,6 +1199,7 @@ class FxGraphCache:
compiled_graph: CompiledFxGraph,
example_inputs: Sequence[InputType],
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule] = None,
) -> CompiledFxGraph:
"""
Run a set of post processing steps after loading from the cache. These involve:
@ -1206,6 +1229,7 @@ class FxGraphCache:
example_inputs,
compiled_graph,
cudagraphs,
gm,
)
inputs_to_check = compiled_graph.inputs_to_check
# cudagraphs could have been disabled from the earlier conditions
@ -1294,9 +1318,15 @@ class FxGraphCache:
raise BypassFxGraphCache("Unsupported post grad custom pass")
# Freezing can embed constants that wouldn't be static across runs.
if config.freezing or config.aot_inductor.use_runtime_constant_folding:
if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
"pytorch/inductor:allow_freezing_with_caching"
):
raise BypassFxGraphCache("Skipping graph with frozen constants")
if config.aot_inductor.use_runtime_constant_folding:
raise BypassFxGraphCache(
"Freezing may introduce constants that aren't static across runs"
"Runtime constant folding can introduce constants that aren't "
"static across runs"
)
from torch._inductor.bisect_helper import BisectionManager
@ -1386,6 +1416,7 @@ class FxGraphCache:
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
is_backward: bool,
gm: Optional[torch.fx.GraphModule] = None,
) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
"""
Lookup the graph with the given key, and return results and metadata.
@ -1393,7 +1424,7 @@ class FxGraphCache:
differently from FXGraphCache.
"""
compiled_graph, cache_info = FxGraphCache._lookup_graph(
key, example_inputs, local, remote_cache
key, example_inputs, local, remote_cache, gm
)
cache_info = {
**cache_info,
@ -1453,6 +1484,7 @@ class FxGraphCache:
local,
remote_cache,
is_backward=fx_kwargs.get("is_backward", False),
gm=gm,
)
# CACHE BYPASS: Compile the graph, don't save it to the cache
@ -1528,7 +1560,7 @@ class FxGraphCache:
)
# Use the passed in cudagraphs so that we mutate the BoxedBool correctly
FxGraphCache.post_compile(
compiled_graph, example_inputs, fx_kwargs["cudagraphs"] # type: ignore[arg-type]
compiled_graph, example_inputs, fx_kwargs["cudagraphs"], gm # type: ignore[arg-type]
)
return compiled_graph
@ -1561,7 +1593,15 @@ class CompiledFxGraph:
device_idxs: Set[int]
mutated_inputs: Set[str]
mutated_input_idxs: Set[int]
constants: Dict[str, torch.Tensor]
# We populate exactly one of the next two fields. In the common case, we store the
constant attributes in the cache entry and re-attach them to the module created in
# PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
# however, we save the mapping from attribute names in the GraphLowering to the
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
@ -1588,6 +1628,7 @@ class CompiledFxGraph:
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
@ -1604,7 +1645,12 @@ class CompiledFxGraph:
self.device_idxs = set(graph.device_idxs)
self.mutated_inputs = set(graph.mutated_inputs)
self.mutated_input_idxs = set(graph.mutated_input_idxs)
self.constants = graph.constants
if has_frozen_params(gm):
self.allocated_constant_name = graph.allocated_constant_name
self.constants = None
else:
self.allocated_constant_name = None
self.constants = graph.constants
self.torchbind_constants = graph.torchbind_constants
self.output_strides = output_strides
self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
@ -1623,6 +1669,26 @@ class CompiledFxGraph:
finally:
AutotuneCacheBundler.end_compile()
def get_constants(
self, gm: Optional[torch.fx.GraphModule]
) -> Dict[str, torch.Tensor]:
"""
Get the constant attributes.
"""
# Normal case: The constants are stored in the entry.
if self.constants is not None:
return self.constants
# Freezing case: Look up the constants from attributes on the GraphModule using
# the allocated_constant_name map.
assert gm is not None
assert self.allocated_constant_name is not None
constants = {
name: getattr(gm, orig_name)
for name, orig_name in self.allocated_constant_name.items()
}
return constants
def run_command_and_check(cmd_: str) -> None:
cmd = shlex.split(cmd_)
@ -2976,9 +3042,13 @@ def touch(filename: str): # type: ignore[no-untyped-def]
@clear_on_fresh_inductor_cache
class PyCodeCache:
# Track the loaded modules so we can remove the on-disk artifacts when
# clearing the cache. Note also that we may load the same path more
than once, but attach different attributes, e.g., due to different
# constant values.
modules: List[ModuleType] = []
cache: Dict[str, ModuleType] = {}
linemaps: Dict[str, List[Tuple[Any, ...]]] = {}
cache_clear = staticmethod(cache.clear)
@classmethod
def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
@ -3005,24 +3075,33 @@ class PyCodeCache:
) -> ModuleType:
if linemap is None:
linemap = []
if key not in cls.cache:
mod = _reload_python_module(key, path)
# another thread might set this first
cls.cache.setdefault(key, mod)
# unzip into separate lines/nodes lists
cls.linemaps[path] = list(zip(*linemap))
mod = _reload_python_module(key, path)
if attrs is not None:
for k, v in attrs.items():
setattr(mod, k, v)
# unzip into separate lines/nodes lists
cls.linemaps[path] = list(zip(*linemap))
if not (linemap or attrs):
mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined]
_reload_python_module_in_subproc, key, path
)
if attrs is not None:
for k, v in attrs.items():
setattr(mod, k, v)
return cls.cache[key]
if not (linemap or attrs):
mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined]
_reload_python_module_in_subproc, key, path
)
cls.modules.append(mod)
return mod
@classmethod
def cache_clear(cls) -> None:
for mod in cls.modules:
try:
assert mod.__file__
os.remove(mod.__file__)
except FileNotFoundError:
pass
cls.modules.clear()
@classmethod
@functools.lru_cache(None)

View File

@ -760,7 +760,7 @@ def _compile_fx_inner(
# to return the string directly.
return compiled_graph
compiled_graph = FxGraphCache.post_compile(
compiled_graph, example_inputs, cudagraphs
compiled_graph, example_inputs, cudagraphs, gm
)
log.debug("FX codegen and compilation took %.3fs", time.time() - start)
@ -1015,6 +1015,7 @@ def fx_codegen_and_compile(
compiled_graph = CompiledFxGraph(
compiled_fn,
graph,
gm,
output_strides,
V.graph.disable_cudagraphs_reason,
metrics_helper.get_deltas(),
@ -1289,6 +1290,8 @@ def fw_compiler_freezing(
aot_example_inputs, # type: ignore[arg-type]
)
setattr(opt_model, "_has_frozen_params", True) # noqa: B010
aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
num_fixed = len(preserved_arg_indices) - num_example_inputs

View File

@ -77,7 +77,8 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
from torch._inductor.codecache import PyCodeCache
nfound = 0
for kernel_key, kernel_mod in PyCodeCache.cache.items():
for kernel_mod in PyCodeCache.modules:
kernel_key = kernel_mod.key
if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
continue

View File

@ -890,7 +890,7 @@ class ExecutionTraceObserver(_ITraceObserver):
kernel_files = [
v.__file__
for v in PyCodeCache.cache.values()
for v in PyCodeCache.modules
if getattr(v, "__file__", None) is not None
]
work_dir, file_name = os.path.split(self._output_file_path)