pytorch/torch/_dynamo/variables/dicts.py
Michael Lazos 342d78d1a2 Cache guards once per variable tracker, rather than re-propagating them repeatedly (#89827)
This significantly improves the performance of tracing optimizers (about 2x). In essence, it removes the recursion from `propagate`, which is unnecessary because ListVariables and ConstDictVariables already contain the guards of the items they hold.

Adds two other optimizations for special cases of `recursively_contains`

helps with https://github.com/pytorch/torchdynamo/issues/1803

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89827
Approved by: https://github.com/anijain2305, https://github.com/jansel
2022-12-02 01:45:05 +00:00

440 lines
15 KiB
Python

import collections
import dataclasses
import functools
import inspect
from typing import Dict, List
from .. import variables
from ..bytecode_transformation import create_instruction
from ..eval_frame import skip_code
from ..exc import unimplemented
from ..source import AttrSource, GlobalWeakRefSource
from ..utils import global_key_name, istensor
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .tensor import TensorVariable
class ConstDictVariable(VariableTracker):
    """Tracks a dict whose keys are known while tracing.

    ``items`` maps concrete keys (Python constants, or the specialized value
    of a ``TensorVariable``) to the ``VariableTracker`` of each value.
    ``user_cls`` is the type reported by ``python_type``.
    """

    def __init__(self, items, user_cls, recursively_contains=None, **kwargs):
        super(ConstDictVariable, self).__init__(
            recursively_contains=recursively_contains, **kwargs
        )
        # Collect the guards of every contained value once, at construction,
        # so this tracker already carries them.
        self.guards.update(VariableTracker.propagate(items.values())["guards"])
        self.items = items
        self.user_cls = user_cls

    def as_proxy(self):
        """Return a plain dict with the same keys whose values are proxies."""
        return {k: v.as_proxy() for k, v in self.items.items()}

    def python_type(self):
        """Return the dict class this variable stands for."""
        return self.user_cls

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds the dict.

        Each key and then its value is pushed through ``codegen``; the
        returned ``BUILD_MAP`` instruction combines them.
        """
        for key, value in self.items.items():
            if istensor(key):
                # Tensor keys cannot be loaded as constants: load the global
                # registered under global_key_name(key) and call it with no
                # arguments.
                # NOTE(review): presumably a weakref stored through
                # tx.store_dict_key -- confirm.
                codegen.extend_output(
                    [
                        codegen.create_load_global(global_key_name(key), add=True),
                        create_instruction("CALL_FUNCTION", 0),
                    ]
                )
            else:
                codegen.append_output(codegen.create_load_const(key))
            codegen(value)

        return [create_instruction("BUILD_MAP", len(self.items))]

    def getitem_const(self, arg: VariableTracker):
        """Look up the value stored under the key described by ``arg``.

        Raises ``KeyError`` if the key is not in ``items``.
        """
        return self.items[ConstDictVariable.get_key(arg)].add_options(self, arg)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model a method call on the tracked dict.

        Handled here: ``__getitem__``, ``items``, ``keys``, ``values``,
        ``__len__``, ``__setitem__``, ``pop``, ``get``, ``update`` and
        ``__contains__``. The mutating methods are only handled when
        ``self.mutable_local`` is set; they build a modified copy and swap it
        in through ``tx.replace_all``. Anything else is delegated to the
        parent class.
        """
        from . import ConstantVariable, TupleVariable

        options = VariableTracker.propagate(self, args, kwargs.values())
        val = self.items

        if name == "__getitem__":
            return self.getitem_const(args[0])
        elif name == "items":
            assert not (args or kwargs)
            return TupleVariable(
                [
                    TupleVariable(
                        [
                            ConstDictVariable._key_to_var(
                                tx,
                                k,
                                **options,
                            ),
                            v,
                        ],
                        **options,
                    )
                    for k, v in val.items()
                ],
                **options,
            )
        elif name == "keys":
            assert not (args or kwargs)
            return TupleVariable(
                [
                    ConstDictVariable._key_to_var(
                        tx,
                        k,
                        **options,
                    )
                    for k in val.keys()
                ],
                **options,
            )
        elif name == "values":
            assert not (args or kwargs)
            return TupleVariable(list(val.values()), **options)
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable(len(self.items), **options)
        elif (
            name == "__setitem__"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            assert not kwargs and len(args) == 2
            k = ConstDictVariable.get_key(args[0])

            if istensor(k):
                tx.store_dict_key(global_key_name(k), k)
            newval = collections.OrderedDict(val)
            newval[k] = args[1]

            # The updated dict now also contains whatever the new value
            # contains, plus the new value itself when it is mutable.
            new_rec_contains = self.recursively_contains.union(
                args[1].recursively_contains
            )
            if args[1].mutable_local is not None:
                new_rec_contains.add(args[1].mutable_local)

            return tx.replace_all(
                self,
                self.modifed(newval, new_rec_contains, **options),
            )
        elif (
            name in ("pop", "get")
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
            and len(args) == 2
        ):
            # missing item, return the default value
            return args[1].add_options(options)
        elif (
            name == "get"
            and len(args) == 1
            and not kwargs
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
        ):
            # dict.get(key) without an explicit default yields None when the
            # key is missing.
            return ConstantVariable(None, **options)
        elif (
            name == "pop"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            newval = collections.OrderedDict(val)
            result = newval.pop(ConstDictVariable.get_key(args[0]))
            # recursively_contains is passed as None here rather than being
            # derived from the remaining items.
            tx.replace_all(self, self.modifed(newval, None, **options))
            return result.add_options(options)
        elif (
            name == "update"
            and args
            and isinstance(args[0], ConstDictVariable)
            and self.mutable_local
        ):
            newval = collections.OrderedDict(val)
            newval.update(args[0].items)
            new_rec_contains = self.recursively_contains.union(
                args[0].recursively_contains
            )
            result = self.modifed(
                newval, recursively_contains=new_rec_contains, **options
            )
            return tx.replace_all(self, result)
        elif (
            name in ("get", "__getattr__")
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) in self.items
        ):
            result = self.items[ConstDictVariable.get_key(args[0])]
            return result.add_options(options)
        elif (
            name == "__contains__" and args and ConstDictVariable.is_valid_key(args[0])
        ):
            return ConstantVariable(
                ConstDictVariable.get_key(args[0]) in self.items, **options
            )
        else:
            return super().call_method(tx, name, args, kwargs)

    def modifed(self, items, recursively_contains, **options):
        """a copy of self with different items"""
        return self.clone(
            items=items, recursively_contains=recursively_contains, **options
        )

    def unpack_var_sequence(self, tx):
        """Return one variable per key, in the dict's iteration order."""
        options = VariableTracker.propagate([self])
        val = self.items
        result = [ConstDictVariable._key_to_var(tx, k, **options) for k in val.keys()]
        return result

    @classmethod
    def get_key(cls, arg: VariableTracker):
        """Turn a key variable into the concrete object used to index ``items``.

        A ``TensorVariable`` with a specialized value yields that value;
        anything else goes through ``as_python_constant()``.
        """
        if isinstance(arg, TensorVariable) and arg.specialized_value is not None:
            return arg.specialized_value
        else:
            return arg.as_python_constant()

    @classmethod
    def is_valid_key(cls, key):
        """Return whether ``get_key`` can produce a concrete key for ``key``."""
        return key.is_python_constant() or (
            isinstance(key, TensorVariable) and key.specialized_value is not None
        )

    @classmethod
    def _key_to_var(cls, tx, key, **options):
        """Wrap a concrete dict key back into a ``VariableTracker``.

        Tensor keys are rebuilt through ``VariableBuilder`` from a
        ``GlobalWeakRefSource`` (``options`` are not applied to them); every
        other key must be a literal and becomes a ``ConstantVariable``.
        """
        from .builder import VariableBuilder

        if istensor(key):
            return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
        else:
            assert ConstantVariable.is_literal(key)
            return ConstantVariable(key, **options)
class DefaultDictVariable(ConstDictVariable):
    """ConstDictVariable specialisation that models ``collections.defaultdict``.

    A ``__getitem__`` on a missing key inserts a fresh default value when
    ``default_factory`` is ``list``, ``tuple`` or ``dict``.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` with defaultdict semantics; defer the rest.

        Raises ``KeyError`` for a missing key when no ``default_factory`` is
        set, and calls ``unimplemented`` for factories other than ``list``,
        ``tuple`` and ``dict``.
        """
        from . import ListVariable, TupleVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        # Everything except item lookup behaves like a regular dict.
        if name != "__getitem__":
            return super().call_method(tx, name, args, kwargs)

        k = ConstDictVariable.get_key(args[0])
        if k in self.items:
            return self.getitem_const(args[0])

        if self.default_factory is None:
            raise KeyError(f"{k}")

        if istensor(k):
            tx.store_dict_key(global_key_name(k), k)

        updated_items = collections.OrderedDict(self.items)
        if self.default_factory is list:
            fresh = ListVariable([], mutable_local=MutableLocal())
        elif self.default_factory is tuple:
            fresh = TupleVariable([], mutable_local=MutableLocal())
        elif self.default_factory is dict:
            fresh = ConstDictVariable({}, dict, mutable_local=MutableLocal())
        else:
            unimplemented(
                f"defaultdict with default_factory = {self.default_factory}"
            )
        updated_items[k] = fresh

        # The replacement dict contains the new default and all it contains.
        contained = self.recursively_contains.union(fresh.recursively_contains)
        if fresh.mutable_local is not None:
            contained.add(fresh.mutable_local)

        replacement = self.modifed(updated_items, contained, **options)
        tx.replace_all(self, replacement)
        return fresh
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.
    """

    # ModelOutput() excludes None, though generic dataclasses don't
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Call ``skip_code`` on every callable defined directly on ModelOutput.

        The ``lru_cache`` wrapper makes repeated calls a no-op.
        """
        from transformers.file_utils import ModelOutput

        for obj in ModelOutput.__dict__.values():
            if callable(obj):
                skip_code(obj.__code__)

    @staticmethod
    def is_matching_cls(cls):
        """Return whether ``cls`` subclasses ModelOutput.

        Returns False when ``transformers`` cannot be imported.
        """
        try:
            from transformers.file_utils import ModelOutput

            return issubclass(cls, ModelOutput)
        except ImportError:
            return False

    @classmethod
    def is_matching_object(cls, obj):
        """Return whether ``obj`` is an instance of a ModelOutput subclass."""
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build the variable for a traced ``user_cls(*args, **kwargs)`` call.

        The arguments are bound against the signature of ``user_cls`` with
        defaults applied. Arguments that are already ``VariableTracker``s are
        kept as items; other values are wrapped as constants when
        ``include_none`` is set and must otherwise be ``None`` (and are
        dropped).
        """
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = collections.OrderedDict()
        for key in keys:
            val = bound.arguments[key]
            if isinstance(val, VariableTracker):
                items[key] = val
            else:
                if cls.include_none:
                    assert variables.ConstantVariable.is_literal(val)
                    items[key] = variables.ConstantVariable(val)
                else:
                    assert val is None, f"unexpected {val}"

        # Only inspect the first field when it was actually kept: if it was
        # None and dropped above, the single remaining item is some other
        # field and indexing items[keys[0]] would raise KeyError.
        if (
            len(items) == 1
            and keys[0] in items
            and not isinstance(items[keys[0]], variables.TensorVariable)
        ):
            unimplemented("DataClassVariable iterator constructor")
            # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Build the variable for an existing ``obj``.

        Every dataclass field present on ``obj`` is wrapped with a new
        builder sourced at ``AttrSource(builder.source, field_name)``. Fields
        whose value is ``None`` are left out of ``items`` unless
        ``include_none`` is set, but their variables are still passed to
        ``VariableTracker.propagate`` when building the result's options.
        """
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        excluded = []
        items = collections.OrderedDict()
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if hasattr(obj, key):
                val = getattr(obj, key)
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None or cls.include_none:
                    items[key] = var
                else:
                    excluded.append(var)
        return cls(
            items, user_cls, **VariableTracker.propagate(excluded, items.values())
        )

    def __init__(self, items, user_cls, **options):
        super(DataClassVariable, self).__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        """Not supported for this variable; always raises NotImplementedError."""
        raise NotImplementedError()

    def reconstruct(self, codegen):
        """Emit bytecode that calls ``user_cls`` with each item as a keyword."""
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return [
            codegen.create_load_const(keys),
            create_instruction("CALL_FUNCTION_KW", len(keys)),
        ]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model a method call on the tracked object.

        ``__getitem__`` accepts a field name (str) or any other constant
        index, which is applied to the ``to_tuple()`` result. ``__setattr__``
        is forwarded to the parent as ``__setitem__``; all remaining methods
        go to the parent unchanged.
        """
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            index = args[0].as_python_constant()
            if isinstance(index, str):
                return self.items[index].add_options(options)
            else:
                return (
                    self.call_method(tx, "to_tuple", [], {})
                    .call_method(tx, "__getitem__", args, kwargs)
                    .add_options(options)
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()), **options)
        elif name == "__setattr__":
            name = "__setitem__"
        return super(DataClassVariable, self).call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name`` on the tracked object.

        Stored items take priority. When None-valued fields are excluded
        (``include_none`` is False), a dataclass field that is not stored
        resolves to its declared default. Anything else is delegated to the
        parent implementation.
        """
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable(name)], {}
            )
        elif not self.include_none:
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable(defaults[name]).add_options(self)
        # Return the parent's result; it was previously computed and dropped,
        # so this method returned None.
        return super(DataClassVariable, self).var_getattr(tx, name)
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig: wraps a config object and serves
    its attributes as constants.
    """

    @staticmethod
    def is_matching_cls(cls):
        """Return whether ``cls`` subclasses PretrainedConfig.

        Returns False when ``transformers`` cannot be imported.
        """
        try:
            from transformers.configuration_utils import PretrainedConfig

            matches = issubclass(cls, PretrainedConfig)
        except ImportError:
            matches = False
        return matches

    @classmethod
    def is_matching_object(cls, obj):
        """Return whether the type of ``obj`` subclasses PretrainedConfig."""
        obj_type = type(obj)
        return cls.is_matching_cls(obj_type)

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Read ``name`` off the wrapped config and wrap it as a constant."""
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable(attr_value)