mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This significantly improves optimizer tracing performance (2x). In essence, it just removes the recursion from `propagate`, which is unnecessary because ListVariables and ConstDictVariables already contain the guards of the items they hold. It also adds two other optimizations for special cases of `recursively_contains`. Helps with https://github.com/pytorch/torchdynamo/issues/1803. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89827 Approved by: https://github.com/anijain2305, https://github.com/jansel
440 lines
15 KiB
Python
440 lines
15 KiB
Python
import collections
|
|
import dataclasses
|
|
import functools
|
|
import inspect
|
|
from typing import Dict, List
|
|
|
|
from .. import variables
|
|
from ..bytecode_transformation import create_instruction
|
|
from ..eval_frame import skip_code
|
|
from ..exc import unimplemented
|
|
from ..source import AttrSource, GlobalWeakRefSource
|
|
from ..utils import global_key_name, istensor
|
|
from .base import MutableLocal, VariableTracker
|
|
from .constant import ConstantVariable
|
|
from .tensor import TensorVariable
|
|
|
|
|
|
class ConstDictVariable(VariableTracker):
    """Tracker for a dict whose keys are python constants or specialized tensors.

    ``items`` maps each raw key to the ``VariableTracker`` of its value and
    ``user_cls`` is the type reported by ``python_type``.
    """

    def __init__(self, items, user_cls, recursively_contains=None, **kwargs):
        super().__init__(recursively_contains=recursively_contains, **kwargs)

        # Fold the guards of every contained value into this tracker's guards.
        value_guards = VariableTracker.propagate(items.values())["guards"]
        self.guards.update(value_guards)
        self.items = items
        self.user_cls = user_cls

    def as_proxy(self):
        """Return a plain dict mapping each key to the proxy of its value."""
        proxies = {}
        for key, var in self.items.items():
            proxies[key] = var.as_proxy()
        return proxies

    def python_type(self):
        """Return the user-facing class recorded for this dict."""
        return self.user_cls

    def reconstruct(self, codegen):
        """Emit key/value loads for every entry and return the BUILD_MAP instruction."""
        for key, value in self.items.items():
            if istensor(key):
                # Tensor keys are loaded through a global that is called with no args.
                load_key = codegen.create_load_global(global_key_name(key), add=True)
                call_key = create_instruction("CALL_FUNCTION", 0)
                codegen.extend_output([load_key, call_key])
            else:
                codegen.append_output(codegen.create_load_const(key))
            codegen(value)

        return [create_instruction("BUILD_MAP", len(self.items))]

    def getitem_const(self, arg: VariableTracker):
        """Look up the entry keyed by ``arg`` and attach options from self and ``arg``."""
        key = ConstDictVariable.get_key(arg)
        return self.items[key].add_options(self, arg)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle the supported dict methods; defer everything else to the base class."""
        from . import ConstantVariable, TupleVariable

        options = VariableTracker.propagate(self, args, kwargs.values())
        current = self.items

        def first_arg_is_key():
            # Evaluated lazily so the key check only runs for branches that need it.
            return bool(args) and ConstDictVariable.is_valid_key(args[0])

        if name == "__getitem__":
            return self.getitem_const(args[0])

        if name == "items":
            assert not (args or kwargs)
            pairs = []
            for k, v in current.items():
                key_var = ConstDictVariable._key_to_var(tx, k, **options)
                pairs.append(TupleVariable([key_var, v], **options))
            return TupleVariable(pairs, **options)

        if name == "keys":
            assert not (args or kwargs)
            key_vars = [
                ConstDictVariable._key_to_var(tx, k, **options)
                for k in current.keys()
            ]
            return TupleVariable(key_vars, **options)

        if name == "values":
            assert not (args or kwargs)
            return TupleVariable(list(current.values()), **options)

        if name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable(len(self.items), **options)

        if name == "__setitem__" and first_arg_is_key() and self.mutable_local:
            assert not kwargs and len(args) == 2
            k = ConstDictVariable.get_key(args[0])

            if istensor(k):
                tx.store_dict_key(global_key_name(k), k)
            updated = collections.OrderedDict(current)
            updated[k] = args[1]

            contained = self.recursively_contains.union(args[1].recursively_contains)
            if args[1].mutable_local is not None:
                contained.add(args[1].mutable_local)

            replacement = self.modifed(updated, contained, **options)
            return tx.replace_all(self, replacement)

        if (
            name in ("pop", "get")
            and first_arg_is_key()
            and ConstDictVariable.get_key(args[0]) not in self.items
            and len(args) == 2
        ):
            # missing item, return the default value
            return args[1].add_options(options)

        if name == "pop" and first_arg_is_key() and self.mutable_local:
            remaining = collections.OrderedDict(current)
            popped = remaining.pop(ConstDictVariable.get_key(args[0]))
            tx.replace_all(self, self.modifed(remaining, None, **options))
            return popped.add_options(options)

        if (
            name == "update"
            and args
            and isinstance(args[0], ConstDictVariable)
            and self.mutable_local
        ):
            merged = collections.OrderedDict(current)
            merged.update(args[0].items)
            contained = self.recursively_contains.union(args[0].recursively_contains)
            replacement = self.modifed(
                merged, recursively_contains=contained, **options
            )
            return tx.replace_all(self, replacement)

        if (
            name in ("get", "__getattr__")
            and first_arg_is_key()
            and ConstDictVariable.get_key(args[0]) in self.items
        ):
            found = self.items[ConstDictVariable.get_key(args[0])]
            return found.add_options(options)

        if name == "__contains__" and first_arg_is_key():
            present = ConstDictVariable.get_key(args[0]) in self.items
            return ConstantVariable(present, **options)

        return super().call_method(tx, name, args, kwargs)

    def modifed(self, items, recursively_contains, **options):
        """a copy of self with different items"""
        return self.clone(
            items=items,
            recursively_contains=recursively_contains,
            **options,
        )

    def unpack_var_sequence(self, tx):
        """Return one variable per key, in the dict's iteration order."""
        options = VariableTracker.propagate([self])
        return [
            ConstDictVariable._key_to_var(tx, k, **options)
            for k in self.items.keys()
        ]

    @classmethod
    def get_key(cls, arg: VariableTracker):
        """Return the raw key for ``arg``: its specialized tensor value or its python constant."""
        if not isinstance(arg, TensorVariable) or arg.specialized_value is None:
            return arg.as_python_constant()
        return arg.specialized_value

    @classmethod
    def is_valid_key(cls, key):
        """Tell whether ``key`` is a python constant or a tensor with a specialized value."""
        constant = key.is_python_constant()
        if constant:
            return constant
        return isinstance(key, TensorVariable) and key.specialized_value is not None

    @classmethod
    def _key_to_var(cls, tx, key, **options):
        """Wrap a raw dict key back into a variable."""
        from .builder import VariableBuilder

        if not istensor(key):
            assert ConstantVariable.is_literal(key)
            return ConstantVariable(key, **options)
        # Tensor keys are rebuilt from the global weak reference named after the key.
        return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
|
|
|
|
|
|
class DefaultDictVariable(ConstDictVariable):
    """ConstDictVariable for ``collections.defaultdict`` that remembers its ``default_factory``."""

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` with default creation; defer other methods to the parent."""
        from . import ListVariable, TupleVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if name != "__getitem__":
            return super().call_method(tx, name, args, kwargs)

        key = ConstDictVariable.get_key(args[0])

        # Existing entry: behave like a normal dict lookup.
        if key in self.items:
            return self.getitem_const(args[0])

        if self.default_factory is None:
            raise KeyError(f"{key}")

        if istensor(key):
            tx.store_dict_key(global_key_name(key), key)
        grown = collections.OrderedDict(self.items)

        # Only the list, tuple and dict factories are modelled.
        if self.default_factory is list:
            default_var = ListVariable([], mutable_local=MutableLocal())
        elif self.default_factory is tuple:
            default_var = TupleVariable([], mutable_local=MutableLocal())
        elif self.default_factory is dict:
            default_var = ConstDictVariable({}, dict, mutable_local=MutableLocal())
        else:
            unimplemented(
                f"defaultdict with default_factory = {self.default_factory}"
            )

        grown[key] = default_var
        contained = self.recursively_contains.union(default_var.recursively_contains)
        if default_var.mutable_local is not None:
            contained.add(default_var.mutable_local)

        tx.replace_all(self, self.modifed(grown, contained, **options))
        return default_var
|
|
|
|
|
|
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.
    """

    # ModelOutput() excludes None, though generic dataclasses don't
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Pass the code object of every callable in ``ModelOutput.__dict__`` to ``skip_code``.

        Wrapped in ``lru_cache`` so the work is done only on the first call.
        """
        from transformers.file_utils import ModelOutput

        for obj in ModelOutput.__dict__.values():
            if callable(obj):
                skip_code(obj.__code__)

    @staticmethod
    def is_matching_cls(cls):
        """Return True if ``cls`` subclasses ``ModelOutput``; False if transformers is not importable."""
        try:
            from transformers.file_utils import ModelOutput

            return issubclass(cls, ModelOutput)
        except ImportError:
            return False

    @classmethod
    def is_matching_object(cls, obj):
        """Return True if ``obj`` is an instance of a matching class."""
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build a variable for a ``user_cls(*args, **kwargs)`` constructor call.

        The arguments are bound against the signature of ``user_cls`` (with
        defaults applied) and stored per dataclass field. Values that are not
        ``VariableTracker`` instances are wrapped as constants when
        ``include_none`` is set and must otherwise be ``None`` (they are dropped).
        """
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = collections.OrderedDict()
        for key in keys:
            val = bound.arguments[key]
            if isinstance(val, VariableTracker):
                items[key] = val
            else:
                if cls.include_none:
                    assert variables.ConstantVariable.is_literal(val)
                    items[key] = variables.ConstantVariable(val)
                else:
                    assert val is None, f"unexpected {val}"

        # Only inspect the first field when it is the single populated one;
        # indexing items[keys[0]] unconditionally raised KeyError when the lone
        # populated field was a later one.
        if (
            len(items) == 1
            and keys[0] in items
            and not isinstance(items[keys[0]], variables.TensorVariable)
        ):
            unimplemented("DataClassVariable iterator constructor")
            # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Build a variable from an existing object ``obj`` using ``builder``.

        Each field present on ``obj`` is wrapped with a new builder sourced from
        ``AttrSource(builder.source, key)``. ``None`` fields are left out of the
        items unless ``include_none`` is set, but their variables are still
        passed to ``VariableTracker.propagate``.
        """
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        excluded = []
        items = collections.OrderedDict()
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if hasattr(obj, key):
                val = getattr(obj, key)
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None or cls.include_none:
                    items[key] = var
                else:
                    excluded.append(var)
        return cls(
            items, user_cls, **VariableTracker.propagate(excluded, items.values())
        )

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        """Not supported for this variable type."""
        raise NotImplementedError()

    def reconstruct(self, codegen):
        """Emit code that calls ``user_cls`` with every item passed as a keyword argument."""
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return [
            codegen.create_load_const(keys),
            create_instruction("CALL_FUNCTION_KW", len(keys)),
        ]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__``, ``to_tuple`` and ``__setattr__``; defer the rest to the parent."""
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            index = args[0].as_python_constant()
            if isinstance(index, str):
                return self.items[index].add_options(options)
            else:
                # Non-string indices are resolved against the tuple of values.
                return (
                    self.call_method(tx, "to_tuple", [], {})
                    .call_method(tx, "__getitem__", args, kwargs)
                    .add_options(options)
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()), **options)
        elif name == "__setattr__":
            # Attribute assignment is treated as item assignment by the parent.
            name = "__setitem__"
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name`` on the tracked object.

        Stored items are returned through ``__getitem__``. When ``include_none``
        is False, a field missing from the items falls back to its dataclass
        default wrapped as a constant. Anything else is delegated to the parent.
        """
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable(name)], {}
            )
        elif not self.include_none:
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable(defaults[name]).add_options(self)
        # The parent's result must be returned; previously it was discarded and
        # callers received None.
        return super().var_getattr(tx, name)
|
|
|
|
|
|
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig
    """

    @staticmethod
    def is_matching_cls(cls):
        """Return True if ``cls`` subclasses ``PretrainedConfig``; False if transformers is not importable."""
        try:
            from transformers.configuration_utils import PretrainedConfig

            matches = issubclass(cls, PretrainedConfig)
        except ImportError:
            matches = False
        return matches

    @classmethod
    def is_matching_object(cls, obj):
        """Return True if the type of ``obj`` is a matching class."""
        obj_type = type(obj)
        return cls.is_matching_cls(obj_type)

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Read attribute ``name`` from the wrapped object and wrap it as a constant."""
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable(attr_value)
|