diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index f9a798fc66e..d73be018b66 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9321,31 +9321,6 @@ def ___make_guard_fn(): self.assertEqual(counter.frame_count, 1) self.assertEqual(result, eager_result) - def test_input_set_graph_break(self): - def foo(x): - return x.pop() * x.pop() - - x = torch.randn(10, 10) - y = torch.randn(10, 10) - - counter = CompileCounter() - - inp = {x, x, x, x, y, y} - foo = torch.compile(foo, backend=counter, fullgraph=True) - - # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part. - # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents) - # and so the guard story for the objects passed into input just isn't there atm. - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, - "Unsupported method call", - ): - foo(inp) - - foo = torch.compile(foo, backend=counter, fullgraph=False) - foo(inp) - self.assertEqual(counter.frame_count, 1) - def test_reconstruct_set_across_graph_break(self): def foo(x, y): setty = set() diff --git a/test/dynamo/test_sets.py b/test/dynamo/test_sets.py new file mode 100644 index 00000000000..3287de613fe --- /dev/null +++ b/test/dynamo/test_sets.py @@ -0,0 +1,233 @@ +# Owner(s): ["module: dynamo"] + +# TODO: move set tests from test_functions.py/test_misc.py to this file + +import math +import unittest + +import torch +import torch._dynamo.test_case +from torch._dynamo.exc import Unsupported +from torch._dynamo.testing import CompileCounter +from torch.testing._internal.common_utils import munge_exc +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test + + +class TestSetGuards(LoggingTestCase): + def test_set_with_function(self): + s = { + torch._C._set_grad_enabled, + "hello", + torch.amp._exit_autocast, + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + if torch.amp._exit_autocast in s: + return x.sin() + return x.cos() + + x = torch.randn(2) + y = fn(x, s) + self.assertEqual(y, x.sin()) + self.assertEqual(cnts.frame_count, 1) + + s.remove(torch.amp._exit_autocast) + s.add(torch._C._set_fwd_grad_enabled) + y = fn(x, s) + self.assertEqual(y, x.cos()) + self.assertEqual(cnts.frame_count, 2) + + @make_logging_test(recompiles=True) + def test_in_guard(self, records): + s = { + "Dynamo", + "Inductor", + "PyTorch", + torch.sin, + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + if "PyTorch" in s: + return x.sin() + return x.cos() + + x = torch.randn(2) + y = fn(x, s) + self.assertEqual(y, x.sin()) + self.assertEqual(cnts.frame_count, 1) + + s.remove("PyTorch") + s.add("Cuda") + y = fn(x, s) + self.assertEqual(y, x.cos()) + self.assertEqual(cnts.frame_count, 2) + self.assertGreater(len(records), 0) + record = self.getRecord(records, "set.__contains__") + self.assertIn( + """set.__contains__(s, 'PyTorch')""", + munge_exc(record.getMessage()), + ) + + def test_set_with_tensors(self): + s = { + torch.ones(1), + torch.tensor([1.0]), + torch.zeros(1), + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + z = torch.zeros(1) + for i in s: + z += i + return x + z + + x = torch.tensor([1.0]) + self.assertExpectedInlineMunged( + Unsupported, + lambda: fn(x, s), + """\ +Attempted to wrap a set with tensors + Explanation: Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors. + Hint: Use a dictionary where the keys are tensors. + Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. + + Developer debug context: Python set containing torch.Tensor elements + + +from user code: + File "test_sets.py", line N, in fn + for i in s:""", # noqa: B950 + ) + + def test_set_multiple_types(self): + s = { + "PyTorch", + 3.3, + 1j, + math.nan, + } + cnts = CompileCounter() + + @torch.compile(backend=cnts, fullgraph=True) + def fn(x, s): + if "PyTorch" in s: + return x.sin() + return x.cos() + + x = torch.tensor(1.0) + y = fn(x, s) + self.assertEqual(y, x.sin()) + self.assertEqual(cnts.frame_count, 1) + + s.remove("PyTorch") + y = fn(x, s) + self.assertEqual(y, x.cos()) + self.assertEqual(cnts.frame_count, 2) + + def test_set_recompile_on_key_pop(self): + s = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, + } + + cnts = CompileCounter() + + def fn(x, s): + if torch.amp._exit_autocast in s: + return x.sin() + return x.cos() + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + res = opt_fn(x, s) + opt_fn(x, s) + self.assertEqual(res, fn(x, s)) + # No recompilation + self.assertEqual(cnts.frame_count, 1) + + # Pop a value + s.remove(torch.amp._exit_autocast) + + res = opt_fn(x, s) + # Check recompilation + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(res, fn(x, s)) + + def test_set_recompile_on_key_change(self): + s = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, + } + + cnts = CompileCounter() + + def fn(x, s): + if torch.amp._exit_autocast in s: + return x.sin() + return x.cos() + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + res = opt_fn(x, s) + opt_fn(x, s) + self.assertEqual(res, fn(x, s)) + # No recompilation + self.assertEqual(cnts.frame_count, 1) + + # Pop a value + s.remove(torch.amp._exit_autocast) + # Add a different value + s.add(torch._C._set_autograd_fallback_mode) + + res = opt_fn(x, s) + # Check recompilation + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(res, fn(x, s)) + + @unittest.skip("random failures on Python 3.9") + def test_set_guard_on_keys_change(self): + # This test guarantee that we're not triggering any of the dict guards + # on sets + s = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, + } + + cnts = CompileCounter() + + def fn(x, s): + for e in s: + x = x * len(str(e)) + return x + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + opt_fn(torch.randn(4), s) + opt_fn(torch.randn(4), s) + # No recompilation + self.assertEqual(cnts.frame_count, 1) + + # pop and add the same item + s.remove(torch.amp._exit_autocast) + # It is not guaranteed that _exit_autocast will be in a specific order + s.add(torch.amp._exit_autocast) + + x = torch.randn(4) + res = opt_fn(x, s) + # Check Dynamo don't recompile + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(res, fn(x, s)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference_rev b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference_rev deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_intersection deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_isdisjoint b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_isdisjoint deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_symmetric_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_union b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_union deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_equivalent_equality b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_equivalent_equality deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_intersection_empty b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_intersection_empty deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_length b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_length deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_equality b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_equality deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_intersection deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_isdisjoint b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_isdisjoint deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_symmetric_difference deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_union b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_union deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_union_empty b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_union_empty deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_and b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_and deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_equality b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_equality deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_sub b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_sub deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_xor b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_xor deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_module_names b/test/dynamo_expected_failures/CPython313-test_sys-SysModuleTest.test_module_names deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_copy b/test/dynamo_expected_failures/TestTorch.test_print similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_copy rename to test/dynamo_expected_failures/TestTorch.test_print diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index c5c3117e117..153b7132bf1 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2188,5 +2188,16 @@ "Remove the `@contextlib.contextmanager` decorator." ] } + ], + "GB0222": [ + { + "Gb_type": "Attempted to wrap a set with tensors", + "Context": "Python set containing torch.Tensor elements", + "Explanation": "Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors.", + "Hints": [ + "Use a dictionary where the keys are tensors.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] -} +} \ No newline at end of file diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4d60b3cbb05..fb064df4375 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -119,6 +119,7 @@ from .source import ( ListGetItemSource, LocalSource, NNModuleSource, + NonSerializableSetGetItemSource, NumpyTensorSource, OptimizerSource, ScriptObjectQualifiedNameSource, @@ -945,6 +946,11 @@ class GuardBuilder(GuardBuilderBase): # Fix this if condition if isinstance(example_value, dict_keys): guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER + elif isinstance(example_value, (set, frozenset)): + # we don't need to guard on key order for set/frozenset + # but the if above will be true for these types as set is + # implemented using a dict in Dynamo + guard_manager_enum = GuardManagerType.GUARD_MANAGER else: assert isinstance(example_value, dict) guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER @@ -1289,6 +1295,14 @@ class GuardBuilder(GuardBuilderBase): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, NonSerializableSetGetItemSource): + assert base_guard_manager + out = base_guard_manager.set_getitem_manager( + index=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, WeakRefCallSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.weakref_call_manager( @@ -1507,6 +1521,19 @@ class GuardBuilder(GuardBuilderBase): not invert, key, get_verbose_code_parts(code, guard) ) + def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool): + set_ref = self.arg_ref(guard) + item = key + contains = not invert # install_dict_contains_guard inverts "contains" + + code = f"set.__contains__({set_ref}, {item!r})" + + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_set_contains_guard( + contains, item, get_verbose_code_parts(code, guard) + ) + def BOOL_MATCH(self, guard: Guard): # checks val == True or val == False ref = self.arg_ref(guard) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 2ae169f099f..4c18c1f47ce 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -587,6 +587,34 @@ class ConstDictKeySource(ChainedSource): return True +@dataclasses.dataclass(frozen=True) +class NonSerializableSetGetItemSource(ChainedSource): + index: int + + def __post_init__(self): + from .variables import ConstantVariable + + assert ConstantVariable.is_literal(self.index) + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "set_getitem") + ) + codegen(self.base) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, False)) + + def name(self): + # set ordering might not be stable + return f"list({self.base.name()})[{self.index!r}]" + + def is_dict_key(self): + return False + + # Used to access an item from the dictionary @dataclasses.dataclass(frozen=True) class DictGetItemSource(ChainedSource): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index e11888b9dc6..9a1f79293b3 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2572,6 +2572,11 @@ def dict_keys_getitem(d, n): return next(itertools.islice(dict_class.keys(d), n, n + 1)) +def set_getitem(s, n): + # Set ordering might not be stable + return list(s)[n] + + def enum_repr(value, local): # enum class can override __str__ method. Use __class__ and name attribute # to extract the class name and key name. diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index adfa3451487..1853e909420 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -110,6 +110,7 @@ from ..source import ( is_from_unspecialized_nn_module_source, ListGetItemSource, LocalSource, + NonSerializableSetGetItemSource, NumpyTensorSource, OptimizerSource, RandomValueSource, @@ -772,6 +773,38 @@ class VariableBuilder: var = TorchFunctionModeVariable(value, source=self.source) self.tx.output.side_effects.track_object_existing(value, var) return var + elif istype(value, set): + if any(isinstance(x, torch.Tensor) for x in value): + unimplemented_v2( + gb_type="Attempted to wrap a set with tensors", + context="Python set containing torch.Tensor elements", + explanation=( + "Dynamo cannot trace sets of tensors. To get a stable ordering, " + "Dynamo needs to convert the set into a list and the order might not be " + "stable if the set contains tensors." + ), + hints=[ + "Use a dictionary where the keys are tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # The list gives a ordering for the set items. The ordering is based + # on the Python hash and it is not related to object ordering inside + # the set object. The order being incorrect at runtime will lead to + # a recompilation. + L = list(value) + items = [ + LazyVariableTracker.create( + v, source=NonSerializableSetGetItemSource(self.source, i) + ) + for i, v in enumerate(L) + ] + result = SetVariable(items, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) elif istype(value, frozenset) and all( ( # For DBR quantization, we could get a frozenset of torch funcs. diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index a5eea11e65f..f815ca58e99 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -129,6 +129,8 @@ def is_hashable(x): class ConstDictVariable(VariableTracker): + CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS + _nonvar_fields = { "user_cls", *VariableTracker._nonvar_fields, @@ -399,7 +401,7 @@ class ConstDictVariable(VariableTracker): install_guard( self.make_guard( functools.partial( - GuardBuilder.DICT_CONTAINS, + type(self).CONTAINS_GUARD, key=args[0].value, invert=not contains, ) @@ -761,6 +763,8 @@ class DefaultDictVariable(ConstDictVariable): class SetVariable(ConstDictVariable): """We model a sets as dictionary with None values""" + CONTAINS_GUARD = GuardBuilder.SET_CONTAINS + def __init__( self, items: list[VariableTracker], @@ -913,8 +917,7 @@ class SetVariable(ConstDictVariable): pass def install_dict_contains_guard(self, tx, args): - # Already EQUALS_MATCH guarded - pass + super().install_dict_contains_guard(tx, args) class FrozensetVariable(SetVariable): @@ -991,6 +994,14 @@ class DictKeySetVariable(SetVariable): + "])" ) + def install_dict_keys_match_guard(self): + # Already EQUALS_MATCH guarded + pass + + def install_dict_contains_guard(self, tx, args): + # Already EQUALS_MATCH guarded + pass + @property def set_items(self): return self.items diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 0362649140b..83fb0adbe6c 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -2007,6 +2008,33 @@ class DICT_CONTAINS : public LeafGuard { py::object _key; }; +// Check that set contains an item. +class SET_CONTAINS : public LeafGuard { + public: + SET_CONTAINS( + RootGuardManager* root_guard_manager, + bool contains, + py::object item, + py::object verbose_code_parts) + : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), + _contains(contains ? 1 : 0), + _item(std::move(item)) {} + + bool check_nopybind(PyObject* value) override { // borrowed ref + int result = (PySet_Check(value) || PyFrozenSet_Check(value)) && + PySet_Contains(value, _item.ptr()); + if (result == -1) { + PyErr_Clear(); + return false; + } + return result == _contains; + } + + private: + int _contains; + py::object _item; +}; + /** * Relational guards compare more than one value. We implement Relational * guards by capturing some state in the guard object. For example for tensor @@ -4213,6 +4241,82 @@ class ListGetItemGuardAccessor : public GuardAccessor { Py_ssize_t _index{-1}; }; +/** + * Represents set[index] accessor by converting the set into a list. + */ +class SetGetItemGuardAccessor : public GuardAccessor { + public: + SetGetItemGuardAccessor( + RootGuardManager* root, + const py::object& index, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + index, + std::move(source), + example_value, + guard_manager_enum), + _index(py::cast(index)) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + + PyObject* lst = PySequence_List(obj); + PyObject* x = PyList_GetItem(lst, _index); // borrowed ref + Py_XDECREF(lst); + if (x == nullptr) { + PyErr_Clear(); + return false; + } + bool result = _guard_manager->check_nopybind(x); + return result; + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + + PyObject* lst = PySequence_List(obj); + PyObject* x = PyList_GetItem(lst, _index); // borrowed ref + Py_XDECREF(lst); + + if (x == nullptr) { + PyErr_Clear(); + return GuardDebugInfo(false, 0); + } + GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x); + return result; + } + + std::string repr() const override { + return fmt::format("SetGetItemGuardAccessor(index={})", _index); + } + + public: // cloning functions + SetGetItemGuardAccessor( + GuardManager* guard_manager, + SetGetItemGuardAccessor* from) + : GuardAccessor(guard_manager, from) { + from->clone_visitor(this); + } + + GuardAccessor* clone( + RootGuardManager* cloned_root, + const py::function& clone_filter_fn) override { + return clone_common(cloned_root, clone_filter_fn); + } + + void clone_visitor(SetGetItemGuardAccessor* to) { + to->_index = _index; + } + + private: + Py_ssize_t _index{-1}; +}; + /** * Represents tuple[index] accessor. It is faster than generic * GetItemGuardAccessor. @@ -5607,6 +5711,10 @@ PyObject* torch_c_dynamo_guards_init() { py_m, "DICT_CONTAINS") .def(py::init()) .def("__call__", &DICT_CONTAINS::check); + py::class_>( + py_m, "SET_CONTAINS") + .def(py::init()) + .def("__call__", &SET_CONTAINS::check); py::class_>( py_m, "DYNAMIC_INDICES") .def(py::init()) @@ -5969,6 +6077,18 @@ PyObject* torch_c_dynamo_guards_init() { std::move(key), std::move(verbose_code_parts))); }) + .def( + "add_set_contains_guard", + [](GuardManager& self, + bool contains, + py::object item, + py::object verbose_code_parts) -> void { + self.add_leaf_guard(std::make_shared( + self.get_root(), + contains, + std::move(item), + std::move(verbose_code_parts))); + }) .def( "add_dynamic_indices_guard", [](GuardManager& self, @@ -6226,6 +6346,14 @@ PyObject* torch_c_dynamo_guards_init() { py::arg("example_value"), py::arg("guard_manager_enum"), py::return_value_policy::reference) + .def( + "set_getitem_manager", + &GuardManager::get_child_manager, + py::arg("index"), + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers .def(