mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70339 When a Python program is translated to TorchScript, the Python exception type is dropped. This makes life hard for users who need to categorize errors based on more than just the exception message. Here we make the change so that when we raise a Python exception, we record the fully qualified class name of the exception. Later, when the TorchScript is interpreted, a special exception CustomJITException is thrown. Users can get the Python class name from CustomJITException::getPythonClassName. Note that this diff does not customize the mapping from C++ exceptions to Python exceptions; users are left to do whatever mapping they want. Code under scripts/shunting is just my own experimental code. I can split it out if requested. ghstack-source-id: 146221879 Test Plan: buck test mode/opt //caffe2/test:jit Reviewed By: gmagogsfm Differential Revision: D33282878 fbshipit-source-id: 910f67a764519f1053a48589d1a34df69001525d
160 lines
5.0 KiB
C++
160 lines
5.0 KiB
C++
/*
|
|
* We have a python unit test for exceptions in test/jit/test_exception.py .
|
|
 * Add a C++ version here to verify that the expected exception types are thrown from
|
|
* C++. This is hard to test in python code since C++ exceptions will be
|
|
* translated to python exceptions.
|
|
*/
|
|
#include <gtest/gtest.h>
|
|
#include <pybind11/embed.h>
|
|
#include <torch/csrc/jit/frontend/parser.h>
|
|
#include <torch/csrc/jit/frontend/resolver.h>
|
|
#include <torch/csrc/jit/runtime/jit_exception.h>
|
|
#include <torch/jit.h>
|
|
#include <iostream>
|
|
#include <stdexcept>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
namespace py = pybind11;
|
|
|
|
// A builtin `raise AssertionError(...)` compiled through torch::jit::compile
// must surface in C++ as a JITException that carries no Python class name,
// with the builtin exception name embedded in the error message.
TEST(TestException, TestAssertion) {
  std::string pythonCode = R"PY(
  def foo():
    raise AssertionError("An assertion failed")
  )PY";
  auto cu_ptr = torch::jit::compile(pythonCode);
  // get_function() hands back the base Function; the named cast makes the
  // downcast explicit instead of hiding it behind a C-style pointer cast.
  auto& gf =
      static_cast<torch::jit::GraphFunction&>(cu_ptr->get_function("foo"));
  std::cerr << "Graph is\n" << *gf.graph() << std::endl;

  bool is_jit_exception = false;
  std::string message;
  c10::optional<std::string> exception_class;
  try {
    cu_ptr->run_method("foo");
  } catch (JITException& e) {
    is_jit_exception = true;
    message = e.what();
    exception_class = e.getPythonClassName();
  }
  EXPECT_TRUE(is_jit_exception);
  // Builtin exceptions are not tagged with a Python class name.
  EXPECT_FALSE(exception_class);
  EXPECT_TRUE(
      message.find("RuntimeError: AssertionError: An assertion failed") !=
      std::string::npos);
}
|
|
|
|
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
|
|
explicit MyPythonExceptionValue(const py::object& exception_class) {
|
|
qualified_name_ =
|
|
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
|
|
py::str(".") +
|
|
py::str(py::getattr(exception_class, "__name__", py::str(""))))
|
|
.cast<std::string>();
|
|
}
|
|
|
|
std::string kind() const override {
|
|
return "My Python exception";
|
|
}
|
|
|
|
// Simplified from PythonExceptionValue::call
|
|
std::shared_ptr<torch::jit::SugaredValue> call(
|
|
const torch::jit::SourceRange& loc,
|
|
torch::jit::GraphFunction& caller,
|
|
at::ArrayRef<torch::jit::NamedValue> args,
|
|
at::ArrayRef<torch::jit::NamedValue> kwargs,
|
|
size_t n_binders) override {
|
|
TORCH_CHECK(args.size() == 1);
|
|
Value* error_message = args.at(0).value(*caller.graph());
|
|
Value* qualified_class_name =
|
|
insertConstant(*caller.graph(), qualified_name_, loc);
|
|
return std::make_shared<ExceptionMessageValue>(
|
|
error_message, qualified_class_name);
|
|
}
|
|
|
|
private:
|
|
std::string qualified_name_;
|
|
};
|
|
|
|
class SimpleResolver : public torch::jit::Resolver {
|
|
public:
|
|
explicit SimpleResolver() {}
|
|
|
|
std::shared_ptr<torch::jit::SugaredValue> resolveValue(
|
|
const std::string& name,
|
|
torch::jit::GraphFunction& m,
|
|
const torch::jit::SourceRange& loc) override {
|
|
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
|
|
// a python extension. We can not add that as a cpp_binary's dep)
|
|
if (name == "SimpleValueError") {
|
|
py::object obj = py::globals()["SimpleValueError"];
|
|
return std::make_shared<MyPythonExceptionValue>(obj);
|
|
}
|
|
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
|
|
}
|
|
|
|
torch::jit::TypePtr resolveType(
|
|
const std::string& name,
|
|
const torch::jit::SourceRange& loc) override {
|
|
return nullptr;
|
|
}
|
|
};
|
|
|
|
/*
|
|
 * - The Python source parsing for TorchScript here is adapted from
 *   torch::jit::compile.
 * - The code parses only one Def. If the source contains several, all but
 *   the first are ignored.
|
|
*/
|
|
// Raising a user-defined Python exception class from TorchScript must surface
// in C++ as a JITException whose Python class name is the class's fully
// qualified name, which also prefixes the error message.
TEST(TestException, TestCustomException) {
  // SimpleResolver reads SimpleValueError from the interpreter's globals, so
  // the embedded interpreter must stay alive for the whole test.
  py::scoped_interpreter guard{};
  py::exec(R"PY(
  class SimpleValueError(ValueError):
    def __init__(self, message):
      super(SimpleValueError, self).__init__(message)
  )PY");

  std::string pythonCode = R"PY(
  def foo():
    raise SimpleValueError("An assertion failed")
  )PY";

  torch::jit::Parser p(
      std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
  auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
  std::cerr << "Def is:\n" << def << std::endl;
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  (void)cu->define(
      c10::nullopt,
      {},
      {},
      {def},
      // class PythonResolver is defined in
      // torch/csrc/jit/python/script_init.cpp. It's not in a header file, so
      // it can't be used here; SimpleResolver stands in for it instead.
      {std::make_shared<SimpleResolver>()},
      nullptr);
  // get_function() hands back the base Function; downcast explicitly.
  auto& gf = static_cast<torch::jit::GraphFunction&>(cu->get_function("foo"));
  std::cerr << "Graph is\n" << *gf.graph() << std::endl;

  bool is_jit_exception = false;
  c10::optional<std::string> exception_class;
  std::string message;
  try {
    cu->run_method("foo");
  } catch (JITException& e) {
    is_jit_exception = true;
    exception_class = e.getPythonClassName();
    message = e.what();
  }
  EXPECT_TRUE(is_jit_exception);
  // Check presence before reading: the optional is empty if nothing was thrown
  // or no class name was recorded, and dereferencing it then would be UB.
  EXPECT_TRUE(exception_class.has_value());
  EXPECT_EQ("__main__.SimpleValueError", exception_class.value_or(""));
  EXPECT_TRUE(
      message.find("__main__.SimpleValueError: An assertion failed") !=
      std::string::npos);
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|