Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[fx graph cache] Support freezing with FX graph caching (#136505)
Summary: The main changes to support freezing are:

1) When pickling constant tensors as part of the cache-key calculation: if freezing has not been applied, keep the existing behavior (pickle the metadata and values); if freezing has been applied, pickle the values only if the constant will be inlined, and otherwise consider only the metadata.
2) If freezing has been applied, change what we store in the cache: instead of storing the constant attributes in the cache entry, store the _names_ of the constants, and then grab those constants from the GraphModule when we need to attach the attributes to a newly loaded Python module. Since the cache lookup path loads the Python module, this means we need to thread a GraphModule argument through several places.
3) Since this feature means we may need to reload the same Python module path more than once (but attach different constant attributes), I changed PyCodeCache.load_by_key_path to not store an in-memory map of path to module (since there may be more than one). I don't _think_ this will have any effect on performance; it's unclear why we were using an in-memory cache here anyway, since this function should only be called once for each module that needs to be loaded.
4) Several tests were removing on-disk PyCodeCache artifacts by iterating over the loaded modules. I made this more straightforward by implementing a cache_clear method that removes the on-disk artifacts. Arguably, this should have been the implementation all along.

A minimal illustrative sketch of the key-computation behavior in 1) follows this summary.

Differential Revision: [D63542170](https://our.internmc.facebook.com/intern/diff/D63542170)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136505
Approved by: https://github.com/eellison
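The sketch below illustrates the cache-key behavior described in 1). It is not the code from this commit: `key_component_for_constant` and the size threshold inside `can_inline_constant` are hypothetical stand-ins, while `TensorMetadata`, `TensorMetadataAndValues`, and the `include_non_inlined` flag mirror identifiers that appear in the diff hunks further down.

```python
# Hedged sketch: TensorMetadata / TensorMetadataAndValues are simplified stand-ins for
# the types referenced in the diff; can_inline_constant uses a made-up threshold.
from dataclasses import dataclass
from typing import List, Union

import torch


@dataclass
class TensorMetadata:
    dtype: torch.dtype
    shape: torch.Size


@dataclass
class TensorMetadataAndValues:
    metadata: TensorMetadata
    values: List[object]


def can_inline_constant(t: torch.Tensor) -> bool:
    # Stand-in for GraphLowering.can_inline_constant: small 1-D constants get
    # inlined into the generated code, so their values must always affect the key.
    return t.ndim == 1 and t.numel() <= 8


def key_component_for_constant(
    t: torch.Tensor, include_non_inlined: bool
) -> Union[TensorMetadata, TensorMetadataAndValues]:
    """What a constant tensor contributes to the cache key.

    include_non_inlined is True when freezing has not been applied and False when
    it has, mirroring `include_non_inlined = not has_frozen_params(gm)` in the diff.
    """
    metadata = TensorMetadata(t.dtype, t.shape)
    if can_inline_constant(t) or include_non_inlined:
        # Inlined constants (and all constants in the non-frozen case) contribute
        # both metadata and values.
        return TensorMetadataAndValues(metadata, t.tolist())
    # Frozen, non-inlined constant: only metadata participates, so a cached graph
    # can be reused even though the frozen values differ between runs.
    return metadata


small, large = torch.arange(8), torch.arange(32)
assert isinstance(key_component_for_constant(small, True), TensorMetadataAndValues)
assert isinstance(key_component_for_constant(large, True), TensorMetadataAndValues)
assert isinstance(key_component_for_constant(small, False), TensorMetadataAndValues)
assert isinstance(key_component_for_constant(large, False), TensorMetadata)
```

This mirrors the assertions added in the test_tensor_constants hunk below, where the range(8) tensor is inlinable and the range(32) tensor is not.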
Commit d8b606ecb5 (parent 7d644f025f)
@@ -1,6 +1,5 @@
# Owner(s): ["module: dynamo"]

import os
import unittest
from unittest.mock import patch

@@ -53,8 +52,6 @@ class AOTAutogradCacheTests(InductorTestCase):
Clear unrelated caches, like dynamo and PyCodeCache
"""
torch._dynamo.reset()
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
torch._inductor.codecache.PyCodeCache.cache_clear()

@inductor_config.patch("fx_graph_remote_cache", False)
@@ -160,11 +160,12 @@ class TestFxGraphCache(TestCase):

# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
# Clean triton kernels
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
self.reset()

# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

self.assertEqual(fn(a, b), compiled_fn(a, b))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@@ -381,6 +382,24 @@ class TestFxGraphCache(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

# Now pretend the constants are frozen params.
counters.clear()
self.reset()

with mock.patch(
"torch._inductor.codecache.has_frozen_params", return_value=True
):
# A call to fn1 should miss in the cache since we do not consider
# the constant values.
self.assertEqual(fn1(a), compiled_fn1(a))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

# A call to fn2 should hit for the same reason.
self.assertEqual(fn2(a), compiled_fn2(a))
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

@requires_cuda
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@@ -417,8 +436,6 @@ class TestFxGraphCache(TestCase):

# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
self.reset()
self.assertEqual(fn(a, b, c), compiled_fn(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)

@@ -426,8 +443,6 @@ class TestFxGraphCache(TestCase):
self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)

# A third call with different score_mod should have a cache miss
for m in torch._inductor.codecache.PyCodeCache.cache.values():
os.remove(m.__file__)
self.reset()
self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
@@ -647,14 +662,80 @@ class TestFxGraphCache(TestCase):

self.assertNotEqual(a, b)

@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@config.patch({"freezing": True})
@parametrize("device", (GPU_TYPE, "cpu"))
def test_freezing(self, device):
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")

class MM(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(8, 8))

def forward(self, x):
return x @ self.param

dtype = torch.float16

# Populate a cache entry.
mod1 = MM().to(device=device, dtype=dtype)
with torch.no_grad():
x = torch.rand(8, 8).to(device=device, dtype=dtype)
out0 = mod1(x)
out1 = torch.compile(mod1)(x)
self.assertEqual(out0, out1)

self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

counters.clear()
self.reset()

# Same nn.Module, but with different parameters should cache hit.
mod2 = MM().to(device=device, dtype=dtype)
self.assertNotEqual(mod1.param, mod2.param)

with torch.no_grad():
x = torch.rand(8, 8).to(device=device, dtype=dtype)
out0 = mod2(x)
out1 = torch.compile(mod2)(x)
self.assertEqual(out0, out1)

self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)


class TestFxGraphCacheHashing(TestCase):
def test_tensor_constants(self):
"""
Test the hashing of tensor constants.
"""
data = FxGraphCachePickler().dumps(torch.tensor(list(range(9))))
small = torch.tensor(list(range(8)))
large = torch.tensor(list(range(32)))

self.assertTrue(GraphLowering.can_inline_constant(small))
self.assertFalse(GraphLowering.can_inline_constant(large))

# By default, we hash the metadata and values independent of the size.
pickler = FxGraphCachePickler()

data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)

# If include_non_inlined=False, we only hash the values of small tensors.
pickler = FxGraphCachePickler(False)

data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadata)

def test_hash_fake_tensors(self):
"""
@@ -1133,13 +1214,13 @@ class TestUtils(TestCase):
b = torch.rand(10)

with fresh_inductor_cache():
self.assertEqual(len(PyCodeCache.cache.keys()), 0)
self.assertEqual(len(PyCodeCache.modules), 0)
res1 = torch.compile(fn)(a, b)
cache_dir1 = cache_dir()

torch._dynamo.reset()
with fresh_inductor_cache():
self.assertEqual(len(PyCodeCache.cache.keys()), 0)
self.assertEqual(len(PyCodeCache.modules), 0)
res2 = torch.compile(fn)(a, b)
cache_dir2 = cache_dir()
@@ -40,11 +40,11 @@ class TestKernelBenchmark(TestCase):

def setUp(self):
super().setUp()
PyCodeCache.cache.clear()
PyCodeCache.cache_clear()

def get_compiled_module(self):
compiled_module = None
for v in PyCodeCache.cache.values():
for v in PyCodeCache.modules:
if hasattr(v, "benchmark_compiled_module"):
self.assertTrue(
compiled_module is None, "Found multiple compiled modules"
@@ -14,7 +14,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
class TestTritonWrapper(TestCase):
def get_compiled_module(self):
compiled_module = None
for v in PyCodeCache.cache.values():
for v in PyCodeCache.modules:
if hasattr(v, "benchmark_compiled_module"):
self.assertTrue(
compiled_module is None, "Found multiple compiled modules"
@@ -526,10 +526,18 @@ class FxGraphCachePickler(pickle.Pickler):
data that allow us to compute a stable, but safe hash.
"""

def __init__(self) -> None:
def __init__(self, include_non_inlined: bool = True) -> None:
"""
Create an FX graph pickler. If include_non_inlined=True, then pickling will
include the _values_ for all Tensors. (Note that any tensors are constants
attached as attributes to the GraphModule). Otherwise, pickling will include
only the metadata for these tensors.
"""
self._stream = io.BytesIO()
super().__init__(self._stream)

self.include_non_inlined = include_non_inlined

self.dispatch_table = copyreg.dispatch_table.copy()
self.dispatch_table.update(
{
@@ -558,35 +566,38 @@ class FxGraphCachePickler(pickle.Pickler):
def _reduce_tensor(
self,
t: Tensor,
) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
"""
Custom reducer to pickle Tensors. If we see tensors, we know they're constants
stored as attributes on the GraphModule. Include the values in the key
calculation. Small tensors will be inlined, so we can't serve the same cache
entry for different values anyway. Large constants are treated as parameters, so
we could conceivably reuse a cache entry. To do that, however, PyCodeCache would
need more complexity to create a new module from its cache, but with the right
constants attached as attributes.
stored as attributes on the GraphModule.
"""
from .graph import GraphLowering

if t.is_mkldnn:
# TODO: These tensors don't currently pickle, so we can't cache a compiled
# graph containing them. Just fail now. If mkldnn tensors get pickling
# support, we can remove this.
raise BypassFxGraphCache("mkldnn tensors unpickleable")

# Very large tensors could be expensive to copy to cpu and hash. Let's at least
# report if we find slowness.
start = time()
values = t.tolist()
elapsed = time() - start
if elapsed > 1.0:
warnings.warn(
f"FX graph cache handling of a large constant took {elapsed:.1}s. "
"Please file an issue."
)

# If this is an inlined constant or include_non_inlined=True, then we include
# the metadata and the values.
metadata = extract_tensor_metadata_for_cache_key(t)
return (_ident, (TensorMetadataAndValues(metadata, values),))
if GraphLowering.can_inline_constant(t) or self.include_non_inlined:
# Very large tensors will be expensive to copy to cpu and hash. Let's at
# least report any slowness.
start = time()
values = t.tolist()
elapsed = time() - start
if elapsed > 1.0:
warnings.warn(
f"FX graph cache copying of a large constant took {elapsed:.1}s. "
"Please file an issue."
)

return (_ident, (TensorMetadataAndValues(metadata, values),))

# Otherwise, we just include the metadata.
return (_ident, (metadata,))

def _reduce_symint(self, s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]:
"""
@@ -804,6 +815,10 @@ class FxGraphHashDetails:
return custom_pass.uuid()


def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
    return getattr(gm, "_has_frozen_params", False)


def compiled_fx_graph_hash(
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
@@ -813,8 +828,13 @@ def compiled_fx_graph_hash(
"""
Generate a unique hash of the FX graph for caching.
"""
# To support caching when the graph has frozen params, we ignore the tensor values
# of non-inlined constants since they won't be included in the cache entry. Without
# freezing, we want to include the values of any constant attribute.
include_non_inlined = not has_frozen_params(gm)

details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
pickler = FxGraphCachePickler()
pickler = FxGraphCachePickler(include_non_inlined)
# The prefix distinguishes among the other kinds of objects we
# cache in this module.
key = "f" + pickler.get_hash(details)
@@ -828,6 +848,7 @@ def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule],
) -> None:
"""
Checks for any reasons not to run cudagraphs and then

@@ -873,7 +894,7 @@ def cudagraph_post_compile(
stack_traces=stack_traces,
is_backward=is_backward,
is_inference=is_inference,
constants=tuple(compiled_graph.constants.values()),
constants=tuple(compiled_graph.get_constants(gm).values()),
placeholders=placeholders,
mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
)
@@ -1030,6 +1051,7 @@ class FxGraphCache:
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
gm: Optional[torch.fx.GraphModule],
) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
"""
Lookup a compiled graph in the cache by key. On a hit, return the

@@ -1135,7 +1157,7 @@ class FxGraphCache:
graph.cache_key,
artifact_path,
graph.cache_linemap,
graph.constants,
graph.get_constants(gm),
).call
except OSError:
# Not expected, but in case the PyCodeCache entry is removed from
@@ -1177,6 +1199,7 @@ class FxGraphCache:
compiled_graph: CompiledFxGraph,
example_inputs: Sequence[InputType],
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule] = None,
) -> CompiledFxGraph:
"""
Run a set of post processing steps after loading from the cache. These involve:

@@ -1206,6 +1229,7 @@ class FxGraphCache:
example_inputs,
compiled_graph,
cudagraphs,
gm,
)
inputs_to_check = compiled_graph.inputs_to_check
# cudagraphs could have been disabled from the earlier conditions
@@ -1294,9 +1318,15 @@ class FxGraphCache:
raise BypassFxGraphCache("Unsupported post grad custom pass")

# Freezing can embed constants that wouldn't be static across runs.
if config.freezing or config.aot_inductor.use_runtime_constant_folding:
if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
"pytorch/inductor:allow_freezing_with_caching"
):
raise BypassFxGraphCache("Skipping graph with frozen constants")

if config.aot_inductor.use_runtime_constant_folding:
raise BypassFxGraphCache(
"Freezing may introduce constants that aren't static across runs"
"Runtime constant folding can introduce constants that aren't "
"static across runs"
)

from torch._inductor.bisect_helper import BisectionManager
@@ -1386,6 +1416,7 @@ class FxGraphCache:
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
is_backward: bool,
gm: Optional[torch.fx.GraphModule] = None,
) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
"""
Lookup the graph with the given key, and return results and metadata.

@@ -1393,7 +1424,7 @@ class FxGraphCache:
differently from FXGraphCache.
"""
compiled_graph, cache_info = FxGraphCache._lookup_graph(
key, example_inputs, local, remote_cache
key, example_inputs, local, remote_cache, gm
)
cache_info = {
**cache_info,
@@ -1453,6 +1484,7 @@ class FxGraphCache:
local,
remote_cache,
is_backward=fx_kwargs.get("is_backward", False),
gm=gm,
)

# CACHE BYPASS: Compile the graph, don't save it to the cache

@@ -1528,7 +1560,7 @@ class FxGraphCache:
)
# Use the passed in cudagraphs so that we mutate the BoxedBool correctly
FxGraphCache.post_compile(
compiled_graph, example_inputs, fx_kwargs["cudagraphs"] # type: ignore[arg-type]
compiled_graph, example_inputs, fx_kwargs["cudagraphs"], gm # type: ignore[arg-type]
)
return compiled_graph
@@ -1561,7 +1593,15 @@ class CompiledFxGraph:
device_idxs: Set[int]
mutated_inputs: Set[str]
mutated_input_idxs: Set[int]
constants: Dict[str, torch.Tensor]
# We populate exactly one of the next two fields. In the common case, we store the
# constant attirbutes in the cache entry and re-attach them to the module created in
# PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
# however, we save the mapping from attribute names in the GraphLowering to the
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]

@@ -1588,6 +1628,7 @@ class CompiledFxGraph:
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
@@ -1604,7 +1645,12 @@ class CompiledFxGraph:
self.device_idxs = set(graph.device_idxs)
self.mutated_inputs = set(graph.mutated_inputs)
self.mutated_input_idxs = set(graph.mutated_input_idxs)
self.constants = graph.constants
if has_frozen_params(gm):
self.allocated_constant_name = graph.allocated_constant_name
self.constants = None
else:
self.allocated_constant_name = None
self.constants = graph.constants
self.torchbind_constants = graph.torchbind_constants
self.output_strides = output_strides
self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
@@ -1623,6 +1669,26 @@ class CompiledFxGraph:
finally:
AutotuneCacheBundler.end_compile()

def get_constants(
self, gm: Optional[torch.fx.GraphModule]
) -> Dict[str, torch.Tensor]:
"""
Get the constant attributes.
"""
# Normal case: The constants are stored in the entry.
if self.constants is not None:
return self.constants

# Freezing case: Look up the constants from attributes on the GraphModule using
# the allocated_constant_name map.
assert gm is not None
assert self.allocated_constant_name is not None
constants = {
name: getattr(gm, orig_name)
for name, orig_name in self.allocated_constant_name.items()
}
return constants


def run_command_and_check(cmd_: str) -> None:
cmd = shlex.split(cmd_)
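For the frozen-parameter path, get_constants above re-resolves constant values from the live GraphModule instead of reading them from the cache entry. The sketch below illustrates that name-based lookup scheme; `FrozenEntry` and `Mod` are hypothetical stand-ins, not the real CompiledFxGraph or an actual GraphModule.

```python
# Hedged sketch of the allocated_constant_name -> GraphModule attribute resolution.
from typing import Dict, Optional

import torch


class FrozenEntry:
    def __init__(self, allocated_constant_name: Dict[str, str]) -> None:
        # Maps the name allocated for a constant during lowering to the attribute
        # name it came from on the original GraphModule.
        self.allocated_constant_name = allocated_constant_name
        # In the frozen case the values themselves are not stored in the entry.
        self.constants: Optional[Dict[str, torch.Tensor]] = None

    def get_constants(self, gm: torch.nn.Module) -> Dict[str, torch.Tensor]:
        if self.constants is not None:
            return self.constants
        # Frozen case: re-resolve values from the *current* module, so the cached
        # compiled code can be reused even though the parameter values were not cached.
        return {
            name: getattr(gm, orig_name)
            for name, orig_name in self.allocated_constant_name.items()
        }


class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(4, 4))


entry = FrozenEntry({"constant0": "weight"})
m1, m2 = Mod(), Mod()
# The same cache entry yields each module's own weight tensor.
assert entry.get_constants(m1)["constant0"] is m1.weight
assert entry.get_constants(m2)["constant0"] is m2.weight
```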
@@ -2976,9 +3042,13 @@ def touch(filename: str): # type: ignore[no-untyped-def]

@clear_on_fresh_inductor_cache
class PyCodeCache:
# Track the loaded modules so we can remove the on-disk artifacts when
# clearing the cache. Note also that we may load the same path more
# than once, but attach different attributes, i.e., due to different
# constant values.
modules: List[ModuleType] = []
cache: Dict[str, ModuleType] = {}
linemaps: Dict[str, List[Tuple[Any, ...]]] = {}
cache_clear = staticmethod(cache.clear)

@classmethod
def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
@@ -3005,24 +3075,33 @@ class PyCodeCache:
) -> ModuleType:
if linemap is None:
linemap = []
if key not in cls.cache:
mod = _reload_python_module(key, path)

# another thread might set this first
cls.cache.setdefault(key, mod)
# unzip into separate lines/nodes lists
cls.linemaps[path] = list(zip(*linemap))
mod = _reload_python_module(key, path)

if attrs is not None:
for k, v in attrs.items():
setattr(mod, k, v)
# unzip into separate lines/nodes lists
cls.linemaps[path] = list(zip(*linemap))

if not (linemap or attrs):
mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined]
_reload_python_module_in_subproc, key, path
)
if attrs is not None:
for k, v in attrs.items():
setattr(mod, k, v)

return cls.cache[key]
if not (linemap or attrs):
mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined]
_reload_python_module_in_subproc, key, path
)

cls.modules.append(mod)
return mod

@classmethod
def cache_clear(cls) -> None:
for mod in cls.modules:
try:
assert mod.__file__
os.remove(mod.__file__)
except FileNotFoundError:
pass
cls.modules.clear()

@classmethod
@functools.lru_cache(None)
@@ -760,7 +760,7 @@ def _compile_fx_inner(
# to return the string directly.
return compiled_graph
compiled_graph = FxGraphCache.post_compile(
compiled_graph, example_inputs, cudagraphs
compiled_graph, example_inputs, cudagraphs, gm
)

log.debug("FX codegen and compilation took %.3fs", time.time() - start)

@@ -1015,6 +1015,7 @@ def fx_codegen_and_compile(
compiled_graph = CompiledFxGraph(
compiled_fn,
graph,
gm,
output_strides,
V.graph.disable_cudagraphs_reason,
metrics_helper.get_deltas(),
@@ -1289,6 +1290,8 @@ def fw_compiler_freezing(
aot_example_inputs, # type: ignore[arg-type]
)

setattr(opt_model, "_has_frozen_params", True) # noqa: B010

aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
num_fixed = len(preserved_arg_indices) - num_example_inputs
@@ -77,7 +77,8 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
from torch._inductor.codecache import PyCodeCache

nfound = 0
for kernel_key, kernel_mod in PyCodeCache.cache.items():
for kernel_mod in PyCodeCache.modules:
kernel_key = kernel_mod.key
if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
continue
@@ -890,7 +890,7 @@ class ExecutionTraceObserver(_ITraceObserver):

kernel_files = [
v.__file__
for v in PyCodeCache.cache.values()
for v in PyCodeCache.modules
if getattr(v, "__file__", None) is not None
]
work_dir, file_name = os.path.split(self._output_file_path)