# pytorch/torch/_dynamo/variables/dicts.py
#
# 839 lines
# 28 KiB
# Python

import collections
import dataclasses
import functools
import inspect
import sys
from typing import Any, Dict, List, Optional
import torch
import torch.fx
from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name, istensor, iter_contains
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .tensor import TensorVariable
class ConstDictVariable(VariableTracker):
    """Tracker for a dict (or dict subclass) keyed by raw Python values.

    ``items`` maps raw keys (never ``VariableTracker`` instances) to the
    ``VariableTracker`` holding each value. ``user_cls`` is the dict type
    being modeled and is what ``python_type()`` reports.
    """

    def __init__(self, items, user_cls, **kwargs):
        super().__init__(**kwargs)
        # Keys are stored as raw values, never wrapped in a VariableTracker.
        assert not any(isinstance(key, VariableTracker) for key in items)
        self.items = items
        self.user_cls = user_cls

    def as_proxy(self):
        """Return a plain dict mapping each key to its value's proxy."""
        return {key: value.as_proxy() for key, value in self.items.items()}

    def as_python_constant(self):
        """Return a plain dict mapping each key to its value's constant."""
        return {
            key: value.as_python_constant() for key, value in self.items.items()
        }

    def python_type(self):
        return self.user_cls

    def reconstruct(self, codegen):
        """Emit the loads for every key/value and return the trailing
        instructions that assemble the dict (and wrap it in OrderedDict
        when ``user_cls`` is ``collections.OrderedDict``)."""
        wants_ordered_dict = self.user_cls is collections.OrderedDict

        if wants_ordered_dict:
            # Push collections.OrderedDict so it can be called on the map.
            codegen.extend_output(
                [
                    codegen.create_load_python_module(collections, True),
                    codegen.create_load_attr("OrderedDict"),
                ]
            )

        for key, value in self.items.items():
            if istensor(key):
                # Tensor keys are loaded through a global and a 0-arg call.
                codegen.append_output(
                    codegen.create_load_global(global_key_name(key), True, add=True)
                )
                codegen.extend_output(create_call_function(0, False))
            else:
                codegen.append_output(codegen.create_load_const(key))
            codegen(value)

        build_map = create_instruction("BUILD_MAP", arg=len(self.items))
        if wants_ordered_dict:
            return [build_map, *create_call_function(1, False)]
        return [build_map]

    def getitem_const(self, arg: VariableTracker):
        """Look up the value stored under the key represented by ``arg``."""
        return self.items[ConstDictVariable.get_key(arg)]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model a dict method call.

        The checks below run in a fixed order and the first match wins;
        anything unmatched is delegated to the base class.
        """
        from . import (
            ConstantVariable,
            ListIteratorVariable,
            ListVariable,
            TupleVariable,
        )

        contents = self.items

        if name == "__getitem__":
            assert len(args) == 1
            return self.getitem_const(args[0])

        if name == "items":
            assert not (args or kwargs)
            pairs = [
                TupleVariable(
                    items=[ConstDictVariable._key_to_var(tx, key), value],
                )
                for key, value in contents.items()
            ]
            return TupleVariable(pairs)

        if name == "keys":
            assert not (args or kwargs)
            return SetVariable(
                items=[ConstDictVariable._key_to_var(tx, key) for key in contents.keys()],
                mutable_local=MutableLocal(),
            )

        if name == "values":
            assert not (args or kwargs)
            return TupleVariable(list(contents.values()))

        if name == "copy":
            assert not (args or kwargs)
            return self.modifed(self.items.copy(), mutable_local=MutableLocal())

        if name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(self.items))

        if (
            name == "__setitem__"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            assert not kwargs and len(args) == 2
            key = ConstDictVariable.get_key(args[0])
            if istensor(key):
                tx.store_global_weakref(global_key_name(key), key)
            updated = dict(contents)
            updated[key] = args[1]
            return tx.replace_all(self, self.modifed(updated))

        if (
            name in ("pop", "get")
            and len(args) == 2
            and not kwargs
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
        ):
            # Key is absent: hand back the caller-supplied default.
            return args[1]

        if (
            name == "get"
            and len(args) == 1
            and not kwargs
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
        ):
            return ConstantVariable(None)

        if (
            name == "pop"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            remaining = dict(contents)
            popped = remaining.pop(ConstDictVariable.get_key(args[0]))
            tx.replace_all(self, self.modifed(remaining))
            return popped

        if (
            name == "update"
            and len(args) == 1
            and isinstance(args[0], ConstDictVariable)
            and self.mutable_local
        ):
            merged = dict(contents)
            merged.update(args[0].items)
            merged.update(kwargs)  # all keys in kwargs are valid (`str`s)
            return tx.replace_all(self, self.modifed(merged))

        if (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (ListVariable, TupleVariable, ListIteratorVariable),
            )
            and self.mutable_local
        ):
            merged = dict(contents)
            for entry in args[0].unpack_var_sequence(tx):
                key_var, value_var = entry.unpack_var_sequence(tx)
                assert ConstDictVariable.is_valid_key(key_var)
                merged[ConstDictVariable.get_key(key_var)] = value_var
            merged.update(kwargs)  # all keys in kwargs are valid (`str`s)
            return tx.replace_all(self, self.modifed(merged))

        if (
            name in ("get", "__getattr__")
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) in self.items
        ):
            return self.items[ConstDictVariable.get_key(args[0])]

        if name == "__contains__" and args and ConstDictVariable.is_valid_key(args[0]):
            return ConstantVariable.create(
                ConstDictVariable.get_key(args[0]) in self.items
            )

        return super().call_method(tx, name, args, kwargs)

    def modifed(self, items, **options):
        """a copy of self with different items"""
        return self.clone(items=items, **options)

    def unpack_var_sequence(self, tx):
        """Return one VariableTracker per key, in the dict's key order."""
        return [ConstDictVariable._key_to_var(tx, key) for key in self.items.keys()]

    @classmethod
    def get_key(cls, arg: VariableTracker):
        """Convert a key VariableTracker into the raw key used in ``items``."""
        has_specialized_value = (
            isinstance(arg, TensorVariable) and arg.specialized_value is not None
        )
        if has_specialized_value:
            return arg.specialized_value
        return arg.as_python_constant()

    @classmethod
    def is_valid_key(cls, key):
        """Report whether ``key`` can be turned into a raw key by get_key()."""
        return (
            key.is_python_constant()
            or (isinstance(key, TensorVariable) and key.specialized_value is not None)
            or (isinstance(key, ConstantVariable) and key.python_type() is torch.dtype)
        )

    @classmethod
    def _key_to_var(cls, tx, key, **options):
        """Wrap a raw key back into a VariableTracker."""
        from .builder import VariableBuilder

        if not istensor(key):
            assert ConstantVariable.is_literal(key)
            return ConstantVariable.create(key, **options)
        return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
class DefaultDictVariable(ConstDictVariable):
    """ConstDictVariable specialization for ``collections.defaultdict``.

    ``default_factory`` is the object whose ``call_function`` is invoked
    to produce the value for a missing key (or None when there is none).
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        # Return false for unsupported defaults. This ensures that a bad handler
        # path is not taken in BuiltinVariable for getitem.
        if self.default_factory in [list, tuple, dict] or self.items:
            return super().is_python_constant()
        return False

    @staticmethod
    def is_supported_arg(arg):
        """Report whether ``arg`` is accepted as a default factory."""
        if not isinstance(arg, variables.BuiltinVariable):
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)
        return arg.fn in [list, tuple, dict]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` with default-factory semantics; every
        other method is delegated to ConstDictVariable."""
        if name != "__getitem__":
            return super().call_method(tx, name, args, kwargs)

        key = ConstDictVariable.get_key(args[0])
        if key in self.items:
            return self.getitem_const(args[0])

        if self.default_factory is None:
            raise KeyError(f"{key}")

        # Missing key: build the default, store it, and return it.
        if istensor(key):
            tx.store_global_weakref(global_key_name(key), key)
        updated = dict(self.items)
        default_var = self.default_factory.call_function(tx, [], {})
        updated[key] = default_var
        tx.replace_all(self, self.modifed(updated))
        return default_var
class SetVariable(VariableTracker):
    """Tracker for a Python ``set`` whose elements are VariableTrackers.

    Elements are kept in ``self.items`` (a list). Uniqueness is enforced by
    ``_add`` through the hashing/equality rules of ``SetElement``.
    """

    @dataclasses.dataclass
    class SetElement:
        """Pairs a VariableTracker with the value used to hash/compare it."""

        vt: VariableTracker
        underlying_value: Any

        def __hash__(self) -> int:
            return hash(self.underlying_value)

        def __eq__(self, other: object) -> bool:
            if not isinstance(other, SetVariable.SetElement):
                return False
            # Tensors are compared by object identity, everything else by
            # value equality.
            if isinstance(self.vt, variables.TensorVariable):
                return self.underlying_value is other.underlying_value
            else:
                return self.underlying_value == other.underlying_value

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Note - the set is still backed by a list; set behavior over the
        # contents is provided by _add(), which skips duplicate elements.
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)

        self.items = []
        self._add(items)

    def as_proxy(self):
        """Return a list with the proxy of every element."""
        return [x.as_proxy() for x in self.items]

    def python_type(self):
        return set

    def reconstruct(self, codegen):
        """Emit loads for ``builtins.set`` and every element; return the
        trailing BUILD_SET + call instructions."""
        codegen.load_import_from("builtins", "set")
        codegen.foreach(self.items)
        return [
            create_instruction("BUILD_SET", arg=len(self.items))
        ] + create_call_function(1, True)

    # Note - this is only used for producing a set
    def _as_set_element(self, vt):
        """Wrap ``vt`` in a SetElement; unsupported kinds go to unimplemented()."""
        from .base import VariableTracker
        from .misc import MethodWrapperVariable
        from .tensor import TensorVariable

        assert isinstance(vt, VariableTracker)

        if isinstance(vt, TensorVariable):
            fake_tensor = vt.as_proxy().node.meta.get("example_value")
            if fake_tensor is None:
                unimplemented(
                    "Cannot check Tensor object identity without its fake value"
                )
            return SetVariable.SetElement(vt, fake_tensor)
        if isinstance(vt, ConstantVariable):
            return SetVariable.SetElement(vt, vt.value)
        if isinstance(vt, MethodWrapperVariable):
            return SetVariable.SetElement(vt, vt.as_python_constant())

        unimplemented(f"Sets with {type(vt)} NYI")

    @property
    def _underlying_items(self):
        """Return a real ``set`` of SetElements built from ``self.items``."""
        underlying_items = set()
        for current_item in self.items:
            # Compare the wrapped element, not the raw VariableTracker: the
            # set only ever contains SetElement instances, so testing the
            # VariableTracker itself could never detect a duplicate.
            element = self._as_set_element(current_item)
            assert (
                element not in underlying_items
            ), "Items modeling set invariant violated"
            underlying_items.add(element)
        return underlying_items

    def _add(self, item):
        """Add one VariableTracker, or a list/set of them, to ``self.items``.

        New elements are appended in place. For an element that is already
        present, a dupe guard is installed (when make_dupe_guard yields one)
        against each existing element with the same hash. Returns
        ``self.items``.
        """
        underlying_items = self._underlying_items

        if isinstance(item, (list, set)):
            items_to_add = item
        else:
            items_to_add = [item]

        for item_to_add in items_to_add:
            set_element = self._as_set_element(item_to_add)
            if set_element not in underlying_items:
                underlying_items.add(set_element)
                self.items.append(set_element.vt)
            else:
                for e in underlying_items:
                    if hash(set_element) == hash(e):
                        alias_guard = make_dupe_guard(
                            e.vt.source, set_element.vt.source
                        )
                        if alias_guard:
                            install_guard(e.vt.source.make_guard(alias_guard))

        return self.items

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Model ``add``, ``pop``, ``__len__`` and ``__contains__``; other
        methods are delegated to the base class."""
        # Somewhat duplicative of CommonListMethodsVariable - but better than to violate substitution
        # principles and end up with things like direct item access attempts on a set, or
        # getitem sources.
        if name == "add" and args and self.mutable_local:
            assert not kwargs
            item = args[0]
            result = SetVariable(
                self._add(item),
                mutable_local=self.mutable_local,
            )
            tx.replace_all(self, result)
            return ConstantVariable.create(None)
        elif name == "pop" and self.mutable_local:
            assert not kwargs
            assert not args
            # Pop from a copy so this tracker's own list is left untouched.
            items = list(self.items)
            result = items.pop()
            tx.replace_all(
                self,
                SetVariable(items),
            )
            return result
        elif name == "__len__":
            return ConstantVariable.create(len(self.items))
        elif name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.items, args[0], tx, check_tensor_identity=True)
        else:
            return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")

    def as_python_constant(self):
        """Return a real ``set`` of the elements' Python constants."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        """Return a shallow copy of the element list."""
        return list(self.items)
def _is_matching_transformers_cls(cls) -> bool:
    """Return True when ``transformers.file_utils`` is already imported and
    ``cls`` is a subclass of its ``ModelOutput``.

    The module is looked up in ``sys.modules`` only; it is never imported here.
    """
    file_utils = sys.modules.get("transformers.file_utils")
    if file_utils is None:
        return False
    return issubclass(cls, file_utils.ModelOutput)
def _is_matching_diffusers_cls(cls) -> bool:
    """Return True when ``diffusers.utils`` is already imported and ``cls``
    is a subclass of its ``BaseOutput``.

    The module is looked up in ``sys.modules`` only; it is never imported here.
    """
    diffusers_utils = sys.modules.get("diffusers.utils")
    if diffusers_utils is None:
        return False
    return issubclass(cls, diffusers_utils.BaseOutput)
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.
    """

    # ModelOutput() excludes None, though generic dataclasses don't
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Call skip_code() on every callable attribute of ModelOutput and
        BaseOutput (whichever can be imported). Runs its body at most once
        thanks to the lru_cache."""
        try:
            from transformers.file_utils import ModelOutput

            for obj in ModelOutput.__dict__.values():
                if callable(obj):
                    skip_code(obj.__code__)
        except ImportError:
            pass

        try:
            from diffusers.utils import BaseOutput

            for obj in BaseOutput.__dict__.values():
                if callable(obj):
                    skip_code(obj.__code__)
        except ImportError:
            pass

    @staticmethod
    def is_matching_cls(cls):
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build a DataClassVariable from constructor ``args``/``kwargs``.

        The arguments are bound against ``user_cls``'s signature (with
        defaults applied). VariableTracker values are kept as-is; other
        values are wrapped as constants when ``include_none`` is set and
        must otherwise be None, in which case the field is omitted.
        """
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = {}
        for key in keys:
            val = bound.arguments[key]
            if isinstance(val, VariableTracker):
                items[key] = val
            else:
                if cls.include_none:
                    assert variables.ConstantVariable.is_literal(val)
                    items[key] = variables.ConstantVariable.create(val)
                else:
                    assert val is None, f"unexpected {val}"

        # The single remaining item is not necessarily the first field (the
        # first field may have been None and dropped above), so check that
        # keys[0] is present before indexing to avoid a KeyError.
        if (
            len(items) == 1
            and keys[0] in items
            and not isinstance(items[keys[0]], variables.TensorVariable)
        ):
            unimplemented("DataClassVariable iterator constructor")
            # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Build a DataClassVariable from an existing instance ``obj``.

        Every dataclass field present on ``obj`` is wrapped with a new
        builder sourced from ``AttrSource(builder.source, field_name)``;
        None-valued fields are wrapped but left out of ``items`` unless
        ``include_none`` is set.
        """
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        items = {}
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if hasattr(obj, key):
                val = getattr(obj, key)
                # The value is wrapped even when it ends up excluded below,
                # exactly as before; only the unused bookkeeping list of
                # excluded variables was removed.
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None or cls.include_none:
                    items[key] = var
        return cls(items, user_cls)

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError()

    def reconstruct(self, codegen):
        """Emit a keyword-argument call of ``user_cls`` with every item."""
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return codegen.create_call_function_kw(len(keys), keys, True)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model ``__getitem__`` (string key or tuple-style index),
        ``to_tuple`` and ``__setattr__`` (treated as ``__setitem__``);
        everything else goes to ConstDictVariable."""
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            index = args[0].as_python_constant()
            if isinstance(index, str):
                return self.items[index]
            else:
                # Non-string index: index into the tuple of values instead.
                return self.call_method(tx, "to_tuple", [], {}).call_method(
                    tx, "__getitem__", args, kwargs
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()))
        elif name == "__setattr__":
            name = "__setitem__"
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name``: a stored item first, then (when None
        values are excluded) the dataclass field default, then the base
        class lookup."""
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable.create(name)], {}
            )
        elif not self.include_none:
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable.create(defaults[name])
        # Propagate the base-class result; previously it was computed and
        # then discarded, so this method returned None.
        return super().var_getattr(tx, name)
class CustomizedDictVariable(ConstDictVariable):
    """ConstDictVariable for user-defined OrderedDict subclasses and the
    HuggingFace ``ModelOutput`` / diffusers ``BaseOutput`` families."""

    @staticmethod
    def is_matching_cls(cls):
        # True if using default OrderedDict.__init__ and did not implement __post_init__
        if (
            issubclass(cls, collections.OrderedDict)
            and cls.__init__ is collections.OrderedDict.__init__
            and not hasattr(cls, "__post_init__")
        ):
            return True
        # hack for HF usecase:
        #   assume dataclass annotation for ModelOutput subclass
        #   assume self.create is AA to ModelOutput.__post_init__
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    # called from user_defined.py
    # when is_matching_cls(cls) is true
    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build a CustomizedDictVariable from constructor arguments.

        Supported forms: no arguments, dataclass-style arguments, keyword
        arguments only, or a single ConstDictVariable positional argument.
        Anything else is reported through unimplemented().
        """
        # avoid tracing when returning ModelOutput from forward func
        for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
            if hasattr(user_cls, attr_name):
                fn = getattr(user_cls, attr_name)
                assert callable(fn), f"expect callable attr {attr_name}"
                if hasattr(fn, "__code__"):
                    skip_code(fn.__code__)

        if not args and not kwargs:
            # CustomDict() init with empty arguments
            raw_items = {}
        elif dataclasses.is_dataclass(user_cls):
            # @dataclass CustomDict(a=1, b=2)
            bound = inspect.signature(user_cls).bind(*args, **kwargs)
            bound.apply_defaults()
            raw_items = bound.arguments
        elif not args:
            # CustomDict(a=1, b=2) in the general (non-dataclass) case.
            raw_items = dict(kwargs)
        elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
            # CustomDict({'a': 1, 'b': 2})
            raw_items = args[0].items
        else:
            unimplemented("custom dict init with args/kwargs unimplemented")

        items = {}
        for key, val in raw_items.items():
            if isinstance(val, VariableTracker):
                items[key] = val
            elif variables.ConstantVariable.is_literal(val):
                items[key] = variables.ConstantVariable.create(val)
            else:
                unimplemented("expect VariableTracker or ConstantVariable.is_literal")

        return cls(items, user_cls, **options)

    # called from builder.py
    @classmethod
    def wrap(cls, builder, obj):
        raise NotImplementedError()

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError()

    # 'RETURN_VALUE triggered compile'
    # called from torch/_dynamo/codegen.py
    def reconstruct(self, codegen):
        """Emit a keyword-argument call of ``user_cls`` with every item."""
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return codegen.create_call_function_kw(len(keys), keys, True)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Dispatch a method call on the custom dict.

        Methods whose ``__objclass__`` is ``dict`` or ``OrderedDict`` use
        the ConstDictVariable model; a fixed set of user-overridden methods
        is inlined; anything else is reported through unimplemented().
        """
        fn = getattr(self.user_cls, name)
        source = None if self.source is None else AttrSource(self.source, name)

        if hasattr(fn, "__objclass__") and fn.__objclass__ in (
            dict,
            collections.OrderedDict,
        ):
            # for python dict method without overridden
            return super().call_method(tx, name, args, kwargs)
        elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"):
            # for user overridden method
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=source),
                [self] + list(args),
                kwargs,
            )

        # Format the message up front: every other unimplemented() call in
        # this file passes exactly one string, not logging-style %-args.
        unimplemented(f"custom dict: call_method unimplemented name={name}")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name`` from the stored items, falling back
        to the base class lookup."""
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable.create(name)], {}
            )
        # Propagate the base-class result; previously it was computed and
        # then discarded, so this method returned None.
        return super().var_getattr(tx, name)
@functools.lru_cache(None)
def _install_PretrainedConfig_patch():
    """Replace ``PretrainedConfig.__eq__`` in transformers with a comparison
    of the two objects' ``__dict__``.

    The lru_cache makes repeated calls a no-op after the first one.
    """
    import transformers

    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib

    def _dynamo_overriden_transformers_eq(self, other):
        # Objects without a __dict__ never compare equal.
        return hasattr(other, "__dict__") and self.__dict__ == other.__dict__

    config_cls = transformers.configuration_utils.PretrainedConfig
    config_cls.__eq__ = _dynamo_overriden_transformers_eq
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig
    """

    @staticmethod
    def is_matching_cls(cls):
        """Return True when ``cls`` subclasses the already-imported
        ``transformers.configuration_utils.PretrainedConfig``."""
        config_utils = sys.modules.get("transformers.configuration_utils")
        if config_utils is None or not issubclass(cls, config_utils.PretrainedConfig):
            return False
        # Lazily install monkeypatch the first time we see it in dynamo
        _install_PretrainedConfig_patch()
        return True

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Read ``name`` from the wrapped config object and wrap it as a
        constant."""
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable.create(attr_value)

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Return a constant telling whether the wrapped object has ``name``."""
        present = hasattr(self.obj, name)
        return variables.ConstantVariable.create(present)
class PythonSysModulesVariable(VariableTracker):
    """Special case for sys.modules.

    Without this we will guard on the exact set of modules imported in the
    lifetime of the python program.
    """

    def python_type(self):
        return dict

    def reconstruct(self, codegen):
        """Emit the instructions that load ``sys.modules``.

        This is a regular instance method: the previous ``@staticmethod``
        decorator combined with an explicit ``self`` parameter made
        ``instance.reconstruct(codegen)`` fail with a TypeError.

        Returns an empty list (no trailing instructions), matching the
        list-returning ``reconstruct`` methods of the other trackers in
        this file.
        """
        codegen.extend_output(
            [
                codegen.create_load_python_module(sys, True),
                codegen.create_load_attr("modules"),
            ]
        )
        return []

    def call_method(
        self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ):
        """Handle ``__getitem__``, ``get`` and ``__contains__`` with
        per-key guards; other methods fall back to a tracker built from
        the real ``sys.modules`` dict."""
        from .builder import VariableBuilder

        if name == "__getitem__":
            return self.call_getitem(tx, *args, **kwargs)
        elif name == "get":
            return self.call_get(tx, *args, **kwargs)
        elif name == "__contains__":
            return self.call_contains(tx, *args, **kwargs)

        # Fallback to dict implementation
        real_dict = VariableBuilder(tx, self.source)(sys.modules)
        return real_dict.call_method(tx, name, args, kwargs)

    def _contains_helper(self, tx, key: VariableTracker):
        """Return ``(raw_key, has_key)`` and install a DICT_CONTAINS guard
        recording whether the key was present in ``sys.modules``."""
        k = ConstDictVariable.get_key(key)
        has_key = k in sys.modules
        install_guard(
            self.make_guard(
                functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
            )
        )
        return k, has_key

    def call_contains(self, tx, key: VariableTracker):
        """Model ``key in sys.modules``."""
        _, has_key = self._contains_helper(tx, key)
        return ConstantVariable.create(value=has_key)

    def call_get(
        self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
    ):
        """Model ``sys.modules.get(key, default)``."""
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)

        if has_key:
            return VariableBuilder(
                tx,
                GetItemSource(self.source, k),
            )(sys.modules[k])

        if default is not None:
            return default

        return ConstantVariable.create(value=None)

    def call_getitem(self, tx, key: VariableTracker):
        """Model ``sys.modules[key]``; the guard is installed before the
        real lookup is performed."""
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)
        return VariableBuilder(
            tx,
            GetItemSource(self.source, k),
        )(sys.modules[k])