cache loaded python modules (#149910)

I am splitting the caching of module loading from the caching of codegen, since it is trivial and much easier.
Module loading is 50% of the cost, and codegen is the other 50%, of maybe_append_choice on the full-graph model, which is 40% of total compile time.

<img width="434" alt="Screenshot 2025-03-24 at 4 35 12 PM" src="https://github.com/user-attachments/assets/aa851c6a-bde9-43f8-b12d-e439504ef62c" />

running mm_loop benchmark,
before this change:
67947323682

after this change:
25845073249

2.6X faster.

It seems that the cache existed previously but was dropped at some point. I added a benchmark so it won't be dropped again by mistake.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149910
Approved by: https://github.com/eellison, https://github.com/aorenste
ghstack dependencies: #149932
This commit is contained in:
Laith Sakka 2025-03-25 21:56:53 -07:00 committed by PyTorch MergeBot
parent 48cff64a54
commit 128b32f363

View File

@ -2837,6 +2837,11 @@ class PyCodeCache:
# than once, but attach different attributes, i.e., due to different
# constant values.
modules: list[ModuleType] = []
# Modules loaded without extra attributes are stored here, those do not
# need to be re-loaded.
modules_no_attr: dict[str, ModuleType] = {}
linemaps: dict[str, list[tuple[Any, ...]]] = {}
@classmethod
@ -2844,15 +2849,9 @@ class PyCodeCache:
return write(source_code, "py", extra=extra)
@classmethod
def load(
cls,
source_code: str,
extra: str = "",
linemap: Optional[list[tuple[int, str]]] = None,
attrs: Optional[dict[str, Any]] = None,
) -> ModuleType:
def load(cls, source_code: str, extra: str = "") -> ModuleType:
key, path = write(source_code, "py", extra=extra)
return cls.load_by_key_path(key, path, linemap, attrs)
return cls.load_by_key_path(key, path)
@classmethod
def load_by_key_path(
@ -2865,6 +2864,10 @@ class PyCodeCache:
if linemap is None:
linemap = []
# we only cache when attrs is None
if attrs is None and path in cls.modules_no_attr:
return cls.modules_no_attr[path]
in_toplevel = in_toplevel_process()
mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
@ -2877,6 +2880,10 @@ class PyCodeCache:
setattr(mod, k, v)
if in_toplevel:
# we only cache when attrs is None
if attrs is None:
cls.modules_no_attr[path] = mod
cls.modules.append(mod)
return mod
@ -2894,6 +2901,7 @@ class PyCodeCache:
except FileNotFoundError:
pass
cls.modules.clear()
cls.modules_no_attr.clear()
@classmethod
@functools.lru_cache(None)