pytorch/test/cpp/jit/test_exception.cpp
Shunting Zhang 911d527b87 Make TorchScript Preserve Fully Qualified Class Name for Python Exceptions (#70339)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70339

When a python program is translated to TorchScript, the python exception type is dropped. This makes users's life hard when they need to categorize errors based more than only exception message.

Here we make the change so when we raise a python exception, we record the fully qualified class name for the exception. Later on when the TorchScript is interpreted, a special exception CustomJITException is thrown. User can get the python class name from CustomJITException::getPythonClassName .

Note that, this diff does not customize the mapping from C++ exception to Python exception. It's left to the users to do whatever mapping they want.

Code under scripts/shunting are just my own experimental code. I can split them out if requested.
ghstack-source-id: 146221879

Test Plan: buck test mode/opt //caffe2/test:jit

Reviewed By: gmagogsfm

Differential Revision: D33282878

fbshipit-source-id: 910f67a764519f1053a48589d1a34df69001525d
2021-12-24 00:25:40 -08:00

160 lines
5.0 KiB
C++

/*
* We have a python unit test for exceptions in test/jit/test_exception.py .
* Add a CPP version here to verify that excepted exception types thrown from
* C++. This is hard to test in python code since C++ exceptions will be
* translated to python exceptions.
*/
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/jit.h>
#include <iostream>
#include <stdexcept>
namespace torch {
namespace jit {
namespace py = pybind11;
TEST(TestException, TestAssertion) {
std::string pythonCode = R"PY(
def foo():
raise AssertionError("An assertion failed")
)PY";
auto cu_ptr = torch::jit::compile(pythonCode);
torch::jit::GraphFunction* gf =
(torch::jit::GraphFunction*)&cu_ptr->get_function("foo");
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
bool is_jit_exception = false;
std::string message;
c10::optional<std::string> exception_class;
try {
cu_ptr->run_method("foo");
} catch (JITException& e) {
is_jit_exception = true;
message = e.what();
exception_class = e.getPythonClassName();
}
EXPECT_TRUE(is_jit_exception);
EXPECT_FALSE(exception_class);
EXPECT_TRUE(
message.find("RuntimeError: AssertionError: An assertion failed") !=
std::string::npos);
}
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
explicit MyPythonExceptionValue(const py::object& exception_class) {
qualified_name_ =
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
py::str(".") +
py::str(py::getattr(exception_class, "__name__", py::str(""))))
.cast<std::string>();
}
std::string kind() const override {
return "My Python exception";
}
// Simplified from PythonExceptionValue::call
std::shared_ptr<torch::jit::SugaredValue> call(
const torch::jit::SourceRange& loc,
torch::jit::GraphFunction& caller,
at::ArrayRef<torch::jit::NamedValue> args,
at::ArrayRef<torch::jit::NamedValue> kwargs,
size_t n_binders) override {
TORCH_CHECK(args.size() == 1);
Value* error_message = args.at(0).value(*caller.graph());
Value* qualified_class_name =
insertConstant(*caller.graph(), qualified_name_, loc);
return std::make_shared<ExceptionMessageValue>(
error_message, qualified_class_name);
}
private:
std::string qualified_name_;
};
class SimpleResolver : public torch::jit::Resolver {
public:
explicit SimpleResolver() {}
std::shared_ptr<torch::jit::SugaredValue> resolveValue(
const std::string& name,
torch::jit::GraphFunction& m,
const torch::jit::SourceRange& loc) override {
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
// a python extension. We can not add that as a cpp_binary's dep)
if (name == "SimpleValueError") {
py::object obj = py::globals()["SimpleValueError"];
return std::make_shared<MyPythonExceptionValue>(obj);
}
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
}
torch::jit::TypePtr resolveType(
const std::string& name,
const torch::jit::SourceRange& loc) override {
return nullptr;
}
};
/*
* - The python source code parsing for TorchScript here is learned from
* torch::jit::compile.
* - The code only parses one Def. If there are multiple in the code, those
* except the first one are skipped.
*/
TEST(TestException, TestCustomException) {
py::scoped_interpreter guard{};
py::exec(R"PY(
class SimpleValueError(ValueError):
def __init__(self, message):
super(SimpleValueError, self).__init__(message)
)PY");
std::string pythonCode = R"PY(
def foo():
raise SimpleValueError("An assertion failed")
)PY";
torch::jit::Parser p(
std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
std::cerr << "Def is:\n" << def << std::endl;
auto cu = std::make_shared<torch::jit::CompilationUnit>();
(void)cu->define(
c10::nullopt,
{},
{},
{def},
// class PythonResolver is defined in
// torch/csrc/jit/python/script_init.cpp. It's not in a header file so I
// can not use it. Create a SimpleResolver insteand
{std::make_shared<SimpleResolver>()},
nullptr);
torch::jit::GraphFunction* gf =
(torch::jit::GraphFunction*)&cu->get_function("foo");
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
bool is_jit_exception = false;
c10::optional<std::string> exception_class;
std::string message;
try {
cu->run_method("foo");
} catch (JITException& e) {
is_jit_exception = true;
exception_class = e.getPythonClassName();
message = e.what();
}
EXPECT_TRUE(is_jit_exception);
EXPECT_EQ("__main__.SimpleValueError", *exception_class);
EXPECT_TRUE(
message.find("__main__.SimpleValueError: An assertion failed") !=
std::string::npos);
}
} // namespace jit
} // namespace torch