mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dict] Implement dict.__eq__ and dict.__ne__ (#154942)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154942 Approved by: https://github.com/zou3519 ghstack dependencies: #154003, #154793, #154794
This commit is contained in:
parent
ba8d19ec02
commit
ba71eb496b
|
|
@ -1086,6 +1086,56 @@ class DictGuardTests(LoggingTestCase):
|
|||
record = self.getRecord(records, "d")
|
||||
self.assertIn(
|
||||
"""d[3] == 4""",
|
||||
munge_exc(record),
|
||||
)
|
||||
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_cmp_eq(self, records):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x, d1, d2):
|
||||
if d1 == d2:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
d1 = self.thetype({1: 2, 3: 4})
|
||||
d2 = self.thetype({1: 2, 5: 6})
|
||||
y = fn(x, d1, d2)
|
||||
# sanity check
|
||||
self.assertEqual(len(records), 0)
|
||||
self.assertEqual(y, x.cos())
|
||||
|
||||
y = fn(x, d1, d1)
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(y, x.sin())
|
||||
record = self.getRecord(records, "d2")
|
||||
self.assertIn(
|
||||
"""list(dict.keys(d2))""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_cmp_ne(self, records):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(x, d1, d2):
|
||||
if d1 == d2:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
d1 = self.thetype({1: 2, 3: 4})
|
||||
d2 = self.thetype({1: 2, 5: 6})
|
||||
y = fn(x, d1, d2)
|
||||
# sanity check
|
||||
self.assertEqual(len(records), 0)
|
||||
self.assertEqual(y, x.cos())
|
||||
|
||||
y = fn(x, d1, d1)
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(y, x.sin())
|
||||
record = self.getRecord(records, "d2")
|
||||
self.assertIn(
|
||||
"""list(dict.keys(d2))""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
|
|
@ -1363,14 +1413,6 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase):
|
|||
class DictSubclassMethodsTests(DictMethodsTests):
|
||||
thetype = SimpleDict
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_cmp_eq(self):
|
||||
return super().test_cmp_eq()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_cmp_ne(self):
|
||||
return super().test_cmp_ne()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ Python polyfills for common builtins.
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
import types
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Hashable, Iterable, MutableMapping, Sequence
|
||||
from itertools import repeat as _repeat
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
|
@ -101,6 +102,20 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc
|
|||
return op(len(left), len(right))
|
||||
|
||||
|
||||
def dict___eq__(d, other):
|
||||
if (len(d) != len(other)) or (d.keys() != other.keys()):
|
||||
return False
|
||||
|
||||
if all(isinstance(a, OrderedDict) for a in (d, other)):
|
||||
return list(d.items()) == list(other.items())
|
||||
|
||||
for k, v in d.items():
|
||||
if v != other[k]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def set_symmetric_difference(set1, set2):
|
||||
symmetric_difference_set = set()
|
||||
for x in set1:
|
||||
|
|
|
|||
|
|
@ -606,6 +606,19 @@ class ConstDictVariable(VariableTracker):
|
|||
self.items.pop(key)
|
||||
self.items[key] = val
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__eq__" and istype(
|
||||
self, ConstDictVariable
|
||||
): # don't let Set use this function
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
return variables.UserFunctionVariable(polyfills.dict___eq__).call_function(
|
||||
tx, [self, args[0]], {}
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return ConstantVariable.create(
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||
)
|
||||
elif name == "__or__":
|
||||
assert len(args) == 1
|
||||
# Dicts can only be unioned with other dicts or subclasses of dicts.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user