[fx graph cache] Support freezing with FX graph caching (#136505)
Summary: The main changes to support freezing are:

1) When pickling constant tensors as part of the cache key calculation: if freezing has not been applied, keep the existing behavior (pickle the metadata and values). If freezing has been applied, pickle the values only if the constant will be inlined; otherwise, consider only the metadata.

2) If freezing has been applied, change what we store in the cache: instead of storing the constant attributes in the cache entry, store the _names_ of the constants, and then grab those constants from the GraphModule when we need to attach the attributes to a newly loaded Python module. Since the cache lookup path loads the Python module, this means threading a GraphModule argument through several places.

3) Since this feature means we may need to reload the same Python module path more than once (but attach different constant attributes), PyCodeCache.load_by_key_path no longer stores an in-memory map of path to module (since there may be more than one). I don't _think_ this will have any effect on performance; it's unclear why we were using an in-memory cache here anyway, since this function should only be called once for each module that needs to be loaded.

4) Several tests were removing on-disk PyCodeCache artifacts by iterating over the modules. I made this more straightforward by implementing a cache_clear method that removes the on-disk artifacts. Arguably, this should have been the implementation all along.

Differential Revision: [D63542170](https://our.internmc.facebook.com/intern/diff/D63542170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136505
Approved by: https://github.com/eellison
This commit is contained in:
parent
7d644f025f
commit
d8b606ecb5
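Illustrative sketch (not part of the commit): the cache-key behavior described in point 1 of the summary can be pictured with the simplified stand-ins below. The names TensorMetadata, TensorMetadataAndValues, and can_inline_constant mirror names that appear in the diff, but the classes and the inlining rule here are assumptions for illustration only, not the torch._inductor implementation.

    from dataclasses import dataclass
    from typing import List, Tuple, Union

    import torch


    @dataclass
    class TensorMetadata:
        dtype: torch.dtype
        shape: Tuple[int, ...]


    @dataclass
    class TensorMetadataAndValues:
        metadata: TensorMetadata
        values: List[float]


    def can_inline_constant(t: torch.Tensor) -> bool:
        # Assumed inlining rule for illustration: only small 1-D tensors are inlined.
        return t.dim() == 1 and t.numel() <= 8


    def reduce_for_cache_key(
        t: torch.Tensor, include_non_inlined: bool
    ) -> Union[TensorMetadata, TensorMetadataAndValues]:
        # Mirrors the decision in FxGraphCachePickler._reduce_tensor in the diff below.
        metadata = TensorMetadata(t.dtype, tuple(t.shape))
        if can_inline_constant(t) or include_non_inlined:
            # Inlined (or non-frozen) constants: hash metadata and values.
            return TensorMetadataAndValues(metadata, t.tolist())
        # Frozen, non-inlined constants: hash only the metadata.
        return metadata


    small, large = torch.arange(8), torch.arange(32)
    print(type(reduce_for_cache_key(small, include_non_inlined=False)).__name__)  # TensorMetadataAndValues
    print(type(reduce_for_cache_key(large, include_non_inlined=False)).__name__)  # TensorMetadata
    print(type(reduce_for_cache_key(large, include_non_inlined=True)).__name__)   # TensorMetadataAndValues

With include_non_inlined=False (the frozen-params case), a large constant contributes only its metadata to the key, so two runs that differ only in frozen parameter values can share a cache entry.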
@@ -1,6 +1,5 @@
 # Owner(s): ["module: dynamo"]

-import os
 import unittest
 from unittest.mock import patch

@@ -53,8 +52,6 @@ class AOTAutogradCacheTests(InductorTestCase):
         Clear unrelated caches, like dynamo and PyCodeCache
         """
         torch._dynamo.reset()
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
         torch._inductor.codecache.PyCodeCache.cache_clear()

     @inductor_config.patch("fx_graph_remote_cache", False)
@@ -160,11 +160,12 @@ class TestFxGraphCache(TestCase):

         # A second call should hit. (First reset so in-memory guards
         # don't prevent compilation).
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
-        # Clean triton kernels
-        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
         self.reset()

+        # Clean PyCodeCache and triton kernels
+        PyCodeCache.cache_clear()
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
+
         self.assertEqual(fn(a, b), compiled_fn(a, b))
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@@ -381,6 +382,24 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

+        # Now pretend the constants are frozen params.
+        counters.clear()
+        self.reset()
+
+        with mock.patch(
+            "torch._inductor.codecache.has_frozen_params", return_value=True
+        ):
+            # A call to fn1 should miss in the cache since we do not consider
+            # the constant values.
+            self.assertEqual(fn1(a), compiled_fn1(a))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            # A call to fn2 should hit for the same reason.
+            self.assertEqual(fn2(a), compiled_fn2(a))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
     @requires_cuda
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -417,8 +436,6 @@ class TestFxGraphCache(TestCase):

         # A second call should hit. (First reset so in-memory guards
         # don't prevent compilation).
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
         self.reset()
         self.assertEqual(fn(a, b, c), compiled_fn(a, b, c), atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
@@ -426,8 +443,6 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)

         # A third call with different score_mod should have a cache miss
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
         self.reset()
         self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c), atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
@@ -647,14 +662,80 @@ class TestFxGraphCache(TestCase):

         self.assertNotEqual(a, b)

+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @config.patch({"freezing": True})
+    @parametrize("device", (GPU_TYPE, "cpu"))
+    def test_freezing(self, device):
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+
+        class MM(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(8, 8))
+
+            def forward(self, x):
+                return x @ self.param
+
+        dtype = torch.float16
+
+        # Populate a cache entry.
+        mod1 = MM().to(device=device, dtype=dtype)
+        with torch.no_grad():
+            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            out0 = mod1(x)
+            out1 = torch.compile(mod1)(x)
+            self.assertEqual(out0, out1)
+
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+        counters.clear()
+        self.reset()
+
+        # Same nn.Module, but with different parameters should cache hit.
+        mod2 = MM().to(device=device, dtype=dtype)
+        self.assertNotEqual(mod1.param, mod2.param)
+
+        with torch.no_grad():
+            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            out0 = mod2(x)
+            out1 = torch.compile(mod2)(x)
+            self.assertEqual(out0, out1)
+
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+
 class TestFxGraphCacheHashing(TestCase):
     def test_tensor_constants(self):
         """
         Test the hashing of tensor constants.
         """
-        data = FxGraphCachePickler().dumps(torch.tensor(list(range(9))))
+        small = torch.tensor(list(range(8)))
+        large = torch.tensor(list(range(32)))
+
+        self.assertTrue(GraphLowering.can_inline_constant(small))
+        self.assertFalse(GraphLowering.can_inline_constant(large))
+
+        # By default, we hash the metadata and values independent of the size.
+        pickler = FxGraphCachePickler()
+
+        data = pickler.dumps(small)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
+        data = pickler.dumps(large)
+        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
+
+        # If include_non_inlined=False, we only hash the values of small tensors.
+        pickler = FxGraphCachePickler(False)
+
+        data = pickler.dumps(small)
+        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
+        data = pickler.dumps(large)
+        self.assertIsInstance(pickle.loads(data), TensorMetadata)

     def test_hash_fake_tensors(self):
         """
@@ -1133,13 +1214,13 @@ class TestUtils(TestCase):
         b = torch.rand(10)

         with fresh_inductor_cache():
-            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
+            self.assertEqual(len(PyCodeCache.modules), 0)
             res1 = torch.compile(fn)(a, b)
             cache_dir1 = cache_dir()

         torch._dynamo.reset()
         with fresh_inductor_cache():
-            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
+            self.assertEqual(len(PyCodeCache.modules), 0)
             res2 = torch.compile(fn)(a, b)
             cache_dir2 = cache_dir()

@@ -40,11 +40,11 @@ class TestKernelBenchmark(TestCase):

     def setUp(self):
         super().setUp()
-        PyCodeCache.cache.clear()
+        PyCodeCache.cache_clear()

     def get_compiled_module(self):
         compiled_module = None
-        for v in PyCodeCache.cache.values():
+        for v in PyCodeCache.modules:
             if hasattr(v, "benchmark_compiled_module"):
                 self.assertTrue(
                     compiled_module is None, "Found multiple compiled modules"
@@ -14,7 +14,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 class TestTritonWrapper(TestCase):
     def get_compiled_module(self):
         compiled_module = None
-        for v in PyCodeCache.cache.values():
+        for v in PyCodeCache.modules:
             if hasattr(v, "benchmark_compiled_module"):
                 self.assertTrue(
                     compiled_module is None, "Found multiple compiled modules"
@@ -526,10 +526,18 @@ class FxGraphCachePickler(pickle.Pickler):
    data that allow us to compute a stable, but safe hash.
    """

-    def __init__(self) -> None:
+    def __init__(self, include_non_inlined: bool = True) -> None:
+        """
+        Create an FX graph pickler. If include_non_inlined=True, then pickling will
+        include the _values_ for all Tensors. (Note that any tensors are constants
+        attached as attributes to the GraphModule). Otherwise, pickling will include
+        only the metadata for these tensors.
+        """
         self._stream = io.BytesIO()
         super().__init__(self._stream)

+        self.include_non_inlined = include_non_inlined
+
         self.dispatch_table = copyreg.dispatch_table.copy()
         self.dispatch_table.update(
             {
@@ -558,35 +566,38 @@ class FxGraphCachePickler(pickle.Pickler):
    def _reduce_tensor(
        self,
        t: Tensor,
-    ) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
+    ) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
        """
        Custom reducer to pickle Tensors. If we see tensors, we know they're constants
-        stored as attributes on the GraphModule. Include the values in the key
-        calculation. Small tensors will be inlined, so we can't serve the same cache
-        entry for different values anyway. Large constants are treated as parameters, so
-        we could conceivably reuse a cache entry. To do that, however, PyCodeCache would
-        need more complexity to create a new module from its cache, but with the right
-        constants attached as attributes.
+        stored as attributes on the GraphModule.
        """
+        from .graph import GraphLowering
+
        if t.is_mkldnn:
            # TODO: These tensors don't currently pickle, so we can't cache a compiled
            # graph containing them. Just fail now. If mkldnn tensors get pickling
            # support, we can remove this.
            raise BypassFxGraphCache("mkldnn tensors unpickleable")

-        # Very large tensors could be expensive to copy to cpu and hash. Let's at least
-        # report if we find slowness.
-        start = time()
-        values = t.tolist()
-        elapsed = time() - start
-        if elapsed > 1.0:
-            warnings.warn(
-                f"FX graph cache handling of a large constant took {elapsed:.1}s. "
-                "Please file an issue."
-            )
+        # If this is an inlined constant or include_non_inlined=True, then we include
+        # the metadata and the values.

        metadata = extract_tensor_metadata_for_cache_key(t)
-        return (_ident, (TensorMetadataAndValues(metadata, values),))
+        if GraphLowering.can_inline_constant(t) or self.include_non_inlined:
+            # Very large tensors will be expensive to copy to cpu and hash. Let's at
+            # least report any slowness.
+            start = time()
+            values = t.tolist()
+            elapsed = time() - start
+            if elapsed > 1.0:
+                warnings.warn(
+                    f"FX graph cache copying of a large constant took {elapsed:.1}s. "
+                    "Please file an issue."
+                )
+
+            return (_ident, (TensorMetadataAndValues(metadata, values),))
+
+        # Otherwise, we just include the metadata.
+        return (_ident, (metadata,))

    def _reduce_symint(self, s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]:
        """
@@ -804,6 +815,10 @@ class FxGraphHashDetails:
        return custom_pass.uuid()


+def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
+    return getattr(gm, "_has_frozen_params", False)
+
+
def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
    example_inputs: Sequence[InputType],
@@ -813,8 +828,13 @@ def compiled_fx_graph_hash(
    """
    Generate a unique hash of the FX graph for caching.
    """
+    # To support caching when the graph has frozen params, we ignore the tensor values
+    # of non-inlined constants since they won't be included in the cache entry. Without
+    # freezing, we want to include the values of any constant attribute.
+    include_non_inlined = not has_frozen_params(gm)
+
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
-    pickler = FxGraphCachePickler()
+    pickler = FxGraphCachePickler(include_non_inlined)
    # The prefix distinguishes among the other kinds of objects we
    # cache in this module.
    key = "f" + pickler.get_hash(details)
@@ -828,6 +848,7 @@ def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
+    gm: Optional[torch.fx.GraphModule],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
@@ -873,7 +894,7 @@ def cudagraph_post_compile(
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
-            constants=tuple(compiled_graph.constants.values()),
+            constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )
@@ -1030,6 +1051,7 @@ class FxGraphCache:
        example_inputs: Sequence[InputType],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
+        gm: Optional[torch.fx.GraphModule],
    ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
        """
        Lookup a compiled graph in the cache by key. On a hit, return the
@@ -1135,7 +1157,7 @@ class FxGraphCache:
                graph.cache_key,
                artifact_path,
                graph.cache_linemap,
-                graph.constants,
+                graph.get_constants(gm),
            ).call
        except OSError:
            # Not expected, but in case the PyCodeCache entry is removed from
@@ -1177,6 +1199,7 @@ class FxGraphCache:
        compiled_graph: CompiledFxGraph,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
+        gm: Optional[torch.fx.GraphModule] = None,
    ) -> CompiledFxGraph:
        """
        Run a set of post processing steps after loading from the cache. These involve:
@@ -1206,6 +1229,7 @@ class FxGraphCache:
                example_inputs,
                compiled_graph,
                cudagraphs,
+                gm,
            )
            inputs_to_check = compiled_graph.inputs_to_check
            # cudagraphs could have been disabled from the earlier conditions
@@ -1294,9 +1318,15 @@ class FxGraphCache:
            raise BypassFxGraphCache("Unsupported post grad custom pass")

        # Freezing can embed constants that wouldn't be static across runs.
-        if config.freezing or config.aot_inductor.use_runtime_constant_folding:
+        if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
+            "pytorch/inductor:allow_freezing_with_caching"
+        ):
+            raise BypassFxGraphCache("Skipping graph with frozen constants")
+
+        if config.aot_inductor.use_runtime_constant_folding:
            raise BypassFxGraphCache(
-                "Freezing may introduce constants that aren't static across runs"
+                "Runtime constant folding can introduce constants that aren't "
+                "static across runs"
            )

        from torch._inductor.bisect_helper import BisectionManager
@@ -1386,6 +1416,7 @@ class FxGraphCache:
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        is_backward: bool,
+        gm: Optional[torch.fx.GraphModule] = None,
    ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
        """
        Lookup the graph with the given key, and return results and metadata.
@@ -1393,7 +1424,7 @@ class FxGraphCache:
        differently from FXGraphCache.
        """
        compiled_graph, cache_info = FxGraphCache._lookup_graph(
-            key, example_inputs, local, remote_cache
+            key, example_inputs, local, remote_cache, gm
        )
        cache_info = {
            **cache_info,
@@ -1453,6 +1484,7 @@ class FxGraphCache:
                local,
                remote_cache,
                is_backward=fx_kwargs.get("is_backward", False),
+                gm=gm,
            )

        # CACHE BYPASS: Compile the graph, don't save it to the cache
@@ -1528,7 +1560,7 @@ class FxGraphCache:
            )
        # Use the passed in cudagraphs so that we mutate the BoxedBool correctly
        FxGraphCache.post_compile(
-            compiled_graph, example_inputs, fx_kwargs["cudagraphs"]  # type: ignore[arg-type]
+            compiled_graph, example_inputs, fx_kwargs["cudagraphs"], gm  # type: ignore[arg-type]
        )
        return compiled_graph
@@ -1561,7 +1593,15 @@ class CompiledFxGraph:
    device_idxs: Set[int]
    mutated_inputs: Set[str]
    mutated_input_idxs: Set[int]
-    constants: Dict[str, torch.Tensor]
+    # We populate exactly one of the next two fields. In the common case, we store the
+    # constant attirbutes in the cache entry and re-attach them to the module created in
+    # PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
+    # however, we save the mapping from attribute names in the GraphLowering to the
+    # original name of the attribute in the GraphModule. When we create the module from
+    # the cache entry, we then look up the constants from the current GraphModule. This
+    # scheme allows us to support caching with freezing.
+    allocated_constant_name: Optional[Dict[str, str]]
+    constants: Optional[Dict[str, torch.Tensor]]
    torchbind_constants: Dict[str, torch._C.ScriptObject]
    output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
    disabled_cudagraphs_reason: Optional[str]
@@ -1588,6 +1628,7 @@ class CompiledFxGraph:
        self,
        current_callable: Optional[Callable[..., Any]],
        graph: GraphLowering,
+        gm: torch.fx.GraphModule,
        output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
        disabled_cudagraphs_reason: Optional[str],
        metrics_deltas: metrics.CachedMetricsDeltas,
@@ -1604,7 +1645,12 @@ class CompiledFxGraph:
        self.device_idxs = set(graph.device_idxs)
        self.mutated_inputs = set(graph.mutated_inputs)
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
-        self.constants = graph.constants
+        if has_frozen_params(gm):
+            self.allocated_constant_name = graph.allocated_constant_name
+            self.constants = None
+        else:
+            self.allocated_constant_name = None
+            self.constants = graph.constants
        self.torchbind_constants = graph.torchbind_constants
        self.output_strides = output_strides
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
@@ -1623,6 +1669,26 @@ class CompiledFxGraph:
        finally:
            AutotuneCacheBundler.end_compile()

+    def get_constants(
+        self, gm: Optional[torch.fx.GraphModule]
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Get the constant attributes.
+        """
+        # Normal case: The constants are stored in the entry.
+        if self.constants is not None:
+            return self.constants
+
+        # Freezing case: Look up the constants from attributes on the GraphModule using
+        # the allocated_constant_name map.
+        assert gm is not None
+        assert self.allocated_constant_name is not None
+        constants = {
+            name: getattr(gm, orig_name)
+            for name, orig_name in self.allocated_constant_name.items()
+        }
+        return constants
+
+
def run_command_and_check(cmd_: str) -> None:
    cmd = shlex.split(cmd_)
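Illustrative sketch (not part of the commit): the get_constants method added above resolves constants in one of two mutually exclusive ways. The toy classes, attribute names, and values below are hypothetical stand-ins that only mirror the logic of the new method.

    import torch


    class FrozenGM(torch.nn.Module):
        # Hypothetical stand-in for the GraphModule available at cache-lookup time;
        # the frozen parameter lives on it under its original attribute name.
        def __init__(self) -> None:
            super().__init__()
            self._frozen_param0 = torch.nn.Parameter(torch.ones(8, 8))


    class FakeEntry:
        # Mimics the two mutually exclusive fields on CompiledFxGraph.
        def __init__(self, constants, allocated_constant_name):
            self.constants = constants
            self.allocated_constant_name = allocated_constant_name

        def get_constants(self, gm):
            if self.constants is not None:
                # Normal case: the values were stored in the cache entry itself.
                return self.constants
            # Freezing case: resolve each constant by name from the current GraphModule.
            assert gm is not None and self.allocated_constant_name is not None
            return {
                name: getattr(gm, orig_name)
                for name, orig_name in self.allocated_constant_name.items()
            }


    gm = FrozenGM()
    entry = FakeEntry(constants=None, allocated_constant_name={"constant0": "_frozen_param0"})
    print(entry.get_constants(gm)["constant0"].shape)  # torch.Size([8, 8])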
@@ -2976,9 +3042,13 @@ def touch(filename: str):  # type: ignore[no-untyped-def]

@clear_on_fresh_inductor_cache
class PyCodeCache:
+    # Track the loaded modules so we can remove the on-disk artifacts when
+    # clearing the cache. Note also that we may load the same path more
+    # than once, but attach different attributes, i.e., due to different
+    # constant values.
+    modules: List[ModuleType] = []
    cache: Dict[str, ModuleType] = {}
    linemaps: Dict[str, List[Tuple[Any, ...]]] = {}
-    cache_clear = staticmethod(cache.clear)

    @classmethod
    def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
@@ -3005,24 +3075,33 @@ class PyCodeCache:
    ) -> ModuleType:
        if linemap is None:
            linemap = []
-        if key not in cls.cache:
-            mod = _reload_python_module(key, path)
-
-            # another thread might set this first
-            cls.cache.setdefault(key, mod)
-            # unzip into separate lines/nodes lists
-            cls.linemaps[path] = list(zip(*linemap))
-
-            if attrs is not None:
-                for k, v in attrs.items():
-                    setattr(mod, k, v)
-
-            if not (linemap or attrs):
-                mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
-                    _reload_python_module_in_subproc, key, path
-                )
-
-        return cls.cache[key]
+        mod = _reload_python_module(key, path)
+
+        # unzip into separate lines/nodes lists
+        cls.linemaps[path] = list(zip(*linemap))
+
+        if attrs is not None:
+            for k, v in attrs.items():
+                setattr(mod, k, v)
+
+        if not (linemap or attrs):
+            mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
+                _reload_python_module_in_subproc, key, path
+            )
+
+        cls.modules.append(mod)
+        return mod
+
+    @classmethod
+    def cache_clear(cls) -> None:
+        for mod in cls.modules:
+            try:
+                assert mod.__file__
+                os.remove(mod.__file__)
+            except FileNotFoundError:
+                pass
+        cls.modules.clear()

    @classmethod
    @functools.lru_cache(None)
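Usage note (not part of the commit): with the new PyCodeCache.cache_clear, test code no longer needs to iterate PyCodeCache.cache and unlink files by hand. A minimal sketch of the cleanup idiom used by the updated tests follows; the helper name reset_compile_caches is hypothetical.

    import torch
    import torch._dynamo
    from torch._inductor.codecache import PyCodeCache


    def reset_compile_caches() -> None:
        # Hypothetical helper; the two calls mirror the updated tests in this diff.
        torch._dynamo.reset()
        PyCodeCache.cache_clear()  # now also removes the on-disk module artifacts


    reset_compile_caches()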
@@ -760,7 +760,7 @@ def _compile_fx_inner(
                # to return the string directly.
                return compiled_graph
            compiled_graph = FxGraphCache.post_compile(
-                compiled_graph, example_inputs, cudagraphs
+                compiled_graph, example_inputs, cudagraphs, gm
            )

    log.debug("FX codegen and compilation took %.3fs", time.time() - start)
@@ -1015,6 +1015,7 @@ def fx_codegen_and_compile(
        compiled_graph = CompiledFxGraph(
            compiled_fn,
            graph,
+            gm,
            output_strides,
            V.graph.disable_cudagraphs_reason,
            metrics_helper.get_deltas(),
@@ -1289,6 +1290,8 @@ def fw_compiler_freezing(
        aot_example_inputs,  # type: ignore[arg-type]
    )

+    setattr(opt_model, "_has_frozen_params", True)  # noqa: B010
+
    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs
@@ -77,7 +77,8 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    from torch._inductor.codecache import PyCodeCache

    nfound = 0
-    for kernel_key, kernel_mod in PyCodeCache.cache.items():
+    for kernel_mod in PyCodeCache.modules:
+        kernel_key = kernel_mod.key
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue
@@ -890,7 +890,7 @@ class ExecutionTraceObserver(_ITraceObserver):

        kernel_files = [
            v.__file__
-            for v in PyCodeCache.cache.values()
+            for v in PyCodeCache.modules
            if getattr(v, "__file__", None) is not None
        ]
        work_dir, file_name = os.path.split(self._output_file_path)