First step at adding exceptions (#12789)

Summary:
This is a first step towards adding exceptions. We need minimal support in order to begin converting the torch library to weak script mode (which is the main goal here).

Some limitations (that are documented in the tests & compiler):
1. Cannot assign exceptions to variables
2. Any name after raise is being treated as a valid Exception
3. No control flow analysis yet. Below a will be undefined:

if True:
     a = 1
else:
     raise Exception("Hi")
return a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12789

Differential Revision: D12848936

Pulled By: eellison

fbshipit-source-id: 1f60ceef2381040486123ec797e97d65b074862d
This commit is contained in:
Elias Ellison 2018-10-30 20:20:26 -07:00 committed by Facebook Github Bot
parent c7027a511f
commit 59f8e8ada7
19 changed files with 209 additions and 34 deletions

View File

@ -72,6 +72,7 @@ namespace c10 {
_(prim, StoreWorld) \ _(prim, StoreWorld) \
_(prim, DummyWorld) \ _(prim, DummyWorld) \
_(prim, fork) \ _(prim, fork) \
_(prim, RaiseException) \
_(aten, append) \ _(aten, append) \
_(aten, __not__) \ _(aten, __not__) \
FORALL_ATEN_BASE_SYMBOLS(_) \ FORALL_ATEN_BASE_SYMBOLS(_) \

View File

@ -0,0 +1,4 @@
(def
(ident fn)
(decl (list) (option))
(list (raise (option))))

View File

@ -0,0 +1,10 @@
(def
(ident fn)
(decl (list) (option))
(list
(raise
(option
(apply
(variable (ident Exception))
(list (string_literal hello))
(list))))))

View File

@ -3860,6 +3860,20 @@ a")
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
self.assertExpected(str(ast)) self.assertExpected(str(ast))
@unittest.skipIf(not PY2, "Requires python 2")
def test_python_frontend_py2(self):
def fn():
raise Exception("hello")
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
self.assertExpected(str(ast))
@unittest.skipIf(PY2, "Requires python 3")
def test_python_frontend_py3(self):
def fn():
raise Exception("hello")
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
self.assertExpected(str(ast))
def _make_scalar_vars(self, arr, dtype): def _make_scalar_vars(self, arr, dtype):
return [torch.tensor(val, dtype=dtype) for val in arr] return [torch.tensor(val, dtype=dtype) for val in arr]
@ -4676,7 +4690,7 @@ a")
tester = self tester = self
class Foo(torch.jit.ScriptModule): class Foo(torch.jit.ScriptModule):
__constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
def __init__(self): def __init__(self):
super(Foo, self).__init__(False) super(Foo, self).__init__(False)
@ -4695,14 +4709,8 @@ a")
self.h = type(1) self.h = type(1)
with tester.assertRaisesRegex(TypeError, "not a valid constant"): with tester.assertRaisesRegex(TypeError, "not a valid constant"):
self.i = (3, 4, {}) self.i = (3, 4, {})
self.j = (6, (1, 2, 3), 8)
@torch.jit.script_method
def forward(self, x):
return x + self.a + self.b + self.f[0] + self.j[1][2]
f = Foo() f = Foo()
self.assertEqual(f(torch.ones(1)), torch.ones(1) + 1 + 1.2 + 3 + 3)
def test_script_module_for(self): def test_script_module_for(self):
class M(torch.jit.ScriptModule): class M(torch.jit.ScriptModule):
@ -5078,17 +5086,6 @@ a")
v = torch.randn(1, device='cuda') v = torch.randn(1, device='cuda')
self.assertEqual(foo(v), 0) self.assertEqual(foo(v), 0)
def test_script_storage_offset(self):
@torch.jit.script
def foo(a):
return a.storage_offset()
v = torch.randn(5)
self.assertEqual(foo(v), 0)
v.set_(v.storage(), 3, [1], [1])
self.assertEquals(foo(v), 3)
def test_script_chunk(self): def test_script_chunk(self):
@torch.jit.script @torch.jit.script
def foo(a): def foo(a):
@ -7535,6 +7532,66 @@ a")
ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c) ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
self.assertExpectedGraph(ge_graph, 'jit') self.assertExpectedGraph(ge_graph, 'jit')
def test_exceptions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
if bool(cond):
raise ValueError(3)
return 1
''')
cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit._Exception, "Exception"):
cu.foo(torch.tensor(1))
@torch.jit.script
def foo(cond):
a = 3
if bool(cond):
raise ArbitraryError(a, "hi")
if False:
raise ArbitraryError
return a
foo(torch.tensor(0))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit._Exception, "Exception"):
foo(torch.tensor(1))
@torch.jit.script
def foo():
a = Exception()
raise a
# a gets DCEd because the expression following raise is ignored
with self.assertRaisesRegex(torch.jit._Exception, "failed in interpreter"):
foo()
@torch.jit.script
def foo_except_used():
a = Exception()
print(a)
raise a
# a not DCEd
with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
foo_except_used()
# We don't validate the expr following raise
@torch.jit.script
def foo():
raise 3 + 4
# no control flow analysis yet
with self.assertRaisesRegex(RuntimeError, "undefined value a"):
@torch.jit.script
def foo():
if True:
a = 1
else:
raise Exception("Hi")
return a
def test_weak_script_function(self): def test_weak_script_function(self):
outer_var = 10 outer_var = 10
outer_var2 = 11 outer_var2 = 11
@ -9309,6 +9366,7 @@ def add_nn_module_test(module_name, constructor_args, call_args, skipTestIf=()):
call_args_str = ', '.join(actuals) call_args_str = ', '.join(actuals)
call = "self.submodule({})".format(call_args_str) call = "self.submodule({})".format(call_args_str)
script = script_method_template.format(method_args, call) script = script_method_template.format(method_args, call)
print(script)
# Create module to use the script method # Create module to use the script method
class TheModule(torch.jit.ScriptModule): class TheModule(torch.jit.ScriptModule):

View File

@ -36,6 +36,7 @@
#include "torch/csrc/jit/function_schema.h" #include "torch/csrc/jit/function_schema.h"
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/fusers/interface.h" #include "torch/csrc/jit/fusers/interface.h"
#include "torch/csrc/jit/script/jit_exception.h"
#include "caffe2/serialize/inline_container.h" #include "caffe2/serialize/inline_container.h"
@ -84,6 +85,8 @@ std::string runJITCPPTests();
void initJITBindings(PyObject *module) { void initJITBindings(PyObject *module) {
auto m = py::handle(module).cast<py::module>(); auto m = py::handle(module).cast<py::module>();
py::register_exception<JITException>(m, "JITException");
py::class_<python::IODescriptor>(m, "IODescriptor"); py::class_<python::IODescriptor>(m, "IODescriptor");
m.def("_jit_init", loadPythonClasses) m.def("_jit_init", loadPythonClasses)

View File

@ -12,6 +12,7 @@
#include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/variable_tensor_functions.h"
#include "torch/csrc/jit/script/jit_exception.h"
#include <exception> #include <exception>
#include <iostream> #include <iostream>
@ -664,10 +665,15 @@ struct InterpreterStateImpl {
} }
pc = new_pc; pc = new_pc;
} catch(std::exception & e) { } catch(std::exception & e) {
if(!instructions[pc].debug_location) if (!instructions[pc].debug_location) {
throw; // rethrow original exception throw;
// throw a new exception with enhanced debugging information }
instructions[pc].debug_location->wrapAndRethrowException(e, "operation failed in interpreter"); auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter");
if (dynamic_cast<JITException *>(&e)) {
throw JITException(msg);
} else {
throw std::runtime_error(msg);
}
} }
} }
current_pc = pc; current_pc = pc;

View File

@ -16,17 +16,10 @@ std::unordered_set<Symbol> skip_list = {
prim::If, prim::If,
prim::Loop, //TODO: handle Loop prim::Loop, //TODO: handle Loop
prim::Print, prim::Print,
prim::RaiseException,
prim::PythonOp, //may have side effects prim::PythonOp, //may have side effects
prim::LoadWorld, prim::LoadWorld,
prim::StoreWorld, prim::StoreWorld,
//all the rand functions from native_functions.yaml
aten::rand,
aten::rand_like,
aten::randint,
aten::randint_like,
aten::randn,
aten::randn_like,
aten::randperm,
prim::Constant, prim::Constant,
prim::Undefined, prim::Undefined,
prim::NoneGenerator, prim::NoneGenerator,
@ -129,7 +122,8 @@ void ConstantPropagation(Node* n, bool recurse) {
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) { std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
return v->node()->kind() == prim::Constant; return v->node()->kind() == prim::Constant;
}); });
bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0; bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0
&& !n->isNondeterministic();
auto run_blocks = [&]() { auto run_blocks = [&]() {
if (recurse) { if (recurse) {
for (Block * block : n->blocks()) { for (Block * block : n->blocks()) {

View File

@ -25,6 +25,7 @@ bool hasSideEffects(Node * node, bool_memo_type& memo) {
return it->second; return it->second;
bool has_side_effects = bool has_side_effects =
node->kind() == prim::Print || node->kind() == prim::StoreWorld || node->kind() == prim::Print || node->kind() == prim::StoreWorld ||
node->kind() == prim::RaiseException ||
std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) {
return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) { return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) {
return hasSideEffects(n, memo); return hasSideEffects(n, memo);

View File

@ -348,6 +348,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) {
} }
case prim::PythonOp: case prim::PythonOp:
case prim::Print: case prim::Print:
case prim::RaiseException:
case prim::Undefined: { case prim::Undefined: {
setUnshapedType(node); setUnshapedType(node);
return; return;

View File

@ -8,6 +8,7 @@
#include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h" #include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/script/jit_exception.h"
#include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/variable_tensor_functions.h"
@ -255,6 +256,15 @@ RegisterOperators reg({
return 0; return 0;
}; };
}), }),
Operator(
prim::RaiseException,
[](Node* node) -> Operation {
return [](Stack& stack) {
throw JITException(pop(stack).toStringRef());
return 0;
};
}),
// Load x, y // Load x, y
// loads values from registers onto the stack, the actual callback does // loads values from registers onto the stack, the actual callback does
// nothing since the stack manipulation is already encoded in inst.inputs // nothing since the stack manipulation is already encoded in inst.inputs

View File

@ -948,6 +948,9 @@ private:
} }
} }
break; break;
case TK_RAISE:
emitRaise(Raise(stmt));
break;
case TK_RETURN: case TK_RETURN:
throw ErrorReport(stmt) << "return statements can appear only at the end " throw ErrorReport(stmt) << "return statements can appear only at the end "
<< "of the function body"; << "of the function body";
@ -1294,6 +1297,27 @@ private:
emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {}); emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
} }
// Currently we do not support assigning exceptions to variables,
// a = Exception("hi")
// raise a
//
// We ignore the expression following raise
//
// NYI: add exception logic to control-flow nodes
// if True:
// a = 1
// else
// raise Exception("Hi")
// print(a)
void emitRaise(const Raise& stmt) {
const std::string exception = "Exception";
auto string_input = insertConstant(*graph, exception, stmt.range());
graph->insertNode(graph->create(prim::RaiseException, {string_input}, 0)
->setSourceLocation(std::make_shared<SourceRange>(stmt.range())));
}
// Validate that the `lhs` Expr's in an assignment statement are valid. That // Validate that the `lhs` Expr's in an assignment statement are valid. That
// is: // is:
// //

View File

@ -0,0 +1,16 @@
#pragma once
#include <stdexcept>
namespace torch {
namespace jit {
struct JITException
: public std::runtime_error {
JITException() = default;
explicit JITException(const std::string& msg)
: std::runtime_error(msg) {}
};
} // namespace jit
} // namespace torch

View File

@ -84,7 +84,9 @@ namespace script {
_(TK_ARROW, "arrow", "->") \ _(TK_ARROW, "arrow", "->") \
_(TK_DECL, "decl", "") \ _(TK_DECL, "decl", "") \
_(TK_SLICE_EXPR, "slice expr", "") \ _(TK_SLICE_EXPR, "slice expr", "") \
_(TK_TYPE_COMMENT, "type comment", "# type:") _(TK_TYPE_COMMENT, "type comment", "# type:") \
_(TK_RAISE, "raise", "raise")
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!"; static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!";

View File

@ -382,6 +382,12 @@ struct Parser {
auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp); auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp);
return Return::create(range, values); return Return::create(range, values);
} }
case TK_RAISE: {
auto range = L.next().range;
auto expr = parseExp();
L.expect(TK_NEWLINE);
return Raise::create(range, expr);
}
default: { default: {
List<Expr> exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); List<Expr> exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp);
if (L.cur().kind != TK_NEWLINE) { if (L.cur().kind != TK_NEWLINE) {

View File

@ -131,6 +131,10 @@ void initTreeViewBindings(PyObject *module) {
.def(py::init([](const SourceRange& range, std::vector<Expr> values) { .def(py::init([](const SourceRange& range, std::vector<Expr> values) {
return Return::create(range, wrap_list(range, std::move(values))); return Return::create(range, wrap_list(range, std::move(values)));
})); }));
py::class_<Raise, Stmt>(m, "Raise")
.def(py::init([](const SourceRange& range, Expr *expr) {
return Raise::create(range, wrap_maybe(range, expr));
}));
py::class_<If, Stmt>(m, "If") py::class_<If, Stmt>(m, "If")
.def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) { .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) {
return If::create(range, cond, return If::create(range, cond,

View File

@ -30,6 +30,7 @@ namespace script {
// | Assign(List<Expr> lhs, AssignType maybe_reduce, Expr rhs) TK_ASSIGN // | Assign(List<Expr> lhs, AssignType maybe_reduce, Expr rhs) TK_ASSIGN
// | Return(List<Expr> values) TK_RETURN // | Return(List<Expr> values) TK_RETURN
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT // | ExprStmt(List<Expr> expr) TK_EXPR_STMT
// | Raise(Expr expr) TK_RAISE
// //
// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR // Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
// | BinOp(Expr lhs, Expr rhs) // | BinOp(Expr lhs, Expr rhs)
@ -208,6 +209,7 @@ struct Stmt : public TreeView {
case TK_ASSIGN: case TK_ASSIGN:
case TK_RETURN: case TK_RETURN:
case TK_EXPR_STMT: case TK_EXPR_STMT:
case TK_RAISE:
return; return;
default: default:
throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt"; throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
@ -464,6 +466,19 @@ struct Return : public Stmt {
} }
}; };
struct Raise : public Stmt {
explicit Raise(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_RAISE);
}
Maybe<Expr> expr() const {
return Maybe<Expr>(subtree(0));
}
static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
return Raise(Compound::create(TK_RAISE, range, {expr}));
}
};
struct ExprStmt : public Stmt { struct ExprStmt : public Stmt {
explicit ExprStmt(const TreeRef& tree) : Stmt(tree) { explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
tree_->match(TK_EXPR_STMT); tree_->match(TK_EXPR_STMT);

View File

@ -14,15 +14,20 @@ namespace torch { namespace jit {
struct SourceLocation { struct SourceLocation {
virtual ~SourceLocation() = default; virtual ~SourceLocation() = default;
virtual void highlight(std::ostream & out) const = 0; virtual void highlight(std::ostream & out) const = 0;
void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") {
std::string wrapException(const std::exception & e, const std::string & additional = "") {
std::stringstream msg; std::stringstream msg;
msg << "\n" << e.what() << ":\n"; msg << "\n" << e.what() << ":\n";
if(!additional.empty()) { if(!additional.empty()) {
msg << additional << ":\n"; msg << additional << ":\n";
} }
highlight(msg); highlight(msg);
throw std::runtime_error(msg.str()); return msg.str();
} }
void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") {
throw std::runtime_error(wrapException(e, additional));
}
}; };
inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) { inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {

View File

@ -1374,6 +1374,9 @@ _register_builtin(len, 'aten::len')
_register_builtin(_wait, 'aten::wait') _register_builtin(_wait, 'aten::wait')
# torch.jit._Exception
_Exception = torch._C.JITException
class _disable_tracing(object): class _disable_tracing(object):
def __enter__(self): def __enter__(self):

View File

@ -266,6 +266,18 @@ class StmtBuilder(Builder):
values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts
return Return(r, [build_expr(ctx, val) for val in values if val is not None]) return Return(r, [build_expr(ctx, val) for val in values if val is not None])
@staticmethod
def build_Raise(ctx, stmt):
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
if PY2:
if stmt.tback:
raise NotSupportedError(r, "tracebacks with exceptions is not supported")
# TODO use stmt.type once instantiating exceptions is supported
expr = build_expr(ctx, stmt.inst) if stmt.inst else None
else:
expr = build_expr(ctx, stmt.exc)
return Raise(r, expr)
@staticmethod @staticmethod
def build_AugAssign(ctx, stmt): def build_AugAssign(ctx, stmt):
lhs = [StmtBuilder.get_assign_lhs_expr(ctx, stmt.target)] lhs = [StmtBuilder.get_assign_lhs_expr(ctx, stmt.target)]