mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
We define specializations for pybind11 defined templates (in particular, PYBIND11_DECLARE_HOLDER_TYPE) and consequently it is important that these specializations *always* be #include'd when making use of pybind11 templates whose behavior depends on these specializations, otherwise we can cause an ODR violation. The easiest way to ensure that all the specializations are always loaded is to designate a header (in this case, torch/csrc/util/pybind.h) that ensures the specializations are defined, and then add a lint to ensure this header is included whenever pybind11 headers are included. The existing grep linter didn't have enough knobs to do this conveniently, so I added some features. I'm open to suggestions for how to structure the features better. The main changes: - Added an --allowlist-pattern flag, which turns off the grep lint if some other line exists. This is used to stop the grep lint from complaining about pybind11 includes if the util include already exists. - Added --match-first-only flag, which lets grep only match against the first matching line. This is because, even if there are multiple includes that are problematic, I only need to fix one of them. We don't /really/ need this, but when I was running lintrunner -a to fixup the preexisting codebase it was annoying without this, as the lintrunner overall driver fails if there are multiple edits on the same file. I excluded any files that didn't otherwise have a dependency on torch/ATen, this was mostly caffe2 and the valgrind wrapper compat bindings. Note the grep replacement is kind of crappy, but clang-tidy lint cleaned it up in most cases. See also https://github.com/pybind/pybind11/issues/4099 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/82552 Approved by: https://github.com/albanD
161 lines
5.0 KiB
C++
161 lines
5.0 KiB
C++
/*
 * We have a Python unit test for exceptions in test/jit/test_exception.py.
 * This file adds a C++ version to verify that the expected exception types
 * are thrown from C++. This is hard to test in Python code since C++
 * exceptions are translated to Python exceptions.
 */
|
|
#include <gtest/gtest.h>
|
|
#include <pybind11/embed.h>
|
|
#include <torch/csrc/jit/frontend/parser.h>
|
|
#include <torch/csrc/jit/frontend/resolver.h>
|
|
#include <torch/csrc/jit/runtime/jit_exception.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/jit.h>
|
|
#include <iostream>
|
|
#include <stdexcept>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
namespace py = pybind11;
|
|
|
|
// Compiles a TorchScript function that raises AssertionError and checks that
// running it surfaces a JITException in C++ with no Python class name set and
// with the expected text in its message.
TEST(TestException, TestAssertion) {
  const std::string pythonCode = R"PY(
def foo():
  raise AssertionError("An assertion failed")
)PY";
  auto cu_ptr = torch::jit::compile(pythonCode);

  // Dump the compiled graph to stderr to help debugging failures.
  auto* graph_fn =
      static_cast<torch::jit::GraphFunction*>(&cu_ptr->get_function("foo"));
  std::cerr << "Graph is\n" << *graph_fn->graph() << std::endl;

  // Record what was observed inside the handler; the checks run afterwards.
  bool caught_jit_exception = false;
  std::string what_text;
  c10::optional<std::string> python_class;
  try {
    cu_ptr->run_method("foo");
  } catch (const JITException& err) {
    caught_jit_exception = true;
    what_text = err.what();
    python_class = err.getPythonClassName();
  }

  EXPECT_TRUE(caught_jit_exception);
  EXPECT_FALSE(python_class);
  const bool has_expected_text =
      what_text.find("RuntimeError: AssertionError: An assertion failed") !=
      std::string::npos;
  EXPECT_TRUE(has_expected_text);
}
|
|
|
|
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
|
|
explicit MyPythonExceptionValue(const py::object& exception_class) {
|
|
qualified_name_ =
|
|
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
|
|
py::str(".") +
|
|
py::str(py::getattr(exception_class, "__name__", py::str(""))))
|
|
.cast<std::string>();
|
|
}
|
|
|
|
std::string kind() const override {
|
|
return "My Python exception";
|
|
}
|
|
|
|
// Simplified from PythonExceptionValue::call
|
|
std::shared_ptr<torch::jit::SugaredValue> call(
|
|
const torch::jit::SourceRange& loc,
|
|
torch::jit::GraphFunction& caller,
|
|
at::ArrayRef<torch::jit::NamedValue> args,
|
|
at::ArrayRef<torch::jit::NamedValue> kwargs,
|
|
size_t n_binders) override {
|
|
TORCH_CHECK(args.size() == 1);
|
|
Value* error_message = args.at(0).value(*caller.graph());
|
|
Value* qualified_class_name =
|
|
insertConstant(*caller.graph(), qualified_name_, loc);
|
|
return std::make_shared<ExceptionMessageValue>(
|
|
error_message, qualified_class_name);
|
|
}
|
|
|
|
private:
|
|
std::string qualified_name_;
|
|
};
|
|
|
|
class SimpleResolver : public torch::jit::Resolver {
|
|
public:
|
|
explicit SimpleResolver() {}
|
|
|
|
std::shared_ptr<torch::jit::SugaredValue> resolveValue(
|
|
const std::string& name,
|
|
torch::jit::GraphFunction& m,
|
|
const torch::jit::SourceRange& loc) override {
|
|
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
|
|
// a python extension. We can not add that as a cpp_binary's dep)
|
|
if (name == "SimpleValueError") {
|
|
py::object obj = py::globals()["SimpleValueError"];
|
|
return std::make_shared<MyPythonExceptionValue>(obj);
|
|
}
|
|
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
|
|
}
|
|
|
|
torch::jit::TypePtr resolveType(
|
|
const std::string& name,
|
|
const torch::jit::SourceRange& loc) override {
|
|
return nullptr;
|
|
}
|
|
};
|
|
|
|
/*
 * - The Python source parsing for TorchScript here is modeled on
 *   torch::jit::compile.
 * - The code only parses one Def. If the source contains several, all but
 *   the first are skipped.
 */
|
|
// Compiles a TorchScript function that raises a user-defined Python exception
// and checks that the JITException caught in C++ reports the qualified Python
// class name and carries the expected message text.
TEST(TestException, TestCustomException) {
  // Define SimpleValueError in the embedded interpreter; SimpleResolver looks
  // it up through py::globals().
  py::scoped_interpreter guard{};
  py::exec(R"PY(
class SimpleValueError(ValueError):
  def __init__(self, message):
    super(SimpleValueError, self).__init__(message)
)PY");

  std::string pythonCode = R"PY(
def foo():
  raise SimpleValueError("An assertion failed")
)PY";

  torch::jit::Parser p(
      std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
  auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
  std::cerr << "Def is:\n" << def << std::endl;
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  (void)cu->define(
      c10::nullopt,
      {},
      {},
      {def},
      // class PythonResolver is defined in
      // torch/csrc/jit/python/script_init.cpp. It's not in a header file so I
      // can not use it. Create a SimpleResolver instead.
      {std::make_shared<SimpleResolver>()},
      nullptr);
  auto* gf =
      static_cast<torch::jit::GraphFunction*>(&cu->get_function("foo"));
  std::cerr << "Graph is\n" << *gf->graph() << std::endl;

  bool is_jit_exception = false;
  c10::optional<std::string> exception_class;
  std::string message;
  try {
    cu->run_method("foo");
  } catch (const JITException& e) {
    is_jit_exception = true;
    exception_class = e.getPythonClassName();
    message = e.what();
  }

  // Use fatal assertions here: if no JITException was caught, or it carries
  // no class name, dereferencing exception_class below would read an empty
  // optional.
  ASSERT_TRUE(is_jit_exception);
  ASSERT_TRUE(exception_class.has_value());
  EXPECT_EQ("__main__.SimpleValueError", *exception_class);
  EXPECT_TRUE(
      message.find("__main__.SimpleValueError: An assertion failed") !=
      std::string::npos);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|