[dynamo][guards] Print relational guards only once (#150810)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150810
Approved by: https://github.com/anijain2305
Isuru Fernando 2025-04-08 14:46:06 +00:00 committed by PyTorch MergeBot
parent 8b5e717601
commit a22d3e778e
3 changed files with 21 additions and 15 deletions

@@ -62,8 +62,9 @@ def munge_shape_guards(s: str) -> str:
     if torch._dynamo.config.enable_cpp_symbolic_shape_guards:
         # Since we can have multiple guard accessors for one guard, the shape guard
-        # printing will have duplicates. We remove duplicates whie preserving order.
-        lines = list(dict.fromkeys(lines))
+        # printing will have just SYMBOLIC_SHAPE_GUARD in one line for the second
+        # guard accessor and onwards. We remove those lines
+        lines = [line for line in lines if "__SHAPE_GUARD__:" in line]
     return "\n".join(lines)

@@ -201,17 +201,17 @@ class GuardManagerWrapper:
         self.id_matched_objs = {}
         self.no_tensor_aliasing_sources = []
 
-        self.print_no_tensor_aliasing_guard = True
+        self.printed_relational_guards = set()
 
         self.diff_guard_sources: OrderedSet[str] = OrderedSet()
 
     @contextmanager
-    def _preserve_print_no_tensor_aliasing_flag(self):
-        self.print_no_tensor_aliasing_guard = True
+    def _preserve_printed_relational_guards(self):
+        self.printed_relational_guards = set()
         try:
             yield
         finally:
-            self.print_no_tensor_aliasing_guard = True
+            self.printed_relational_guards = set()
 
     def collect_diff_guard_sources(self):
         # At the time of finalize, we have only marked guard managers with
@@ -314,9 +314,9 @@ class GuardManagerWrapper:
     def construct_manager_string(self, mgr, body):
         with body.indent():
             for guard in mgr.get_leaf_guards():
-                if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING):  # type: ignore[attr-defined]
-                    if self.print_no_tensor_aliasing_guard:
-                        self.print_no_tensor_aliasing_guard = False
+                if isinstance(guard, torch._C._dynamo.guards.RelationalGuard):  # type: ignore[attr-defined]
+                    if guard not in self.printed_relational_guards:
+                        self.printed_relational_guards.add(guard)
                         body.writelines(self.get_guard_lines(guard))
                     else:
                         body.writelines(
@@ -353,7 +353,7 @@ class GuardManagerWrapper:
                 else:
                     super().writeline("+- " + line)
 
-        with self._preserve_print_no_tensor_aliasing_flag():
+        with self._preserve_printed_relational_guards():
            body = IndentedBufferWithPrefix()
            body.tabwidth = 1
            body.writeline("", skip_prefix=True)
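Taken together, the three hunks above implement a simple seen-set pattern: a relational guard installed on several guard managers is printed in full only the first time it is encountered, later occurrences get a one-line marker, and the set is cleared around each printing pass. The self-contained sketch below illustrates the idea with stand-in classes; GuardPrinter, render and the sample guard text are hypothetical, not the real GuardManagerWrapper API.

from contextlib import contextmanager

class RelationalGuard:
    """Stand-in for a guard object shared by several guard managers."""
    def __init__(self, text):
        self.text = text

class GuardPrinter:
    def __init__(self):
        self.printed_relational_guards = set()

    @contextmanager
    def _preserve_printed_relational_guards(self):
        # Clear the seen-set before and after a printing pass so repeated
        # renders start from a clean slate even if a pass raises.
        self.printed_relational_guards = set()
        try:
            yield
        finally:
            self.printed_relational_guards = set()

    def render(self, managers):
        lines = []
        with self._preserve_printed_relational_guards():
            for mgr_name, guards in managers:
                lines.append(f"+- {mgr_name}")
                for guard in guards:
                    if isinstance(guard, RelationalGuard):
                        if guard not in self.printed_relational_guards:
                            # First sighting: remember the object, print full text.
                            self.printed_relational_guards.add(guard)
                            lines.append(f"|  +- {guard.text}")
                        else:
                            # Later sightings: only a short one-line marker.
                            lines.append("|  +- RelationalGuard")
        return "\n".join(lines)

shared = RelationalGuard("SYMBOLIC_SHAPE_GUARD: L['x'].size()[0] <= 8")
print(GuardPrinter().render([("mgr_a", [shared]), ("mgr_b", [shared])]))
# mgr_a prints the full guard text; mgr_b prints only the short marker.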

@@ -5455,22 +5455,27 @@ PyObject* torch_c_dynamo_guards_init() {
           py::list>())
       .def("__call__", &TENSOR_MATCH::check);
   // NOLINTNEXTLINE(bugprone-unused-raii)
-  py::class_<OBJECT_ALIASING, LeafGuard, std::shared_ptr<OBJECT_ALIASING>>(
-      py_m, "OBJECT_ALIASING");
+  py::class_<RelationalGuard, LeafGuard, std::shared_ptr<RelationalGuard>>(
+      py_m, "RelationalGuard");
+  // NOLINTNEXTLINE(bugprone-unused-raii)
+  py::class_<
+      OBJECT_ALIASING,
+      RelationalGuard,
+      std::shared_ptr<OBJECT_ALIASING>>(py_m, "OBJECT_ALIASING");
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<
       NO_TENSOR_ALIASING,
-      LeafGuard,
+      RelationalGuard,
       std::shared_ptr<NO_TENSOR_ALIASING>>(py_m, "NO_TENSOR_ALIASING");
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<
       STORAGE_OVERLAPPING,
-      LeafGuard,
+      RelationalGuard,
       std::shared_ptr<STORAGE_OVERLAPPING>>(py_m, "STORAGE_OVERLAPPING");
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<
       SYMBOLIC_SHAPE_GUARD,
-      LeafGuard,
+      RelationalGuard,
       std::shared_ptr<SYMBOLIC_SHAPE_GUARD>>(py_m, "SYMBOLIC_SHAPE_GUARD");
 
   // Guard Accessors - These are present so that we can iterate over the
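Seen from the Python side, this binding change gives OBJECT_ALIASING, NO_TENSOR_ALIASING, STORAGE_OVERLAPPING and SYMBOLIC_SHAPE_GUARD a common RelationalGuard base, which is what lets the printer use a single isinstance check. A rough sketch with plain Python classes standing in for the pybind11 bindings:

# Plain-Python stand-ins for the class hierarchy registered above.
class LeafGuard: ...
class RelationalGuard(LeafGuard): ...            # new intermediate base
class OBJECT_ALIASING(RelationalGuard): ...
class NO_TENSOR_ALIASING(RelationalGuard): ...
class STORAGE_OVERLAPPING(RelationalGuard): ...
class SYMBOLIC_SHAPE_GUARD(RelationalGuard): ...

# One check now covers every relational guard kind, which is what the
# construct_manager_string change above relies on.
for g in (OBJECT_ALIASING(), NO_TENSOR_ALIASING(),
          STORAGE_OVERLAPPING(), SYMBOLIC_SHAPE_GUARD()):
    assert isinstance(g, RelationalGuard)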