diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 506b1dff30f..ec2951a4fe5 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1755,7 +1755,9 @@ class FunctionTests(torch._dynamo.test_case.TestCase): y = a - b return x, y - @parametrize("fn_name", ["add"]) + @parametrize( + "fn_name", ["add", "symmetric_difference", "symmetric_difference_update"] + ) def test_set_raise_TypeError(self, fn_name): @make_test def fn(a, b): @@ -1783,6 +1785,36 @@ class FunctionTests(torch._dynamo.test_case.TestCase): y = a - b return x, y + @make_test + def test_set_symmetric_difference(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + symmetric_diff_set = set1.difference(set2) + if "apple" in symmetric_diff_set: + x = a + b + else: + x = a - b + if "banana" in symmetric_diff_set: + y = a + b + else: + y = a - b + return x, y + + @make_test + def test_set_symmetric_difference_update(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + set1.difference(set2) + if "apple" in set1: + x = a + b + else: + x = a - b + if "banana" in set1: + y = a + b + else: + y = a - b + return x, y + def test_set_keys_view(self): from collections.abc import KeysView diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_sym_difference b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_sym_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_sym_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_sym_difference_update deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_sym_difference b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_sym_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_sym_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_sym_difference_update deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_sym_difference b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_sym_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_sym_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_sym_difference_update deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_xor b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_xor deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestUpdateOps.test_sym_difference_method_call b/test/dynamo_expected_failures/CPython313-test_set-TestUpdateOps.test_sym_difference_method_call deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index cc80a67c607..11de175d5ce 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -101,6 +101,23 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc return op(len(left), len(right)) +def set_symmetric_difference(set1, set2): + symmetric_difference_set = set() + for x in set1: + if x not in set2: + symmetric_difference_set.add(x) + for x in set2: + if x not in set1: + symmetric_difference_set.add(x) + return symmetric_difference_set + + +def set_symmetric_difference_update(set1, set2): + result = set1.symmetric_difference(set2) + set1.clear() + set1.update(result) + + def set_isdisjoint(set1, set2): for x in set1: if x in set2: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 423c5967a31..3e326be5de5 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -790,6 +790,20 @@ class SetVariable(ConstDictVariable): return variables.UserFunctionVariable( polyfills.set_difference ).call_function(tx, [self, args[0]], {}) + elif name == "symmetric_difference": + if len(args) != 1: + raise_args_mismatch(tx, name) + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference_update": + if len(args) != 1: + raise_args_mismatch(tx, name) + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference_update + ).call_function(tx, [self, *args], {}) elif name == "update" and len(args) == 1 and self.is_mutable(): assert not kwargs assert len(args) == 1