mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Support reading attributes from pybind objects (#134630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134630 Approved by: https://github.com/jansel
This commit is contained in:
parent
92e38a476f
commit
594162f7ab
|
|
@ -835,6 +835,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/dynamo/extra_state.cpp",
|
||||
"torch/csrc/dynamo/framelocals_mapping.cpp",
|
||||
"torch/csrc/dynamo/guards.cpp",
|
||||
"torch/csrc/dynamo/utils.cpp",
|
||||
"torch/csrc/dynamo/init.cpp",
|
||||
"torch/csrc/functorch/init.cpp",
|
||||
"torch/csrc/fx/node.cpp",
|
||||
|
|
|
|||
|
|
@ -3602,6 +3602,22 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
|||
dynamo_class_name = dynamo_default_str[1].split(" object at")[0]
|
||||
self.assertEqual(eager_class_name, dynamo_class_name)
|
||||
|
||||
def test_pybind_object(self):
|
||||
def fn(x, pybind_obj):
|
||||
if pybind_obj.result:
|
||||
return torch.cos(x)
|
||||
return torch.sin(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
|
||||
pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0)
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))
|
||||
|
||||
pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1)
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(FunctionTests)
|
||||
|
||||
|
|
|
|||
|
|
@ -899,6 +899,18 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
def _check_for_getattr(self):
|
||||
return get_custom_getattr(self.value)
|
||||
|
||||
def _is_c_defined_property(self, subobj):
|
||||
if not isinstance(subobj, property):
|
||||
return False
|
||||
|
||||
# pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose
|
||||
# fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check
|
||||
|
||||
# If we have a PyCFunction, we make an assumption that there is no side effect.
|
||||
return isinstance(
|
||||
subobj.fget, types.BuiltinFunctionType
|
||||
) or torch._C._dynamo.utils.is_instancemethod(subobj.fget)
|
||||
|
||||
def _getattr_static(self, name):
|
||||
subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ)
|
||||
import _collections
|
||||
|
|
@ -913,12 +925,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||
or (
|
||||
inspect.ismemberdescriptor(subobj) and name in self.value.__slots__
|
||||
) # handle memberdecriptor and slots
|
||||
or (
|
||||
isinstance(subobj, property)
|
||||
and isinstance(
|
||||
subobj.fget, types.BuiltinFunctionType
|
||||
) # property with C-defined fget
|
||||
)
|
||||
or self._is_c_defined_property(subobj)
|
||||
):
|
||||
# Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't
|
||||
# want to call getattr because it can be user-overridden.
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <torch/csrc/dynamo/init.h>
|
||||
#include <torch/csrc/dynamo/utils.h>
|
||||
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
|
|
@ -44,6 +45,11 @@ void initDynamoBindings(PyObject* torch) {
|
|||
throw python_error();
|
||||
}
|
||||
|
||||
PyObject* utils = torch_c_dynamo_utils_init();
|
||||
if (utils == nullptr || PyModule_AddObject(dynamo, "utils", utils) != 0) {
|
||||
throw python_error();
|
||||
}
|
||||
|
||||
PyObject* guards = torch_c_dynamo_guards_init();
|
||||
if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) {
|
||||
throw python_error();
|
||||
|
|
|
|||
33
torch/csrc/dynamo/utils.cpp
Normal file
33
torch/csrc/dynamo/utils.cpp
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#include <torch/csrc/dynamo/utils.h>
|
||||
|
||||
namespace torch::dynamo {
|
||||
|
||||
static std::array<PyMethodDef, 1> _methods = {{
|
||||
{nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
nullptr} // Sentinel value indicating the end of the array
|
||||
}};
|
||||
|
||||
bool is_instancemethod(py::object obj) {
|
||||
return PyInstanceMethod_Check(obj.ptr());
|
||||
}
|
||||
|
||||
static struct PyModuleDef _module = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"torch._C._dynamo.utils",
|
||||
"Module containing C utils",
|
||||
-1,
|
||||
_methods.data()};
|
||||
|
||||
PyObject* torch_c_dynamo_utils_init() {
|
||||
auto m = PyModule_Create(&_module);
|
||||
if (m == nullptr)
|
||||
return nullptr;
|
||||
|
||||
auto py_m = py::handle(m).cast<py::module>();
|
||||
py_m.def("is_instancemethod", is_instancemethod);
|
||||
return m;
|
||||
}
|
||||
|
||||
} // namespace torch::dynamo
|
||||
|
|
@ -1,5 +1,10 @@
|
|||
#pragma once
|
||||
#include <torch/csrc/python_headers.h>
|
||||
// C2039 MSVC
|
||||
#include <pybind11/complex.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#include <Python.h>
|
||||
// The visibility attribute is to avoid a warning about storing a field in the
|
||||
// struct that has a different visibility (from pybind) than the struct.
|
||||
#ifdef _WIN32
|
||||
|
|
@ -7,3 +12,7 @@
|
|||
#else
|
||||
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
|
||||
#endif
|
||||
|
||||
namespace torch::dynamo {
|
||||
PyObject* torch_c_dynamo_utils_init();
|
||||
} // namespace torch::dynamo
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user