mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][guards] Revert introduction of different types of lambda_guards (#163385)
With https://fb.workplace.com/groups/260102303573409/permalink/787294574187510/ issue, it might be a better idea to just speedup _realize_dict and keep the changes very local. So reverting this PR as well, to return to clean slate. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163385 Approved by: https://github.com/jansel
This commit is contained in:
parent
8f6dbc0ba8
commit
991e3d0d16
|
|
@ -224,12 +224,6 @@ class GuardManager:
|
|||
def add_lambda_guard(
|
||||
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
def add_lambda_guard_no_args(
|
||||
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
def add_lambda_guard_no_framelocals(
|
||||
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
def add_id_match_guard(
|
||||
self, id_val: int, verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
|
|
|
|||
|
|
@ -2061,10 +2061,10 @@ class GuardBuilder(GuardBuilderBase):
|
|||
# TODO(anijain2305) - Consider this moving this guard to C++
|
||||
compare_fn = torch._functorch.pyfunctorch.compare_functorch_state
|
||||
|
||||
def fn() -> bool:
|
||||
def fn(x: Any) -> bool:
|
||||
return compare_fn(states)
|
||||
|
||||
self.guard_manager.root.add_lambda_guard_no_args(
|
||||
self.guard_manager.root.add_lambda_guard(
|
||||
fn, get_verbose_code_parts(code, guard)
|
||||
)
|
||||
|
||||
|
|
@ -2090,10 +2090,10 @@ class GuardBuilder(GuardBuilderBase):
|
|||
]
|
||||
self._set_guard_export_info(guard, code)
|
||||
|
||||
def fn() -> bool:
|
||||
def fn(x: Any) -> bool:
|
||||
return guard_hooks_ids == hooks_ids_fn(get_hooks())
|
||||
|
||||
self.guard_manager.root.add_lambda_guard_no_args(
|
||||
self.guard_manager.root.add_lambda_guard(
|
||||
fn, get_verbose_code_parts(code, guard)
|
||||
)
|
||||
|
||||
|
|
@ -2114,7 +2114,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||
return x.__tensor_flatten__()[1] == original_metadata
|
||||
|
||||
global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
|
||||
self.get_guard_manager(guard).add_lambda_guard_no_framelocals(
|
||||
self.get_guard_manager(guard).add_lambda_guard(
|
||||
metadata_checker, get_verbose_code_parts(global_name, guard)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1700,95 +1700,6 @@ class LAMBDA_GUARD : public LeafGuard {
|
|||
py::function _guard_check_fn;
|
||||
};
|
||||
|
||||
/*
|
||||
Similar to LAMBDA_GUARD but where lambda does not take any arguments. This
|
||||
ensures that we don't need to construct a dictionary from framelocals even if
|
||||
the guard is at the root. These guards are for root guards like GlobalState.
|
||||
*/
|
||||
class LAMBDA_GUARD_NO_ARGS : public LeafGuard {
|
||||
public:
|
||||
LAMBDA_GUARD_NO_ARGS(
|
||||
RootGuardManager* root_guard_manager,
|
||||
py::object guard_check_fn,
|
||||
py::object verbose_code_parts)
|
||||
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {
|
||||
if (py::isinstance<py::function>(guard_check_fn)) {
|
||||
_guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
|
||||
} else {
|
||||
throw py::type_error("LAMBDA_GUARD_NO_ARGS expects (callable, str)");
|
||||
}
|
||||
}
|
||||
|
||||
bool _check() {
|
||||
PyObject* x = PyObject_CallNoArgs(_guard_check_fn.ptr()); // new ref
|
||||
if (x == nullptr) {
|
||||
// An exception is caught in the lambda function.
|
||||
PyErr_Clear();
|
||||
return false;
|
||||
}
|
||||
bool result = PyObject_IsTrue(x);
|
||||
Py_DECREF(x);
|
||||
return result;
|
||||
}
|
||||
|
||||
bool check_nopybind(PyObject* value) override { // borrowed ref
|
||||
return _check();
|
||||
}
|
||||
|
||||
GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
|
||||
PyObject* x = PyObject_CallNoArgs(_guard_check_fn.ptr()); // new ref
|
||||
if (x == nullptr) {
|
||||
// An exception is caught in the lambda function.
|
||||
std::string exc_message = get_exception_message();
|
||||
PyErr_Clear();
|
||||
return GuardDebugInfo(false, exc_message, 0);
|
||||
}
|
||||
bool result = PyObject_IsTrue(x);
|
||||
Py_DECREF(x);
|
||||
if (result) {
|
||||
return GuardDebugInfo(true, 0);
|
||||
}
|
||||
return GuardDebugInfo(false, verbose_code_parts(), 0);
|
||||
}
|
||||
|
||||
// Ensure that framelocals dict is not constructed.
|
||||
bool check_nopybind(FrameLocalsMapping* map) override {
|
||||
return _check();
|
||||
}
|
||||
|
||||
private:
|
||||
// The user provided lambda function for check_fn.
|
||||
py::function _guard_check_fn;
|
||||
};
|
||||
|
||||
/*
|
||||
Similar to LAMBDA_GUARD but disallows running on a FrameLocalsMapping input.
|
||||
These guards are at trunk or leaf, and not at the root.
|
||||
*/
|
||||
class LAMBDA_GUARD_NO_FRAMELOCALS : public LAMBDA_GUARD {
|
||||
public:
|
||||
LAMBDA_GUARD_NO_FRAMELOCALS(
|
||||
RootGuardManager* root_guard_manager,
|
||||
py::object guard_check_fn,
|
||||
py::object verbose_code_parts)
|
||||
: LAMBDA_GUARD(root_guard_manager, guard_check_fn, verbose_code_parts) {}
|
||||
|
||||
bool check_nopybind(PyObject* value) override { // borrowed ref
|
||||
return LAMBDA_GUARD::check_nopybind(value);
|
||||
}
|
||||
|
||||
GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
|
||||
return LAMBDA_GUARD::check_verbose_nopybind(value);
|
||||
}
|
||||
|
||||
bool check_nopybind(FrameLocalsMapping* map) override {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"FramelocalsMapping input to LAMBDA_GUARD_NO_FRAMELOCALS,"
|
||||
"use LAMBDA_GUARD instead");
|
||||
}
|
||||
};
|
||||
|
||||
class TYPE_MATCH : public LeafGuard {
|
||||
public:
|
||||
// type_id = id(type(obj))
|
||||
|
|
@ -6768,19 +6679,6 @@ PyObject* torch_c_dynamo_guards_init() {
|
|||
py_m, "LAMBDA_GUARD")
|
||||
.def(py::init<RootGuardManager*, py::function, py::list>())
|
||||
.def("__call__", &LAMBDA_GUARD::check);
|
||||
py::class_<
|
||||
LAMBDA_GUARD_NO_ARGS,
|
||||
LeafGuard,
|
||||
std::shared_ptr<LAMBDA_GUARD_NO_ARGS>>(py_m, "LAMBDA_GUARD_NO_ARGS")
|
||||
.def(py::init<RootGuardManager*, py::function, py::list>())
|
||||
.def("__call__", &LAMBDA_GUARD_NO_ARGS::check);
|
||||
py::class_<
|
||||
LAMBDA_GUARD_NO_FRAMELOCALS,
|
||||
LeafGuard,
|
||||
std::shared_ptr<LAMBDA_GUARD_NO_FRAMELOCALS>>(
|
||||
py_m, "LAMBDA_GUARD_NO_FRAMELOCALS")
|
||||
.def(py::init<RootGuardManager*, py::function, py::list>())
|
||||
.def("__call__", &LAMBDA_GUARD_NO_FRAMELOCALS::check);
|
||||
py::class_<TYPE_MATCH, LeafGuard, std::shared_ptr<TYPE_MATCH>>(
|
||||
py_m, "TYPE_MATCH")
|
||||
.def(py::init<RootGuardManager*, py::object, py::list>())
|
||||
|
|
@ -7095,26 +6993,6 @@ PyObject* torch_c_dynamo_guards_init() {
|
|||
std::move(lambda),
|
||||
std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_lambda_guard_no_args",
|
||||
[](GuardManager& self,
|
||||
py::object lambda,
|
||||
py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<LAMBDA_GUARD_NO_ARGS>(
|
||||
self.get_root(),
|
||||
std::move(lambda),
|
||||
std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_lambda_guard_no_framelocals",
|
||||
[](GuardManager& self,
|
||||
py::object lambda,
|
||||
py::object verbose_code_parts) -> void {
|
||||
self.add_leaf_guard(std::make_shared<LAMBDA_GUARD_NO_FRAMELOCALS>(
|
||||
self.get_root(),
|
||||
std::move(lambda),
|
||||
std::move(verbose_code_parts)));
|
||||
})
|
||||
.def(
|
||||
"add_type_match_guard",
|
||||
[](GuardManager& self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user