mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Dynamo was aggressively specializing on lazy VTs over `set_name_hint` in `STORE_FAST`, etc., and `isinstance` in `LOAD_FAST_CHECK`. This prevents regional `torch.compile` from optimizing ComfyUI GGUF + LoRA well: either (1) it exceeds the recompilation limit of 8, which results in suboptimal performance, or (2) even if the recompilation limit is increased, the compilation time gets unnecessarily high (180s vs. 20s for Flux). This patch fixes the recompilation issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156891 Approved by: https://github.com/williamwen42, https://github.com/mlazos
225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
import collections
|
|
import functools
|
|
import inspect
|
|
from typing import Any, Callable, final, Optional, Union
|
|
from typing_extensions import Self
|
|
|
|
from ..utils import is_function_or_wrapper
|
|
from .base import VariableTracker
|
|
from .tensor import SymNodeVariable
|
|
|
|
|
|
class LazyCache:
    """Holds a raw value and its source until the real VariableTracker is built."""

    def __init__(self, value: Any, source: Any) -> None:
        # Only LazySymNodeFormatString payloads are allowed to come without a source.
        needs_source = not isinstance(value, LazySymNodeFormatString)
        if needs_source:
            assert source
        self.value = value
        self.source = source
        # Name recorded before realization; passed on to the tracker in realize().
        self.name_hint: Optional[str] = None
        # The built tracker; stays None until realize() has run.
        self.vt: Optional[VariableTracker] = None

    def realize(self) -> None:
        """
        Build the real VariableTracker and store it in `self.vt`.

        Must be called at most once (asserts that `vt` is still None). A
        LazySymNodeFormatString value is built through SourcelessBuilder; any
        other value goes through VariableBuilder with the stored source.
        After building, a pending name hint is forwarded to the new tracker
        and the `value`, `source` and `name_hint` attributes are deleted.
        """
        assert self.vt is None
        from ..symbolic_convert import InstructionTranslator
        from . import builder

        tx = InstructionTranslator.current_tx()

        payload = self.value
        if isinstance(payload, LazySymNodeFormatString):
            built = builder.SourcelessBuilder.create(tx, payload)
        else:
            built = builder.VariableBuilder(tx, self.source)(payload)
        # Publish the tracker before forwarding the name hint.
        self.vt = built

        hint = self.name_hint
        if hint is not None:
            built.set_name_hint(hint)

        # Only `vt` is kept once the tracker exists.
        for stale in ("value", "source", "name_hint"):
            delattr(self, stale)
|
|
|
|
|
|
@final
class LazyVariableTracker(VariableTracker):
    """
    Stand-in VariableTracker whose real counterpart is built on demand.

    The wrapped value and its source live in a LazyCache. `realize` asks that
    cache to build the real VariableTracker the first time it is needed and
    hands back the already-built tracker on every later call, so the build
    happens at most once.

    Meant for containers, or objects that reference other objects, where
    creating every VariableTracker up front is not desirable.
    """

    _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}

    @staticmethod
    def create(value: Any, source: Any, **options: Any) -> "LazyVariableTracker":
        """Wrap `value` and `source` in a fresh LazyCache and a tracker around it."""
        cache = LazyCache(value, source)
        return LazyVariableTracker(cache, source=source, **options)

    def __init__(self, _cache: LazyCache, **kwargs: Any) -> None:
        """Store `_cache`; every other keyword goes to VariableTracker.__init__."""
        assert isinstance(_cache, LazyCache)
        super().__init__(**kwargs)
        self._cache = _cache

    def realize(self) -> VariableTracker:
        """Build the real VariableTracker if that has not happened yet and return it."""
        cache = self._cache
        if cache.vt is None:
            cache.realize()
            assert cache.vt is not None
        return cache.vt

    def unwrap(self) -> Union[VariableTracker, Self]:
        """Return the real VariableTracker when one exists, otherwise this object."""
        if not self.is_realized():
            return self
        real = self._cache.vt
        assert real is not None
        return real

    def is_realized(self) -> bool:
        """Tell whether the cache already holds a built VariableTracker."""
        return self._cache.vt is not None

    def clone(self, **kwargs: Any) -> VariableTracker:
        """
        Clone through VariableTracker.clone, applied to the unwrapped tracker.

        A `_cache` override must be this tracker's own cache. Passing a
        `source` that is not this tracker's source forces realization first.
        """
        own_cache = self._cache
        assert kwargs.get("_cache", own_cache) is own_cache
        own_source = self.source
        source_changed = kwargs.get("source", own_source) is not own_source
        if source_changed:
            self.realize()
        target = self.unwrap()
        return VariableTracker.clone(target, **kwargs)

    def peek_type(self) -> type[Any]:
        """Return the type of the wrapped value; only valid before realization."""
        assert not self.is_realized()
        return type(self._cache.value)

    def peek_value(self) -> Any:
        """Return the wrapped value itself; only valid before realization."""
        assert not self.is_realized()
        return self._cache.value

    def set_name_hint(self, name: str) -> None:
        """Record `name` on the cache, or forward it to the real tracker if built."""
        if not self.is_realized():
            # Kept on the cache; LazyCache.realize() forwards it later.
            self._cache.name_hint = name
            return
        self._cache.vt.set_name_hint(name)  # type: ignore[union-attr]

    def __str__(self) -> str:
        """Use the real tracker's repr once realized, else the inherited repr."""
        if not self.is_realized():
            return super().__repr__()
        return repr(self.unwrap())

    def __getattr__(self, item: str) -> Any:
        """Resolve attributes missing on this object by realizing and delegating."""
        return getattr(self.realize(), item)

    # The module-level `_populate()` pass only fills in names that are absent
    # from this class's `__dict__`; the two names below are defined here so
    # that they do not get a generated realize-and-forward wrapper.
    visit = VariableTracker.visit  # type: ignore[assignment]
    __repr__ = __str__

    @classmethod
    def realize_all(
        cls,
        value: Any,
        cache: Optional[dict[int, tuple[Any, Any]]] = None,
    ) -> Any:
        """
        Recursively realize every LazyVariableTracker reachable from `value`.

        VariableTrackers are updated in place (fields listed in
        `_nonvar_fields` are skipped); lists, tuples and dicts are rebuilt.
        `cache` maps `id(obj)` to `(result, obj)` so that an object is only
        processed once.
        """
        if cache is None:
            cache = {}

        key = id(value)
        seen = cache.get(key)
        if seen is not None:
            return seen[0]

        kind = type(value)
        if issubclass(kind, LazyVariableTracker):
            realized = cls.realize_all(value.realize(), cache)
        elif issubclass(kind, VariableTracker):
            # The tracker itself is kept; only its fields are replaced.
            realized = value
            fields = value.__dict__
            skipped = value._nonvar_fields
            for field_name in fields:
                if field_name in skipped:
                    continue
                fields[field_name] = cls.realize_all(fields[field_name], cache)
        elif kind is list:
            realized = [cls.realize_all(item, cache) for item in value]
        elif kind is tuple:
            realized = tuple(cls.realize_all(item, cache) for item in value)
        elif kind in (dict, collections.OrderedDict):
            realized = {
                k: cls.realize_all(v, cache) for k, v in list(value.items())
            }
        else:
            realized = value

        # Storing `value` next to the result keeps it alive, so its id()
        # cannot be reused by another object while the cache is in use.
        cache[key] = (realized, value)
        return realized

    def is_hashable(self) -> bool:
        """
        Check hashability of the wrapped value without realizing the tracker.

        ConstDictVariable uses this to decide whether a lazy key can be
        hashed. A tuple passes only if every element passes.
        """

        def _known_hashable(obj: Any) -> bool:
            # TODO: Add support for more types
            if inspect.isbuiltin(obj):
                return True
            if issubclass(type(obj), type):
                return True
            return is_function_or_wrapper(obj)

        assert not self.is_realized()
        wrapped = self._cache.value
        if isinstance(wrapped, tuple):
            return all(_known_hashable(element) for element in wrapped)
        return _known_hashable(wrapped)

    def original_value(self) -> Any:
        """Return the wrapped value without realizing; only valid before realization."""
        assert not self.is_realized()
        return self._cache.value

    def original_source(self) -> Any:
        """Return the wrapped source without realizing; only valid before realization."""
        assert not self.is_realized()
        return self._cache.source
|
|
|
|
|
|
class LazySymNodeFormatString:
    """Pairs a SymNodeVariable with a format spec; the text is produced by repr()."""

    def __init__(
        self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
    ) -> None:
        """Keep the sym node and wrap the spec into a "{:<spec>}" constant."""
        from .constant import ConstantVariable

        self.sym_node_var = sym_node_variable
        spec = fmt_spec_var.as_python_constant()
        self.fmt_var = ConstantVariable.create("{:" + spec + "}")

    def __repr__(self) -> str:
        """Format the stringified result of `evaluate_expr()` with the stored template."""
        template = self.fmt_var.as_python_constant()
        # The evaluated expression is converted with str() before formatting.
        rendered = str(self.sym_node_var.evaluate_expr())
        return str.format(template, rendered)
|
|
|
|
|
|
def _create_realize_and_forward(
    name: str,
) -> Callable[[LazyVariableTracker, Any, Any], Any]:
    """
    Build a method that realizes the lazy tracker and calls `name` on the result.

    The returned function takes its metadata from `VariableTracker.<name>`
    via functools.wraps.
    """
    wrapped_attr = getattr(VariableTracker, name)

    def realize_and_forward(
        self: LazyVariableTracker, *args: Any, **kwargs: Any
    ) -> Any:
        # Look the attribute up on the realized tracker, not on the lazy one.
        bound = getattr(self.realize(), name)
        return bound(*args, **kwargs)

    return functools.wraps(wrapped_attr)(realize_and_forward)
|
|
|
|
|
|
def _populate() -> None:
    """
    Add realize-and-forward wrappers to LazyVariableTracker.

    Each callable entry of `VariableTracker.__dict__` whose name is not
    already in `LazyVariableTracker.__dict__` gets a generated wrapper.
    """
    # mappingproxy is a live view, so wrappers added below are visible here too.
    own_names = LazyVariableTracker.__dict__
    for name, value in VariableTracker.__dict__.items():
        if name in own_names or not callable(value):
            continue
        setattr(LazyVariableTracker, name, _create_realize_and_forward(name))


# Runs once at import time so the class is complete before use.
_populate()
|