mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Support multiple inheritance for custom dict construction (#142416)
This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created #142414 to track that. Fixes #141118. Pull Request resolved: https://github.com/pytorch/pytorch/pull/142416 Approved by: https://github.com/jansel
This commit is contained in:
parent
b5d8d2444a
commit
b4f4c75e19
|
|
@ -3015,6 +3015,45 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
self.assertEqual(fn(args2, x), opt_fn(args2, x))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
def test_mutable_mapping_multiple_inheritance(self):
|
||||
class MyWeirdDict(collections.abc.MutableMapping, torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self._items = kwargs
|
||||
|
||||
def keys(self):
|
||||
return self._items.keys()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._items[item]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._items[key] = value
|
||||
|
||||
def __delitem__(self, item):
|
||||
del self._items[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._items)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._items
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
|
||||
def items(self):
|
||||
for k, v in self._items.items():
|
||||
yield (k, v)
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def to_weird_dict(td):
|
||||
return MyWeirdDict(**td)
|
||||
|
||||
d = MyWeirdDict(a=1, b=2, c=3)
|
||||
res = to_weird_dict(d)
|
||||
self.assertEqual(tuple(d.items()), tuple(res.items()))
|
||||
|
||||
def test_dunder_new_function_inlining(self):
|
||||
# https://github.com/pytorch/pytorch/issues/107460
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import math
|
|||
import operator
|
||||
import types
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections.abc import KeysView
|
||||
from collections.abc import KeysView, MutableMapping
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
|
@ -1394,18 +1394,28 @@ class BuiltinVariable(VariableTracker):
|
|||
return ConstDictVariable(
|
||||
items, user_cls, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif isinstance(arg, variables.MutableMappingVariable):
|
||||
# This is applicable for user defined objects which seem like dict, but are not really dicts. For
|
||||
# example, TensorDict derives from MutableMapping. For such cases, we can directly inline the .items
|
||||
# method and create a new dict.
|
||||
elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping):
|
||||
# This handles all other `MutableMapping` instances; for
|
||||
# example, TensorDict which derives from MutableMapping.
|
||||
#
|
||||
# TODO(#142414) `hasattr(arg, 'value')` is a local workaround
|
||||
# for lack of generall multiple inheritance in Dynamo. We can't
|
||||
# use `isinstance(arg, MutableMappingVariable)` here because
|
||||
# `arg` could be, e.g., a `UnspecializedNNModuleVariable` when
|
||||
# `arg.value` has multiple inheritace.
|
||||
if does_not_override_dict_iter_methods(type(arg.value)):
|
||||
# These are implemeted in C, so we will have to manually construct the items
|
||||
|
||||
# In this case, `arg.value.items()` uses the default impls,
|
||||
# which are implemented in C and cannot be traced, so we
|
||||
# will have to manually construct the items. This is safe
|
||||
# because we know they are side-effect free.
|
||||
#
|
||||
# Mutation tracked by Dynamo isn't reflected in `arg.value`,
|
||||
# so we can't handle such cases by just calling
|
||||
# `arg.value.items()`
|
||||
if tx.output.side_effects.has_pending_mutation(arg):
|
||||
unimplemented(
|
||||
f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
|
||||
)
|
||||
|
||||
new_dict = dict(arg.value.items())
|
||||
return VariableTracker.build(tx, new_dict)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class ConstDictVariable(VariableTracker):
|
|||
Hashable = ConstDictVariable._HashableTracker
|
||||
x = tuple(Hashable(e).underlying_value for e in self.vt.items)
|
||||
elif isinstance(self.vt, variables.NNModuleVariable):
|
||||
return self.vt.module
|
||||
return self.vt.value
|
||||
elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
|
||||
return self.vt.value
|
||||
elif isinstance(self.vt, variables.UserFunctionVariable):
|
||||
|
|
@ -277,14 +277,7 @@ class ConstDictVariable(VariableTracker):
|
|||
args: "List[VariableTracker]",
|
||||
kwargs: "Dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from . import (
|
||||
BuiltinVariable,
|
||||
ConstantVariable,
|
||||
ListIteratorVariable,
|
||||
ListVariable,
|
||||
TupleVariable,
|
||||
UserDefinedObjectVariable,
|
||||
)
|
||||
from . import BuiltinVariable, ConstantVariable, TupleVariable
|
||||
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
|
||||
|
|
@ -344,31 +337,23 @@ class ConstDictVariable(VariableTracker):
|
|||
self.items.clear()
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "update" and self.is_mutable():
|
||||
is_args_supported = len(args) == 1 and isinstance(
|
||||
args[0],
|
||||
(
|
||||
ConstDictVariable,
|
||||
ListVariable,
|
||||
TupleVariable,
|
||||
ListIteratorVariable,
|
||||
variables.IteratorVariable,
|
||||
UserDefinedObjectVariable,
|
||||
),
|
||||
)
|
||||
|
||||
is_kwargs_supported = len(kwargs) > 0 and len(args) == 0
|
||||
|
||||
if is_args_supported or is_kwargs_supported:
|
||||
# In general, this call looks like `a.update(b, x=1, y=2, ...)`.
|
||||
# Either `b` or the kwargs is omittable, but not both.
|
||||
has_arg = len(args) == 1
|
||||
has_kwargs = len(kwargs) > 0
|
||||
if has_arg or has_kwargs:
|
||||
tx.output.side_effects.mutation(self)
|
||||
if len(args) == 1:
|
||||
if has_arg:
|
||||
if isinstance(args[0], ConstDictVariable):
|
||||
dict_vt = args[0]
|
||||
else:
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
|
||||
self.items.update(dict_vt.items)
|
||||
# Wrap strings
|
||||
if has_kwargs:
|
||||
# Handle kwargs
|
||||
kwargs = {
|
||||
Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
|
||||
Hashable(ConstantVariable.create(k)): v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.items.update(kwargs)
|
||||
return ConstantVariable.create(None)
|
||||
|
|
|
|||
|
|
@ -131,18 +131,18 @@ class NNModuleVariable(VariableTracker):
|
|||
_nonvar_fields = {
|
||||
"module_type",
|
||||
"module_key",
|
||||
"module",
|
||||
"value",
|
||||
"nn_module_stack_source",
|
||||
*VariableTracker._nonvar_fields,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
|
||||
self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.module_type = module_type
|
||||
self.module_key = module_key
|
||||
self.module = module
|
||||
self.value = value
|
||||
assert self.source
|
||||
self.nn_module_stack_source = self.source
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user