mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Set] Support sets in VariableBuilder (#153150)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153150 Approved by: https://github.com/zou3519
This commit is contained in:
parent
6c42afe196
commit
e7167dbacf
|
|
@ -9321,31 +9321,6 @@ def ___make_guard_fn():
|
|||
self.assertEqual(counter.frame_count, 1)
|
||||
self.assertEqual(result, eager_result)
|
||||
|
||||
def test_input_set_graph_break(self):
|
||||
def foo(x):
|
||||
return x.pop() * x.pop()
|
||||
|
||||
x = torch.randn(10, 10)
|
||||
y = torch.randn(10, 10)
|
||||
|
||||
counter = CompileCounter()
|
||||
|
||||
inp = {x, x, x, x, y, y}
|
||||
foo = torch.compile(foo, backend=counter, fullgraph=True)
|
||||
|
||||
# There's a lot of stuff about sets that cannot work without a good deal of exertion on our part.
|
||||
# Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents)
|
||||
# and so the guard story for the objects passed into input just isn't there atm.
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported,
|
||||
"Unsupported method call",
|
||||
):
|
||||
foo(inp)
|
||||
|
||||
foo = torch.compile(foo, backend=counter, fullgraph=False)
|
||||
foo(inp)
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
def test_reconstruct_set_across_graph_break(self):
|
||||
def foo(x, y):
|
||||
setty = set()
|
||||
|
|
|
|||
233
test/dynamo/test_sets.py
Normal file
233
test/dynamo/test_sets.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
# TODO: move set tests from test_functions.py/test_misc.py to this file
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch._dynamo.testing import CompileCounter
|
||||
from torch.testing._internal.common_utils import munge_exc
|
||||
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
||||
|
||||
|
||||
class TestSetGuards(LoggingTestCase):
|
||||
def test_set_with_function(self):
|
||||
s = {
|
||||
torch._C._set_grad_enabled,
|
||||
"hello",
|
||||
torch.amp._exit_autocast,
|
||||
}
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def fn(x, s):
|
||||
if torch.amp._exit_autocast in s:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.randn(2)
|
||||
y = fn(x, s)
|
||||
self.assertEqual(y, x.sin())
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
s.remove(torch.amp._exit_autocast)
|
||||
s.add(torch._C._set_fwd_grad_enabled)
|
||||
y = fn(x, s)
|
||||
self.assertEqual(y, x.cos())
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_in_guard(self, records):
|
||||
s = {
|
||||
"Dynamo",
|
||||
"Inductor",
|
||||
"PyTorch",
|
||||
torch.sin,
|
||||
}
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def fn(x, s):
|
||||
if "PyTorch" in s:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.randn(2)
|
||||
y = fn(x, s)
|
||||
self.assertEqual(y, x.sin())
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
s.remove("PyTorch")
|
||||
s.add("Cuda")
|
||||
y = fn(x, s)
|
||||
self.assertEqual(y, x.cos())
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertGreater(len(records), 0)
|
||||
record = self.getRecord(records, "set.__contains__")
|
||||
self.assertIn(
|
||||
"""set.__contains__(s, 'PyTorch')""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
def test_set_with_tensors(self):
|
||||
s = {
|
||||
torch.ones(1),
|
||||
torch.tensor([1.0]),
|
||||
torch.zeros(1),
|
||||
}
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def fn(x, s):
|
||||
z = torch.zeros(1)
|
||||
for i in s:
|
||||
z += i
|
||||
return x + z
|
||||
|
||||
x = torch.tensor([1.0])
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: fn(x, s),
|
||||
"""\
|
||||
Attempted to wrap a set with tensors
|
||||
Explanation: Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors.
|
||||
Hint: Use a dictionary where the keys are tensors.
|
||||
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
|
||||
|
||||
Developer debug context: Python set containing torch.Tensor elements
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_sets.py", line N, in fn
|
||||
for i in s:""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_set_multiple_types(self):
|
||||
s = {
|
||||
"PyTorch",
|
||||
3.3,
|
||||
1j,
|
||||
math.nan,
|
||||
}
|
||||
cnts = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def fn(x, s):
|
||||
if "PyTorch" in s:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.tensor(1.0)
|
||||
y = fn(x, s)
|
||||
self.assertEqual(y, x.sin())
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
s.remove("PyTorch")
|
||||
y = fn(x, s)
|
||||
self.assertEqual(y, x.cos())
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
def test_set_recompile_on_key_pop(self):
|
||||
s = {
|
||||
torch._C._set_grad_enabled,
|
||||
torch.amp._enter_autocast,
|
||||
torch.amp._exit_autocast,
|
||||
}
|
||||
|
||||
cnts = CompileCounter()
|
||||
|
||||
def fn(x, s):
|
||||
if torch.amp._exit_autocast in s:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
res = opt_fn(x, s)
|
||||
opt_fn(x, s)
|
||||
self.assertEqual(res, fn(x, s))
|
||||
# No recompilation
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
# Pop a value
|
||||
s.remove(torch.amp._exit_autocast)
|
||||
|
||||
res = opt_fn(x, s)
|
||||
# Check recompilation
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(res, fn(x, s))
|
||||
|
||||
def test_set_recompile_on_key_change(self):
|
||||
s = {
|
||||
torch._C._set_grad_enabled,
|
||||
torch.amp._enter_autocast,
|
||||
torch.amp._exit_autocast,
|
||||
}
|
||||
|
||||
cnts = CompileCounter()
|
||||
|
||||
def fn(x, s):
|
||||
if torch.amp._exit_autocast in s:
|
||||
return x.sin()
|
||||
return x.cos()
|
||||
|
||||
x = torch.randn(4)
|
||||
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
res = opt_fn(x, s)
|
||||
opt_fn(x, s)
|
||||
self.assertEqual(res, fn(x, s))
|
||||
# No recompilation
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
# Pop a value
|
||||
s.remove(torch.amp._exit_autocast)
|
||||
# Add a different value
|
||||
s.add(torch._C._set_autograd_fallback_mode)
|
||||
|
||||
res = opt_fn(x, s)
|
||||
# Check recompilation
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(res, fn(x, s))
|
||||
|
||||
@unittest.skip("random failures on Python 3.9")
|
||||
def test_set_guard_on_keys_change(self):
|
||||
# This test guarantee that we're not triggering any of the dict guards
|
||||
# on sets
|
||||
s = {
|
||||
torch._C._set_grad_enabled,
|
||||
torch.amp._enter_autocast,
|
||||
torch.amp._exit_autocast,
|
||||
}
|
||||
|
||||
cnts = CompileCounter()
|
||||
|
||||
def fn(x, s):
|
||||
for e in s:
|
||||
x = x * len(str(e))
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
opt_fn(torch.randn(4), s)
|
||||
opt_fn(torch.randn(4), s)
|
||||
# No recompilation
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
# pop and add the same item
|
||||
s.remove(torch.amp._exit_autocast)
|
||||
# It is not guaranteed that _exit_autocast will be in a specific order
|
||||
s.add(torch.amp._exit_autocast)
|
||||
|
||||
x = torch.randn(4)
|
||||
res = opt_fn(x, s)
|
||||
# Check Dynamo don't recompile
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(res, fn(x, s))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
|
|
@ -2188,5 +2188,16 @@
|
|||
"Remove the `@contextlib.contextmanager` decorator."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0222": [
|
||||
{
|
||||
"Gb_type": "Attempted to wrap a set with tensors",
|
||||
"Context": "Python set containing torch.Tensor elements",
|
||||
"Explanation": "Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors.",
|
||||
"Hints": [
|
||||
"Use a dictionary where the keys are tensors.",
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -119,6 +119,7 @@ from .source import (
|
|||
ListGetItemSource,
|
||||
LocalSource,
|
||||
NNModuleSource,
|
||||
NonSerializableSetGetItemSource,
|
||||
NumpyTensorSource,
|
||||
OptimizerSource,
|
||||
ScriptObjectQualifiedNameSource,
|
||||
|
|
@ -945,6 +946,11 @@ class GuardBuilder(GuardBuilderBase):
|
|||
# Fix this if condition
|
||||
if isinstance(example_value, dict_keys):
|
||||
guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
|
||||
elif isinstance(example_value, (set, frozenset)):
|
||||
# we don't need to guard on key order for set/frozenset
|
||||
# but the if above will be true for these types as set is
|
||||
# implemented using a dict in Dynamo
|
||||
guard_manager_enum = GuardManagerType.GUARD_MANAGER
|
||||
else:
|
||||
assert isinstance(example_value, dict)
|
||||
guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
|
||||
|
|
@ -1289,6 +1295,14 @@ class GuardBuilder(GuardBuilderBase):
|
|||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, NonSerializableSetGetItemSource):
|
||||
assert base_guard_manager
|
||||
out = base_guard_manager.set_getitem_manager(
|
||||
index=source.index,
|
||||
source=source_name,
|
||||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, WeakRefCallSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
out = base_guard_manager.weakref_call_manager(
|
||||
|
|
@ -1507,6 +1521,19 @@ class GuardBuilder(GuardBuilderBase):
|
|||
not invert, key, get_verbose_code_parts(code, guard)
|
||||
)
|
||||
|
||||
def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool):
|
||||
set_ref = self.arg_ref(guard)
|
||||
item = key
|
||||
contains = not invert # install_dict_contains_guard inverts "contains"
|
||||
|
||||
code = f"set.__contains__({set_ref}, {item!r})"
|
||||
|
||||
self._set_guard_export_info(guard, [code])
|
||||
|
||||
self.get_guard_manager(guard).add_set_contains_guard(
|
||||
contains, item, get_verbose_code_parts(code, guard)
|
||||
)
|
||||
|
||||
def BOOL_MATCH(self, guard: Guard):
|
||||
# checks val == True or val == False
|
||||
ref = self.arg_ref(guard)
|
||||
|
|
|
|||
|
|
@ -587,6 +587,34 @@ class ConstDictKeySource(ChainedSource):
|
|||
return True
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NonSerializableSetGetItemSource(ChainedSource):
|
||||
index: int
|
||||
|
||||
def __post_init__(self):
|
||||
from .variables import ConstantVariable
|
||||
|
||||
assert ConstantVariable.is_literal(self.index)
|
||||
|
||||
def guard_source(self):
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "set_getitem")
|
||||
)
|
||||
codegen(self.base)
|
||||
codegen.append_output(codegen.create_load_const(self.index))
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
def name(self):
|
||||
# set ordering might not be stable
|
||||
return f"list({self.base.name()})[{self.index!r}]"
|
||||
|
||||
def is_dict_key(self):
|
||||
return False
|
||||
|
||||
|
||||
# Used to access an item from the dictionary
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DictGetItemSource(ChainedSource):
|
||||
|
|
|
|||
|
|
@ -2572,6 +2572,11 @@ def dict_keys_getitem(d, n):
|
|||
return next(itertools.islice(dict_class.keys(d), n, n + 1))
|
||||
|
||||
|
||||
def set_getitem(s, n):
|
||||
# Set ordering might not be stable
|
||||
return list(s)[n]
|
||||
|
||||
|
||||
def enum_repr(value, local):
|
||||
# enum class can override __str__ method. Use __class__ and name attribute
|
||||
# to extract the class name and key name.
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ from ..source import (
|
|||
is_from_unspecialized_nn_module_source,
|
||||
ListGetItemSource,
|
||||
LocalSource,
|
||||
NonSerializableSetGetItemSource,
|
||||
NumpyTensorSource,
|
||||
OptimizerSource,
|
||||
RandomValueSource,
|
||||
|
|
@ -772,6 +773,38 @@ class VariableBuilder:
|
|||
var = TorchFunctionModeVariable(value, source=self.source)
|
||||
self.tx.output.side_effects.track_object_existing(value, var)
|
||||
return var
|
||||
elif istype(value, set):
|
||||
if any(isinstance(x, torch.Tensor) for x in value):
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to wrap a set with tensors",
|
||||
context="Python set containing torch.Tensor elements",
|
||||
explanation=(
|
||||
"Dynamo cannot trace sets of tensors. To get a stable ordering, "
|
||||
"Dynamo needs to convert the set into a list and the order might not be "
|
||||
"stable if the set contains tensors."
|
||||
),
|
||||
hints=[
|
||||
"Use a dictionary where the keys are tensors.",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
|
||||
|
||||
# The list gives a ordering for the set items. The ordering is based
|
||||
# on the Python hash and it is not related to object ordering inside
|
||||
# the set object. The order being incorrect at runtime will lead to
|
||||
# a recompilation.
|
||||
L = list(value)
|
||||
items = [
|
||||
LazyVariableTracker.create(
|
||||
v, source=NonSerializableSetGetItemSource(self.source, i)
|
||||
)
|
||||
for i, v in enumerate(L)
|
||||
]
|
||||
result = SetVariable(items, source=self.source)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif istype(value, frozenset) and all(
|
||||
(
|
||||
# For DBR quantization, we could get a frozenset of torch funcs.
|
||||
|
|
|
|||
|
|
@ -129,6 +129,8 @@ def is_hashable(x):
|
|||
|
||||
|
||||
class ConstDictVariable(VariableTracker):
|
||||
CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS
|
||||
|
||||
_nonvar_fields = {
|
||||
"user_cls",
|
||||
*VariableTracker._nonvar_fields,
|
||||
|
|
@ -399,7 +401,7 @@ class ConstDictVariable(VariableTracker):
|
|||
install_guard(
|
||||
self.make_guard(
|
||||
functools.partial(
|
||||
GuardBuilder.DICT_CONTAINS,
|
||||
type(self).CONTAINS_GUARD,
|
||||
key=args[0].value,
|
||||
invert=not contains,
|
||||
)
|
||||
|
|
@ -761,6 +763,8 @@ class DefaultDictVariable(ConstDictVariable):
|
|||
class SetVariable(ConstDictVariable):
|
||||
"""We model a sets as dictionary with None values"""
|
||||
|
||||
CONTAINS_GUARD = GuardBuilder.SET_CONTAINS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
|
|
@ -913,8 +917,7 @@ class SetVariable(ConstDictVariable):
|
|||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
super().install_dict_contains_guard(tx, args)
|
||||
|
||||
|
||||
class FrozensetVariable(SetVariable):
|
||||
|
|
@ -991,6 +994,14 @@ class DictKeySetVariable(SetVariable):
|
|||
+ "])"
|
||||
)
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
return self.items
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <ATen/EmptyTensor.h>
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <fmt/format.h>
|
||||
#include <torch/csrc/autograd/grad_mode.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/dynamo/guards.h>
|
||||
|
|
@ -2007,6 +2008,33 @@ class DICT_CONTAINS : public LeafGuard {
|
|||
py::object _key;
|
||||
};
|
||||
|
||||
// Check that set contains an item.
|
||||
class SET_CONTAINS : public LeafGuard {
|
||||
public:
|
||||
SET_CONTAINS(
|
||||
RootGuardManager* root_guard_manager,
|
||||
bool contains,
|
||||
py::object item,
|
||||
py::object verbose_code_parts)
|
||||
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
|
||||
_contains(contains ? 1 : 0),
|
||||
_item(std::move(item)) {}
|
||||
|
||||
bool check_nopybind(PyObject* value) override { // borrowed ref
|
||||
int result = (PySet_Check(value) || PyFrozenSet_Check(value)) &&
|
||||
PySet_Contains(value, _item.ptr());
|
||||
if (result == -1) {
|
||||
PyErr_Clear();
|
||||
return false;
|
||||
}
|
||||
return result == _contains;
|
||||
}
|
||||
|
||||
private:
|
||||
int _contains;
|
||||
py::object _item;
|
||||
};
|
||||
|
||||
/**
|
||||
* Relational guards compare more than one value. We implement Relational
|
||||
* guards by capturing some state in the guard object. For example for tensor
|
||||
|
|
@ -4213,6 +4241,82 @@ class ListGetItemGuardAccessor : public GuardAccessor {
|
|||
Py_ssize_t _index{-1};
|
||||
};
|
||||
|
||||
/**
|
||||
* Represents set[index] accessor by converting the set into a list.
|
||||
*/
|
||||
class SetGetItemGuardAccessor : public GuardAccessor {
|
||||
public:
|
||||
SetGetItemGuardAccessor(
|
||||
RootGuardManager* root,
|
||||
const py::object& index,
|
||||
std::string source,
|
||||
py::handle example_value,
|
||||
py::handle guard_manager_enum)
|
||||
: GuardAccessor(
|
||||
root,
|
||||
index,
|
||||
std::move(source),
|
||||
example_value,
|
||||
guard_manager_enum),
|
||||
_index(py::cast<Py_ssize_t>(index)) {}
|
||||
|
||||
// NB: Intentional duplication between check_nopybind and
|
||||
// check_verbose_nopybind.
|
||||
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
|
||||
override { // borrowed ref
|
||||
|
||||
PyObject* lst = PySequence_List(obj);
|
||||
PyObject* x = PyList_GetItem(lst, _index); // borrowed ref
|
||||
Py_XDECREF(lst);
|
||||
if (x == nullptr) {
|
||||
PyErr_Clear();
|
||||
return false;
|
||||
}
|
||||
bool result = _guard_manager->check_nopybind(x);
|
||||
return result;
|
||||
}
|
||||
|
||||
GuardDebugInfo check_verbose_nopybind(
|
||||
PyObject* obj) override { // borrowed ref
|
||||
|
||||
PyObject* lst = PySequence_List(obj);
|
||||
PyObject* x = PyList_GetItem(lst, _index); // borrowed ref
|
||||
Py_XDECREF(lst);
|
||||
|
||||
if (x == nullptr) {
|
||||
PyErr_Clear();
|
||||
return GuardDebugInfo(false, 0);
|
||||
}
|
||||
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string repr() const override {
|
||||
return fmt::format("SetGetItemGuardAccessor(index={})", _index);
|
||||
}
|
||||
|
||||
public: // cloning functions
|
||||
SetGetItemGuardAccessor(
|
||||
GuardManager* guard_manager,
|
||||
SetGetItemGuardAccessor* from)
|
||||
: GuardAccessor(guard_manager, from) {
|
||||
from->clone_visitor(this);
|
||||
}
|
||||
|
||||
GuardAccessor* clone(
|
||||
RootGuardManager* cloned_root,
|
||||
const py::function& clone_filter_fn) override {
|
||||
return clone_common<SetGetItemGuardAccessor>(cloned_root, clone_filter_fn);
|
||||
}
|
||||
|
||||
void clone_visitor(SetGetItemGuardAccessor* to) {
|
||||
to->_index = _index;
|
||||
}
|
||||
|
||||
private:
|
||||
Py_ssize_t _index{-1};
|
||||
};
|
||||
|
||||
/**
|
||||
* Represents tuple[index] accessor. It is faster than generic
|
||||
* GetItemGuardAccessor.
|
||||
|
|
@ -5607,6 +5711,10 @@ PyObject* torch_c_dynamo_guards_init() {
|
|||
py_m, "DICT_CONTAINS")
|
||||
.def(py::init<RootGuardManager*, bool, py::object, py::list>())
|
||||
.def("__call__", &DICT_CONTAINS::check);
|
||||
py::class_<SET_CONTAINS, LeafGuard, std::shared_ptr<SET_CONTAINS>>(
|
||||
py_m, "SET_CONTAINS")
|
||||
.def(py::init<RootGuardManager*, bool, py::object, py::list>())
|
||||
.def("__call__", &SET_CONTAINS::check);
|
||||
py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
|
||||
py_m, "DYNAMIC_INDICES")
|
||||
.def(py::init<RootGuardManager*, py::set, py::list>())
|
||||
|
|
@ -5969,6 +6077,18 @@ PyObject* torch_c_dynamo_guards_init() {
|
|||
std::move(key),
|
||||
std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_set_contains_guard",
|
||||
[](GuardManager& self,
|
||||
bool contains,
|
||||
py::object item,
|
||||
py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<SET_CONTAINS>(
|
||||
self.get_root(),
|
||||
contains,
|
||||
std::move(item),
|
||||
std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_dynamic_indices_guard",
|
||||
[](GuardManager& self,
|
||||
|
|
@ -6226,6 +6346,14 @@ PyObject* torch_c_dynamo_guards_init() {
|
|||
py::arg("example_value"),
|
||||
py::arg("guard_manager_enum"),
|
||||
py::return_value_policy::reference)
|
||||
.def(
|
||||
"set_getitem_manager",
|
||||
&GuardManager::get_child_manager<SetGetItemGuardAccessor>,
|
||||
py::arg("index"),
|
||||
py::arg("source"),
|
||||
py::arg("example_value"),
|
||||
py::arg("guard_manager_enum"),
|
||||
py::return_value_policy::reference)
|
||||
// return by reference because GuardManager has the ownership of accessors
|
||||
// and guard managers
|
||||
.def(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user