[Set] Support sets in VariableBuilder (#153150)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153150
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas 2025-07-03 14:35:34 -03:00 committed by PyTorch MergeBot
parent 6c42afe196
commit e7167dbacf
31 changed files with 480 additions and 29 deletions

View File

@ -9321,31 +9321,6 @@ def ___make_guard_fn():
self.assertEqual(counter.frame_count, 1)
self.assertEqual(result, eager_result)
def test_input_set_graph_break(self):
    """A set of tensors passed as an input fails under fullgraph=True.

    With ``fullgraph=True`` the call is expected to raise ``Unsupported``;
    with ``fullgraph=False`` the call succeeds and the counter backend
    records a single frame.
    """

    def foo(x):
        return x.pop() * x.pop()

    first = torch.randn(10, 10)
    second = torch.randn(10, 10)
    cnts = CompileCounter()
    tensor_set = {first, first, first, first, second, second}

    strict_foo = torch.compile(foo, backend=cnts, fullgraph=True)
    # Sets as inputs need considerable extra work to support: GetItemSource
    # cannot index arbitrarily into a set's contents, so there is currently
    # no way to build guards for the objects held by an input set.
    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported,
        "Unsupported method call",
    ):
        strict_foo(tensor_set)

    # Re-wrap the already compiled callable, this time allowing graph breaks.
    relaxed_foo = torch.compile(strict_foo, backend=cnts, fullgraph=False)
    relaxed_foo(tensor_set)
    self.assertEqual(cnts.frame_count, 1)
def test_reconstruct_set_across_graph_break(self):
def foo(x, y):
setty = set()

233
test/dynamo/test_sets.py Normal file
View File

@ -0,0 +1,233 @@
# Owner(s): ["module: dynamo"]
# TODO: move set tests from test_functions.py/test_misc.py to this file
import math
import unittest
import torch
import torch._dynamo.test_case
from torch._dynamo.exc import Unsupported
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import munge_exc
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
class TestSetGuards(LoggingTestCase):
    """Tests for how compiled functions react to changes in a set passed as input.

    Most tests compile a small function that takes a set argument, call it,
    mutate the set, call it again, and then check ``CompileCounter.frame_count``
    to see whether a recompilation happened.
    """

    def test_set_with_function(self):
        """Membership test on a set of functions; changing the set recompiles."""
        s = {
            torch._C._set_grad_enabled,
            "hello",
            torch.amp._exit_autocast,
        }
        cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def fn(x, s):
            if torch.amp._exit_autocast in s:
                return x.sin()
            return x.cos()

        x = torch.randn(2)
        y = fn(x, s)
        self.assertEqual(y, x.sin())
        self.assertEqual(cnts.frame_count, 1)
        # Swap the element the function branches on for a different one; the
        # second call must take the other branch and compile a second frame.
        s.remove(torch.amp._exit_autocast)
        s.add(torch._C._set_fwd_grad_enabled)
        y = fn(x, s)
        self.assertEqual(y, x.cos())
        self.assertEqual(cnts.frame_count, 2)

    @make_logging_test(recompiles=True)
    def test_in_guard(self, records):
        """Removing the tested string recompiles and logs a ``set.__contains__`` guard.

        ``records`` is presumably supplied by the ``make_logging_test``
        decorator (not visible here).
        """
        s = {
            "Dynamo",
            "Inductor",
            "PyTorch",
            torch.sin,
        }
        cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def fn(x, s):
            if "PyTorch" in s:
                return x.sin()
            return x.cos()

        x = torch.randn(2)
        y = fn(x, s)
        self.assertEqual(y, x.sin())
        self.assertEqual(cnts.frame_count, 1)
        s.remove("PyTorch")
        s.add("Cuda")
        y = fn(x, s)
        self.assertEqual(y, x.cos())
        self.assertEqual(cnts.frame_count, 2)
        # The recompile log must mention the membership guard on 'PyTorch'.
        self.assertGreater(len(records), 0)
        record = self.getRecord(records, "set.__contains__")
        self.assertIn(
            """set.__contains__(s, 'PyTorch')""",
            munge_exc(record.getMessage()),
        )

    def test_set_with_tensors(self):
        """Compiling a function over a set of tensors raises ``Unsupported``."""
        s = {
            torch.ones(1),
            torch.tensor([1.0]),
            torch.zeros(1),
        }
        cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def fn(x, s):
            z = torch.zeros(1)
            for i in s:
                z += i
            return x + z

        x = torch.tensor([1.0])
        # The expected text below is compared against the munged exception
        # message, so its content must not be re-indented or reflowed.
        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: fn(x, s),
            """\
Attempted to wrap a set with tensors
Explanation: Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors.
Hint: Use a dictionary where the keys are tensors.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: Python set containing torch.Tensor elements
from user code:
File "test_sets.py", line N, in fn
for i in s:""", # noqa: B950
        )

    def test_set_multiple_types(self):
        """Membership test on a set mixing str, float, complex and NaN elements."""
        s = {
            "PyTorch",
            3.3,
            1j,
            math.nan,
        }
        cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def fn(x, s):
            if "PyTorch" in s:
                return x.sin()
            return x.cos()

        x = torch.tensor(1.0)
        y = fn(x, s)
        self.assertEqual(y, x.sin())
        self.assertEqual(cnts.frame_count, 1)
        s.remove("PyTorch")
        y = fn(x, s)
        self.assertEqual(y, x.cos())
        self.assertEqual(cnts.frame_count, 2)

    def test_set_recompile_on_key_pop(self):
        """Removing the element the function branches on triggers a recompile."""
        s = {
            torch._C._set_grad_enabled,
            torch.amp._enter_autocast,
            torch.amp._exit_autocast,
        }
        cnts = CompileCounter()

        def fn(x, s):
            if torch.amp._exit_autocast in s:
                return x.sin()
            return x.cos()

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
        res = opt_fn(x, s)
        opt_fn(x, s)
        self.assertEqual(res, fn(x, s))
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)
        # Pop a value
        s.remove(torch.amp._exit_autocast)
        res = opt_fn(x, s)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, s))

    def test_set_recompile_on_key_change(self):
        """Replacing the branched-on element with another one triggers a recompile."""
        s = {
            torch._C._set_grad_enabled,
            torch.amp._enter_autocast,
            torch.amp._exit_autocast,
        }
        cnts = CompileCounter()

        def fn(x, s):
            if torch.amp._exit_autocast in s:
                return x.sin()
            return x.cos()

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
        res = opt_fn(x, s)
        opt_fn(x, s)
        self.assertEqual(res, fn(x, s))
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)
        # Pop a value
        s.remove(torch.amp._exit_autocast)
        # Add a different value
        s.add(torch._C._set_autograd_fallback_mode)
        res = opt_fn(x, s)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, s))

    @unittest.skip("random failures on Python 3.9")
    def test_set_guard_on_keys_change(self):
        """Removing and re-adding the same element must not recompile.

        Currently skipped (see the decorator's reason string).
        """
        # This test guarantees that we're not triggering any of the dict
        # guards on sets
        s = {
            torch._C._set_grad_enabled,
            torch.amp._enter_autocast,
            torch.amp._exit_autocast,
        }
        cnts = CompileCounter()

        def fn(x, s):
            for e in s:
                x = x * len(str(e))
            return x

        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
        opt_fn(torch.randn(4), s)
        opt_fn(torch.randn(4), s)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)
        # pop and add the same item
        s.remove(torch.amp._exit_autocast)
        # It is not guaranteed that _exit_autocast will be in a specific order
        s.add(torch.amp._exit_autocast)
        x = torch.randn(4)
        res = opt_fn(x, s)
        # Check that Dynamo doesn't recompile
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(res, fn(x, s))
# Script entry point: run this file's tests through the Dynamo test runner.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

View File

@ -2188,5 +2188,16 @@
"Remove the `@contextlib.contextmanager` decorator."
]
}
],
"GB0222": [
{
"Gb_type": "Attempted to wrap a set with tensors",
"Context": "Python set containing torch.Tensor elements",
"Explanation": "Dynamo cannot trace sets of tensors. To get a stable ordering, Dynamo needs to convert the set into a list and the order might not be stable if the set contains tensors.",
"Hints": [
"Use a dictionary where the keys are tensors.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -119,6 +119,7 @@ from .source import (
ListGetItemSource,
LocalSource,
NNModuleSource,
NonSerializableSetGetItemSource,
NumpyTensorSource,
OptimizerSource,
ScriptObjectQualifiedNameSource,
@ -945,6 +946,11 @@ class GuardBuilder(GuardBuilderBase):
# Fix this if condition
if isinstance(example_value, dict_keys):
guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
elif isinstance(example_value, (set, frozenset)):
# we don't need to guard on key order for set/frozenset
# but the if above will be true for these types as set is
# implemented using a dict in Dynamo
guard_manager_enum = GuardManagerType.GUARD_MANAGER
else:
assert isinstance(example_value, dict)
guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
@ -1289,6 +1295,14 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, NonSerializableSetGetItemSource):
assert base_guard_manager
out = base_guard_manager.set_getitem_manager(
index=source.index,
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, WeakRefCallSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.weakref_call_manager(
@ -1507,6 +1521,19 @@ class GuardBuilder(GuardBuilderBase):
not invert, key, get_verbose_code_parts(code, guard)
)
def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool):
    """Install a leaf guard on whether ``key`` is a member of the guarded set.

    Args:
        guard: Guard whose source identifies the set being checked.
        key: Item whose membership is guarded.
        invert: When true, the guard requires ``key`` to be absent.
    """
    set_ref = self.arg_ref(guard)
    code = f"set.__contains__({set_ref}, {key!r})"
    self._set_guard_export_info(guard, [code])

    # The caller passes ``invert`` (the negation of "contains"), so flip it
    # back to get the value the guard has to observe.
    manager = self.get_guard_manager(guard)
    manager.add_set_contains_guard(
        not invert, key, get_verbose_code_parts(code, guard)
    )
def BOOL_MATCH(self, guard: Guard):
# checks val == True or val == False
ref = self.arg_ref(guard)

View File

@ -587,6 +587,34 @@ class ConstDictKeySource(ChainedSource):
return True
@dataclasses.dataclass(frozen=True)
class NonSerializableSetGetItemSource(ChainedSource):
    """Source for one element of a set, addressed as ``list(base)[index]``.

    Set ordering might not be stable, so ``index`` only refers to the
    position the element has when the set is converted to a list.
    """

    index: int

    def __post_init__(self):
        from .variables import ConstantVariable

        # The index is embedded in generated code and guard names, so it
        # must be a literal.
        assert ConstantVariable.is_literal(self.index)

    def name(self):
        # set ordering might not be stable
        base_name = self.base.name()
        return f"list({base_name})[{self.index!r}]"

    def guard_source(self):
        # Delegates to the source of the set itself.
        return self.base.guard_source()

    def is_dict_key(self):
        return False

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode equivalent to ``set_getitem(<base>, index)``."""

        def load_set_getitem():
            return codegen.load_import_from(utils.__name__, "set_getitem")

        codegen.add_push_null(load_set_getitem)
        codegen(self.base)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, False))
# Used to access an item from the dictionary
@dataclasses.dataclass(frozen=True)
class DictGetItemSource(ChainedSource):

View File

@ -2572,6 +2572,11 @@ def dict_keys_getitem(d, n):
return next(itertools.islice(dict_class.keys(d), n, n + 1))
def set_getitem(s, n):
    """Return element ``n`` of ``s`` after converting ``s`` to a list.

    Set ordering might not be stable, so the element found at a given index
    is only meaningful relative to ``list(s)`` for that same set.
    """
    ordered = list(s)
    return ordered[n]
def enum_repr(value, local):
# enum class can override __str__ method. Use __class__ and name attribute
# to extract the class name and key name.

View File

@ -110,6 +110,7 @@ from ..source import (
is_from_unspecialized_nn_module_source,
ListGetItemSource,
LocalSource,
NonSerializableSetGetItemSource,
NumpyTensorSource,
OptimizerSource,
RandomValueSource,
@ -772,6 +773,38 @@ class VariableBuilder:
var = TorchFunctionModeVariable(value, source=self.source)
self.tx.output.side_effects.track_object_existing(value, var)
return var
elif istype(value, set):
if any(isinstance(x, torch.Tensor) for x in value):
unimplemented_v2(
gb_type="Attempted to wrap a set with tensors",
context="Python set containing torch.Tensor elements",
explanation=(
"Dynamo cannot trace sets of tensors. To get a stable ordering, "
"Dynamo needs to convert the set into a list and the order might not be "
"stable if the set contains tensors."
),
hints=[
"Use a dictionary where the keys are tensors.",
*graph_break_hints.SUPPORTABLE,
],
)
self.install_guards(GuardBuilder.TYPE_MATCH)
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
# The list gives an ordering for the set items. The ordering is based
# on the Python hash and it is not related to object ordering inside
# the set object. The order being incorrect at runtime will lead to
# a recompilation.
L = list(value)
items = [
LazyVariableTracker.create(
v, source=NonSerializableSetGetItemSource(self.source, i)
)
for i, v in enumerate(L)
]
result = SetVariable(items, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
elif istype(value, frozenset) and all(
(
# For DBR quantization, we could get a frozenset of torch funcs.

View File

@ -129,6 +129,8 @@ def is_hashable(x):
class ConstDictVariable(VariableTracker):
CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS
_nonvar_fields = {
"user_cls",
*VariableTracker._nonvar_fields,
@ -399,7 +401,7 @@ class ConstDictVariable(VariableTracker):
install_guard(
self.make_guard(
functools.partial(
GuardBuilder.DICT_CONTAINS,
type(self).CONTAINS_GUARD,
key=args[0].value,
invert=not contains,
)
@ -761,6 +763,8 @@ class DefaultDictVariable(ConstDictVariable):
class SetVariable(ConstDictVariable):
"""We model a sets as dictionary with None values"""
CONTAINS_GUARD = GuardBuilder.SET_CONTAINS
def __init__(
self,
items: list[VariableTracker],
@ -913,8 +917,7 @@ class SetVariable(ConstDictVariable):
pass
def install_dict_contains_guard(self, tx, args):
# Already EQUALS_MATCH guarded
pass
super().install_dict_contains_guard(tx, args)
class FrozensetVariable(SetVariable):
@ -991,6 +994,14 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self):
    """Do nothing: no keys-match guard is installed for this variable."""
    # Already EQUALS_MATCH guarded
    return None
def install_dict_contains_guard(self, tx, args):
    """Do nothing: no contains guard is installed for this variable."""
    # Already EQUALS_MATCH guarded
    return None
@property
def set_items(self):
return self.items

View File

@ -6,6 +6,7 @@
#include <ATen/EmptyTensor.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <c10/util/flat_hash_map.h>
#include <fmt/format.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/dynamo/guards.h>
@ -2007,6 +2008,33 @@ class DICT_CONTAINS : public LeafGuard {
py::object _key;
};
// Check that set contains an item.
//
// The guard passes when the membership of `_item` in the checked object
// matches the `contains` flag given at construction. Objects that are neither
// a set nor a frozenset are treated as not containing the item.
class SET_CONTAINS : public LeafGuard {
 public:
  SET_CONTAINS(
      RootGuardManager* root_guard_manager,
      bool contains,
      py::object item,
      py::object verbose_code_parts)
      : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
        _contains(contains ? 1 : 0),
        _item(std::move(item)) {}

  bool check_nopybind(PyObject* value) override { // borrowed ref
    int result = 0;
    if (PySet_Check(value) || PyFrozenSet_Check(value)) {
      // PySet_Contains returns 1/0, or -1 with a Python error set. Keep the
      // raw int: combining it with `&&` would collapse -1 to `true`, which
      // both hides the error and leaves the exception pending.
      result = PySet_Contains(value, _item.ptr());
      if (result == -1) {
        PyErr_Clear();
        return false;
      }
    }
    return result == _contains;
  }

 private:
  // 1 if the item is expected to be present, 0 if expected to be absent.
  int _contains;
  // Item whose membership is checked.
  py::object _item;
};
/**
* Relational guards compare more than one value. We implement Relational
* guards by capturing some state in the guard object. For example for tensor
@ -4213,6 +4241,82 @@ class ListGetItemGuardAccessor : public GuardAccessor {
Py_ssize_t _index{-1};
};
/**
* Represents set[index] accessor by converting the set into a list.
*/
class SetGetItemGuardAccessor : public GuardAccessor {
public:
SetGetItemGuardAccessor(
RootGuardManager* root,
const py::object& index,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
index,
std::move(source),
example_value,
guard_manager_enum),
_index(py::cast<Py_ssize_t>(index)) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* lst = PySequence_List(obj);
PyObject* x = PyList_GetItem(lst, _index); // borrowed ref
Py_XDECREF(lst);
if (x == nullptr) {
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* lst = PySequence_List(obj);
PyObject* x = PyList_GetItem(lst, _index); // borrowed ref
Py_XDECREF(lst);
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(false, 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
return result;
}
std::string repr() const override {
return fmt::format("SetGetItemGuardAccessor(index={})", _index);
}
public: // cloning functions
SetGetItemGuardAccessor(
GuardManager* guard_manager,
SetGetItemGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<SetGetItemGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(SetGetItemGuardAccessor* to) {
to->_index = _index;
}
private:
Py_ssize_t _index{-1};
};
/**
* Represents tuple[index] accessor. It is faster than generic
* GetItemGuardAccessor.
@ -5607,6 +5711,10 @@ PyObject* torch_c_dynamo_guards_init() {
py_m, "DICT_CONTAINS")
.def(py::init<RootGuardManager*, bool, py::object, py::list>())
.def("__call__", &DICT_CONTAINS::check);
py::class_<SET_CONTAINS, LeafGuard, std::shared_ptr<SET_CONTAINS>>(
py_m, "SET_CONTAINS")
.def(py::init<RootGuardManager*, bool, py::object, py::list>())
.def("__call__", &SET_CONTAINS::check);
py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
py_m, "DYNAMIC_INDICES")
.def(py::init<RootGuardManager*, py::set, py::list>())
@ -5969,6 +6077,18 @@ PyObject* torch_c_dynamo_guards_init() {
std::move(key),
std::move(verbose_code_parts)));
})
.def(
"add_set_contains_guard",
[](GuardManager& self,
bool contains,
py::object item,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<SET_CONTAINS>(
self.get_root(),
contains,
std::move(item),
std::move(verbose_code_parts)));
})
.def(
"add_dynamic_indices_guard",
[](GuardManager& self,
@ -6226,6 +6346,14 @@ PyObject* torch_c_dynamo_guards_init() {
py::arg("example_value"),
py::arg("guard_manager_enum"),
py::return_value_policy::reference)
.def(
"set_getitem_manager",
&GuardManager::get_child_manager<SetGetItemGuardAccessor>,
py::arg("index"),
py::arg("source"),
py::arg("example_value"),
py::arg("guard_manager_enum"),
py::return_value_policy::reference)
// return by reference because GuardManager has the ownership of accessors
// and guard managers
.def(