[dynamo] Support multiple inheritance for custom dict construction (#142416)

This patch applies a local and practical workaround for custom dict
construction when multiple inheritance is involved.

Handling multiple inheritance in general could be a lot more involved,
so I created #142414 to track that.

Fixes #141118.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142416
Approved by: https://github.com/jansel
This commit is contained in:
Ryan Guo 2024-12-12 12:05:55 -08:00 committed by PyTorch MergeBot
parent b5d8d2444a
commit b4f4c75e19
4 changed files with 75 additions and 41 deletions

View File

@ -3015,6 +3015,45 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(fn(args2, x), opt_fn(args2, x))
self.assertEqual(cnts.frame_count, 2)
def test_mutable_mapping_multiple_inheritance(self):
class MyWeirdDict(collections.abc.MutableMapping, torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self._items = kwargs
def keys(self):
return self._items.keys()
def __getitem__(self, item):
return self._items[item]
def __setitem__(self, key, value):
self._items[key] = value
def __delitem__(self, item):
del self._items[item]
def __len__(self):
return len(self._items)
def __iter__(self):
yield from self._items
def __hash__(self):
return hash(id(self))
def items(self):
for k, v in self._items.items():
yield (k, v)
@torch.compile(fullgraph=True)
def to_weird_dict(td):
return MyWeirdDict(**td)
d = MyWeirdDict(a=1, b=2, c=3)
res = to_weird_dict(d)
self.assertEqual(tuple(d.items()), tuple(res.items()))
def test_dunder_new_function_inlining(self):
# https://github.com/pytorch/pytorch/issues/107460

View File

@ -9,7 +9,7 @@ import math
import operator
import types
from collections import defaultdict, OrderedDict
from collections.abc import KeysView
from collections.abc import KeysView, MutableMapping
from typing import Dict, List, TYPE_CHECKING
import torch
@ -1394,18 +1394,28 @@ class BuiltinVariable(VariableTracker):
return ConstDictVariable(
items, user_cls, mutation_type=ValueMutationNew()
)
elif isinstance(arg, variables.MutableMappingVariable):
# This is applicable for user defined objects which seem like dict, but are not really dicts. For
# example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items
# method and create a new dict.
elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping):
# This handles all other `MutableMapping` instances; for
# example, TensorDict which derives from MutableMapping.
#
# TODO(#142414) `hasattr(arg, 'value')` is a local workaround
# for lack of generall multiple inheritance in Dynamo. We can't
# use `isinstance(arg, MutableMappingVariable)` here because
# `arg` could be, e.g., a `UnspecializedNNModuleVariable` when
# `arg.value` has multiple inheritace.
if does_not_override_dict_iter_methods(type(arg.value)):
# These are implemeted in C, so we will have to manually construct the items
# In this case, `arg.value.items()` uses the default impls,
# which are implemented in C and cannot be traced, so we
# will have to manually construct the items. This is safe
# because we know they are side-effect free.
#
# Mutation tracked by Dynamo isn't reflected in `arg.value`,
# so we can't handle such cases by just calling
# `arg.value.items()`
if tx.output.side_effects.has_pending_mutation(arg):
unimplemented(
f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
)
new_dict = dict(arg.value.items())
return VariableTracker.build(tx, new_dict)
else:

View File

@ -89,7 +89,7 @@ class ConstDictVariable(VariableTracker):
Hashable = ConstDictVariable._HashableTracker
x = tuple(Hashable(e).underlying_value for e in self.vt.items)
elif isinstance(self.vt, variables.NNModuleVariable):
return self.vt.module
return self.vt.value
elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
return self.vt.value
elif isinstance(self.vt, variables.UserFunctionVariable):
@ -277,14 +277,7 @@ class ConstDictVariable(VariableTracker):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from . import (
BuiltinVariable,
ConstantVariable,
ListIteratorVariable,
ListVariable,
TupleVariable,
UserDefinedObjectVariable,
)
from . import BuiltinVariable, ConstantVariable, TupleVariable
Hashable = ConstDictVariable._HashableTracker
@ -344,31 +337,23 @@ class ConstDictVariable(VariableTracker):
self.items.clear()
return ConstantVariable.create(None)
elif name == "update" and self.is_mutable():
is_args_supported = len(args) == 1 and isinstance(
args[0],
(
ConstDictVariable,
ListVariable,
TupleVariable,
ListIteratorVariable,
variables.IteratorVariable,
UserDefinedObjectVariable,
),
)
is_kwargs_supported = len(kwargs) > 0 and len(args) == 0
if is_args_supported or is_kwargs_supported:
# In general, this call looks like `a.update(b, x=1, y=2, ...)`.
# Either `b` or the kwargs is omittable, but not both.
has_arg = len(args) == 1
has_kwargs = len(kwargs) > 0
if has_arg or has_kwargs:
tx.output.side_effects.mutation(self)
if len(args) == 1:
if has_arg:
if isinstance(args[0], ConstDictVariable):
dict_vt = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
# Wrap strings
if has_kwargs:
# Handle kwargs
kwargs = {
Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs)
return ConstantVariable.create(None)

View File

@ -131,18 +131,18 @@ class NNModuleVariable(VariableTracker):
_nonvar_fields = {
"module_type",
"module_key",
"module",
"value",
"nn_module_stack_source",
*VariableTracker._nonvar_fields,
}
def __init__(
self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs
) -> None:
super().__init__(**kwargs)
self.module_type = module_type
self.module_key = module_key
self.module = module
self.value = value
assert self.source
self.nn_module_stack_source = self.source