[dynamo][guards] Revert introduction of different types of lambda_guards (#163385)

With
https://fb.workplace.com/groups/260102303573409/permalink/787294574187510/
issue, it might be a better idea to just speedup _realize_dict and keep
the changes very local. So reverting this PR as well, to return to clean
slate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163385
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain 2025-09-26 18:25:05 -07:00 committed by PyTorch MergeBot
parent 8f6dbc0ba8
commit 991e3d0d16
3 changed files with 5 additions and 133 deletions

View File

@ -224,12 +224,6 @@ class GuardManager:
def add_lambda_guard(
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
) -> None: ...
def add_lambda_guard_no_args(
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
) -> None: ...
def add_lambda_guard_no_framelocals(
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
) -> None: ...
def add_id_match_guard(
self, id_val: int, verbose_code_parts: list[str]
) -> None: ...

View File

@ -2061,10 +2061,10 @@ class GuardBuilder(GuardBuilderBase):
# TODO(anijain2305) - Consider this moving this guard to C++
compare_fn = torch._functorch.pyfunctorch.compare_functorch_state
def fn() -> bool:
def fn(x: Any) -> bool:
return compare_fn(states)
self.guard_manager.root.add_lambda_guard_no_args(
self.guard_manager.root.add_lambda_guard(
fn, get_verbose_code_parts(code, guard)
)
@ -2090,10 +2090,10 @@ class GuardBuilder(GuardBuilderBase):
]
self._set_guard_export_info(guard, code)
def fn() -> bool:
def fn(x: Any) -> bool:
return guard_hooks_ids == hooks_ids_fn(get_hooks())
self.guard_manager.root.add_lambda_guard_no_args(
self.guard_manager.root.add_lambda_guard(
fn, get_verbose_code_parts(code, guard)
)
@ -2114,7 +2114,7 @@ class GuardBuilder(GuardBuilderBase):
return x.__tensor_flatten__()[1] == original_metadata
global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
self.get_guard_manager(guard).add_lambda_guard(
metadata_checker, get_verbose_code_parts(global_name, guard)
)

View File

@ -1700,95 +1700,6 @@ class LAMBDA_GUARD : public LeafGuard {
py::function _guard_check_fn;
};
/*
Similar to LAMBDA_GUARD but where lambda does not take any arguments. This
ensures that we don't need to construct a dictionary from framelocals even if
the guard is at the root. These guards are for root guards like GlobalState.
*/
class LAMBDA_GUARD_NO_ARGS : public LeafGuard {
public:
LAMBDA_GUARD_NO_ARGS(
RootGuardManager* root_guard_manager,
py::object guard_check_fn,
py::object verbose_code_parts)
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {
if (py::isinstance<py::function>(guard_check_fn)) {
_guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
} else {
throw py::type_error("LAMBDA_GUARD_NO_ARGS expects (callable, str)");
}
}
bool _check() {
PyObject* x = PyObject_CallNoArgs(_guard_check_fn.ptr()); // new ref
if (x == nullptr) {
// An exception is caught in the lambda function.
PyErr_Clear();
return false;
}
bool result = PyObject_IsTrue(x);
Py_DECREF(x);
return result;
}
bool check_nopybind(PyObject* value) override { // borrowed ref
return _check();
}
GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
PyObject* x = PyObject_CallNoArgs(_guard_check_fn.ptr()); // new ref
if (x == nullptr) {
// An exception is caught in the lambda function.
std::string exc_message = get_exception_message();
PyErr_Clear();
return GuardDebugInfo(false, exc_message, 0);
}
bool result = PyObject_IsTrue(x);
Py_DECREF(x);
if (result) {
return GuardDebugInfo(true, 0);
}
return GuardDebugInfo(false, verbose_code_parts(), 0);
}
// Ensure that framelocals dict is not constructed.
bool check_nopybind(FrameLocalsMapping* map) override {
return _check();
}
private:
// The user provided lambda function for check_fn.
py::function _guard_check_fn;
};
/*
Similar to LAMBDA_GUARD but disallows running on a FrameLocalsMapping input.
These guards are at trunk or leaf, and not at the root.
*/
class LAMBDA_GUARD_NO_FRAMELOCALS : public LAMBDA_GUARD {
public:
LAMBDA_GUARD_NO_FRAMELOCALS(
RootGuardManager* root_guard_manager,
py::object guard_check_fn,
py::object verbose_code_parts)
: LAMBDA_GUARD(root_guard_manager, guard_check_fn, verbose_code_parts) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
return LAMBDA_GUARD::check_nopybind(value);
}
GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
return LAMBDA_GUARD::check_verbose_nopybind(value);
}
bool check_nopybind(FrameLocalsMapping* map) override {
TORCH_CHECK(
false,
"FramelocalsMapping input to LAMBDA_GUARD_NO_FRAMELOCALS,"
"use LAMBDA_GUARD instead");
}
};
class TYPE_MATCH : public LeafGuard {
public:
// type_id = id(type(obj))
@ -6768,19 +6679,6 @@ PyObject* torch_c_dynamo_guards_init() {
py_m, "LAMBDA_GUARD")
.def(py::init<RootGuardManager*, py::function, py::list>())
.def("__call__", &LAMBDA_GUARD::check);
py::class_<
LAMBDA_GUARD_NO_ARGS,
LeafGuard,
std::shared_ptr<LAMBDA_GUARD_NO_ARGS>>(py_m, "LAMBDA_GUARD_NO_ARGS")
.def(py::init<RootGuardManager*, py::function, py::list>())
.def("__call__", &LAMBDA_GUARD_NO_ARGS::check);
py::class_<
LAMBDA_GUARD_NO_FRAMELOCALS,
LeafGuard,
std::shared_ptr<LAMBDA_GUARD_NO_FRAMELOCALS>>(
py_m, "LAMBDA_GUARD_NO_FRAMELOCALS")
.def(py::init<RootGuardManager*, py::function, py::list>())
.def("__call__", &LAMBDA_GUARD_NO_FRAMELOCALS::check);
py::class_<TYPE_MATCH, LeafGuard, std::shared_ptr<TYPE_MATCH>>(
py_m, "TYPE_MATCH")
.def(py::init<RootGuardManager*, py::object, py::list>())
@ -7095,26 +6993,6 @@ PyObject* torch_c_dynamo_guards_init() {
std::move(lambda),
std::move(verbose_code_parts)));
})
.def(
"add_lambda_guard_no_args",
[](GuardManager& self,
py::object lambda,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<LAMBDA_GUARD_NO_ARGS>(
self.get_root(),
std::move(lambda),
std::move(verbose_code_parts)));
})
.def(
"add_lambda_guard_no_framelocals",
[](GuardManager& self,
py::object lambda,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<LAMBDA_GUARD_NO_FRAMELOCALS>(
self.get_root(),
std::move(lambda),
std::move(verbose_code_parts)));
})
.def(
"add_type_match_guard",
[](GuardManager& self,