[fx graph cache] Support freezing with FX graph caching (#136505)

Summary: The main changes to support freezing are:
1) When pickling constant tensors as part of the cache-key calculation: if freezing has not been applied, keep the existing behavior (pickle the metadata and values); if freezing has been applied, pickle the values only when the constant will be inlined, and otherwise consider only the metadata. (A condensed sketch of this logic follows the list.)
2) If freezing has been applied, modify what we store in the cache: instead of storing the constant attributes in the cache entry, store the _names_ of the constants, and then grab those constants from the GraphModule when we need to attach the attributes to a newly-loaded Python module. Since the cache lookup path loads the Python module, this change means we need to thread a GraphModule argument through several places.
3) Since this feature means that we may need to reload the same Python module path more than once (but attach different constant attributes), I changed PyCodeCache.load_by_key_path to not store an in-memory map of path to module (since there may be more than one). I don't _think_ this will have any effect on performance; it's unclear why we were using an in-memory cache here anyway, since this function should only be called once per module that needs to be loaded.
4) Several tests were removing on-disk PyCodeCache artifacts by iterating over the loaded modules. I made this more straightforward by implementing a cache_clear method that removes the on-disk artifacts. Arguably, this should have been the implementation all along.
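
The list above describes the logic at a high level; the following is a condensed, self-contained sketch of the two key pieces: skipping non-inlined constant values in the cache key when freezing is enabled, and re-resolving constants from the GraphModule by name on a cache hit. All names here (`TensorMeta`, `TensorMetaAndValues`, `can_inline`, `cache_key`, `resolve_constants`, the `max_numel=8` threshold) are illustrative stand-ins, not the real API; the actual changes live in `FxGraphCachePickler._reduce_tensor`, `compiled_fx_graph_hash`, and `CompiledFxGraph.get_constants` in the diff below.

```python
import hashlib
import io
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch


# Hypothetical stand-ins for TensorMetadata / TensorMetadataAndValues.
@dataclass
class TensorMeta:
    dtype: torch.dtype
    shape: tuple
    device: str


@dataclass
class TensorMetaAndValues:
    meta: TensorMeta
    values: List[float]


def can_inline(t: torch.Tensor, max_numel: int = 8) -> bool:
    # Assumption: "small enough to inline" is approximated by element count.
    return t.numel() <= max_numel


def reduce_tensor(
    t: torch.Tensor, include_non_inlined: bool
) -> Union[TensorMeta, TensorMetaAndValues]:
    """Mimic the pickler's custom reducer: include values only when they matter."""
    meta = TensorMeta(t.dtype, tuple(t.shape), str(t.device))
    if can_inline(t) or include_non_inlined:
        return TensorMetaAndValues(meta, t.flatten().tolist())
    return meta  # frozen, non-inlined constant: metadata only


def cache_key(constants: Dict[str, torch.Tensor], frozen: bool) -> str:
    """Sketch of the key calculation over a graph's constant attributes."""
    include_non_inlined = not frozen
    payload = {
        name: reduce_tensor(t, include_non_inlined) for name, t in constants.items()
    }
    buf = io.BytesIO()
    pickle.dump(payload, buf)
    return hashlib.sha256(buf.getvalue()).hexdigest()[:16]


def resolve_constants(
    stored_constants: Optional[Dict[str, torch.Tensor]],
    constant_names: Optional[Dict[str, str]],
    gm: Optional[torch.nn.Module],
) -> Dict[str, torch.Tensor]:
    """Sketch of the get_constants idea: take values from the cache entry, or
    look them up by name on the current GraphModule when freezing is enabled."""
    if stored_constants is not None:
        return stored_constants
    assert gm is not None and constant_names is not None
    return {name: getattr(gm, orig) for name, orig in constant_names.items()}


if __name__ == "__main__":
    big_a = {"w": torch.randn(64, 64)}
    big_b = {"w": torch.randn(64, 64)}
    # Without freezing, different constant values -> different keys (cache miss).
    assert cache_key(big_a, frozen=False) != cache_key(big_b, frozen=False)
    # With freezing, non-inlined constants contribute only metadata -> same key.
    assert cache_key(big_a, frozen=True) == cache_key(big_b, frozen=True)
```

The `test_freezing` case added in the diff below exercises this end to end: with `freezing` and `fx_graph_cache` enabled, two modules with different parameter values produce a cache miss followed by a cache hit.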

Differential Revision: [D63542170](https://our.internmc.facebook.com/intern/diff/D63542170)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136505
Approved by: https://github.com/eellison
Authored by Sam Larsen on 2024-10-31 15:19:32 -07:00; committed by PyTorch MergeBot
parent 7d644f025f
commit d8b606ecb5
8 changed files with 225 additions and 64 deletions

View File

@@ -1,6 +1,5 @@
 # Owner(s): ["module: dynamo"]
-import os
 import unittest
 from unittest.mock import patch
@@ -53,8 +52,6 @@ class AOTAutogradCacheTests(InductorTestCase):
         Clear unrelated caches, like dynamo and PyCodeCache
         """
         torch._dynamo.reset()
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
         torch._inductor.codecache.PyCodeCache.cache_clear()
     @inductor_config.patch("fx_graph_remote_cache", False)

View File

@@ -160,11 +160,12 @@ class TestFxGraphCache(TestCase):
         # A second call should hit. (First reset so in-memory guards
         # don't prevent compilation).
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
-        # Clean triton kernels
-        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
         self.reset()
+        # Clean PyCodeCache and triton kernels
+        PyCodeCache.cache_clear()
+        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
         self.assertEqual(fn(a, b), compiled_fn(a, b))
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@@ -381,6 +382,24 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+        # Now pretend the constants are frozen params.
+        counters.clear()
+        self.reset()
+        with mock.patch(
+            "torch._inductor.codecache.has_frozen_params", return_value=True
+        ):
+            # A call to fn1 should miss in the cache since we do not consider
+            # the constant values.
+            self.assertEqual(fn1(a), compiled_fn1(a))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+            # A call to fn2 should hit for the same reason.
+            self.assertEqual(fn2(a), compiled_fn2(a))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
     @requires_cuda
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -417,8 +436,6 @@ class TestFxGraphCache(TestCase):
         # A second call should hit. (First reset so in-memory guards
         # don't prevent compilation).
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
         self.reset()
         self.assertEqual(fn(a, b, c), compiled_fn(a, b, c), atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
@@ -426,8 +443,6 @@ class TestFxGraphCache(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1)
         # A third call with different score_mod should have a cache miss
-        for m in torch._inductor.codecache.PyCodeCache.cache.values():
-            os.remove(m.__file__)
         self.reset()
         self.assertEqual(fn2(a, b, c), compiled_fn2(a, b, c), atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
@@ -647,14 +662,80 @@ class TestFxGraphCache(TestCase):
         self.assertNotEqual(a, b)
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @config.patch({"freezing": True})
+    @parametrize("device", (GPU_TYPE, "cpu"))
+    def test_freezing(self, device):
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+        class MM(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(8, 8))
+            def forward(self, x):
+                return x @ self.param
+        dtype = torch.float16
+        # Populate a cache entry.
+        mod1 = MM().to(device=device, dtype=dtype)
+        with torch.no_grad():
+            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            out0 = mod1(x)
+            out1 = torch.compile(mod1)(x)
+            self.assertEqual(out0, out1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+        counters.clear()
+        self.reset()
+        # Same nn.Module, but with different parameters should cache hit.
+        mod2 = MM().to(device=device, dtype=dtype)
+        self.assertNotEqual(mod1.param, mod2.param)
+        with torch.no_grad():
+            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            out0 = mod2(x)
+            out1 = torch.compile(mod2)(x)
+            self.assertEqual(out0, out1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
 class TestFxGraphCacheHashing(TestCase):
     def test_tensor_constants(self):
         """
         Test the hashing of tensor constants.
         """
-        data = FxGraphCachePickler().dumps(torch.tensor(list(range(9))))
+        small = torch.tensor(list(range(8)))
+        large = torch.tensor(list(range(32)))
+        self.assertTrue(GraphLowering.can_inline_constant(small))
+        self.assertFalse(GraphLowering.can_inline_constant(large))
+        # By default, we hash the metadata and values independent of the size.
+        pickler = FxGraphCachePickler()
+        data = pickler.dumps(small)
         self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
+        data = pickler.dumps(large)
+        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
+        # If include_non_inlined=False, we only hash the values of small tensors.
+        pickler = FxGraphCachePickler(False)
+        data = pickler.dumps(small)
+        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
+        data = pickler.dumps(large)
+        self.assertIsInstance(pickle.loads(data), TensorMetadata)
     def test_hash_fake_tensors(self):
         """
@@ -1133,13 +1214,13 @@ class TestUtils(TestCase):
         b = torch.rand(10)
         with fresh_inductor_cache():
-            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
+            self.assertEqual(len(PyCodeCache.modules), 0)
            res1 = torch.compile(fn)(a, b)
            cache_dir1 = cache_dir()
        torch._dynamo.reset()
        with fresh_inductor_cache():
-            self.assertEqual(len(PyCodeCache.cache.keys()), 0)
+            self.assertEqual(len(PyCodeCache.modules), 0)
            res2 = torch.compile(fn)(a, b)
            cache_dir2 = cache_dir()

View File

@@ -40,11 +40,11 @@ class TestKernelBenchmark(TestCase):
     def setUp(self):
         super().setUp()
-        PyCodeCache.cache.clear()
+        PyCodeCache.cache_clear()
     def get_compiled_module(self):
         compiled_module = None
-        for v in PyCodeCache.cache.values():
+        for v in PyCodeCache.modules:
             if hasattr(v, "benchmark_compiled_module"):
                 self.assertTrue(
                     compiled_module is None, "Found multiple compiled modules"

View File

@@ -14,7 +14,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 class TestTritonWrapper(TestCase):
     def get_compiled_module(self):
         compiled_module = None
-        for v in PyCodeCache.cache.values():
+        for v in PyCodeCache.modules:
             if hasattr(v, "benchmark_compiled_module"):
                 self.assertTrue(
                     compiled_module is None, "Found multiple compiled modules"

View File

@@ -526,10 +526,18 @@ class FxGraphCachePickler(pickle.Pickler):
     data that allow us to compute a stable, but safe hash.
     """
-    def __init__(self) -> None:
+    def __init__(self, include_non_inlined: bool = True) -> None:
+        """
+        Create an FX graph pickler. If include_non_inlined=True, then pickling will
+        include the _values_ for all Tensors. (Note that any tensors are constants
+        attached as attributes to the GraphModule). Otherwise, pickling will include
+        only the metadata for these tensors.
+        """
         self._stream = io.BytesIO()
         super().__init__(self._stream)
+        self.include_non_inlined = include_non_inlined
         self.dispatch_table = copyreg.dispatch_table.copy()
         self.dispatch_table.update(
             {
@@ -558,35 +566,38 @@ class FxGraphCachePickler(pickle.Pickler):
     def _reduce_tensor(
         self,
         t: Tensor,
-    ) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
+    ) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
         """
         Custom reducer to pickle Tensors. If we see tensors, we know they're constants
-        stored as attributes on the GraphModule. Include the values in the key
-        calculation. Small tensors will be inlined, so we can't serve the same cache
-        entry for different values anyway. Large constants are treated as parameters, so
-        we could conceivably reuse a cache entry. To do that, however, PyCodeCache would
-        need more complexity to create a new module from its cache, but with the right
-        constants attached as attributes.
+        stored as attributes on the GraphModule.
         """
+        from .graph import GraphLowering
         if t.is_mkldnn:
             # TODO: These tensors don't currently pickle, so we can't cache a compiled
             # graph containing them. Just fail now. If mkldnn tensors get pickling
             # support, we can remove this.
             raise BypassFxGraphCache("mkldnn tensors unpickleable")
-        # Very large tensors could be expensive to copy to cpu and hash. Let's at least
-        # report if we find slowness.
-        start = time()
-        values = t.tolist()
-        elapsed = time() - start
-        if elapsed > 1.0:
-            warnings.warn(
-                f"FX graph cache handling of a large constant took {elapsed:.1}s. "
-                "Please file an issue."
-            )
+        # If this is an inlined constant or include_non_inlined=True, then we include
+        # the metadata and the values.
         metadata = extract_tensor_metadata_for_cache_key(t)
-        return (_ident, (TensorMetadataAndValues(metadata, values),))
+        if GraphLowering.can_inline_constant(t) or self.include_non_inlined:
+            # Very large tensors will be expensive to copy to cpu and hash. Let's at
+            # least report any slowness.
+            start = time()
+            values = t.tolist()
+            elapsed = time() - start
+            if elapsed > 1.0:
+                warnings.warn(
+                    f"FX graph cache copying of a large constant took {elapsed:.1}s. "
+                    "Please file an issue."
+                )
+            return (_ident, (TensorMetadataAndValues(metadata, values),))
+        # Otherwise, we just include the metadata.
+        return (_ident, (metadata,))
     def _reduce_symint(self, s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]:
         """
@@ -804,6 +815,10 @@ class FxGraphHashDetails:
         return custom_pass.uuid()
+def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
+    return getattr(gm, "_has_frozen_params", False)
 def compiled_fx_graph_hash(
     gm: torch.fx.GraphModule,
     example_inputs: Sequence[InputType],
@@ -813,8 +828,13 @@ def compiled_fx_graph_hash(
     """
     Generate a unique hash of the FX graph for caching.
     """
+    # To support caching when the graph has frozen params, we ignore the tensor values
+    # of non-inlined constants since they won't be included in the cache entry. Without
+    # freezing, we want to include the values of any constant attribute.
+    include_non_inlined = not has_frozen_params(gm)
     details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
-    pickler = FxGraphCachePickler()
+    pickler = FxGraphCachePickler(include_non_inlined)
     # The prefix distinguishes among the other kinds of objects we
     # cache in this module.
     key = "f" + pickler.get_hash(details)
@@ -828,6 +848,7 @@ def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
+   gm: Optional[torch.fx.GraphModule],
 ) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
@@ -873,7 +894,7 @@ def cudagraph_post_compile(
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
-           constants=tuple(compiled_graph.constants.values()),
+           constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )
@@ -1030,6 +1051,7 @@ class FxGraphCache:
        example_inputs: Sequence[InputType],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
+       gm: Optional[torch.fx.GraphModule],
    ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
        """
        Lookup a compiled graph in the cache by key. On a hit, return the
@@ -1135,7 +1157,7 @@ class FxGraphCache:
                graph.cache_key,
                artifact_path,
                graph.cache_linemap,
-               graph.constants,
+               graph.get_constants(gm),
            ).call
        except OSError:
            # Not expected, but in case the PyCodeCache entry is removed from
@@ -1177,6 +1199,7 @@ class FxGraphCache:
        compiled_graph: CompiledFxGraph,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
+       gm: Optional[torch.fx.GraphModule] = None,
    ) -> CompiledFxGraph:
        """
        Run a set of post processing steps after loading from the cache. These involve:
@@ -1206,6 +1229,7 @@ class FxGraphCache:
                example_inputs,
                compiled_graph,
                cudagraphs,
+               gm,
            )
        inputs_to_check = compiled_graph.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
@@ -1294,9 +1318,15 @@ class FxGraphCache:
            raise BypassFxGraphCache("Unsupported post grad custom pass")
        # Freezing can embed constants that wouldn't be static across runs.
-       if config.freezing or config.aot_inductor.use_runtime_constant_folding:
+       if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
+           "pytorch/inductor:allow_freezing_with_caching"
+       ):
+           raise BypassFxGraphCache("Skipping graph with frozen constants")
+       if config.aot_inductor.use_runtime_constant_folding:
            raise BypassFxGraphCache(
-               "Freezing may introduce constants that aren't static across runs"
+               "Runtime constant folding can introduce constants that aren't "
+               "static across runs"
            )
        from torch._inductor.bisect_helper import BisectionManager
@@ -1386,6 +1416,7 @@ class FxGraphCache:
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        is_backward: bool,
+       gm: Optional[torch.fx.GraphModule] = None,
    ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
        """
        Lookup the graph with the given key, and return results and metadata.
@@ -1393,7 +1424,7 @@ class FxGraphCache:
        differently from FXGraphCache.
        """
        compiled_graph, cache_info = FxGraphCache._lookup_graph(
-           key, example_inputs, local, remote_cache
+           key, example_inputs, local, remote_cache, gm
        )
        cache_info = {
            **cache_info,
@@ -1453,6 +1484,7 @@ class FxGraphCache:
                local,
                remote_cache,
                is_backward=fx_kwargs.get("is_backward", False),
+               gm=gm,
            )
        # CACHE BYPASS: Compile the graph, don't save it to the cache
@@ -1528,7 +1560,7 @@ class FxGraphCache:
            )
            # Use the passed in cudagraphs so that we mutate the BoxedBool correctly
            FxGraphCache.post_compile(
-               compiled_graph, example_inputs, fx_kwargs["cudagraphs"]  # type: ignore[arg-type]
+               compiled_graph, example_inputs, fx_kwargs["cudagraphs"], gm  # type: ignore[arg-type]
            )
        return compiled_graph
@@ -1561,7 +1593,15 @@ class CompiledFxGraph:
    device_idxs: Set[int]
    mutated_inputs: Set[str]
    mutated_input_idxs: Set[int]
-   constants: Dict[str, torch.Tensor]
+   # We populate exactly one of the next two fields. In the common case, we store the
+   # constant attributes in the cache entry and re-attach them to the module created in
+   # PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
+   # however, we save the mapping from attribute names in the GraphLowering to the
+   # original name of the attribute in the GraphModule. When we create the module from
+   # the cache entry, we then look up the constants from the current GraphModule. This
+   # scheme allows us to support caching with freezing.
+   allocated_constant_name: Optional[Dict[str, str]]
+   constants: Optional[Dict[str, torch.Tensor]]
    torchbind_constants: Dict[str, torch._C.ScriptObject]
    output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
    disabled_cudagraphs_reason: Optional[str]
@@ -1588,6 +1628,7 @@ class CompiledFxGraph:
        self,
        current_callable: Optional[Callable[..., Any]],
        graph: GraphLowering,
+       gm: torch.fx.GraphModule,
        output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
        disabled_cudagraphs_reason: Optional[str],
        metrics_deltas: metrics.CachedMetricsDeltas,
@@ -1604,7 +1645,12 @@ class CompiledFxGraph:
        self.device_idxs = set(graph.device_idxs)
        self.mutated_inputs = set(graph.mutated_inputs)
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
-       self.constants = graph.constants
+       if has_frozen_params(gm):
+           self.allocated_constant_name = graph.allocated_constant_name
+           self.constants = None
+       else:
+           self.allocated_constant_name = None
+           self.constants = graph.constants
        self.torchbind_constants = graph.torchbind_constants
        self.output_strides = output_strides
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
@@ -1623,6 +1669,26 @@ class CompiledFxGraph:
        finally:
            AutotuneCacheBundler.end_compile()
+   def get_constants(
+       self, gm: Optional[torch.fx.GraphModule]
+   ) -> Dict[str, torch.Tensor]:
+       """
+       Get the constant attributes.
+       """
+       # Normal case: The constants are stored in the entry.
+       if self.constants is not None:
+           return self.constants
+       # Freezing case: Look up the constants from attributes on the GraphModule using
+       # the allocated_constant_name map.
+       assert gm is not None
+       assert self.allocated_constant_name is not None
+       constants = {
+           name: getattr(gm, orig_name)
+           for name, orig_name in self.allocated_constant_name.items()
+       }
+       return constants
 def run_command_and_check(cmd_: str) -> None:
    cmd = shlex.split(cmd_)
@@ -2976,9 +3042,13 @@ def touch(filename: str):  # type: ignore[no-untyped-def]
 @clear_on_fresh_inductor_cache
 class PyCodeCache:
+   # Track the loaded modules so we can remove the on-disk artifacts when
+   # clearing the cache. Note also that we may load the same path more
+   # than once, but attach different attributes, i.e., due to different
+   # constant values.
+   modules: List[ModuleType] = []
    cache: Dict[str, ModuleType] = {}
    linemaps: Dict[str, List[Tuple[Any, ...]]] = {}
-   cache_clear = staticmethod(cache.clear)
    @classmethod
    def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
@@ -3005,24 +3075,33 @@ class PyCodeCache:
    ) -> ModuleType:
        if linemap is None:
            linemap = []
-       if key not in cls.cache:
-           mod = _reload_python_module(key, path)
-           # another thread might set this first
-           cls.cache.setdefault(key, mod)
-           # unzip into separate lines/nodes lists
-           cls.linemaps[path] = list(zip(*linemap))
-           if attrs is not None:
-               for k, v in attrs.items():
-                   setattr(mod, k, v)
-           if not (linemap or attrs):
-               mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
-                   _reload_python_module_in_subproc, key, path
-               )
-       return cls.cache[key]
+       mod = _reload_python_module(key, path)
+       # unzip into separate lines/nodes lists
+       cls.linemaps[path] = list(zip(*linemap))
+       if attrs is not None:
+           for k, v in attrs.items():
+               setattr(mod, k, v)
+       if not (linemap or attrs):
+           mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
+               _reload_python_module_in_subproc, key, path
+           )
+       cls.modules.append(mod)
+       return mod
+   @classmethod
+   def cache_clear(cls) -> None:
+       for mod in cls.modules:
+           try:
+               assert mod.__file__
+               os.remove(mod.__file__)
+           except FileNotFoundError:
+               pass
+       cls.modules.clear()
    @classmethod
    @functools.lru_cache(None)

View File

@@ -760,7 +760,7 @@ def _compile_fx_inner(
            # to return the string directly.
            return compiled_graph
    compiled_graph = FxGraphCache.post_compile(
-       compiled_graph, example_inputs, cudagraphs
+       compiled_graph, example_inputs, cudagraphs, gm
    )
    log.debug("FX codegen and compilation took %.3fs", time.time() - start)
@@ -1015,6 +1015,7 @@ def fx_codegen_and_compile(
        compiled_graph = CompiledFxGraph(
            compiled_fn,
            graph,
+           gm,
            output_strides,
            V.graph.disable_cudagraphs_reason,
            metrics_helper.get_deltas(),
@@ -1289,6 +1290,8 @@ def fw_compiler_freezing(
        aot_example_inputs,  # type: ignore[arg-type]
    )
+   setattr(opt_model, "_has_frozen_params", True)  # noqa: B010
    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs

View File

@@ -77,7 +77,8 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    from torch._inductor.codecache import PyCodeCache
    nfound = 0
-   for kernel_key, kernel_mod in PyCodeCache.cache.items():
+   for kernel_mod in PyCodeCache.modules:
+       kernel_key = kernel_mod.key
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

View File

@@ -890,7 +890,7 @@ class ExecutionTraceObserver(_ITraceObserver):
        kernel_files = [
            v.__file__
-           for v in PyCodeCache.cache.values()
+           for v in PyCodeCache.modules
            if getattr(v, "__file__", None) is not None
        ]
        work_dir, file_name = os.path.split(self._output_file_path)