mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR removes unused loop variables in tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166690 Approved by: https://github.com/justinchuby, https://github.com/mlazos
1756 lines
51 KiB
Python
1756 lines
51 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
# ruff: noqa: TRY002
|
|
|
|
import enum
|
|
import itertools
|
|
import operator
|
|
import types
|
|
import unittest
|
|
import weakref
|
|
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch._functorch.config
|
|
import torch.nn
|
|
import torch.utils.checkpoint
|
|
from torch._dynamo.testing import same
|
|
from torch._dynamo.utils import dict_items
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
make_dynamo_test,
|
|
munge_exc,
|
|
parametrize,
|
|
)
|
|
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
|
|
|
|
|
class SimpleDict(dict):
    """Plain ``dict`` subclass with no overridden behavior.

    Used by the tests below to check that Dynamo treats a user-defined dict
    subclass like an ordinary dict.
    """
|
|
|
|
|
|
class DummyUserDict(UserDict):
    """Plain ``collections.UserDict`` subclass with no overridden behavior."""
|
|
|
|
|
|
class DictTests(torch._dynamo.test_case.TestCase):
    """Dynamo tracing tests for dicts and dict-like objects.

    Covers plain dicts, dict subclasses, ``OrderedDict``, ``defaultdict``,
    ``UserDict``, ``weakref.WeakKeyDictionary``, ``types.MappingProxyType`` and
    object ``__dict__`` access. Also checks the guards and recompilation
    behavior that these objects trigger.
    """

    def test_dict_subclass_instantiation(self):
        def fn(x):
            sd = SimpleDict(x=5)
            return sd["x"] * x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_local_mutation(self):
        def fn(x):
            sd = SimpleDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return z * sd["x"]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_contains_enum(self):
        # Despite the name, this exercises *set* membership between a
        # str-mixin Enum member and a plain str, in both directions.
        class TensorDim(str, enum.Enum):
            DDP = "ddp"
            FSDP = "fsdp"
            CP = "cp"
            TP = "tp"

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                val = x.sin()
                if TensorDim.DDP in {"ddp"}:
                    val += x.cos()
                if "ddp" in {TensorDim.DDP}:
                    val += x.cos()
                return val

        inp = torch.randn(4, 4)
        mod = Foo()
        # The eager backend is enough to check Dynamo's tracing and matches
        # the other tests in this class.
        opt_f = torch.compile(mod, backend="eager")
        self.assertEqual(mod(inp), opt_f(inp))

    def test_dict_subclass_local_with_non_dict_method(self):
        # Checks that add_1 method is inlined
        class MethodDict(dict):
            def add_1(self, x):
                return x + 1

        def fn(x):
            sd = MethodDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return sd.add_1(z * sd["x"])

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_contains(self):
        sd = {}
        sd[2] = 5
        sd[4] = 10

        def fn(x):
            if 1 in sd:
                x = x * 2
            else:
                x = x * 3
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation
        sd[1] = 15
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure not recompilation because the traced program remains same here.
        sd[2] = 10
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_methods_fallback_readonly(self):
        sd = SimpleDict()
        sd[2] = 5
        sd[4] = 10
        # check that regular attr accesses work well
        sd.attr = 4

        def fn(x):
            for value in sd.values():
                x = x * value
            for key in sd.keys():
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v
            # for k in sd:
            #     x = x * k

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            x = x * sd.attr
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation
        sd[6] = 15
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_instantiation_return(self):
        def fn(x):
            sd = SimpleDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_dict_subclass_methods_fallback_mutation(self):
        def fn(sd, x):
            for value in sd.values():
                x = x * value
            sd[6] = 14
            for key in sd.keys():
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v
            # for k in sd:
            #     x = x * k

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        sd1 = SimpleDict()
        sd1[2] = 5
        sd1[4] = 10

        sd2 = SimpleDict()
        sd2[2] = 5
        sd2[4] = 10
        self.assertTrue(sd1 == sd2)

        # Eager and compiled runs must apply the same mutation (sd[6] = 14).
        self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
        self.assertTrue(sd1 == sd2)

    def test_dict_subclass_setitem(self):
        class SetItemDict(dict):
            def __setitem__(self, key, value):
                super().__setitem__(key, value + 1)

        def fn(x):
            sd = SetItemDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_custom_iter_dict(self):
        class ReversedDict(dict):
            def __iter__(self):
                return reversed(list(self.keys()))

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            # Forces side effects attribute reapplication logic
            d.sample = 1
            d["baz"] = 4
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        # This is intentional because the dict is mutated, so we will have a recompilation.
        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_custom_keys_iter_dict(self):
        class ReversedDict(dict):
            def keys(self):
                return ["bar", "foo"]

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_dict_guard_on_keys_order(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key, value in d.items():
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_dict_guard_on_keys_order2(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key in d:
                value = d[key]
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_ordered_dict_reordered_keys(self):
        d = OrderedDict()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, value in enumerate(d.values()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_ordered_dict_subclass_reordered_keys(self):
        class ODSubclass(OrderedDict):
            def keys(self):
                return super().keys()

        d = ODSubclass()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, value in enumerate(d.values()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_lazy_key_guarding(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            return x * d["a"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key c was not used, it should not lead to a recompilation
        d.pop("c")
        d["d"] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_lazy_key_non_const_guarding(self):
        d = {
            list: 2,
            dict: 3,
            OrderedDict: 5,
            namedtuple: 7,
        }

        def fn(x):
            return x * d[list]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key `dict` was not used, removing it (and adding an unrelated
        # `defaultdict` key) should not lead to a recompilation
        d.pop(dict)
        d[defaultdict] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_dict_mutation_side_effect(self):
        def fn(d):
            d["c"] = d["a"] + d.pop("b")
            return d

        args1 = {"a": torch.randn(10), "b": torch.randn(10)}
        args2 = dict(args1)
        self.assertIs(fn(args1), args1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        # The compiled function must mutate and return the very same dict.
        self.assertIs(opt_fn(args2), args2)
        self.assertTrue(same(args1, args2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)

    def test_dict_copy_alias(self):
        @torch.compile(backend="eager", fullgraph=True)
        def run(x, d0):
            d1 = d0.copy()
            d1[0] = 1
            return x + 1, d1

        d0 = {}
        res, d1 = run(torch.zeros(1), d0)
        self.assertTrue(same(res, torch.ones(1)))
        # Mutating the copy must not leak into the original.
        self.assertEqual(d0, {})
        self.assertEqual(d1, {0: 1})

    def test_dict_subclass_get_method(self):
        class dotdict(dict):
            """dot.notation access to dictionary attributes"""

            __getattr__ = dict.get
            __setattr__ = dict.__setitem__
            __delattr__ = dict.__delitem__

        config = dotdict({"a": 1, "b": 2})

        def fn(x):
            x2 = x * 2  # noqa: F841
            x3 = x * config.get("a", 3)
            return x3

        x = torch.randn(2)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_order_keys(self):
        def fn(d):
            c = 0
            for v in d.values():
                c += v
            return c

        args1 = {}
        args1["a"] = torch.rand(10)
        args1["b"] = torch.rand(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # A different order of keys recompiles
        args2 = {}
        args2["b"] = args1["b"]
        args2["a"] = args1["a"]
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

        # Extra calls with either key order reuse the existing compiled
        # entries and don't recompile.
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_namedtuple(self):
        def fn(d):
            if namedtuple in d:
                return d[3] * 2
            else:
                return d[3] * 3

        args1 = {namedtuple: None, 3: torch.randn(3)}
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        # Test a failing namedtuple guard
        args2 = {2: None, 3: torch.randn(3)}
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_tensors(self):
        def fn(d, x):
            return d[x] + 3

        args1 = {}
        x = torch.randn(10)
        y = torch.randn(10)
        z = torch.randn(10)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_modules(self):
        def fn(d, x):
            return d[x](torch.ones(2, 2))

        args1 = {}
        x = torch.nn.Linear(2, 2)
        y = torch.nn.Linear(2, 2)
        z = torch.nn.Linear(2, 2)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_contains_dunder_dict(self):
        class UserDefined:
            def __init__(self) -> None:
                self.a = 3
                self.b = 5

            def run(self, x):
                if "a" in self.__dict__:
                    x = x * self.a
                if "b" in self.__dict__:
                    x = x * self.b
                # An attribute added during tracing must become visible
                # through __dict__.
                self.c = 7
                if "c" in self.__dict__:
                    x = x * self.c
                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)

        obj = UserDefined()

        def fn(x):
            return obj.run(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_contains_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = 1
                self.bar = 2
                self.baz = 3

            def forward(self, x):
                if "foo" in self.__dict__:
                    return x * self.bar
                return x * self.baz

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_update_dunder_dict(self):
        class UserDefined:
            def run(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        obj1 = UserDefined()
        obj2 = UserDefined()

        def fn(x, obj):
            return obj.run(x)

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, obj1)
        res = opt_fn(x, obj2)
        self.assertEqual(ref, res)
        # Make sure only `a` is updated.
        self.assertEqual(obj1.__dict__, obj2.__dict__)

    def test_update_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_dict_reconstruct_keeps_original_order(self):
        def fn():
            modules = OrderedDict([("act", torch.nn.ReLU())])
            module_dict = torch.nn.ModuleDict(modules)

            next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
            modules.update(next_modules.items())
            module_dict.update(next_modules)
            return modules, module_dict

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        modules, module_dict = opt_fn()

        self.assertEqual(len(module_dict), len(modules))
        # Both containers must list the same module objects in the same order.
        for k1, m2 in zip(modules, module_dict.children()):
            self.assertIs(modules[k1], m2)

    def test_dict_subclass_initialization_in_graph(self):
        for super_class in (
            OrderedDict,
            dict,
        ):

            class CustomDict(super_class):
                def __new__(cls, *args, **kwargs):
                    return super().__new__(cls, *args, **kwargs)

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

            def fn(x):
                c = CustomDict()
                c["key"] = x
                assert "key" in c
                return c["key"] + 1

            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

            x = torch.rand(4)
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_list_values(self):
        def inner_fn(args):
            return [x[1].shape for x in args]

        @torch.compile(backend="eager")
        def fn(tensors):
            return inner_fn(zip(itertools.count(), tensors["args"]))

        res = fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
        self.assertEqual(
            res, [torch.Size([5, 5]), torch.Size([5, 6]), torch.Size([5, 7])]
        )
        # A shorter list must still trace and produce the matching shapes.
        res = fn({"args": [torch.ones(5, 5)]})
        self.assertEqual(res, [torch.Size([5, 5])])

    def test_dict_iter(self):
        class MyMod(torch.nn.Module):
            def forward(self, x):
                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
                tot = 0
                for key in z:
                    tot += z[key]

                return tot

        x = torch.tensor([0])
        model = MyMod()
        opt_model = torch.compile(model, backend="eager", fullgraph=True)
        y = opt_model(x)

        self.assertEqual(y, 10)

    def test_dict_subclass_contains(self):
        # pattern from huggingface
        class ClassInstantier(OrderedDict):
            pass

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, d):
            if "key1" in d:
                x = x + 2
            if "key2" in d:
                x = x + 4
            x = x + 8
            return x

        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 11.0)))

        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 13.0)))

    def test_dict_tag_guard(self):
        class Foo:
            def __init__(self) -> None:
                self.scalar = 10

        def fn(d, x):
            return d["a"] * d["b"] * d["c"].scalar * x

        foo = Foo()

        d = {"a": 2, "b": 3, "c": foo}

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        d["a"] = 4
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        # Check that recompilation happens: mutating an object nested inside
        # the dict must be caught by guards, otherwise the stale scalar is used.
        foo.scalar = 12
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

    def test_empty_dict_recompilation(self):
        def fn(d, x):
            if d:
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn({}, x), opt_fn({}, x))
        self.assertEqual(fn({"a": 1}, x), opt_fn({"a": 1}, x))

    def test_udf_dict_reconstruction(self):
        class MyDict(dict):
            pass

        def fn(x, klass):
            x = x * 2
            sc_dict = dict.__new__(klass)
            sc_dict["x"] = x
            if isinstance(sc_dict, MyDict):
                sc_dict.attr = 3
            return sc_dict

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x, MyDict)
        res = opt_fn(x, MyDict)
        self.assertEqual(ref, res)
        self.assertIsInstance(res, MyDict)
        self.assertEqual(ref.attr, res.attr)

        ref = fn(x, dict)
        res = opt_fn(x, dict)
        self.assertEqual(ref, res)
        self.assertIsInstance(res, dict)

    def test_weakref_dict(self):
        states = weakref.WeakKeyDictionary()

        mod1 = torch.nn.Module()
        mod2 = torch.nn.Module()

        states[mod1] = 2
        states[mod2] = 3

        def fn(x):
            if mod1 in states:
                x = torch.sin(x)
            if mod2 in states:
                x = torch.cos(x)
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_construct_user_dict_and_return(self):
        def fn(x):
            return DummyUserDict({"a": x + 1})

        x = torch.randn(4)
        res = fn(x)
        self.assertEqual(res["a"], x + 1)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(res["a"], opt_fn(x)["a"])

    def test_fn_id(self):
        def fn(x, f):
            d = {id(f): 3}
            return x * d[id(f)]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)

        def nothing():
            pass

        f = nothing
        self.assertEqual(fn(x, f), opt_fn(x, f))

    def test_mapping_proxy_for_local(self):
        def fn(x):
            d = {"a": 2, "b": 3, "c": 5 * x}
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for v in mp.values():
                y += torch.cos(x * v)
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertIs(type(res), types.MappingProxyType)

    def test_mapping_proxy_for_nonlocal(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for v in mp.values():
                y += torch.cos(x * v)
            d["d"] = 4
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertIs(type(res), types.MappingProxyType)

        # check update to d is reflected in res
        d["e"] = 5
        self.assertEqual(d["e"], res["e"])

    def test_mapping_proxy_existing(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            y = torch.sin(x * mp["a"])
            for v in mp.values():
                y += torch.cos(x * v)
            if isinstance(mp, types.MappingProxyType):
                y *= 2
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        # Changes to the underlying dict must be observed through the proxy.
        d["a"] = 3
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        d.pop("b")
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_dict_construction_from_mapping_proxy(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            d = dict(mp)
            y = torch.sin(x * d["a"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_mapping_proxy_existing_mutation(self):
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            d["d"] = 4
            y = torch.sin(x * mp["d"])
            return y

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = torch.sin(x * 4)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_mapping_proxy_existing_local_mutation(self):
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            # Dynamo should not cause a graph break here because it knows that
            # the existing proxy can't point to this new dict
            other_dict = {}
            other_dict["d"] = 4
            y = torch.sin(x * mp["c"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = torch.sin(x * mp["c"])
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_move_to_end(self):
        def fn(x):
            d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5})
            d.move_to_end("a")
            return d

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
        self.assertEqual(fn(x), opt_fn(x))

    def test_mapping_proxy_ban_muation_on_dict_realization(self):
        def fn(x):
            class Foo:
                b = 4

            d = dict(Foo.__dict__)
            y = torch.sin(x) * d["b"]
            # This should cause a graph break, because otherwise the
            # Foo.__dict__ will not be updated.
            Foo.bar = 3
            return Foo, y * Foo.__dict__["bar"]

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        foo1, ref = fn(x)
        foo2, res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(foo1.bar, foo2.bar)

    def test_overridden_get_item(self):
        class MyDict(dict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.calls = 0

            def __getitem__(self, key):
                self.calls += 1
                return super().__getitem__(key) + 1

        def fn(x, d):
            d["d"] = 4
            return x * d["a"] + d["b"] + d["c"] + d["d"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d1 = MyDict({"a": 2, "b": 3, "c": 5})
        ref = fn(x, d1)

        d2 = MyDict({"a": 2, "b": 3, "c": 5})
        res = opt_fn(x, d2)
        self.assertEqual(ref, res)
        # The overridden __getitem__ must run the same number of times.
        self.assertEqual(d1.calls, d2.calls)

    def test_items_type(self):
        def fn():
            d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})  # noqa: C418
            return d.items()

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn()
        res = opt_fn()
        self.assertEqual(ref, res)
        self.assertEqual(type(res), dict_items)

    def test_builtin_or_with_invalid_types(self):
        args = (
            1,  # int
            1.0,  # float
            "a",  # str
            (1, 2),  # tuple
            [1, 2],  # list
        )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(b: Any):
            a = {"one": torch.ones(1)}
            return a | b

        from torch._dynamo.exc import Unsupported

        for arg in args:
            with self.assertRaisesRegex(Unsupported, "Observed exception"):
                _ = fn(arg)

    def test_builtin_or_with_diff_keys(self):
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_or_with_same_keys(self):
        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}
            b = {"one": torch.ones(1), "three": torch.ones(3)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_ior_(self):
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            a |= b
            return a, b

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_newly_constructed_default_dict(self):
        def f(x):
            d = defaultdict(list)
            d[0] = [
                42,
            ]
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

    # NOTE(review): marked as an expected failure; presumably building a
    # defaultdict with initial contents inside the compiled region is not yet
    # supported -- confirm before removing the decorator.
    @unittest.expectedFailure
    def test_newly_constructed_default_dict_with_dict(self):
        def f(x):
            d = defaultdict(dict, {2: {"a": 1}})
            d[0] = {"b": 2}
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

    def test_iter_default_dict(self):
        def f(x):
            d = defaultdict(list)
            d[0] = 42
            for k in d:
                d[k] += 1
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)

    @parametrize("op", ["or_", "and_", "xor", "sub"])
    def test_dict_keys_binop(self, op):
        op = getattr(operator, op)

        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}
            b = {"one": torch.ones(1), "three": torch.ones(3)}
            return op(a.keys(), b.keys()), op(b.keys(), a.keys())

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    @parametrize("op", ["ior", "iand", "ixor", "isub"])
    def test_dict_keys_inplace_binop(self, op):
        op = getattr(operator, op)

        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}.keys()
            b = {"one": torch.ones(1), "three": torch.ones(3)}.keys()
            c = {"one": torch.ones(1), "two": torch.ones(2)}.keys()
            a = op(a, b)
            b = op(b, c)
            return a, b

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())
|
|
|
|
|
|
# Generate concrete test methods from the @parametrize-decorated tests in
# DictTests (test_dict_keys_binop, test_dict_keys_inplace_binop).
instantiate_parametrized_tests(DictTests)
|
|
|
|
|
|
class DictGuardTests(LoggingTestCase):
|
|
thetype = dict
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_popitem(self, records):
|
|
d = self.thetype()
|
|
d[1] = 2
|
|
d[3] = 4
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
k, v = d.popitem()
|
|
if k == 3 and v == 4:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
x = torch.tensor(1.0)
|
|
y = fn(x)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(y, x.sin())
|
|
|
|
d[3] = 5
|
|
y = fn(x)
|
|
self.assertEqual(len(records), 1)
|
|
self.assertEqual(y, x.cos())
|
|
record = self.getRecord(records, "d")
|
|
self.assertIn(
|
|
"""d[3] == 4""",
|
|
munge_exc(record),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_cmp_eq(self, records):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, d1, d2):
|
|
if d1 == d2:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
x = torch.tensor(1.0)
|
|
d1 = self.thetype({1: 2, 3: 4})
|
|
d2 = self.thetype({1: 2, 5: 6})
|
|
y = fn(x, d1, d2)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(y, x.cos())
|
|
|
|
y = fn(x, d1, d1)
|
|
self.assertEqual(len(records), 1)
|
|
self.assertEqual(y, x.sin())
|
|
record = self.getRecord(records, "d2")
|
|
self.assertIn(
|
|
"""list(dict.keys(d2))""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_cmp_ne(self, records):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, d1, d2):
|
|
if d1 == d2:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
x = torch.tensor(1.0)
|
|
d1 = self.thetype({1: 2, 3: 4})
|
|
d2 = self.thetype({1: 2, 5: 6})
|
|
y = fn(x, d1, d2)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(y, x.cos())
|
|
|
|
y = fn(x, d1, d1)
|
|
self.assertEqual(len(records), 1)
|
|
self.assertEqual(y, x.sin())
|
|
record = self.getRecord(records, "d2")
|
|
self.assertIn(
|
|
"""list(dict.keys(d2))""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
|
|
def test_cmp_or(self, records):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, d1, d2):
|
|
d = d1 | d2
|
|
if d.get(5, False):
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
x = torch.tensor(1.0)
|
|
d1 = self.thetype({1: 2, 3: 4})
|
|
d2 = self.thetype({1: 2, 5: 6})
|
|
y = fn(x, d1, d2)
|
|
# sanity check
|
|
self.assertEqual(len(records), 0)
|
|
self.assertEqual(y, x.sin())
|
|
|
|
y = fn(x, d1, d1)
|
|
self.assertEqual(len(records), 1)
|
|
self.assertEqual(y, x.cos())
|
|
record = self.getRecord(records, "d2")
|
|
self.assertIn(
|
|
"""KeyError on d2[5]""",
|
|
munge_exc(record.getMessage()),
|
|
)
|
|
|
|
@make_logging_test(recompiles=True)
def test_cmp_ior(self, records):
    """A recompile caused by in-place ``d2 |= d1`` losing key 3 is reported on d1."""

    # ``fn``'s parameter names appear in the recompile reason asserted below,
    # so they must stay ``d1``/``d2``.
    @torch.compile(backend="eager", fullgraph=True)
    def fn(x, d1, d2):
        d2 |= d1
        return x.sin() if d2.get(3, False) else x.cos()

    inp = torch.tensor(1.0)
    src = self.thetype({1: 2, 3: 4})
    dst = self.thetype({1: 2, 5: 6})
    # Snapshot ``dst`` before the first call mutates it in place.
    dst_copy_a, dst_copy_b = dst.copy(), dst.copy()

    # ``src`` supplies key 3, so the sin branch is taken.
    out = fn(inp, src, dst)
    self.assertEqual(len(records), 0)  # sanity check: nothing recompiled yet
    self.assertEqual(out, inp.sin())

    # Neither snapshot has key 3 -> exactly one recompile, cos branch.
    out = fn(inp, dst_copy_a, dst_copy_b)
    self.assertEqual(len(records), 1)
    self.assertEqual(out, inp.cos())
    reason = self.getRecord(records, "d1")
    self.assertIn("KeyError on d1[3]", munge_exc(reason.getMessage()))
|
|
|
|
|
|
class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
|
thetype = dict
|
|
|
|
# Methods:
|
|
# + clear
|
|
# + copy
|
|
# + fromkeys
|
|
# + get
|
|
# + items
|
|
# + keys
|
|
# + pop
|
|
# + popitem
|
|
# + setdefault
|
|
# + update
|
|
# + values
|
|
# BinOps:
|
|
# ==, !=, |
|
|
|
|
def setUp(self):
|
|
torch._dynamo.config.enable_trace_unittest = True
|
|
super().setUp()
|
|
|
|
def tearDown(self):
|
|
torch._dynamo.config.enable_trace_unittest = False
|
|
return super().tearDown()
|
|
|
|
def assertEqual(self, x, y):
|
|
self.assertTrue(x == y, f"Expected {x} to be equal to {y}")
|
|
|
|
def assertNotEqual(self, x, y):
|
|
self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")
|
|
|
|
@make_dynamo_test
|
|
def test_cmp_eq(self):
|
|
d1 = self.thetype({"a": 1, "b": 2})
|
|
d2 = self.thetype({"a": 1, "b": 2})
|
|
d3 = self.thetype({"a": 1, "b": 3})
|
|
self.assertEqual(d1, d2)
|
|
self.assertNotEqual(d1, d3)
|
|
|
|
# Test the == operator
|
|
self.assertEqual(d1 == d2, True)
|
|
self.assertEqual(d1 == d3, False)
|
|
|
|
# Test the __eq__ method
|
|
self.assertEqual(d1.__eq__(d2), True)
|
|
self.assertEqual(d1.__eq__(d3), False)
|
|
|
|
# Test Dict.__eq__
|
|
self.assertEqual(dict.__eq__(d1, d2), True)
|
|
self.assertEqual(self.thetype.__eq__(d1, d3), False)
|
|
|
|
@make_dynamo_test
|
|
def test_cmp_ne(self):
|
|
d1 = self.thetype({"a": 1, "b": 2})
|
|
d2 = self.thetype({"a": 1, "b": 2})
|
|
d3 = self.thetype({"a": 1, "b": 3})
|
|
self.assertNotEqual(d1, d3)
|
|
self.assertEqual(d1, d2)
|
|
|
|
# Test the != operator
|
|
self.assertEqual(d1 != d3, True)
|
|
self.assertEqual(d1 != d2, False)
|
|
|
|
# Test the __ne__ method
|
|
self.assertEqual(d1.__ne__(d3), True)
|
|
self.assertEqual(d1.__ne__(d2), False)
|
|
|
|
# Test Dict.__ne__
|
|
self.assertEqual(dict.__ne__(d1, d3), True)
|
|
self.assertEqual(self.thetype.__ne__(d1, d2), False)
|
|
|
|
@make_dynamo_test
|
|
def test_binop_or(self):
|
|
d1 = self.thetype({"a": 1, "b": 2})
|
|
d2 = self.thetype({"b": 3, "c": 4})
|
|
|
|
# Test the | operator
|
|
self.assertEqual(d1 | d2, {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(d2 | d1, {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test the __or__ method
|
|
self.assertEqual(d1.__or__(d2), {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(d2.__or__(d1), {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test Dict.__or__
|
|
self.assertEqual(dict.__or__(d1, d2), {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(self.thetype.__or__(d2, d1), {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test with non-dict types
|
|
self.assertRaises(TypeError, lambda: d1 | 1)
|
|
|
|
@make_dynamo_test
|
|
def test_binop_ior(self):
|
|
d1 = self.thetype({"a": 1, "b": 2})
|
|
d2 = self.thetype({"b": 3, "c": 4})
|
|
|
|
# Test the |= operator
|
|
d3, d4 = d1.copy(), d2.copy()
|
|
d3 |= d2
|
|
d4 |= d1
|
|
self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test with an iterable
|
|
d3, d4 = d1.copy(), d2.copy()
|
|
|
|
# Test the __ior__ method
|
|
d3, d4 = d1.copy(), d2.copy()
|
|
d3.__ior__(d2)
|
|
d4.__ior__(d1)
|
|
self.assertEqual(d3, {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(d4, {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test Dict.__or__
|
|
d3, d4 = d1.copy(), d2.copy()
|
|
self.assertEqual(dict.__ior__(d3, d2), {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(self.thetype.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test return value
|
|
d3, d4 = d1.copy(), d2.copy()
|
|
self.assertEqual(d3.__ior__(d2), {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(dict.__ior__(d4, d1), {"a": 1, "b": 2, "c": 4})
|
|
|
|
# Test with non-dict types
|
|
self.assertRaises(TypeError, lambda: dict.__ior__(d1, 1))
|
|
|
|
@make_dynamo_test
|
|
def test_binop_ior_iterable(self):
|
|
d1 = self.thetype({"a": 1, "b": 2})
|
|
d2 = self.thetype({"b": 3, "c": 4})
|
|
d3, d4 = d1.copy(), d2.copy()
|
|
|
|
def fn(d):
|
|
yield from d.items()
|
|
|
|
self.assertEqual(d3.__ior__(d2.items()), {"a": 1, "b": 3, "c": 4})
|
|
self.assertEqual(d4.__ior__(fn(d1)), {"a": 1, "b": 2, "c": 4})
|
|
|
|
@make_dynamo_test
|
|
def test_clear(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
d.clear()
|
|
self.assertEqual(d, {})
|
|
|
|
# Test that clear returns None
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.assertIsNone(d.clear())
|
|
|
|
# Test Dict.clear
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
dict.clear(d)
|
|
self.assertEqual(d, {})
|
|
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.thetype.clear(d)
|
|
self.assertEqual(d, {})
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.clear, 1)
|
|
|
|
@make_dynamo_test
|
|
def test_copy(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
d2 = d.copy()
|
|
self.assertEqual(d, d2)
|
|
|
|
# Test that copy returns a new instance
|
|
self.assertIsNot(d, d2)
|
|
|
|
# Test Dict.copy
|
|
self.assertEqual(dict.copy(d), d2)
|
|
self.assertEqual(self.thetype.copy(d), d2)
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.copy, 1)
|
|
|
|
@unittest.expectedFailure
|
|
@make_dynamo_test
|
|
def test_fromkeys(self):
|
|
d = self.thetype.fromkeys(["a", "b"], 1)
|
|
self.assertEqual(d, {"a": 1, "b": 1})
|
|
p = self.thetype.fromkeys(["a", "b"], None)
|
|
self.assertEqual(p, {"a": None, "b": None})
|
|
|
|
# Test Dict.fromkeys
|
|
d2 = self.thetype.fromkeys(["c", "d"], 2)
|
|
self.assertEqual(d2, {"c": 2, "d": 2})
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, self.thetype.fromkeys)
|
|
self.assertRaises(TypeError, self.thetype.fromkeys, 1, 2)
|
|
|
|
@make_dynamo_test
|
|
def test_get(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.assertEqual(d.get("a"), 1)
|
|
self.assertEqual(d.get("c", 3), 3)
|
|
self.assertIsNone(d.get("c"))
|
|
|
|
# Test Dict.get
|
|
self.assertEqual(dict.get(d, "b"), 2)
|
|
self.assertEqual(self.thetype.get(d, "b"), 2)
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.get)
|
|
|
|
@make_dynamo_test
|
|
def test_items(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
items = d.items()
|
|
self.assertEqual(set(items), {("a", 1), ("b", 2)})
|
|
|
|
# Test Dict.items
|
|
self.assertEqual(set(dict.items(d)), {("a", 1), ("b", 2)})
|
|
self.assertEqual(set(self.thetype.items(d)), {("a", 1), ("b", 2)})
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.items, 1)
|
|
|
|
@make_dynamo_test
|
|
def test_keys(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
keys = d.keys()
|
|
self.assertEqual(set(keys), {"a", "b"})
|
|
|
|
# Test Dict.keys
|
|
self.assertEqual(set(dict.keys(d)), {"a", "b"})
|
|
self.assertEqual(set(self.thetype.keys(d)), {"a", "b"})
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.keys, 1)
|
|
|
|
@make_dynamo_test
|
|
def test_pop(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.assertEqual(d.pop("a"), 1)
|
|
self.assertEqual(d, {"b": 2})
|
|
self.assertIsNone(d.pop("c", None))
|
|
|
|
# Test Dict.pop
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.assertEqual(dict.pop(d, "b"), 2)
|
|
self.assertEqual(self.thetype.pop(d, "a"), 1)
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(KeyError, d.pop, "c")
|
|
self.assertRaises(TypeError, d.pop)
|
|
|
|
@make_dynamo_test
|
|
def test_popitem(self):
|
|
d = self.thetype({"a": 1})
|
|
key, value = d.popitem()
|
|
self.assertEqual(key, "a")
|
|
self.assertEqual(value, 1)
|
|
self.assertEqual(len(d), 0)
|
|
# check LIFO
|
|
d = self.thetype()
|
|
d["a"] = 1
|
|
d["b"] = 2
|
|
self.assertEqual(d.popitem(), ("b", 2))
|
|
|
|
# Test Dict.popitem
|
|
d = self.thetype({"a": 1})
|
|
key, value = dict.popitem(d)
|
|
self.assertEqual(key, "a")
|
|
self.assertEqual(value, 1)
|
|
|
|
d = self.thetype({"a": 1})
|
|
key, value = self.thetype.popitem(d)
|
|
self.assertEqual(key, "a")
|
|
self.assertEqual(value, 1)
|
|
|
|
# Test invalid usage
|
|
if self.thetype is not OrderedDict:
|
|
# OrderedDict accepts a keyword arg
|
|
self.assertRaises(TypeError, d.popitem, 1)
|
|
|
|
@make_dynamo_test
|
|
def test_setdefault(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.assertEqual(d.setdefault("a", 3), 1)
|
|
self.assertEqual(d.setdefault("c", 3), 3)
|
|
self.assertIsNone(d.setdefault("d"), None)
|
|
self.assertEqual(d, {"a": 1, "b": 2, "c": 3, "d": None})
|
|
|
|
# Test Dict.setdefault
|
|
self.assertEqual(dict.setdefault(d, "f", 5), 5)
|
|
self.assertEqual(self.thetype.setdefault(d, "e", 5), 5)
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.setdefault)
|
|
self.assertRaises(TypeError, d.setdefault, [[]])
|
|
|
|
@make_dynamo_test
|
|
def test_update(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
d.update({"b": 3, "c": 4})
|
|
self.assertEqual(d, {"a": 1, "b": 3, "c": 4})
|
|
|
|
# Test with another dict
|
|
d2 = self.thetype({"d": 5})
|
|
d.update(d2)
|
|
self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5})
|
|
|
|
# Test Dict.update
|
|
d3 = self.thetype({"e": 6})
|
|
dict.update(d, d3)
|
|
self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6})
|
|
d4 = self.thetype({"f": 7})
|
|
self.thetype.update(d, d4)
|
|
self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7})
|
|
|
|
# Test with keyword arguments
|
|
d.update(f=7, g=8)
|
|
self.assertEqual(d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7, "g": 8})
|
|
|
|
# Test Dict.update with keyword arguments
|
|
self.thetype.update(d, h=9, i=10)
|
|
self.assertEqual(
|
|
d, {"a": 1, "b": 3, "c": 4, "d": 5, "e": 6, "f": 7, "g": 8, "h": 9, "i": 10}
|
|
)
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.update, 1)
|
|
|
|
@make_dynamo_test
|
|
def test_values(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
values = d.values()
|
|
self.assertEqual(set(values), {1, 2})
|
|
|
|
# Test Dict.values
|
|
self.assertEqual(set(dict.values(d)), {1, 2})
|
|
self.assertEqual(set(self.thetype.values(d)), {1, 2})
|
|
|
|
# Test invalid usage
|
|
self.assertRaises(TypeError, d.values, 1)
|
|
|
|
@make_dynamo_test
|
|
def test_type(self):
|
|
d = self.thetype({"a": 1, "b": 2})
|
|
self.assertIsInstance(d, self.thetype)
|
|
self.assertIs(type(d), self.thetype)
|
|
|
|
@make_dynamo_test
|
|
def test_dict_type_comparison(self):
|
|
types = (dict, OrderedDict, defaultdict)
|
|
self.assertEqual(self.thetype, self.thetype)
|
|
self.assertTrue(self.thetype is self.thetype)
|
|
for other in types:
|
|
if self.thetype == other:
|
|
continue
|
|
self.assertNotEqual(self.thetype, other)
|
|
self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}")
|
|
|
|
@make_dynamo_test
|
|
def test_dict___iter__(self):
|
|
d = self.thetype({1: 2})
|
|
it = d.__iter__()
|
|
self.assertEqual(next(it), 1)
|
|
|
|
|
|
class DictSubclassMethodsTests(DictMethodsTests):
    """Re-run every DictMethodsTests case with SimpleDict, a plain ``dict`` subclass."""

    thetype = SimpleDict
|
|
|
|
|
|
class OrderedDictMethodsTests(DictMethodsTests):
    """Re-run DictMethodsTests with ``OrderedDict``, plus OrderedDict-only cases.

    Extra coverage beyond the inherited tests (``popitem`` itself is
    inherited from DictMethodsTests):
        move_to_end, order-sensitive ``==``, result type of ``|``/``|=``,
        and ``popitem(last=...)``.
    """

    thetype = OrderedDict

    @make_dynamo_test
    def test_move_to_end(self):
        od = self.thetype.fromkeys("abcde")
        self.assertEqual("".join(od), "abcde")

        od.move_to_end("b")
        self.assertEqual("".join(od), "acdeb")

        # Unbound OrderedDict.move_to_end
        self.thetype.move_to_end(od, "a")
        self.assertEqual("".join(od), "cdeba")

        # last=False moves the key to the front instead of the back
        self.thetype.move_to_end(od, "a", last=False)
        self.assertEqual("".join(od), "acdeb")

        # A missing key raises KeyError
        self.assertRaises(KeyError, od.move_to_end, "f")

    def test_cmp_eq_order(self):
        # Same keys inserted in a different order must not compare equal.
        first = self.thetype.fromkeys("abc")
        second = self.thetype.fromkeys("bca")
        self.assertFalse(first == second)

    @make_dynamo_test
    def test_binop_or_return_type(self):
        lhs = self.thetype({"a": 1, "b": 2})
        rhs = self.thetype({"b": 3, "c": 4})

        # ``|`` yields an OrderedDict whenever either operand is one.
        self.assertIs(type(lhs | rhs), OrderedDict)
        self.assertIs(type(dict(lhs) | rhs), OrderedDict)
        self.assertIs(type(lhs | dict(rhs)), OrderedDict)

    @make_dynamo_test
    def test_binop_ior_return_type(self):
        lhs = self.thetype({"a": 1, "b": 2})
        rhs = self.thetype({"b": 3, "c": 4})

        # In-place ``|=`` hands back its left operand, so the result type
        # follows whatever the left-hand side is.
        lhs_copy, rhs_copy = lhs.copy(), rhs.copy()
        self.assertIs(type(lhs_copy.__ior__(rhs)), OrderedDict)
        self.assertIs(type(dict.__ior__(rhs_copy, rhs)), OrderedDict)
        self.assertIs(type(self.thetype.__ior__(rhs_copy, rhs)), OrderedDict)

        lhs_copy, rhs_copy = lhs.copy(), rhs.copy()
        self.assertIs(type(dict.__ior__(lhs_copy, dict(rhs))), OrderedDict)
        self.assertIs(type(dict.__ior__(dict(lhs_copy), rhs)), dict)
        self.assertIs(type(dict(rhs_copy).__ior__(rhs)), dict)

    @make_dynamo_test
    def test_popitem_kwarg(self):
        od = self.thetype.fromkeys("abcdf")
        # ``last`` may be given by keyword or positionally: True pops from
        # the back, False from the front.
        self.assertEqual(od.popitem(last=True), ("f", None))
        self.assertEqual(list(od), list("abcd"))
        self.assertEqual(od.popitem(last=False), ("a", None))
        self.assertEqual(list(od), list("bcd"))
        self.assertEqual(od.popitem(False), ("b", None))
        self.assertEqual(list(od), list("cd"))
        self.assertEqual(od.popitem(True), ("d", None))
        self.assertEqual(list(od), list("c"))
|
|
|
|
|
|
class OrderedDictSubclassOverload(torch._dynamo.test_case.TestCase):
|
|
def setUp(self):
|
|
torch._dynamo.config.enable_trace_unittest = True
|
|
super().setUp()
|
|
|
|
def tearDown(self):
|
|
torch._dynamo.config.enable_trace_unittest = False
|
|
return super().tearDown()
|
|
|
|
def assertEqual(self, x, y):
|
|
self.assertTrue(x == y, f"Expected {x} to be equal to {y}")
|
|
|
|
def assertNotEqual(self, x, y):
|
|
self.assertFalse(x == y, f"Expected {x} to not be equal to {y}")
|
|
|
|
class OrderedDictSubclass(OrderedDict):
|
|
def get(self, key, default=None, /):
|
|
return default
|
|
|
|
def move_to_end(self, key, last=True, /):
|
|
# change the behavior to something else
|
|
self.pop(key)
|
|
|
|
thetype = OrderedDictSubclass
|
|
|
|
@make_dynamo_test
|
|
def test_move_to_end(self):
|
|
p = self.thetype({"a": 1, "b": 2, "c": 3})
|
|
p.move_to_end("a")
|
|
self.assertEqual(list(p.keys()), list("bc"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests through Dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()
|