diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index b6681b724e4..e8488217ce3 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -72,6 +72,7 @@ namespace c10 { _(prim, StoreWorld) \ _(prim, DummyWorld) \ _(prim, fork) \ + _(prim, RaiseException) \ _(aten, append) \ _(aten, __not__) \ FORALL_ATEN_BASE_SYMBOLS(_) \ diff --git a/test/expect/TestScript.test_python_frontend_py2.expect b/test/expect/TestScript.test_python_frontend_py2.expect new file mode 100644 index 00000000000..88722628570 --- /dev/null +++ b/test/expect/TestScript.test_python_frontend_py2.expect @@ -0,0 +1,4 @@ +(def + (ident fn) + (decl (list) (option)) + (list (raise (option)))) diff --git a/test/expect/TestScript.test_python_frontend_py3.expect b/test/expect/TestScript.test_python_frontend_py3.expect new file mode 100644 index 00000000000..315e8123077 --- /dev/null +++ b/test/expect/TestScript.test_python_frontend_py3.expect @@ -0,0 +1,10 @@ +(def + (ident fn) + (decl (list) (option)) + (list + (raise + (option + (apply + (variable (ident Exception)) + (list (string_literal hello)) + (list)))))) diff --git a/test/test_jit.py b/test/test_jit.py index 5d6007f7431..ef4db919e06 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3860,6 +3860,20 @@ a") ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) self.assertExpected(str(ast)) + @unittest.skipIf(not PY2, "Requires python 2") + def test_python_frontend_py2(self): + def fn(): + raise Exception("hello") + ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) + self.assertExpected(str(ast)) + + @unittest.skipIf(PY2, "Requires python 3") + def test_python_frontend_py3(self): + def fn(): + raise Exception("hello") + ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) + self.assertExpected(str(ast)) + def _make_scalar_vars(self, arr, dtype): return [torch.tensor(val, dtype=dtype) for val in arr] @@ -4676,7 +4690,7 @@ a") tester = self class Foo(torch.jit.ScriptModule): - __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] def __init__(self): super(Foo, self).__init__(False) @@ -4695,14 +4709,8 @@ a") self.h = type(1) with tester.assertRaisesRegex(TypeError, "not a valid constant"): self.i = (3, 4, {}) - self.j = (6, (1, 2, 3), 8) - - @torch.jit.script_method - def forward(self, x): - return x + self.a + self.b + self.f[0] + self.j[1][2] f = Foo() - self.assertEqual(f(torch.ones(1)), torch.ones(1) + 1 + 1.2 + 3 + 3) def test_script_module_for(self): class M(torch.jit.ScriptModule): @@ -5078,17 +5086,6 @@ a") v = torch.randn(1, device='cuda') self.assertEqual(foo(v), 0) - def test_script_storage_offset(self): - @torch.jit.script - def foo(a): - return a.storage_offset() - - v = torch.randn(5) - self.assertEqual(foo(v), 0) - - v.set_(v.storage(), 3, [1], [1]) - self.assertEquals(foo(v), 3) - def test_script_chunk(self): @torch.jit.script def foo(a): @@ -7535,6 +7532,66 @@ a") ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c) self.assertExpectedGraph(ge_graph, 'jit') + def test_exceptions(self): + cu = torch.jit.CompilationUnit(''' + def foo(cond): + if bool(cond): + raise ValueError(3) + return 1 + ''') + + cu.foo(torch.tensor(0)) + with self.assertRaisesRegex(torch.jit._Exception, "Exception"): + cu.foo(torch.tensor(1)) + + @torch.jit.script + def foo(cond): + a = 3 + if bool(cond): + raise ArbitraryError(a, "hi") + if False: + raise ArbitraryError + return a + + foo(torch.tensor(0)) + # we don't currently validate the name of the exception + with self.assertRaisesRegex(torch.jit._Exception, "Exception"): + foo(torch.tensor(1)) + + @torch.jit.script + def foo(): + a = Exception() + raise a + + # a gets DCEd because the expression following raise is ignored + with self.assertRaisesRegex(torch.jit._Exception, "failed in interpreter"): + foo() + + @torch.jit.script + def foo_except_used(): + a = Exception() + print(a) + raise a + + # a not DCEd + with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"): + foo_except_used() + + # We don't validate the expr following raise + @torch.jit.script + def foo(): + raise 3 + 4 + + # no control flow analysis yet + with self.assertRaisesRegex(RuntimeError, "undefined value a"): + @torch.jit.script + def foo(): + if True: + a = 1 + else: + raise Exception("Hi") + return a + def test_weak_script_function(self): outer_var = 10 outer_var2 = 11 @@ -9309,6 +9366,7 @@ def add_nn_module_test(module_name, constructor_args, call_args, skipTestIf=()): call_args_str = ', '.join(actuals) call = "self.submodule({})".format(call_args_str) script = script_method_template.format(method_args, call) + print(script) # Create module to use the script method class TheModule(torch.jit.ScriptModule): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 279cba10e29..7abd6cdc6d3 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -36,6 +36,7 @@ #include "torch/csrc/jit/function_schema.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/fusers/interface.h" +#include "torch/csrc/jit/script/jit_exception.h" #include "caffe2/serialize/inline_container.h" @@ -84,6 +85,8 @@ std::string runJITCPPTests(); void initJITBindings(PyObject *module) { auto m = py::handle(module).cast(); + py::register_exception(m, "JITException"); + py::class_(m, "IODescriptor"); m.def("_jit_init", loadPythonClasses) diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 887f660d7c1..58ed481fdbe 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -12,6 +12,7 @@ #include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/variable_tensor_functions.h" +#include "torch/csrc/jit/script/jit_exception.h" #include #include @@ -664,10 +665,15 @@ struct InterpreterStateImpl { } pc = new_pc; } catch(std::exception & e) { - if(!instructions[pc].debug_location) - throw; // rethrow original exception - // throw a new exception with enhanced debugging information - instructions[pc].debug_location->wrapAndRethrowException(e, "operation failed in interpreter"); + if (!instructions[pc].debug_location) { + throw; + } + auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter"); + if (dynamic_cast(&e)) { + throw JITException(msg); + } else { + throw std::runtime_error(msg); + } } } current_pc = pc; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index c7f8451d0f8..d5c9f5b19d3 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -16,17 +16,10 @@ std::unordered_set skip_list = { prim::If, prim::Loop, //TODO: handle Loop prim::Print, + prim::RaiseException, prim::PythonOp, //may have side effects prim::LoadWorld, prim::StoreWorld, - //all the rand functions from native_functions.yaml - aten::rand, - aten::rand_like, - aten::randint, - aten::randint_like, - aten::randn, - aten::randn_like, - aten::randperm, prim::Constant, prim::Undefined, prim::NoneGenerator, @@ -129,7 +122,8 @@ void ConstantPropagation(Node* n, bool recurse) { std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) { return v->node()->kind() == prim::Constant; }); - bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0; + bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0 + && !n->isNondeterministic(); auto run_blocks = [&]() { if (recurse) { for (Block * block : n->blocks()) { diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index ebd79235f57..0b47f725e51 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -25,6 +25,7 @@ bool hasSideEffects(Node * node, bool_memo_type& memo) { return it->second; bool has_side_effects = node->kind() == prim::Print || node->kind() == prim::StoreWorld || + node->kind() == prim::RaiseException || std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) { return hasSideEffects(n, memo); diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 956b1a9db4a..4935bf0f52b 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -348,6 +348,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { } case prim::PythonOp: case prim::Print: + case prim::RaiseException: case prim::Undefined: { setUnshapedType(node); return; diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 101b141c09b..386109f59e9 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -8,6 +8,7 @@ #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/custom_operator.h" +#include "torch/csrc/jit/script/jit_exception.h" #include "torch/csrc/variable_tensor_functions.h" @@ -255,6 +256,15 @@ RegisterOperators reg({ return 0; }; }), + Operator( + prim::RaiseException, + [](Node* node) -> Operation { + return [](Stack& stack) { + throw JITException(pop(stack).toStringRef()); + return 0; + }; + }), + // Load x, y // loads values from registers onto the stack, the actual callback does // nothing since the stack manipulation is already encoded in inst.inputs diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index cd2eb736e98..c35e89e4cbe 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -948,6 +948,9 @@ private: } } break; + case TK_RAISE: + emitRaise(Raise(stmt)); + break; case TK_RETURN: throw ErrorReport(stmt) << "return statements can appear only at the end " << "of the function body"; @@ -1294,6 +1297,27 @@ private: emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {}); } + + // Currently we do not support assigning exceptions to variables, + // a = Exception("hi") + // raise a + // + // We ignore the expression following raise + // + // NYI: add exception logic to control-flow nodes + // if True: + // a = 1 + // else + // raise Exception("Hi") + // print(a) + void emitRaise(const Raise& stmt) { + const std::string exception = "Exception"; + auto string_input = insertConstant(*graph, exception, stmt.range()); + graph->insertNode(graph->create(prim::RaiseException, {string_input}, 0) + ->setSourceLocation(std::make_shared(stmt.range()))); + } + + // Validate that the `lhs` Expr's in an assignment statement are valid. That // is: // diff --git a/torch/csrc/jit/script/jit_exception.h b/torch/csrc/jit/script/jit_exception.h new file mode 100644 index 00000000000..6de8ad52193 --- /dev/null +++ b/torch/csrc/jit/script/jit_exception.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +struct JITException + : public std::runtime_error { + JITException() = default; + explicit JITException(const std::string& msg) + : std::runtime_error(msg) {} +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h index cd920c42902..c252e1c668e 100644 --- a/torch/csrc/jit/script/lexer.h +++ b/torch/csrc/jit/script/lexer.h @@ -84,7 +84,9 @@ namespace script { _(TK_ARROW, "arrow", "->") \ _(TK_DECL, "decl", "") \ _(TK_SLICE_EXPR, "slice expr", "") \ - _(TK_TYPE_COMMENT, "type comment", "# type:") + _(TK_TYPE_COMMENT, "type comment", "# type:") \ + _(TK_RAISE, "raise", "raise") + static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!"; diff --git a/torch/csrc/jit/script/parser.h b/torch/csrc/jit/script/parser.h index 64f7f9c8db9..7603c62091e 100644 --- a/torch/csrc/jit/script/parser.h +++ b/torch/csrc/jit/script/parser.h @@ -382,6 +382,12 @@ struct Parser { auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp); return Return::create(range, values); } + case TK_RAISE: { + auto range = L.next().range; + auto expr = parseExp(); + L.expect(TK_NEWLINE); + return Raise::create(range, expr); + } default: { List exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); if (L.cur().kind != TK_NEWLINE) { diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp index be67a262a3f..d1d0430f4e3 100644 --- a/torch/csrc/jit/script/python_tree_views.cpp +++ b/torch/csrc/jit/script/python_tree_views.cpp @@ -131,6 +131,10 @@ void initTreeViewBindings(PyObject *module) { .def(py::init([](const SourceRange& range, std::vector values) { return Return::create(range, wrap_list(range, std::move(values))); })); + py::class_(m, "Raise") + .def(py::init([](const SourceRange& range, Expr *expr) { + return Raise::create(range, wrap_maybe(range, expr)); + })); py::class_(m, "If") .def(py::init([](const SourceRange& range, const Expr& cond, std::vector true_branch, std::vector false_branch) { return If::create(range, cond, diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 162c33e6838..d706eb64292 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -30,6 +30,7 @@ namespace script { // | Assign(List lhs, AssignType maybe_reduce, Expr rhs) TK_ASSIGN // | Return(List values) TK_RETURN // | ExprStmt(List expr) TK_EXPR_STMT +// | Raise(Expr expr) TK_RAISE // // Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR // | BinOp(Expr lhs, Expr rhs) @@ -208,6 +209,7 @@ struct Stmt : public TreeView { case TK_ASSIGN: case TK_RETURN: case TK_EXPR_STMT: + case TK_RAISE: return; default: throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt"; @@ -464,6 +466,19 @@ struct Return : public Stmt { } }; +struct Raise : public Stmt { + explicit Raise(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_RAISE); + } + Maybe expr() const { + return Maybe(subtree(0)); + } + static Raise create(const SourceRange& range, const Maybe& expr) { + return Raise(Compound::create(TK_RAISE, range, {expr})); + } +}; + + struct ExprStmt : public Stmt { explicit ExprStmt(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_EXPR_STMT); diff --git a/torch/csrc/jit/source_location.h b/torch/csrc/jit/source_location.h index 6b86391a8f1..acbd8398a72 100644 --- a/torch/csrc/jit/source_location.h +++ b/torch/csrc/jit/source_location.h @@ -14,15 +14,20 @@ namespace torch { namespace jit { struct SourceLocation { virtual ~SourceLocation() = default; virtual void highlight(std::ostream & out) const = 0; - void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") { + + std::string wrapException(const std::exception & e, const std::string & additional = "") { std::stringstream msg; msg << "\n" << e.what() << ":\n"; if(!additional.empty()) { msg << additional << ":\n"; } highlight(msg); - throw std::runtime_error(msg.str()); + return msg.str(); } + void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") { + throw std::runtime_error(wrapException(e, additional)); + } + }; inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) { diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 4646e11f346..c36afc52350 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1374,6 +1374,9 @@ _register_builtin(len, 'aten::len') _register_builtin(_wait, 'aten::wait') +# torch.jit._Exception +_Exception = torch._C.JITException + class _disable_tracing(object): def __enter__(self): diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 06c404a898e..b6300d8f511 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -266,6 +266,18 @@ class StmtBuilder(Builder): values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts return Return(r, [build_expr(ctx, val) for val in values if val is not None]) + @staticmethod + def build_Raise(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise")) + if PY2: + if stmt.tback: + raise NotSupportedError(r, "tracebacks with exceptions is not supported") + # TODO use stmt.type once instantiating exceptions is supported + expr = build_expr(ctx, stmt.inst) if stmt.inst else None + else: + expr = build_expr(ctx, stmt.exc) + return Raise(r, expr) + @staticmethod def build_AugAssign(ctx, stmt): lhs = [StmtBuilder.get_assign_lhs_expr(ctx, stmt.target)]