pytorch/test/dynamo/test_dicts.py
2025-10-31 10:20:48 +00:00

1756 lines
51 KiB
Python

# Owner(s): ["module: dynamo"]
# ruff: noqa: TRY002
import enum
import itertools
import operator
import types
import unittest
import weakref
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from typing import Any
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.testing import same
from torch._dynamo.utils import dict_items
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
make_dynamo_test,
munge_exc,
parametrize,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
class SimpleDict(dict):
    """Plain ``dict`` subclass with no overrides, used as a test fixture."""
class DummyUserDict(UserDict):
    """Plain ``collections.UserDict`` subclass with no overrides, used as a test fixture."""
class DictTests(torch._dynamo.test_case.TestCase):
# Constructs a SimpleDict with a keyword argument inside the compiled function,
# reads the value back, and checks the fullgraph-compiled result matches eager.
def test_dict_subclass_instantiation(self):
def fn(x):
sd = SimpleDict(x=5)
return sd["x"] * x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
# Mutates a locally created SimpleDict between two reads of the same key and
# checks the compiled function observes both the old and the new value.
def test_dict_subclass_local_mutation(self):
def fn(x):
sd = SimpleDict(x=5)
z = sd["x"] * x
sd["x"] = 10
return z * sd["x"]
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
# Checks `in` between a str-based Enum member and a plain string, in both
# directions, against set literals inside a compiled module's forward().
# NOTE(review): this test calls torch.compile without backend="eager", unlike
# the neighbouring tests - confirm that the default backend is intended here.
def test_dict_contains_enum(self):
class TensorDim(str, enum.Enum):
DDP = "ddp"
FSDP = "fsdp"
CP = "cp"
TP = "tp"
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
val = x.sin()
# Enum member looked up in a set of plain strings
if TensorDim.DDP in {"ddp"}:
val += x.cos()
# Plain string looked up in a set of Enum members
if "ddp" in {TensorDim.DDP}:
val += x.cos()
return val
inp = torch.randn(4, 4)
mod = Foo()
opt_f = torch.compile(mod)
self.assertEqual(mod(inp), opt_f(inp))
# Calls a user-defined (non-dict) method on a locally built dict subclass
# after a read and a write, and compares compiled output with eager.
def test_dict_subclass_local_with_non_dict_method(self):
# Checks that add_1 method is inlined
class MethodDict(dict):
def add_1(self, x):
return x + 1
def fn(x):
sd = MethodDict(x=5)
z = sd["x"] * x
sd["x"] = 10
return sd.add_1(z * sd["x"])
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
# Branches on `1 in sd` for a closed-over dict. Adding key 1 must change the
# result; changing an existing value must not recompile, which is enforced by
# patching error_on_recompile to True for the last call.
def test_dict_contains(self):
sd = dict()
sd[2] = 5
sd[4] = 10
def fn(x):
if 1 in sd:
x = x * 2
else:
x = x * 3
return x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
# Ensure a recompilation
sd[1] = 15
self.assertEqual(fn(x), opt_fn(x))
# Ensure not recompilation because the traced program remains same here.
sd[2] = 10
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
self.assertEqual(fn(x), opt_fn(x))
# Exercises read-only dict methods (values, keys, items, `in`, get, len) and a
# plain attribute read on a closed-over SimpleDict, then adds a key and checks
# the compiled function still matches eager.
def test_dict_subclass_methods_fallback_readonly(self):
sd = SimpleDict()
sd[2] = 5
sd[4] = 10
# check that regular attr accesses work well
sd.attr = 4
def fn(x):
for value in sd.values():
x = x * value
for key in sd.keys():
x = x * key
for k, v in sd.items():
x = x * k
x = x * v
# for k in sd:
# x = x * k
if 1 in sd:
x = x * 2
else:
x = x * 3
x = x * sd.get(2, 0)
x = x * sd.get(3, 4)
x = len(sd) * x
x = x * sd.attr
return x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
# Ensure a recompilation
sd[6] = 15
self.assertEqual(fn(x), opt_fn(x))
# Returns a SimpleDict built inside the compiled function and checks that the
# returned object's type and both entries match the eager result.
def test_dict_subclass_instantiation_return(self):
def fn(x):
sd = SimpleDict(x=5 * x)
sd["y"] = 10
return sd
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(type(ref), type(res))
self.assertEqual(ref["x"], res["x"])
self.assertEqual(ref["y"], res["y"])
# Same method coverage as the read-only variant, but the SimpleDict is an
# argument and is mutated (sd[6] = 14) mid-function. Two equal dicts are fed to
# eager and compiled runs and must still be equal afterwards.
def test_dict_subclass_methods_fallback_mutation(self):
def fn(sd, x):
for value in sd.values():
x = x * value
sd[6] = 14
for key in sd.keys():
x = x * key
for k, v in sd.items():
x = x * k
x = x * v
# for k in sd:
# x = x * k
if 1 in sd:
x = x * 2
else:
x = x * 3
x = x * sd.get(2, 0)
x = x * sd.get(3, 4)
x = len(sd) * x
return x
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
sd1 = SimpleDict()
sd1[2] = 5
sd1[4] = 10
sd2 = SimpleDict()
sd2[2] = 5
sd2[4] = 10
self.assertTrue(sd1 == sd2)
self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
# The mutation performed inside fn must be visible on both dicts
self.assertTrue(sd1 == sd2)
# Uses a dict subclass whose __setitem__ stores value + 1 and checks the
# returned dict from the compiled function matches eager in type and contents.
def test_dict_subclass_setitem(self):
class SetItemDict(dict):
def __setitem__(self, key, value):
super().__setitem__(key, value + 1)
def fn(x):
sd = SetItemDict(x=5 * x)
sd["y"] = 10
return sd
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(type(ref), type(res))
self.assertEqual(ref["x"], res["x"])
self.assertEqual(ref["y"], res["y"])
# Passes a dict subclass with a custom __iter__ to a compiled function that
# sets an attribute and a key on it. The second call is allowed to recompile;
# the third call runs with error_on_recompile=True and must not.
def test_custom_iter_dict(self):
class ReversedDict(dict):
def __iter__(self):
return reversed(list(self.keys()))
d = {
"foo": 1,
"bar": 2,
}
d = ReversedDict(d)
@torch.compile(backend="eager")
def fn(x, d):
# Forces side effects attribute reapplication logic
d.sample = 1
d["baz"] = 4
return x * d["foo"] * d["bar"]
fn(torch.randn(4), d)
# This is intentional because the dict is mutated, so we will have a recompilation.
fn(torch.randn(4), d)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(4), d)
# Passes a dict subclass whose keys() returns a hard-coded list to a compiled
# function; the second call must not recompile (error_on_recompile=True).
def test_custom_keys_iter_dict(self):
class ReversedDict(dict):
def keys(self):
return ["bar", "foo"]
d = {
"foo": 1,
"bar": 2,
}
d = ReversedDict(d)
@torch.compile(backend="eager")
def fn(x, d):
return x * d["foo"] * d["bar"]
fn(torch.randn(4), d)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn(torch.randn(4), d)
# Iterates d.items() in a compiled function and asserts, via CompileCounter,
# that repeating the call does not recompile but reordering the keys does.
def test_dict_guard_on_keys_order(self):
d = {
2: 4,
3: 5,
}
cnts = torch._dynamo.testing.CompileCounter()
def fn(x, d):
for key, value in d.items():
x = x * key + value
return x
opt_fn = torch.compile(fn, backend=cnts)
opt_fn(torch.randn(4), d)
opt_fn(torch.randn(4), d)
# No recompilation
self.assertEqual(cnts.frame_count, 1)
# move 2 to the end
d[2] = d.pop(2)
x = torch.randn(4)
res = opt_fn(x, d)
# Check recompilation
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(res, fn(x, d))
# Variant of the key-order guard test that iterates the dict directly
# (`for key in d`) and indexes it, instead of using d.items().
def test_dict_guard_on_keys_order2(self):
d = {
2: 4,
3: 5,
}
cnts = torch._dynamo.testing.CompileCounter()
def fn(x, d):
for key in d:
value = d[key]
x = x * key + value
return x
opt_fn = torch.compile(fn, backend=cnts)
opt_fn(torch.randn(4), d)
opt_fn(torch.randn(4), d)
# No recompilation
self.assertEqual(cnts.frame_count, 1)
# move 2 to the end
d[2] = d.pop(2)
x = torch.randn(4)
res = opt_fn(x, d)
# Check recompilation
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(res, fn(x, d))
# Builds an OrderedDict, reorders it with move_to_end before compilation, and
# checks an order-sensitive loop over d.values() matches eager.
def test_ordered_dict_reordered_keys(self):
d = OrderedDict()
d[2] = 4
d[3] = 5
d.move_to_end(2)
cnts = torch._dynamo.testing.CompileCounter()
def fn(x, d):
y = 0
# The op applied depends on the position of the value, so order matters
for idx, value in enumerate(d.values()):
if idx == 0:
y += torch.sin(x * value)
else:
y += torch.cos(x * value)
return y
opt_fn = torch.compile(fn, backend=cnts)
x = torch.randn(4)
self.assertEqual(opt_fn(x, d), fn(x, d))
# Same as the reordered OrderedDict test, but with an OrderedDict subclass
# that overrides keys() to delegate to the parent implementation.
def test_ordered_dict_subclass_reordered_keys(self):
class ODSubclass(OrderedDict):
def keys(self):
return super().keys()
d = ODSubclass()
d[2] = 4
d[3] = 5
d.move_to_end(2)
cnts = torch._dynamo.testing.CompileCounter()
def fn(x, d):
y = 0
for idx, value in enumerate(d.values()):
if idx == 0:
y += torch.sin(x * value)
else:
y += torch.cos(x * value)
return y
opt_fn = torch.compile(fn, backend=cnts)
x = torch.randn(4)
self.assertEqual(opt_fn(x, d), fn(x, d))
# Reads only d["a"] from a closed-over dict, then removes and adds other keys
# and requires (error_on_recompile=True) that no recompilation happens.
def test_lazy_key_guarding(self):
d = {"a": 2, "b": 3, "c": 5}
def fn(x):
return x * d["a"]
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
# Since key c was not used, it should not lead to a recompilation
d.pop("c")
d["d"] = 10
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
# Like test_lazy_key_guarding, but the dict keys are type/function objects
# rather than string constants. Only d[list] is read by the compiled function.
def test_lazy_key_non_const_guarding(self):
d = {
list: 2,
dict: 3,
OrderedDict: 5,
namedtuple: 7,
}
def fn(x):
return x * d[list]
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
# Since key `dict` was not used, removing it (and adding `defaultdict`)
# should not lead to a recompilation
d.pop(dict)
d[defaultdict] = 10
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
def test_dict_mutation_side_effect(self):
    """Check that in-place mutation of an input dict is replayed on the caller's dict.

    ``fn`` pops ``"b"``, stores ``"c"`` and returns the same dict object. The
    eager run on ``args1`` is the reference; the compiled run on ``args2``
    must return ``args2`` itself, leave it equal to ``args1``, and compile
    exactly one frame containing one op.
    """

    def fn(d):
        d["c"] = d["a"] + d.pop("b")
        return d

    args1 = {"a": torch.randn(10), "b": torch.randn(10)}
    args2 = dict(args1)
    # Use a unittest assertion rather than a bare `assert`, which would be
    # stripped under `python -O` and silently skip both the check and the
    # eager reference run that mutates args1.
    self.assertIs(fn(args1), args1)

    cnts = torch._dynamo.testing.CompileCounter()
    opt_fn = torch.compile(fn, backend=cnts)
    self.assertIs(opt_fn(args2), args2)
    self.assertTrue(same(args1, args2))
    self.assertEqual(cnts.frame_count, 1)
    self.assertEqual(cnts.op_count, 1)
# Checks that d0.copy() inside a compiled function yields an independent dict:
# writing to the copy must leave the original input dict empty.
def test_dict_copy_alias(self):
@torch.compile(backend="eager", fullgraph=True)
def run(x, d0):
d1 = d0.copy()
d1[0] = 1
return x + 1, d1
d0 = {}
res, d1 = run(torch.zeros(1), d0)
self.assertTrue(same(res, torch.ones(1)))
self.assertEqual(d0, {})
self.assertEqual(d1, {0: 1})
# Calls .get() on a dict subclass that aliases __getattr__/__setattr__/
# __delattr__ to the dict item methods, and compares compiled with eager.
def test_dict_subclass_get_method(self):
class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
config = dotdict({"a": 1, "b": 2})
def fn(x):
x2 = x * 2 # noqa: F841
x3 = x * config.get("a", 3)
return x3
x = torch.randn(2)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))
def test_dict_order_keys(self):
    """Check that key insertion order of an input dict is guarded.

    Sums ``d.values()`` in a compiled function. The first dict compiles one
    frame with two ops; a dict holding the same tensors in the opposite
    insertion order must trigger a second compilation; calling again with
    that second dict must not compile a third time.
    """

    def fn(d):
        c = 0
        for v in d.values():
            c += v
        return c

    args1 = {}
    args1["a"] = torch.rand(10)
    args1["b"] = torch.rand(10)
    cnts = torch._dynamo.testing.CompileCounter()
    opt_fn = torch.compile(fn, backend=cnts)
    self.assertEqual(fn(args1), opt_fn(args1))
    self.assertEqual(cnts.frame_count, 1)
    self.assertEqual(cnts.op_count, 2)

    # A different order of keys recompiles
    args2 = {}
    args2["b"] = args1["b"]
    args2["a"] = args1["a"]
    self.assertEqual(fn(args2), opt_fn(args2))
    self.assertEqual(cnts.frame_count, 2)

    # Extra calls don't recompile. The original asserted the count without
    # making another call, so the check could never fail.
    self.assertEqual(fn(args2), opt_fn(args2))
    self.assertEqual(cnts.frame_count, 2)
# Branches on whether the `namedtuple` factory function is a key of the input
# dict, then passes a dict without that key and expects a second compilation.
def test_dict_namedtuple(self):
def fn(d):
if namedtuple in d:
return d[3] * 2
else:
return d[3] * 3
args1 = {namedtuple: None, 3: torch.randn(3)}
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(fn(args1), opt_fn(args1))
self.assertEqual(cnts.frame_count, 1)
# Test a failing namedtuple guard
args2 = {2: None, 3: torch.randn(3)}
self.assertEqual(fn(args2), opt_fn(args2))
self.assertEqual(cnts.frame_count, 2)
# Uses a tensor object as a dict key alongside an int key. Re-calling with the
# same dict must not recompile; a dict with the keys inserted in the other
# order must recompile.
def test_dict_order_keys_tensors(self):
def fn(d, x):
return d[x] + 3
args1 = {}
x = torch.randn(10)
y = torch.randn(10)
z = torch.randn(10)
args1[x] = y
args1[3] = z
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(fn(args1, x), opt_fn(args1, x))
self.assertEqual(cnts.frame_count, 1)
# Calling again doesn't recompile (same id and key order)
opt_fn(args1, x)
self.assertEqual(cnts.frame_count, 1)
args2 = {}
args2[3] = z
args2[x] = y
# Different order recompiles
self.assertEqual(fn(args2, x), opt_fn(args2, x))
self.assertEqual(cnts.frame_count, 2)
# Same structure as test_dict_order_keys_tensors, but the dict key and values
# are torch.nn.Linear modules and the looked-up module is called.
def test_dict_order_keys_modules(self):
def fn(d, x):
return d[x](torch.ones(2, 2))
args1 = {}
x = torch.nn.Linear(2, 2)
y = torch.nn.Linear(2, 2)
z = torch.nn.Linear(2, 2)
args1[x] = y
args1[3] = z
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(fn(args1, x), opt_fn(args1, x))
self.assertEqual(cnts.frame_count, 1)
# Calling again doesn't recompile (same id and key order)
opt_fn(args1, x)
self.assertEqual(cnts.frame_count, 1)
args2 = {}
args2[3] = z
args2[x] = y
# Different order recompiles
self.assertEqual(fn(args2, x), opt_fn(args2, x))
self.assertEqual(cnts.frame_count, 2)
# Checks `in` and .get() on a user object's __dict__, including for an
# attribute ("c") that is assigned inside the traced method itself.
def test_contains_dunder_dict(self):
class UserDefined:
def __init__(self) -> None:
self.a = 3
self.b = 5
def run(self, x):
if "a" in self.__dict__:
x = x * self.a
if "b" in self.__dict__:
x = x * self.b
self.c = 7
if "c" in self.__dict__:
x = x * self.c
return x * self.__dict__.get("a") * self.__dict__.get("z", 2)
obj = UserDefined()
def fn(x):
return obj.run(x)
x = torch.randn(4)
# The eager run happens first and already sets obj.c before compilation
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
# Checks `"foo" in self.__dict__` inside the forward() of a compiled nn.Module.
def test_contains_module_dunder_dict(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = 1
self.bar = 2
self.baz = 3
def forward(self, x):
if "foo" in self.__dict__:
return x * self.bar
return x * self.baz
mod = MyModule()
x = torch.randn(10)
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
self.assertEqual(mod(x), opt_mod(x))
# Writes to self.__dict__ directly inside a traced method, then reads the value
# both as an attribute and through __dict__. Eager and compiled runs use
# separate objects whose __dict__ must end up equal.
def test_update_dunder_dict(self):
class UserDefined:
def run(self, x):
self.__dict__["a"] = 10
return x * self.a + self.__dict__["a"]
obj1 = UserDefined()
obj2 = UserDefined()
def fn(x, obj):
return obj.run(x)
x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
ref = fn(x, obj1)
res = opt_fn(x, obj2)
self.assertEqual(ref, res)
# Make sure only `a` is updated.
self.assertEqual(obj1.__dict__, obj2.__dict__)
# nn.Module variant of test_update_dunder_dict: forward() writes to
# self.__dict__ and reads the value back two ways.
def test_update_module_dunder_dict(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
self.__dict__["a"] = 10
return x * self.a + self.__dict__["a"]
mod = MyModule()
x = torch.randn(10)
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
self.assertEqual(mod(x), opt_mod(x))
# Builds an OrderedDict and a ModuleDict from the same entries inside a
# compiled function, updates both, and checks the returned containers have the
# same length and hold the same module objects in the same order.
def test_dict_reconstruct_keeps_original_order(self):
def fn():
modules = OrderedDict([("act", torch.nn.ReLU())])
module_dict = torch.nn.ModuleDict(modules)
next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
modules.update(next_modules.items())
module_dict.update(next_modules)
return modules, module_dict
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
modules, module_dict = opt_fn()
self.assertEqual(len(module_dict), len(modules))
for k1, m2 in zip(modules, module_dict.children()):
self.assertTrue(modules[k1] is m2)
# For both OrderedDict and dict as base class, defines a subclass with
# pass-through __new__/__init__ and checks it can be constructed, written to,
# and read from inside a fullgraph-compiled function.
def test_dict_subclass_initialization_in_graph(self):
for super_class in (
OrderedDict,
dict,
):
class CustomDict(super_class):
def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def fn(x):
c = CustomDict()
c["key"] = x
assert "key" in c
return c["key"] + 1
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.rand(4)
self.assertEqual(fn(x), opt_fn(x))
# Calls a compiled function twice with dicts whose "args" list has a different
# length each time; the function zips the list with itertools.count() and
# collects tensor shapes. The test only checks that both calls run.
def test_dict_list_values(self):
def inner_fn(args):
return [x[1].shape for x in args]
@torch.compile(backend="eager")
def fn(tensors):
return inner_fn(zip(itertools.count(), tensors["args"]))
fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
fn({"args": [torch.ones(5, 5)]})
# Iterates a constant dict literal inside a compiled module's forward() and
# sums its values; the expected total is 1 + 2 + 3 + 4 = 10.
def test_dict_iter(self):
class MyMod(torch.nn.Module):
def forward(self, x):
z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
tot = 0
for key in z:
tot += z[key]
return tot
x = torch.tensor([0])
model = MyMod()
opt_model = torch.compile(model, backend="eager", fullgraph=True)
y = opt_model(x)
self.assertEqual(y, 10)
# Checks `in` on an OrderedDict subclass passed as an argument. With "key1"
# present the result is 1 + 2 + 8 = 11; with "key2" it is 1 + 4 + 8 = 13.
def test_dict_subclass_contains(self):
# pattern from huggingface
class ClassInstantier(OrderedDict):
pass
@torch.compile(fullgraph=True, backend="eager")
def f(x, d):
if "key1" in d:
x = x + 2
if "key2" in d:
x = x + 4
x = x + 8
return x
result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
self.assertTrue(same(result, torch.full([8], 11.0)))
result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
self.assertTrue(same(result, torch.full([8], 13.0)))
# Reads int values and an object attribute through a dict argument, then
# changes a dict value and later the object's attribute, checking the compiled
# function tracks eager after each change.
def test_dict_tag_guard(self):
class Foo:
def __init__(self) -> None:
self.scalar = 10
def fn(d, x):
return d["a"] * d["b"] * d["c"].scalar * x
foo = Foo()
d = {"a": 2, "b": 3, "c": foo}
opt_fn = torch.compile(fn, backend="eager")
inp = torch.randn(3, 3)
self.assertEqual(fn(d, inp), opt_fn(d, inp))
d["a"] = 4
self.assertEqual(fn(d, inp), opt_fn(d, inp))
# Check that recompilation happens
foo.scalar = 12
self.assertEqual(fn(d, inp), opt_fn(d, inp))
# Branches on the truthiness of a dict argument and checks both the empty and
# the non-empty case against eager.
def test_empty_dict_recompilation(self):
def fn(d, x):
if d:
return torch.cos(x)
return torch.sin(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(fn({}, x), opt_fn({}, x))
self.assertEqual(fn({"a": 1}, x), opt_fn({"a": 1}, x))
# Creates a dict via dict.__new__(klass) inside a compiled function, for both a
# dict subclass (which also gets an attribute set) and plain dict, and checks
# the returned object's contents, type, and attribute.
def test_udf_dict_reconstruction(self):
class MyDict(dict):
pass
def fn(x, klass):
x = x * 2
sc_dict = dict.__new__(klass)
sc_dict["x"] = x
if isinstance(sc_dict, MyDict):
sc_dict.attr = 3
return sc_dict
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = fn(x, MyDict)
res = opt_fn(x, MyDict)
self.assertEqual(ref, res)
self.assertTrue(isinstance(res, MyDict))
self.assertEqual(ref.attr, res.attr)
ref = fn(x, dict)
res = opt_fn(x, dict)
self.assertEqual(ref, res)
self.assertTrue(isinstance(res, dict))
# Checks `in` on a closed-over weakref.WeakKeyDictionary keyed by nn.Module
# instances inside a fullgraph-compiled function.
def test_weakref_dict(self):
states = weakref.WeakKeyDictionary()
mod1 = torch.nn.Module()
mod2 = torch.nn.Module()
states[mod1] = 2
states[mod2] = 3
def fn(x):
if mod1 in states:
x = torch.sin(x)
if mod2 in states:
x = torch.cos(x)
return x
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
# Constructs and returns a DummyUserDict (UserDict subclass) holding a tensor
# from a compiled function and compares the stored value with eager.
def test_construct_user_dict_and_return(self):
def fn(x):
return DummyUserDict({"a": x + 1})
x = torch.randn(4)
res = fn(x)
self.assertEqual(res["a"], x + 1)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(res["a"], opt_fn(x)["a"])
# Uses id(f) of a function argument as a dict key, for both construction and
# lookup, inside a fullgraph-compiled function.
def test_fn_id(self):
def fn(x, f):
d = {id(f): 3}
return x * d[id(f)]
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
def nothing():
pass
f = nothing
self.assertEqual(fn(x, f), opt_fn(x, f))
# Wraps a locally built dict in types.MappingProxyType, reads through the
# proxy, and returns the proxy; the result must equal eager and keep its type.
def test_mapping_proxy_for_local(self):
def fn(x):
d = {"a": 2, "b": 3, "c": 5 * x}
mp = types.MappingProxyType(d)
# y is computed through the proxy but not returned
y = torch.sin(x * mp["a"])
for v in mp.values():
y += torch.cos(x * v)
return mp
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertTrue(type(res) is types.MappingProxyType)
# Creates a MappingProxyType over a closed-over dict inside the compiled
# function, mutates the dict afterwards, and returns the proxy. The returned
# proxy must still reflect later updates made to the underlying dict.
def test_mapping_proxy_for_nonlocal(self):
d = {"a": 2, "b": 3, "c": 5}
def fn(x):
mp = types.MappingProxyType(d)
y = torch.sin(x * mp["a"])
for v in mp.values():
y += torch.cos(x * v)
d["d"] = 4
return mp
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertTrue(type(res) is types.MappingProxyType)
# check update to d is reflected in res
d["e"] = 5
self.assertEqual(d["e"], res["e"])
# Passes a pre-existing MappingProxyType as an argument, then changes a value
# in and removes a key from the underlying dict, comparing compiled output with
# eager after each change.
def test_mapping_proxy_existing(self):
d = {"a": 2, "b": 3, "c": 5}
def fn(x, mp):
y = torch.sin(x * mp["a"])
for v in mp.values():
y += torch.cos(x * v)
if isinstance(mp, types.MappingProxyType):
y *= 2
return y
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
mp = types.MappingProxyType(d)
ref = fn(x, mp)
res = opt_fn(x, mp)
self.assertEqual(ref, res)
d["a"] = 3
ref = fn(x, mp)
res = opt_fn(x, mp)
self.assertEqual(ref, res)
d.pop("b")
ref = fn(x, mp)
res = opt_fn(x, mp)
self.assertEqual(ref, res)
# Builds a regular dict from a MappingProxyType argument via dict(mp) inside a
# fullgraph-compiled function and reads a value from it.
def test_dict_construction_from_mapping_proxy(self):
d = {"a": 2, "b": 3, "c": 5}
def fn(x, mp):
d = dict(mp)
y = torch.sin(x * d["a"])
return y
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
mp = types.MappingProxyType(d)
ref = fn(x, mp)
res = opt_fn(x, mp)
self.assertEqual(ref, res)
# Mutates the dict underlying an existing MappingProxyType inside the compiled
# function and immediately reads the new key through the proxy. Compiled
# without fullgraph=True; the expected value is computed directly as sin(x * 4).
def test_mapping_proxy_existing_mutation(self):
d = {"a": 2, "b": 3, "c": 5}
mp = types.MappingProxyType(d)
def fn(x):
d["d"] = 4
y = torch.sin(x * mp["d"])
return y
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(4)
ref = torch.sin(x * 4)
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(d.keys(), mp.keys())
# Mutates an unrelated, locally created dict while reading from an existing
# MappingProxyType; must compile under fullgraph=True (no graph break).
def test_mapping_proxy_existing_local_mutation(self):
d = {"a": 2, "b": 3, "c": 5}
mp = types.MappingProxyType(d)
def fn(x):
# Dynamo should not cause a graph break here because it knows that
# the existing proxy can't point to this new dict
other_dict = {}
other_dict["d"] = 4
y = torch.sin(x * mp["c"])
return y
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = torch.sin(x * mp["c"])
res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(d.keys(), mp.keys())
# Calls OrderedDict.move_to_end on a locally built OrderedDict and checks the
# returned dict has the expected key order ["b", "c", "a"] and equals eager.
def test_move_to_end(self):
def fn(x):
d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5})
d.move_to_end("a")
return d
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
self.assertEqual(fn(x), opt_fn(x))
# Copies a class's __dict__ (a mapping proxy) into a dict, then sets a new class
# attribute and reads it back through Foo.__dict__. Compiled without
# fullgraph=True because the attribute write is expected to graph-break.
# NOTE(review): "muation" in the test name looks like a typo for "mutation";
# left unchanged because the name is the test's public identifier.
def test_mapping_proxy_ban_muation_on_dict_realization(self):
def fn(x):
class Foo:
b = 4
d = dict(Foo.__dict__)
y = torch.sin(x) * d["b"]
# This should cause a graph break, because otherwise the
# Foo.__dict__ will not be updated.
Foo.bar = 3
return Foo, y * Foo.__dict__["bar"]
opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(4)
foo1, ref = fn(x)
foo2, res = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(foo1.bar, foo2.bar)
# Uses a dict subclass whose __getitem__ counts calls and returns value + 1.
# The compiled run must produce the same result and the same call count as the
# eager run on an identically constructed dict.
def test_overridden_get_item(self):
class MyDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.calls = 0
def __getitem__(self, key):
self.calls += 1
return super().__getitem__(key) + 1
def fn(x, d):
d["d"] = 4
return x * d["a"] + d["b"] + d["c"] + d["d"]
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
d1 = MyDict({"a": 2, "b": 3, "c": 5})
ref = fn(x, d1)
d2 = MyDict({"a": 2, "b": 3, "c": 5})
res = opt_fn(x, d2)
self.assertEqual(ref, res)
self.assertEqual(d1.calls, d2.calls)
# Returns d.items() from a compiled function and checks the result equals eager
# and that its type is the dict_items type imported from torch._dynamo.utils.
def test_items_type(self):
def fn():
d = dict({"a": 1, "b": "2", "c": torch.tensor(3)}) # noqa: C418
return d.items()
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
ref = fn()
res = opt_fn()
self.assertEqual(ref, res)
self.assertEqual(type(res), dict_items)
# Applies `dict | b` with a non-dict right operand (int, float, str, tuple,
# list) inside a fullgraph-compiled function and expects Dynamo's Unsupported
# exception with a message matching "Observed exception" for each.
def test_builtin_or_with_invalid_types(self):
args = (
1, # int
1.0, # float
"a", # str
(1, 2), # tuple
[1, 2], # list
)
@torch.compile(backend="eager", fullgraph=True)
def fn(b: Any):
a = {"one": torch.ones(1)}
return a | b
from torch._dynamo.exc import Unsupported
for arg in args:
with self.assertRaisesRegex(Unsupported, "Observed exception"):
_ = fn(arg)
# Merges two dicts with disjoint keys using `|` and __or__ in both operand
# orders and compares all results with eager.
def test_builtin_or_with_diff_keys(self):
def f():
a = {"one": torch.ones(1)}
b = {"two": torch.ones(2)}
return a, b, a | b, b | a, a.__or__(b), b.__or__(a)
opt_f = torch.compile(f, backend="eager", fullgraph=True)
self.assertEqual(f(), opt_f())
# Merges two dicts that share the key "one" using `|` and __or__ in both
# operand orders and compares all results with eager.
def test_builtin_or_with_same_keys(self):
def f():
a = {"one": torch.ones(1), "two": torch.ones(2)}
b = {"one": torch.ones(1), "three": torch.ones(3)}
return a, b, a | b, b | a, a.__or__(b), b.__or__(a)
opt_f = torch.compile(f, backend="eager", fullgraph=True)
self.assertEqual(f(), opt_f())
# Applies the in-place merge operator `|=` to a local dict inside a
# fullgraph-compiled function and compares both dicts with eager.
def test_builtin_ior_(self):
def f():
a = {"one": torch.ones(1)}
b = {"two": torch.ones(2)}
a |= b
return a, b
opt_f = torch.compile(f, backend="eager", fullgraph=True)
self.assertEqual(f(), opt_f())
# Constructs a defaultdict(list) inside a compiled function, assigns one entry,
# and returns it alongside a tensor result.
def test_newly_constructed_default_dict(self):
def f(x):
d = defaultdict(list)
d[0] = [
42,
]
return x + 1, d
x = torch.ones(2)
ref = f(x)
res = torch.compile(f, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
# Constructs a defaultdict with `dict` as the factory plus initial contents
# inside a compiled function. Marked expectedFailure, i.e. this case is
# currently expected not to pass.
@unittest.expectedFailure
def test_newly_constructed_default_dict_with_dict(self):
def f(x):
d = defaultdict(dict, {2: {"a": 1}})
d[0] = {"b": 2}
return x + 1, d
x = torch.ones(2)
ref = f(x)
res = torch.compile(f, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
# Iterates a locally built defaultdict and updates existing values in place
# during the loop (no keys are added), then returns it.
def test_iter_default_dict(self):
def f(x):
d = defaultdict(list)
d[0] = 42
for k in d:
d[k] += 1
return x + 1, d
x = torch.ones(2)
ref = f(x)
res = torch.compile(f, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)
# Parametrized over operator.or_/and_/xor/sub: applies the operator to the
# keys() views of two dicts, in both operand orders, and compares with eager.
@parametrize("op", ["or_", "and_", "xor", "sub"])
def test_dict_keys_binop(self, op):
# Resolve the operator name to the actual function in the operator module
op = getattr(operator, op)
def f():
a = {"one": torch.ones(1), "two": torch.ones(2)}
b = {"one": torch.ones(1), "three": torch.ones(3)}
return op(a.keys(), b.keys()), op(b.keys(), a.keys())
opt_f = torch.compile(f, backend="eager", fullgraph=True)
self.assertEqual(f(), opt_f())
# Parametrized over operator.ior/iand/ixor/isub: applies the in-place operator
# to keys() views of dict literals and compares the results with eager.
@parametrize("op", ["ior", "iand", "ixor", "isub"])
def test_dict_keys_inplace_binop(self, op):
# Resolve the operator name to the actual function in the operator module
op = getattr(operator, op)
def f():
a = {"one": torch.ones(1), "two": torch.ones(2)}.keys()
b = {"one": torch.ones(1), "three": torch.ones(3)}.keys()
c = {"one": torch.ones(1), "two": torch.ones(2)}.keys()
a = op(a, b)
b = op(b, c)
return a, b
opt_f = torch.compile(f, backend="eager", fullgraph=True)
self.assertEqual(f(), opt_f())
# DictTests contains @parametrize-decorated tests; presumably this call expands
# them into concrete per-parameter test methods - see common_utils for details.
instantiate_parametrized_tests(DictTests)
class DictGuardTests(LoggingTestCase):
thetype = dict
@make_logging_test(recompiles=True)
def test_popitem(self, records):
    """Check that ``popitem`` on a closed-over dict guards on the popped item.

    The first call pops ``(3, 4)`` and takes the ``sin`` branch without any
    recompile record. After re-inserting key 3 with a different value, the
    next call must log exactly one recompile whose message mentions the
    failed ``d[3] == 4`` guard, and must take the ``cos`` branch.
    """
    d = self.thetype()
    d[1] = 2
    d[3] = 4

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x):
        k, v = d.popitem()
        if k == 3 and v == 4:
            return x.sin()
        return x.cos()

    x = torch.tensor(1.0)
    y = fn(x)
    # sanity check
    self.assertEqual(len(records), 0)
    self.assertEqual(y, x.sin())

    d[3] = 5
    y = fn(x)
    self.assertEqual(len(records), 1)
    self.assertEqual(y, x.cos())

    record = self.getRecord(records, "d")
    # Inspect the formatted log message, as the sibling guard tests do,
    # rather than the string form of the LogRecord object itself.
    self.assertIn(
        """d[3] == 4""",
        munge_exc(record.getMessage()),
    )
# Branches on `d1 == d2` for two dict arguments. The first call (unequal dicts)
# logs no recompile; calling with the same dict for both arguments must log one
# recompile whose message mentions the key-list guard on d2.
@make_logging_test(recompiles=True)
def test_cmp_eq(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
if d1 == d2:
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.cos())
y = fn(x, d1, d1)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.sin())
record = self.getRecord(records, "d2")
self.assertIn(
"""list(dict.keys(d2))""",
munge_exc(record.getMessage()),
)
@make_logging_test(recompiles=True)
def test_cmp_ne(self, records):
    """Check guards installed when a compiled function branches on ``d1 != d2``.

    The original body compared with ``==`` and was a verbatim copy of
    ``test_cmp_eq``, so the inequality operator was never exercised. It now
    uses ``!=``, with the expected branch for each call swapped accordingly.
    """

    @torch.compile(backend="eager", fullgraph=True)
    def fn(x, d1, d2):
        if d1 != d2:
            return x.sin()
        return x.cos()

    x = torch.tensor(1.0)
    d1 = self.thetype({1: 2, 3: 4})
    d2 = self.thetype({1: 2, 5: 6})

    # Unequal dicts: `!=` is true, so the sin branch is taken.
    y = fn(x, d1, d2)
    # sanity check
    self.assertEqual(len(records), 0)
    self.assertEqual(y, x.sin())

    # Same dict on both sides: `!=` is false, so the cos branch is taken and
    # the guard on d2's keys from the first compilation must fail.
    y = fn(x, d1, d1)
    self.assertEqual(len(records), 1)
    self.assertEqual(y, x.cos())

    record = self.getRecord(records, "d2")
    self.assertIn(
        """list(dict.keys(d2))""",
        munge_exc(record.getMessage()),
    )
# Branches on a key lookup in `d1 | d2`. The first call finds key 5 (from d2);
# the second call passes d1 twice, so key 5 is absent, and must log one
# recompile whose message mentions a KeyError on d2[5].
@make_logging_test(recompiles=True)
def test_cmp_or(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
d = d1 | d2
if d.get(5, False):
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.sin())
y = fn(x, d1, d1)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.cos())
record = self.getRecord(records, "d2")
self.assertIn(
"""KeyError on d2[5]""",
munge_exc(record.getMessage()),
)
# Branches on a key lookup after `d2 |= d1`. The first call merges key 3 from
# d1 into d2; the second call passes two copies of the original d2 (taken before
# the first call mutated it), which lack key 3, and must log one recompile whose
# message mentions a KeyError on d1[3].
@make_logging_test(recompiles=True)
def test_cmp_ior(self, records):
@torch.compile(backend="eager", fullgraph=True)
def fn(x, d1, d2):
d2 |= d1
if d2.get(3, False):
return x.sin()
return x.cos()
x = torch.tensor(1.0)
d1 = self.thetype({1: 2, 3: 4})
d2 = self.thetype({1: 2, 5: 6})
d3, d4 = d2.copy(), d2.copy()
y = fn(x, d1, d2)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(y, x.sin())
y = fn(x, d3, d4)
self.assertEqual(len(records), 1)
self.assertEqual(y, x.cos())
record = self.getRecord(records, "d1")
self.assertIn(
"""KeyError on d1[3]""",
munge_exc(record.getMessage()),
)
class DictMethodsTests(torch._dynamo.test_case.TestCase):
thetype = dict
# Methods:
# + clear
# + copy
# + fromkeys
# + get
# + items
# + keys
# + pop
# + popitem
# + setdefault
# + update
# + values
# BinOps:
# ==, !=, |
def setUp(self):
    """Enable ``enable_trace_unittest`` for the test, remembering the prior value."""
    # Save the current setting so tearDown restores it instead of forcing
    # False, which would clobber a value set by the surrounding environment.
    self._prev_enable_trace_unittest = torch._dynamo.config.enable_trace_unittest
    torch._dynamo.config.enable_trace_unittest = True
    super().setUp()

def tearDown(self):
    """Restore ``enable_trace_unittest`` to the value it had before ``setUp``."""
    # Fall back to False (the previous hard-coded reset value) if setUp did
    # not get far enough to record the prior setting.
    torch._dynamo.config.enable_trace_unittest = getattr(
        self, "_prev_enable_trace_unittest", False
    )
    return super().tearDown()
# Overrides assertEqual with a plain `x == y` check via assertTrue.
# NOTE(review): presumably kept minimal so it can be used inside
# @make_dynamo_test bodies - confirm.
def assertEqual(self, x, y):
self.assertTrue(x == y, f"Expected {x} to be equal to {y}")
# Overrides assertNotEqual with a plain `x == y` check via assertFalse.
# Note it negates `==` rather than evaluating `!=`.
def assertNotEqual(self, x, y):
self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")
# Exercises dict equality through the == operator, the bound __eq__ method, and
# the unbound dict.__eq__ / self.thetype.__eq__ forms.
@make_dynamo_test
def test_cmp_eq(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"a": 1, "b": 2})
d3 = self.thetype({"a": 1, "b": 3})
self.assertEqual(d1, d2)
self.assertNotEqual(d1, d3)
# Test the == operator
self.assertEqual(d1 == d2, True)
self.assertEqual(d1 == d3, False)
# Test the __eq__ method
self.assertEqual(d1.__eq__(d2), True)
self.assertEqual(d1.__eq__(d3), False)
# Test Dict.__eq__
self.assertEqual(dict.__eq__(d1, d2), True)
self.assertEqual(self.thetype.__eq__(d1, d3), False)
# Exercises dict inequality through the != operator, the bound __ne__ method,
# and the unbound dict.__ne__ / self.thetype.__ne__ forms.
@make_dynamo_test
def test_cmp_ne(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"a": 1, "b": 2})
d3 = self.thetype({"a": 1, "b": 3})
self.assertNotEqual(d1, d3)
self.assertEqual(d1, d2)
# Test the != operator
self.assertEqual(d1 != d3, True)
self.assertEqual(d1 != d2, False)
# Test the __ne__ method
self.assertEqual(d1.__ne__(d3), True)
self.assertEqual(d1.__ne__(d2), False)
# Test Dict.__ne__
self.assertEqual(dict.__ne__(d1, d3), True)
self.assertEqual(self.thetype.__ne__(d1, d2), False)
# Exercises dict merge through the | operator, the bound __or__ method, and the
# unbound forms, in both operand orders (the right operand wins on the shared
# key "b"), plus the TypeError for a non-dict operand.
@make_dynamo_test
def test_binop_or(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"b": 3, "c": 4})
# Test the | operator
self.assertEqual(d1 | d2, {"a": 1, "b": 3, "c": 4})
self.assertEqual(d2 | d1, {"a": 1, "b": 2, "c": 4})
# Test the __or__ method
self.assertEqual(d1.__or__(d2), {"a": 1, "b": 3, "c": 4})
self.assertEqual(d2.__or__(d1), {"a": 1, "b": 2, "c": 4})
# Test Dict.__or__
self.assertEqual(dict.__or__(d1, d2), {"a": 1, "b": 3, "c": 4})
self.assertEqual(self.thetype.__or__(d2, d1), {"a": 1, "b": 2, "c": 4})
# Test with non-dict types
self.assertRaises(TypeError, lambda: d1 | 1)
@make_dynamo_test
def test_binop_ior(self):
    """Exercise in-place dict merge via ``|=``, ``__ior__`` and the unbound forms.

    Each stanza starts from fresh copies of ``d1``/``d2`` so the in-place
    updates of one stanza do not leak into the next. The right operand wins
    on the shared key ``"b"``. The iterable right-operand case is covered by
    ``test_binop_ior_iterable``.
    """
    d1 = self.thetype({"a": 1, "b": 2})
    d2 = self.thetype({"b": 3, "c": 4})

    # Test the |= operator
    d3, d4 = d1.copy(), d2.copy()
    d3 |= d2
    d4 |= d1
    self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
    self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})

    # Test the __ior__ method
    d3, d4 = d1.copy(), d2.copy()
    d3.__ior__(d2)
    d4.__ior__(d1)
    self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
    self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})

    # Test Dict.__ior__
    d3, d4 = d1.copy(), d2.copy()
    self.assertEqual(dict.__ior__(d3, d2), {"a": 1, "b": 3, "c": 4})
    self.assertEqual(self.thetype.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})

    # Test return value
    d3, d4 = d1.copy(), d2.copy()
    self.assertEqual(d3.__ior__(d2), {"a": 1, "b": 3, "c": 4})
    self.assertEqual(dict.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})

    # Test with non-dict types
    self.assertRaises(TypeError, lambda: dict.__ior__(d1, 1))
# Exercises __ior__ with a non-dict right operand: a dict_items view and a
# generator yielding (key, value) pairs.
@make_dynamo_test
def test_binop_ior_iterable(self):
d1 = self.thetype({"a": 1, "b": 2})
d2 = self.thetype({"b": 3, "c": 4})
d3, d4 = d1.copy(), d2.copy()
def fn(d):
yield from d.items()
self.assertEqual(d3.__ior__(d2.items()), {"a": 1, "b": 3, "c": 4})
self.assertEqual(d4.__ior__(fn(d1)), {"a": 1, "b": 2, "c": 4})
# Exercises clear() as a bound method and through dict.clear /
# self.thetype.clear, its None return value, and the TypeError raised when an
# extra argument is passed.
@make_dynamo_test
def test_clear(self):
d = self.thetype({"a": 1, "b": 2})
d.clear()
self.assertEqual(d, {})
# Test that clear returns None
d = self.thetype({"a": 1, "b": 2})
self.assertIsNone(d.clear())
# Test Dict.clear
d = self.thetype({"a": 1, "b": 2})
dict.clear(d)
self.assertEqual(d, {})
d = self.thetype({"a": 1, "b": 2})
self.thetype.clear(d)
self.assertEqual(d, {})
# Test invalid usage
self.assertRaises(TypeError, d.clear, 1)
# Exercises copy() as a bound method and through dict.copy / self.thetype.copy:
# the copy must be equal to but not the same object as the original, and an
# extra argument must raise TypeError.
@make_dynamo_test
def test_copy(self):
d = self.thetype({"a": 1, "b": 2})
d2 = d.copy()
self.assertEqual(d, d2)
# Test that copy returns a new instance
self.assertIsNot(d, d2)
# Test Dict.copy
self.assertEqual(dict.copy(d), d2)
self.assertEqual(self.thetype.copy(d), d2)
# Test invalid usage
self.assertRaises(TypeError, d.copy, 1)
# Exercises the fromkeys classmethod with an explicit value, with None, and
# with invalid arguments. Marked expectedFailure, i.e. this case is currently
# expected not to pass.
@unittest.expectedFailure
@make_dynamo_test
def test_fromkeys(self):
d = self.thetype.fromkeys(["a", "b"], 1)
self.assertEqual(d, {"a": 1, "b": 1})
p = self.thetype.fromkeys(["a", "b"], None)
self.assertEqual(p, {"a": None, "b": None})
# Test Dict.fromkeys
d2 = self.thetype.fromkeys(["c", "d"], 2)
self.assertEqual(d2, {"c": 2, "d": 2})
# Test invalid usage
self.assertRaises(TypeError, self.thetype.fromkeys)
self.assertRaises(TypeError, self.thetype.fromkeys, 1, 2)
# Exercises get() for a present key, a missing key with and without a default,
# the unbound dict.get / self.thetype.get forms, and the TypeError raised when
# no key is given.
@make_dynamo_test
def test_get(self):
d = self.thetype({"a": 1, "b": 2})
self.assertEqual(d.get("a"), 1)
self.assertEqual(d.get("c", 3), 3)
self.assertIsNone(d.get("c"))
# Test Dict.get
self.assertEqual(dict.get(d, "b"), 2)
self.assertEqual(self.thetype.get(d, "b"), 2)
# Test invalid usage
self.assertRaises(TypeError, d.get)
@make_dynamo_test
def test_items(self):
    mapping = self.thetype({"a": 1, "b": 2})
    expected_pairs = {("a", 1), ("b", 2)}

    # Bound call.
    view = mapping.items()
    self.assertEqual(set(view), expected_pairs)

    # Unbound calls through dict and through the type under test.
    self.assertEqual(set(dict.items(mapping)), expected_pairs)
    self.assertEqual(set(self.thetype.items(mapping)), expected_pairs)

    # items() takes no arguments.
    self.assertRaises(TypeError, mapping.items, 1)
@make_dynamo_test
def test_keys(self):
    mapping = self.thetype({"a": 1, "b": 2})
    expected_keys = {"a", "b"}

    # Bound call.
    view = mapping.keys()
    self.assertEqual(set(view), expected_keys)

    # Unbound calls through dict and through the type under test.
    self.assertEqual(set(dict.keys(mapping)), expected_keys)
    self.assertEqual(set(self.thetype.keys(mapping)), expected_keys)

    # keys() takes no arguments.
    self.assertRaises(TypeError, mapping.keys, 1)
@make_dynamo_test
def test_pop(self):
    mapping = self.thetype({"a": 1, "b": 2})

    # Popping a present key returns its value and removes the entry.
    popped = mapping.pop("a")
    self.assertEqual(popped, 1)
    self.assertEqual(mapping, {"b": 2})

    # A missing key with a default returns that default.
    self.assertIsNone(mapping.pop("c", None))

    # Unbound calls through dict and through the type under test.
    mapping = self.thetype({"a": 1, "b": 2})
    self.assertEqual(dict.pop(mapping, "b"), 2)
    self.assertEqual(self.thetype.pop(mapping, "a"), 1)

    # Missing key without a default raises KeyError; no key raises TypeError.
    self.assertRaises(KeyError, mapping.pop, "c")
    self.assertRaises(TypeError, mapping.pop)
@make_dynamo_test
def test_popitem(self):
    # Single entry: popitem returns it and leaves the dict empty.
    mapping = self.thetype({"a": 1})
    k, v = mapping.popitem()
    self.assertEqual(k, "a")
    self.assertEqual(v, 1)
    self.assertEqual(len(mapping), 0)

    # LIFO: the most recently inserted pair comes out first.
    mapping = self.thetype()
    mapping["a"] = 1
    mapping["b"] = 2
    last_pair = mapping.popitem()
    self.assertEqual(last_pair, ("b", 2))

    # Unbound call through the builtin dict type.
    mapping = self.thetype({"a": 1})
    k, v = dict.popitem(mapping)
    self.assertEqual(k, "a")
    self.assertEqual(v, 1)

    # Unbound call through the type under test.
    mapping = self.thetype({"a": 1})
    k, v = self.thetype.popitem(mapping)
    self.assertEqual(k, "a")
    self.assertEqual(v, 1)

    # Invalid usage: a positional argument is rejected, except for
    # OrderedDict, whose popitem accepts an argument.
    if self.thetype is not OrderedDict:
        self.assertRaises(TypeError, mapping.popitem, 1)
@make_dynamo_test
def test_setdefault(self):
    d = self.thetype({"a": 1, "b": 2})
    # Existing key: the stored value is returned and the default is ignored.
    self.assertEqual(d.setdefault("a", 3), 1)
    # Missing key: the default is inserted and returned.
    self.assertEqual(d.setdefault("c", 3), 3)
    # Missing key without a default: None is inserted and returned.
    # (assertIsNone takes the value and an optional failure message; the
    # stray second positional `None` that used to be passed here was only a
    # message and has been dropped.)
    self.assertIsNone(d.setdefault("d"))
    self.assertEqual(d, {"a": 1, "b": 2, "c": 3, "d": None})
    # Test Dict.setdefault: unbound calls through dict and the type under test.
    self.assertEqual(dict.setdefault(d, "f", 5), 5)
    self.assertEqual(self.thetype.setdefault(d, "e", 5), 5)
    # Test invalid usage: no key, and an unhashable key.
    self.assertRaises(TypeError, d.setdefault)
    self.assertRaises(TypeError, d.setdefault, [[]])
@make_dynamo_test
def test_update(self):
    target = self.thetype({"a": 1, "b": 2})

    # Plain dict literal: overlapping keys are overwritten, new keys added.
    target.update({"b": 3, "c": 4})
    self.assertEqual(target, {"a": 1, "b": 3, "c": 4})

    # Another instance of the type under test.
    extra = self.thetype({"d": 5})
    target.update(extra)
    self.assertEqual(target, {"a": 1, "b": 3, "c": 4, "d": 5})

    # Unbound call through the builtin dict type.
    via_dict = self.thetype({"e": 6})
    dict.update(target, via_dict)
    self.assertEqual(target, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6})

    # Unbound call through the type under test.
    via_thetype = self.thetype({"f": 7})
    self.thetype.update(target, via_thetype)
    self.assertEqual(target, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7})

    # Keyword arguments on a bound call.
    target.update(f=7, g=8)
    self.assertEqual(
        target, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7, "g": 8}
    )

    # Keyword arguments on an unbound call.
    self.thetype.update(target, h=9, i=10)
    self.assertEqual(
        target,
        {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7, "g": 8, "h": 9, "i": 10},
    )

    # An int argument raises TypeError.
    self.assertRaises(TypeError, target.update, 1)
@make_dynamo_test
def test_values(self):
    mapping = self.thetype({"a": 1, "b": 2})
    expected_values = {1, 2}

    # Bound call.
    view = mapping.values()
    self.assertEqual(set(view), expected_values)

    # Unbound calls through dict and through the type under test.
    self.assertEqual(set(dict.values(mapping)), expected_values)
    self.assertEqual(set(self.thetype.values(mapping)), expected_values)

    # values() takes no arguments.
    self.assertRaises(TypeError, mapping.values, 1)
@make_dynamo_test
def test_type(self):
    # An instance built from the type under test reports exactly that type.
    instance = self.thetype({"a": 1, "b": 2})
    self.assertIsInstance(instance, self.thetype)
    actual_type = type(instance)
    self.assertIs(actual_type, self.thetype)
@make_dynamo_test
def test_dict_type_comparison(self):
    # Named `candidate_types` so the local does not hide the `types` module
    # imported at the top of this file.
    candidate_types = (dict, OrderedDict, defaultdict)

    # The type under test is equal and identical to itself.
    self.assertEqual(self.thetype, self.thetype)
    self.assertTrue(self.thetype is self.thetype)

    # It is neither equal nor identical to any other candidate type.
    for other in candidate_types:
        if self.thetype == other:
            # Skip the candidate that is the type under test.
            continue
        self.assertNotEqual(self.thetype, other)
        self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}")
@make_dynamo_test
def test_dict___iter__(self):
    # Calling __iter__ directly yields the dict's keys.
    mapping = self.thetype({1: 2})
    key_iter = mapping.__iter__()
    first_key = next(key_iter)
    self.assertEqual(first_key, 1)
class DictSubclassMethodsTests(DictMethodsTests):
    # Inherits every DictMethodsTests case and points `thetype` at SimpleDict
    # (the dict subclass declared at the top of this file), so the inherited
    # tests that build their dicts via `self.thetype` run against it.
    thetype = SimpleDict
class OrderedDictMethodsTests(DictMethodsTests):
    # Runs the inherited DictMethodsTests cases with `thetype` set to
    # OrderedDict and adds cases for OrderedDict-specific behavior.
    thetype = OrderedDict
    # Methods:
    # - popitem - Inherited from DictMethodsTest
    # + move_to_end

    @make_dynamo_test
    def test_move_to_end(self):
        # Start from keys a..e in insertion order; each step below depends on
        # the order left behind by the previous one.
        d = self.thetype.fromkeys("abcde")
        self.assertEqual("".join(d), "abcde")
        d.move_to_end("b")
        self.assertEqual("".join(d), "acdeb")
        # Test OrderedDict.move_to_end
        self.thetype.move_to_end(d, "a")
        self.assertEqual("".join(d), "cdeba")
        # Test last=False
        self.thetype.move_to_end(d, "a", last=False)
        self.assertEqual("".join(d), "acdeb")
        # Test KeyError
        self.assertRaises(KeyError, d.move_to_end, "f")

    # NOTE(review): unlike its siblings, this test is not wrapped in
    # make_dynamo_test - confirm that is intentional.
    def test_cmp_eq_order(self):
        # Same keys inserted in a different order must not compare equal.
        a = self.thetype.fromkeys("abc")
        b = self.thetype.fromkeys("bca")
        self.assertFalse(a == b)

    @make_dynamo_test
    def test_binop_or_return_type(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})
        # Test return type: the result is an OrderedDict when either operand
        # of `|` is an OrderedDict.
        self.assertIs(type(d1 | d2), OrderedDict)
        self.assertIs(type(dict(d1) | d2), OrderedDict)
        self.assertIs(type(d1 | dict(d2)), OrderedDict)

    @make_dynamo_test
    def test_binop_ior_return_type(self):
        d1 = self.thetype({"a": 1, "b": 2})
        d2 = self.thetype({"b": 3, "c": 4})
        # Test return type: __ior__ returns an object of the left operand's type.
        d3, d4 = d1.copy(), d2.copy()
        self.assertIs(type(d3.__ior__(d2)), OrderedDict)
        self.assertIs(type(dict.__ior__(d4, d2)), OrderedDict)
        self.assertIs(type(self.thetype.__ior__(d4, d2)), OrderedDict)
        # Fresh copies: a plain-dict right operand keeps the OrderedDict type,
        # while a plain-dict left operand yields a plain dict.
        d3, d4 = d1.copy(), d2.copy()
        self.assertIs(type(dict.__ior__(d3, dict(d2))), OrderedDict)
        self.assertIs(type(dict.__ior__(dict(d3), d2)), dict)
        self.assertIs(type(dict(d4).__ior__(d2)), dict)

    @make_dynamo_test
    def test_popitem_kwarg(self):
        # `last` may be passed by keyword or positionally: True pops from the
        # end, False pops from the front. Each step shrinks the same dict.
        d = self.thetype.fromkeys("abcdf")
        self.assertEqual(d.popitem(last=True), ("f", None))
        self.assertEqual(list(d), list("abcd"))
        self.assertEqual(d.popitem(last=False), ("a", None))
        self.assertEqual(list(d), list("bcd"))
        self.assertEqual(d.popitem(False), ("b", None))
        self.assertEqual(list(d), list("cd"))
        self.assertEqual(d.popitem(True), ("d", None))
        self.assertEqual(list(d), list("c"))
class OrderedDictSubclassOverload(torch._dynamo.test_case.TestCase):
def setUp(self):
torch._dynamo.config.enable_trace_unittest = True
super().setUp()
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = False
return super().tearDown()
def assertEqual(self, x, y):
self.assertTrue(x == y, f"Expected {x} to be equal to {y}")
def assertNotEqual(self, x, y):
self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")
class OrderedDictSubclass(OrderedDict):
def get(self, key, default=None, /):
return default
def move_to_end(self, key, last=True, /):
# change the behavior to something else
self.pop(key)
thetype = OrderedDictSubclass
@make_dynamo_test
def test_move_to_end(self):
p = self.thetype({"a": 1, "b": 2, "c": 3})
p.move_to_end("a")
self.assertEqual(list(p.keys()), list("bc"))
if __name__ == "__main__":
    # When executed as a script, hand control to run_tests from
    # torch._dynamo.test_case (already imported at the top of this file).
    torch._dynamo.test_case.run_tests()