mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This significantly improves the performance of optimizer tracing (about 2x). In essence, it removes the recursion from `propagate`, which is unnecessary because ListVariables and ConstDictVariables already contain the guards of the items they hold. It also adds two optimizations for special cases of `recursively_contains`. Helps with https://github.com/pytorch/torchdynamo/issues/1803 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89827 Approved by: https://github.com/anijain2305, https://github.com/jansel
297 lines
9.0 KiB
Python
297 lines
9.0 KiB
Python
import collections
|
|
from typing import Any, Callable, Dict, List, Optional, Set
|
|
|
|
from .. import variables
|
|
from ..exc import unimplemented
|
|
from ..source import AttrSource, Source
|
|
from ..utils import dict_values, identity, istype, odict_values
|
|
|
|
|
|
class MutableLocal:
    """
    Sentinel attached to a container (list, iter, etc) that was built in
    local scope, signalling that analysis may mutate it without leaking
    state.

    Equality and hashing are identity-based, so every marker instance is
    distinct from every other one.
    """

    def __eq__(self, other):
        # Two markers are equal only when they are the very same object.
        return other is self

    def __hash__(self):
        # Kept consistent with the identity-based __eq__ above.
        return id(self)
|
|
|
|
|
|
# Metaclass that runs __post_init__ right after normal construction
class HasPostInit(type):
    def __call__(cls, *args, **kwargs):
        """
        Construct an instance the usual way, then invoke its
        ``__post_init__`` hook with the same positional and keyword
        arguments that were given to the constructor.
        """
        instance = type.__call__(cls, *args, **kwargs)
        instance.__post_init__(*args, **kwargs)
        return instance
|
|
|
|
|
|
class VariableTracker(metaclass=HasPostInit):
    """
    Common base for values tracked on the stack and in locals.

    Instances are treated as immutable: to change one, build a modified
    copy (see ``clone``) instead of assigning to its attributes.
    """

    # attribute names that apply() passes through without recursing into
    _nonvar_fields = ["value"]

    @staticmethod
    def propagate(*vars: List[List["VariableTracker"]]):
        """
        Union the guards of every VariableTracker reachable through the
        given arguments and return them as ``{"guards": ...}``, ready to be
        splatted as **kwargs into a new instance.

        Only plain list / tuple / dict_values / odict_values containers are
        descended into; any other object must be a VariableTracker.
        """
        collected = set()

        def collect(item):
            if type(item) not in (list, tuple, dict_values, odict_values):
                assert isinstance(item, VariableTracker), typestr(item)
                collected.update(item.guards)
                return
            for child in item:
                collect(child)

        collect(vars)
        return {"guards": collected}

    def clone(self, **kwargs):
        """Return a shallow copy, with ``kwargs`` overriding the copied attributes."""
        # The instance __dict__ doubles as the constructor keyword arguments.
        ctor_args = {**self.__dict__, **kwargs}
        return self.__class__(**ctor_args)

    @classmethod
    def copy(cls, value):
        """Deeper (but not full) copy, leaving FX and user objects alone"""
        return cls.apply(identity, value)

    @classmethod
    def apply(
        cls,
        fn: Callable[["VariableTracker"], "VariableTracker"],
        value,
        cache=None,
        skip_fn=lambda _: False,  # Whether we should skip applying to this var
    ):
        """
        Rebuild ``value`` with ``fn`` applied to every VariableTracker in it.

        Lists, tuples, OrderedDicts and dicts (exact types) are rebuilt
        element by element.  For a VariableTracker, every attribute not
        listed in ``_nonvar_fields`` is processed recursively and ``fn`` is
        then called on a clone carrying the processed attributes, unless
        ``skip_fn`` returns true for it, in which case ``fn`` is called on
        the tracker itself and its attributes are not visited.  Any other
        object is returned unchanged.

        ``cache`` memoizes results by ``id(value)`` so an object reached
        more than once is only processed once.
        """
        if cache is None:
            cache = {}

        memo_key = id(value)
        if memo_key in cache:
            return cache[memo_key][0]

        if isinstance(value, VariableTracker):
            if skip_fn(value):
                result = fn(value)
            else:
                # Iterate over a snapshot of the attributes so the walk is
                # unaffected by anything that happens while recursing.
                fields = {}
                for attr_name, attr in dict(value.__dict__).items():
                    if attr_name in value._nonvar_fields:
                        fields[attr_name] = attr
                    else:
                        fields[attr_name] = cls.apply(fn, attr, cache, skip_fn)
                result = fn(value.clone(**fields))
        elif istype(value, list):
            result = [cls.apply(fn, elem, cache, skip_fn) for elem in value]
        elif istype(value, tuple):
            result = tuple(cls.apply(fn, elem, cache, skip_fn) for elem in value)
        elif istype(value, collections.OrderedDict):
            # Each (key, value) pair goes through apply() as a tuple, so the
            # keys are processed as well as the values.
            result = collections.OrderedDict(
                cls.apply(fn, pair, cache, skip_fn) for pair in value.items()
            )
        elif istype(value, dict):
            result = {
                k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
            }
        else:
            result = value

        # keep a reference to `value` so that its id() cannot be recycled
        # while the cache is alive
        cache[memo_key] = (result, value)
        return result

    def add_guard(self, guard):
        """Return a clone whose guards additionally contain ``guard``."""
        merged = set.union(self.guards, {guard})
        return self.clone(guards=merged)

    def add_guards(self, guards):
        """
        Return a clone whose guards are the union of the current guards and
        ``guards``.  ``None`` means nothing to add and returns ``self``.
        """
        if guards is None:
            return self
        assert isinstance(guards, set)
        merged = set.union(self.guards, guards)
        return self.clone(guards=merged)

    def add_options(self, options, *more):
        """
        Merge in the guards carried by each argument, where an argument is
        either a VariableTracker or an options dict (its "guards" entry is
        used, defaulting to an empty set).
        """
        if more:
            # Fold the arguments in one at a time, left to right.
            updated = self.add_options(options)
            return updated.add_options(*more)
        if isinstance(options, VariableTracker):
            extra = options.guards
        else:
            assert isinstance(options, dict)
            extra = options.get("guards", set())
        return self.add_guards(extra)

    def __str__(self):
        return f"{self.__class__.__name__}()"

    def __repr__(self):
        return str(self)

    def python_type(self):
        """Not available on the base class; always raises NotImplementedError."""
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants; the base class always raises NotImplementedError."""
        raise NotImplementedError(f"{self} is not a constant")

    def is_python_constant(self):
        """Report whether ``as_python_constant`` completes without NotImplementedError."""
        try:
            self.as_python_constant()
        except NotImplementedError:
            return False
        return True

    def as_specialized(self, tx):
        """
        For specialized variables, return itself,
        For unspecialized variables, convert to constant variable and return.

        The base implementation returns ``self``.
        """
        return self

    def can_make_guard(self):
        """Report whether ``make_guard`` completes without NotImplementedError."""
        try:
            self.make_guard(None)
        except NotImplementedError:
            return False
        return True

    def make_guard(self, fn):
        """Delegate to ``self.source.make_guard(fn)``; raise NotImplementedError without a source."""
        if not self.source:
            raise NotImplementedError()
        return self.source.make_guard(fn)

    def replace_guards(self, guards, *fns):
        """
        Return a new set made of the entries of ``guards`` whose ``name``
        differs from this variable's source name, plus one guard built from
        the source for each function in ``fns``.
        """
        own_name = self.source.name()
        kept = set()
        for existing in guards or []:
            if existing.name != own_name:
                kept.add(existing)
        for fn in fns:
            kept.add(self.source.make_guard(fn))
        return kept

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError()

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """
        getattr(self, name) returning a new variable.

        The attribute is obtained through ``const_getattr`` and wrapped in
        a ConstantVariable carrying this variable's guards; a value that is
        not a literal raises NotImplementedError.
        """
        ctor_options = VariableTracker.propagate(self)
        attr_value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(attr_value):
            raise NotImplementedError()
        if self.source:
            ctor_options["source"] = AttrSource(self.source, name)
        return variables.ConstantVariable(attr_value, **ctor_options)

    def is_proxy(self):
        """Report whether ``as_proxy`` completes without NotImplementedError."""
        try:
            self.as_proxy()
        except NotImplementedError:
            return False
        return True

    def as_proxy(self):
        raise NotImplementedError(str(self))

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def unpack_var_sequence(self, tx):
        raise NotImplementedError()

    def has_unpack_var_sequence(self, tx):
        """Report whether ``unpack_var_sequence`` completes without NotImplementedError."""
        try:
            self.unpack_var_sequence(tx)
        except NotImplementedError:
            return False
        return True

    def num_parameters(self):
        unimplemented(f"num_parameters: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        unimplemented(f"hasattr: {self}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Generic method-call fallback.

        Handles ``__len__`` for anything that can be unpacked into a
        sequence, and ``__getattr__`` with a single constant argument;
        every other call is handed to ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            length = len(self.unpack_var_sequence(tx))
            ctor_options = VariableTracker.propagate(self)
            return variables.ConstantVariable(length, **ctor_options)

        simple_getattr = (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        )
        if simple_getattr:
            attr_name = args[0].as_python_constant()
            attr_var = self.var_getattr(tx, attr_name)
            return attr_var.add_options(self, args[0])

        raise unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def __init__(
        self,
        guards: Optional[Set] = None,
        source: Source = None,
        mutable_local: MutableLocal = None,
        recursively_contains: Optional[Set] = None,
    ):
        super().__init__()
        self.source = source
        self.mutable_local = mutable_local
        self.guards = guards or set()
        # provides hint to replace_all when replacing vars
        self.recursively_contains = recursively_contains

    def __post_init__(self, *args, **kwargs):
        """
        Fill in ``recursively_contains`` when the constructor was not given
        one, by gathering the ``recursively_contains`` sets and the
        ``mutable_local`` markers of the trackers visited by ``apply``.
        """
        if self.recursively_contains is None:
            self.recursively_contains = set()

            def aggregate_mutables(var):
                self.recursively_contains.update(var.recursively_contains)
                marker = var.mutable_local
                if marker is not None:
                    self.recursively_contains.add(marker)
                return var

            # skip_fn stops the walk from descending into any tracker other
            # than self.  This call must stay inside the `is None` branch:
            # apply() clones self, which re-enters __post_init__ with the
            # set already populated.
            VariableTracker.apply(
                aggregate_mutables, self, skip_fn=lambda var: var is not self
            )

        assert None not in self.recursively_contains
|
|
|
|
|
|
def typestr(*objs):
    """
    Short description of the given objects for diagnostics.

    A single VariableTracker is rendered with ``str()``; any other single
    object is rendered as its type name.  With zero or several arguments,
    the per-object descriptions are joined with single spaces.
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)

    obj = objs[0]
    if isinstance(obj, VariableTracker):
        return str(obj)
    return type(obj).__name__
|