[dynamo] Support reading attributes from pybind objects (#134630)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134630
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain 2024-08-28 22:04:48 -07:00 committed by PyTorch MergeBot
parent 92e38a476f
commit 594162f7ab
7 changed files with 78 additions and 6 deletions

View File

@ -835,6 +835,7 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/extra_state.cpp",
"torch/csrc/dynamo/framelocals_mapping.cpp",
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/utils.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",

View File

@ -3602,6 +3602,22 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
dynamo_class_name = dynamo_default_str[1].split(" object at")[0]
self.assertEqual(eager_class_name, dynamo_class_name)
def test_pybind_object(self):
def fn(x, pybind_obj):
if pybind_obj.result:
return torch.cos(x)
return torch.sin(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0)
x = torch.randn(4)
self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))
pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1)
x = torch.randn(4)
self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))
instantiate_parametrized_tests(FunctionTests)

View File

@ -899,6 +899,18 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def _check_for_getattr(self):
return get_custom_getattr(self.value)
def _is_c_defined_property(self, subobj):
if not isinstance(subobj, property):
return False
# pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose
# fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check
# If we have a PyCFunction, we make an assumption that there is no side effect.
return isinstance(
subobj.fget, types.BuiltinFunctionType
) or torch._C._dynamo.utils.is_instancemethod(subobj.fget)
def _getattr_static(self, name):
subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
import _collections
@ -913,12 +925,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
or (
inspect.ismemberdescriptor(subobj) and name in self.value.__slots__
) # handle memberdecriptor and slots
or (
isinstance(subobj, property)
and isinstance(
subobj.fget, types.BuiltinFunctionType
) # property with C-defined fget
)
or self._is_c_defined_property(subobj)
):
# Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't
# want to call getattr because it can be user-overridden.

View File

@ -1,4 +1,5 @@
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/dynamo/utils.h>
#include <pybind11/stl_bind.h>
#include <torch/csrc/Exceptions.h>
@ -44,6 +45,11 @@ void initDynamoBindings(PyObject* torch) {
throw python_error();
}
PyObject* utils = torch_c_dynamo_utils_init();
if (utils == nullptr || PyModule_AddObject(dynamo, "utils", utils) != 0) {
throw python_error();
}
PyObject* guards = torch_c_dynamo_guards_init();
if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) {
throw python_error();

View File

@ -0,0 +1,33 @@
#include <torch/csrc/dynamo/utils.h>
namespace torch::dynamo {
static std::array<PyMethodDef, 1> _methods = {{
{nullptr,
nullptr,
0,
nullptr} // Sentinel value indicating the end of the array
}};
bool is_instancemethod(py::object obj) {
return PyInstanceMethod_Check(obj.ptr());
}
static struct PyModuleDef _module = {
PyModuleDef_HEAD_INIT,
"torch._C._dynamo.utils",
"Module containing C utils",
-1,
_methods.data()};
PyObject* torch_c_dynamo_utils_init() {
auto m = PyModule_Create(&_module);
if (m == nullptr)
return nullptr;
auto py_m = py::handle(m).cast<py::module>();
py_m.def("is_instancemethod", is_instancemethod);
return m;
}
} // namespace torch::dynamo

View File

@ -1,5 +1,10 @@
#pragma once
#include <torch/csrc/python_headers.h>
// C2039 MSVC
#include <pybind11/complex.h>
#include <torch/csrc/utils/pybind.h>
#include <Python.h>
// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
#ifdef _WIN32
@ -7,3 +12,7 @@
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif
namespace torch::dynamo {
PyObject* torch_c_dynamo_utils_init();
} // namespace torch::dynamo