mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dict] Implement dict.__eq__ and dict.__ne__ (#154942)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154942 Approved by: https://github.com/zou3519 ghstack dependencies: #154003, #154793, #154794
This commit is contained in:
parent
ba8d19ec02
commit
ba71eb496b
|
|
@ -1086,6 +1086,56 @@ class DictGuardTests(LoggingTestCase):
|
||||||
record = self.getRecord(records, "d")
|
record = self.getRecord(records, "d")
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
"""d[3] == 4""",
|
"""d[3] == 4""",
|
||||||
|
munge_exc(record),
|
||||||
|
)
|
||||||
|
|
||||||
|
@make_logging_test(recompiles=True)
|
||||||
|
def test_cmp_eq(self, records):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
|
def fn(x, d1, d2):
|
||||||
|
if d1 == d2:
|
||||||
|
return x.sin()
|
||||||
|
return x.cos()
|
||||||
|
|
||||||
|
x = torch.tensor(1.0)
|
||||||
|
d1 = self.thetype({1: 2, 3: 4})
|
||||||
|
d2 = self.thetype({1: 2, 5: 6})
|
||||||
|
y = fn(x, d1, d2)
|
||||||
|
# sanity check
|
||||||
|
self.assertEqual(len(records), 0)
|
||||||
|
self.assertEqual(y, x.cos())
|
||||||
|
|
||||||
|
y = fn(x, d1, d1)
|
||||||
|
self.assertEqual(len(records), 1)
|
||||||
|
self.assertEqual(y, x.sin())
|
||||||
|
record = self.getRecord(records, "d2")
|
||||||
|
self.assertIn(
|
||||||
|
"""list(dict.keys(d2))""",
|
||||||
|
munge_exc(record.getMessage()),
|
||||||
|
)
|
||||||
|
|
||||||
|
@make_logging_test(recompiles=True)
|
||||||
|
def test_cmp_ne(self, records):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
|
def fn(x, d1, d2):
|
||||||
|
if d1 == d2:
|
||||||
|
return x.sin()
|
||||||
|
return x.cos()
|
||||||
|
|
||||||
|
x = torch.tensor(1.0)
|
||||||
|
d1 = self.thetype({1: 2, 3: 4})
|
||||||
|
d2 = self.thetype({1: 2, 5: 6})
|
||||||
|
y = fn(x, d1, d2)
|
||||||
|
# sanity check
|
||||||
|
self.assertEqual(len(records), 0)
|
||||||
|
self.assertEqual(y, x.cos())
|
||||||
|
|
||||||
|
y = fn(x, d1, d1)
|
||||||
|
self.assertEqual(len(records), 1)
|
||||||
|
self.assertEqual(y, x.sin())
|
||||||
|
record = self.getRecord(records, "d2")
|
||||||
|
self.assertIn(
|
||||||
|
"""list(dict.keys(d2))""",
|
||||||
munge_exc(record.getMessage()),
|
munge_exc(record.getMessage()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1363,14 +1413,6 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
||||||
class DictSubclassMethodsTests(DictMethodsTests):
|
class DictSubclassMethodsTests(DictMethodsTests):
|
||||||
thetype = SimpleDict
|
thetype = SimpleDict
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_cmp_eq(self):
|
|
||||||
return super().test_cmp_eq()
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_cmp_ne(self):
|
|
||||||
return super().test_cmp_ne()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ Python polyfills for common builtins.
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
|
|
||||||
import types
|
import types
|
||||||
|
from collections import OrderedDict
|
||||||
from collections.abc import Hashable, Iterable, MutableMapping, Sequence
|
from collections.abc import Hashable, Iterable, MutableMapping, Sequence
|
||||||
from itertools import repeat as _repeat
|
from itertools import repeat as _repeat
|
||||||
from typing import Any, Callable, TYPE_CHECKING
|
from typing import Any, Callable, TYPE_CHECKING
|
||||||
|
|
@ -101,6 +102,20 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc
|
||||||
return op(len(left), len(right))
|
return op(len(left), len(right))
|
||||||
|
|
||||||
|
|
||||||
|
def dict___eq__(d, other):
|
||||||
|
if (len(d) != len(other)) or (d.keys() != other.keys()):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if all(isinstance(a, OrderedDict) for a in (d, other)):
|
||||||
|
return list(d.items()) == list(other.items())
|
||||||
|
|
||||||
|
for k, v in d.items():
|
||||||
|
if v != other[k]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def set_symmetric_difference(set1, set2):
|
def set_symmetric_difference(set1, set2):
|
||||||
symmetric_difference_set = set()
|
symmetric_difference_set = set()
|
||||||
for x in set1:
|
for x in set1:
|
||||||
|
|
|
||||||
|
|
@ -606,6 +606,19 @@ class ConstDictVariable(VariableTracker):
|
||||||
self.items.pop(key)
|
self.items.pop(key)
|
||||||
self.items[key] = val
|
self.items[key] = val
|
||||||
return ConstantVariable.create(None)
|
return ConstantVariable.create(None)
|
||||||
|
elif name == "__eq__" and istype(
|
||||||
|
self, ConstDictVariable
|
||||||
|
): # don't let Set use this function
|
||||||
|
if len(args) != 1:
|
||||||
|
raise_args_mismatch(tx, name)
|
||||||
|
|
||||||
|
return variables.UserFunctionVariable(polyfills.dict___eq__).call_function(
|
||||||
|
tx, [self, args[0]], {}
|
||||||
|
)
|
||||||
|
elif name == "__ne__":
|
||||||
|
return ConstantVariable.create(
|
||||||
|
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||||
|
)
|
||||||
elif name == "__or__":
|
elif name == "__or__":
|
||||||
assert len(args) == 1
|
assert len(args) == 1
|
||||||
# Dicts can only be unioned with other dicts or subclasses of dicts.
|
# Dicts can only be unioned with other dicts or subclasses of dicts.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user