Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
cache loaded python modules (#149910)
I am splitting caching of module loading from caching of the codegen, since the former is trivial and much easier. Module loading is 50% of the cost and codegen the other 50% of maybe_append_choice on the full-graph model, which is 40% of total compile time.

(Screenshot of the profile: https://github.com/user-attachments/assets/aa851c6a-bde9-43f8-b12d-e439504ef62c)

Running the mm_loop benchmark: before this change: 67947323682; after this change: 25845073249. That is 2.6X faster. It seems the cache existed at some point and then got dropped; I added a benchmark so it won't be dropped again by mistake.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149910
Approved by: https://github.com/eellison, https://github.com/aorenste
ghstack dependencies: #149932
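For context, the pattern being cached is roughly this: a generated Python file is written to disk and executed as a module; if the caller attaches no extra attributes, the already-executed module can simply be returned from a dict keyed by its path. Below is a minimal sketch of that idea, not the actual PyCodeCache implementation; `load_module` and `_modules_no_attr` are illustrative names.

```python
# Minimal sketch of the caching idea, not the actual PyCodeCache implementation;
# load_module and _modules_no_attr are illustrative names.
import importlib.util
from types import ModuleType
from typing import Any, Optional

_modules_no_attr: dict[str, ModuleType] = {}  # path -> already-loaded module

def load_module(
    key: str, path: str, attrs: Optional[dict[str, Any]] = None
) -> ModuleType:
    # Reusing a module is only safe when no caller-specific attributes are attached.
    if attrs is None and path in _modules_no_attr:
        return _modules_no_attr[path]

    # Execute the generated file as a module (the expensive step the cache avoids).
    spec = importlib.util.spec_from_file_location(key, path)
    assert spec is not None and spec.loader is not None
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    if attrs is not None:
        for k, v in attrs.items():
            setattr(mod, k, v)
    else:
        _modules_no_attr[path] = mod
    return mod
```

The cache is deliberately skipped when attrs is given, since attaching different attribute values to a shared module object would let callers observe each other's state.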
This commit is contained in:
parent 48cff64a54
commit 128b32f363
@@ -2837,6 +2837,11 @@ class PyCodeCache:
     # than once, but attach different attributes, i.e., due to different
     # constant values.
     modules: list[ModuleType] = []
+
+    # Modules loaded without extra attributes are stored here, those do not
+    # need to be re-loaded.
+    modules_no_attr: dict[str, ModuleType] = {}
+
     linemaps: dict[str, list[tuple[Any, ...]]] = {}
 
     @classmethod
@@ -2844,15 +2849,9 @@ class PyCodeCache:
         return write(source_code, "py", extra=extra)
 
     @classmethod
-    def load(
-        cls,
-        source_code: str,
-        extra: str = "",
-        linemap: Optional[list[tuple[int, str]]] = None,
-        attrs: Optional[dict[str, Any]] = None,
-    ) -> ModuleType:
+    def load(cls, source_code: str, extra: str = "") -> ModuleType:
         key, path = write(source_code, "py", extra=extra)
-        return cls.load_by_key_path(key, path, linemap, attrs)
+        return cls.load_by_key_path(key, path)
 
     @classmethod
     def load_by_key_path(
@@ -2865,6 +2864,10 @@ class PyCodeCache:
         if linemap is None:
             linemap = []
 
+        # we only cache when attrs is None
+        if attrs is None and path in cls.modules_no_attr:
+            return cls.modules_no_attr[path]
+
         in_toplevel = in_toplevel_process()
         mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
 
@@ -2877,6 +2880,10 @@ class PyCodeCache:
                 setattr(mod, k, v)
 
         if in_toplevel:
+            # we only cache when attrs is None
+            if attrs is None:
+                cls.modules_no_attr[path] = mod
+
             cls.modules.append(mod)
         return mod
 
@@ -2894,6 +2901,7 @@ class PyCodeCache:
             except FileNotFoundError:
                 pass
         cls.modules.clear()
+        cls.modules_no_attr.clear()
 
     @classmethod
     @functools.lru_cache(None)
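As a rough illustration of the kind of measurement behind the before/after numbers above (this is not the PR's mm_loop benchmark, and the output is machine-dependent), one could time repeated loads of the same generated file with and without a path-keyed cache:

```python
# Not the PR's mm_loop benchmark; a self-contained timing sketch showing why
# memoizing loaded modules by path helps when the same file is loaded repeatedly.
import importlib.util
import tempfile
import time
from types import ModuleType

# Write a small "generated" module to disk.
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("def kernel(x):\n    return x * 2\n")
    path = f.name

def exec_load(path: str) -> ModuleType:
    # Re-executes the file on every call (the uncached behavior).
    spec = importlib.util.spec_from_file_location("generated", path)
    assert spec is not None and spec.loader is not None
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod

_cache: dict[str, ModuleType] = {}

def cached_load(path: str) -> ModuleType:
    # Executes the file once, then returns the memoized module.
    if path not in _cache:
        _cache[path] = exec_load(path)
    return _cache[path]

for name, loader in [("uncached", exec_load), ("cached", cached_load)]:
    start = time.perf_counter()
    for _ in range(200):
        loader(path)
    print(f"{name}: {time.perf_counter() - start:.4f}s")
```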