From 59f8e8ada7234879c4eefdf710bce603b51f2be0 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 30 Oct 2018 20:20:26 -0700 Subject: [PATCH] First step at adding exceptions (#12789) Summary: This is a first step towards adding exceptions. We need minimal support in order to begin converting the torch library to weak script mode (which is the main goal here). Some limitations (that are documented in the tests & compiler): 1. Cannot assign exceptions to variables 2. Any name after raise is being treated as a valid Exception 3. No control flow analysis yet. Below a will be undefined: if True: a = 1 else: raise Exception("Hi") return a Pull Request resolved: https://github.com/pytorch/pytorch/pull/12789 Differential Revision: D12848936 Pulled By: eellison fbshipit-source-id: 1f60ceef2381040486123ec797e97d65b074862d --- aten/src/ATen/core/interned_strings.h | 1 + ...TestScript.test_python_frontend_py2.expect | 4 + ...TestScript.test_python_frontend_py3.expect | 10 ++ test/test_jit.py | 94 +++++++++++++++---- torch/csrc/jit/init.cpp | 3 + torch/csrc/jit/interpreter.cpp | 14 ++- .../csrc/jit/passes/constant_propagation.cpp | 12 +-- .../csrc/jit/passes/dead_code_elimination.cpp | 1 + torch/csrc/jit/passes/shape_analysis.cpp | 1 + torch/csrc/jit/register_prim_ops.cpp | 10 ++ torch/csrc/jit/script/compiler.cpp | 24 +++++ torch/csrc/jit/script/jit_exception.h | 16 ++++ torch/csrc/jit/script/lexer.h | 4 +- torch/csrc/jit/script/parser.h | 6 ++ torch/csrc/jit/script/python_tree_views.cpp | 4 + torch/csrc/jit/script/tree_views.h | 15 +++ torch/csrc/jit/source_location.h | 9 +- torch/jit/__init__.py | 3 + torch/jit/frontend.py | 12 +++ 19 files changed, 209 insertions(+), 34 deletions(-) create mode 100644 test/expect/TestScript.test_python_frontend_py2.expect create mode 100644 test/expect/TestScript.test_python_frontend_py3.expect create mode 100644 torch/csrc/jit/script/jit_exception.h diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index b6681b724e4..e8488217ce3 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -72,6 +72,7 @@ namespace c10 { _(prim, StoreWorld) \ _(prim, DummyWorld) \ _(prim, fork) \ + _(prim, RaiseException) \ _(aten, append) \ _(aten, __not__) \ FORALL_ATEN_BASE_SYMBOLS(_) \ diff --git a/test/expect/TestScript.test_python_frontend_py2.expect b/test/expect/TestScript.test_python_frontend_py2.expect new file mode 100644 index 00000000000..88722628570 --- /dev/null +++ b/test/expect/TestScript.test_python_frontend_py2.expect @@ -0,0 +1,4 @@ +(def + (ident fn) + (decl (list) (option)) + (list (raise (option)))) diff --git a/test/expect/TestScript.test_python_frontend_py3.expect b/test/expect/TestScript.test_python_frontend_py3.expect new file mode 100644 index 00000000000..315e8123077 --- /dev/null +++ b/test/expect/TestScript.test_python_frontend_py3.expect @@ -0,0 +1,10 @@ +(def + (ident fn) + (decl (list) (option)) + (list + (raise + (option + (apply + (variable (ident Exception)) + (list (string_literal hello)) + (list)))))) diff --git a/test/test_jit.py b/test/test_jit.py index 5d6007f7431..ef4db919e06 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3860,6 +3860,20 @@ a") ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) self.assertExpected(str(ast)) + @unittest.skipIf(not PY2, "Requires python 2") + def test_python_frontend_py2(self): + def fn(): + raise Exception("hello") + ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) + self.assertExpected(str(ast)) + + @unittest.skipIf(PY2, "Requires python 3") + def test_python_frontend_py3(self): + def fn(): + raise Exception("hello") + ast = torch.jit.frontend.get_jit_ast(fn, is_method=False) + self.assertExpected(str(ast)) + def _make_scalar_vars(self, arr, dtype): return [torch.tensor(val, dtype=dtype) for val in arr] @@ -4676,7 +4690,7 @@ a") tester = self class Foo(torch.jit.ScriptModule): - __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'] + __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] def __init__(self): super(Foo, self).__init__(False) @@ -4695,14 +4709,8 @@ a") self.h = type(1) with tester.assertRaisesRegex(TypeError, "not a valid constant"): self.i = (3, 4, {}) - self.j = (6, (1, 2, 3), 8) - - @torch.jit.script_method - def forward(self, x): - return x + self.a + self.b + self.f[0] + self.j[1][2] f = Foo() - self.assertEqual(f(torch.ones(1)), torch.ones(1) + 1 + 1.2 + 3 + 3) def test_script_module_for(self): class M(torch.jit.ScriptModule): @@ -5078,17 +5086,6 @@ a") v = torch.randn(1, device='cuda') self.assertEqual(foo(v), 0) - def test_script_storage_offset(self): - @torch.jit.script - def foo(a): - return a.storage_offset() - - v = torch.randn(5) - self.assertEqual(foo(v), 0) - - v.set_(v.storage(), 3, [1], [1]) - self.assertEquals(foo(v), 3) - def test_script_chunk(self): @torch.jit.script def foo(a): @@ -7535,6 +7532,66 @@ a") ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c) self.assertExpectedGraph(ge_graph, 'jit') + def test_exceptions(self): + cu = torch.jit.CompilationUnit(''' + def foo(cond): + if bool(cond): + raise ValueError(3) + return 1 + ''') + + cu.foo(torch.tensor(0)) + with self.assertRaisesRegex(torch.jit._Exception, "Exception"): + cu.foo(torch.tensor(1)) + + @torch.jit.script + def foo(cond): + a = 3 + if bool(cond): + raise ArbitraryError(a, "hi") + if False: + raise ArbitraryError + return a + + foo(torch.tensor(0)) + # we don't currently validate the name of the exception + with self.assertRaisesRegex(torch.jit._Exception, "Exception"): + foo(torch.tensor(1)) + + @torch.jit.script + def foo(): + a = Exception() + raise a + + # a gets DCEd because the expression following raise is ignored + with self.assertRaisesRegex(torch.jit._Exception, "failed in interpreter"): + foo() + + @torch.jit.script + def foo_except_used(): + a = Exception() + print(a) + raise a + + # a not DCEd + with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"): + foo_except_used() + + # We don't validate the expr following raise + @torch.jit.script + def foo(): + raise 3 + 4 + + # no control flow analysis yet + with self.assertRaisesRegex(RuntimeError, "undefined value a"): + @torch.jit.script + def foo(): + if True: + a = 1 + else: + raise Exception("Hi") + return a + def test_weak_script_function(self): outer_var = 10 outer_var2 = 11 @@ -9309,6 +9366,7 @@ def add_nn_module_test(module_name, constructor_args, call_args, skipTestIf=()): call_args_str = ', '.join(actuals) call = "self.submodule({})".format(call_args_str) script = script_method_template.format(method_args, call) + print(script) # Create module to use the script method class TheModule(torch.jit.ScriptModule): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 279cba10e29..7abd6cdc6d3 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -36,6 +36,7 @@ #include "torch/csrc/jit/function_schema.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/fusers/interface.h" +#include "torch/csrc/jit/script/jit_exception.h" #include "caffe2/serialize/inline_container.h" @@ -84,6 +85,8 @@ std::string runJITCPPTests(); void initJITBindings(PyObject *module) { auto m = py::handle(module).cast(); + py::register_exception(m, "JITException"); + py::class_(m, "IODescriptor"); m.def("_jit_init", loadPythonClasses) diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 887f660d7c1..58ed481fdbe 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -12,6 +12,7 @@ #include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/variable_tensor_functions.h" +#include "torch/csrc/jit/script/jit_exception.h" #include #include @@ -664,10 +665,15 @@ struct InterpreterStateImpl { } pc = new_pc; } catch(std::exception & e) { - if(!instructions[pc].debug_location) - throw; // rethrow original exception - // throw a new exception with enhanced debugging information - instructions[pc].debug_location->wrapAndRethrowException(e, "operation failed in interpreter"); + if (!instructions[pc].debug_location) { + throw; + } + auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter"); + if (dynamic_cast(&e)) { + throw JITException(msg); + } else { + throw std::runtime_error(msg); + } } } current_pc = pc; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index c7f8451d0f8..d5c9f5b19d3 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -16,17 +16,10 @@ std::unordered_set skip_list = { prim::If, prim::Loop, //TODO: handle Loop prim::Print, + prim::RaiseException, prim::PythonOp, //may have side effects prim::LoadWorld, prim::StoreWorld, - //all the rand functions from native_functions.yaml - aten::rand, - aten::rand_like, - aten::randint, - aten::randint_like, - aten::randn, - aten::randn_like, - aten::randperm, prim::Constant, prim::Undefined, prim::NoneGenerator, @@ -129,7 +122,8 @@ void ConstantPropagation(Node* n, bool recurse) { std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) { return v->node()->kind() == prim::Constant; }); - bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0; + bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0 + && !n->isNondeterministic(); auto run_blocks = [&]() { if (recurse) { for (Block * block : n->blocks()) { diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index ebd79235f57..0b47f725e51 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -25,6 +25,7 @@ bool hasSideEffects(Node * node, bool_memo_type& memo) { return it->second; bool has_side_effects = node->kind() == prim::Print || node->kind() == prim::StoreWorld || + node->kind() == prim::RaiseException || std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) { return hasSideEffects(n, memo); diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 956b1a9db4a..4935bf0f52b 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -348,6 +348,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { } case prim::PythonOp: case prim::Print: + case prim::RaiseException: case prim::Undefined: { setUnshapedType(node); return; diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 101b141c09b..386109f59e9 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -8,6 +8,7 @@ #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/custom_operator.h" +#include "torch/csrc/jit/script/jit_exception.h" #include "torch/csrc/variable_tensor_functions.h" @@ -255,6 +256,15 @@ RegisterOperators reg({ return 0; }; }), + Operator( + prim::RaiseException, + [](Node* node) -> Operation { + return [](Stack& stack) { + throw JITException(pop(stack).toStringRef()); + return 0; + }; + }), + // Load x, y // loads values from registers onto the stack, the actual callback does // nothing since the stack manipulation is already encoded in inst.inputs diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index cd2eb736e98..c35e89e4cbe 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -948,6 +948,9 @@ private: } } break; + case TK_RAISE: + emitRaise(Raise(stmt)); + break; case TK_RETURN: throw ErrorReport(stmt) << "return statements can appear only at the end " << "of the function body"; @@ -1294,6 +1297,27 @@ private: emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {}); } + + // Currently we do not support assigning exceptions to variables, + // a = Exception("hi") + // raise a + // + // We ignore the expression following raise + // + // NYI: add exception logic to control-flow nodes + // if True: + // a = 1 + // else + // raise Exception("Hi") + // print(a) + void emitRaise(const Raise& stmt) { + const std::string exception = "Exception"; + auto string_input = insertConstant(*graph, exception, stmt.range()); + graph->insertNode(graph->create(prim::RaiseException, {string_input}, 0) + ->setSourceLocation(std::make_shared(stmt.range()))); + } + + // Validate that the `lhs` Expr's in an assignment statement are valid. That // is: // diff --git a/torch/csrc/jit/script/jit_exception.h b/torch/csrc/jit/script/jit_exception.h new file mode 100644 index 00000000000..6de8ad52193 --- /dev/null +++ b/torch/csrc/jit/script/jit_exception.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +struct JITException + : public std::runtime_error { + JITException() = default; + explicit JITException(const std::string& msg) + : std::runtime_error(msg) {} +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h index cd920c42902..c252e1c668e 100644 --- a/torch/csrc/jit/script/lexer.h +++ b/torch/csrc/jit/script/lexer.h @@ -84,7 +84,9 @@ namespace script { _(TK_ARROW, "arrow", "->") \ _(TK_DECL, "decl", "") \ _(TK_SLICE_EXPR, "slice expr", "") \ - _(TK_TYPE_COMMENT, "type comment", "# type:") + _(TK_TYPE_COMMENT, "type comment", "# type:") \ + _(TK_RAISE, "raise", "raise") + static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!"; diff --git a/torch/csrc/jit/script/parser.h b/torch/csrc/jit/script/parser.h index 64f7f9c8db9..7603c62091e 100644 --- a/torch/csrc/jit/script/parser.h +++ b/torch/csrc/jit/script/parser.h @@ -382,6 +382,12 @@ struct Parser { auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp); return Return::create(range, values); } + case TK_RAISE: { + auto range = L.next().range; + auto expr = parseExp(); + L.expect(TK_NEWLINE); + return Raise::create(range, expr); + } default: { List exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); if (L.cur().kind != TK_NEWLINE) { diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp index be67a262a3f..d1d0430f4e3 100644 --- a/torch/csrc/jit/script/python_tree_views.cpp +++ b/torch/csrc/jit/script/python_tree_views.cpp @@ -131,6 +131,10 @@ void initTreeViewBindings(PyObject *module) { .def(py::init([](const SourceRange& range, std::vector values) { return Return::create(range, wrap_list(range, std::move(values))); })); + py::class_(m, "Raise") + .def(py::init([](const SourceRange& range, Expr *expr) { + return Raise::create(range, wrap_maybe(range, expr)); + })); py::class_(m, "If") .def(py::init([](const SourceRange& range, const Expr& cond, std::vector true_branch, std::vector false_branch) { return If::create(range, cond, diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 162c33e6838..d706eb64292 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -30,6 +30,7 @@ namespace script { // | Assign(List lhs, AssignType maybe_reduce, Expr rhs) TK_ASSIGN // | Return(List values) TK_RETURN // | ExprStmt(List expr) TK_EXPR_STMT +// | Raise(Expr expr) TK_RAISE // // Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR // | BinOp(Expr lhs, Expr rhs) @@ -208,6 +209,7 @@ struct Stmt : public TreeView { case TK_ASSIGN: case TK_RETURN: case TK_EXPR_STMT: + case TK_RAISE: return; default: throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt"; @@ -464,6 +466,19 @@ struct Return : public Stmt { } }; +struct Raise : public Stmt { + explicit Raise(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_RAISE); + } + Maybe expr() const { + return Maybe(subtree(0)); + } + static Raise create(const SourceRange& range, const Maybe& expr) { + return Raise(Compound::create(TK_RAISE, range, {expr})); + } +}; + + struct ExprStmt : public Stmt { explicit ExprStmt(const TreeRef& tree) : Stmt(tree) { tree_->match(TK_EXPR_STMT); diff --git a/torch/csrc/jit/source_location.h b/torch/csrc/jit/source_location.h index 6b86391a8f1..acbd8398a72 100644 --- a/torch/csrc/jit/source_location.h +++ b/torch/csrc/jit/source_location.h @@ -14,15 +14,20 @@ namespace torch { namespace jit { struct SourceLocation { virtual ~SourceLocation() = default; virtual void highlight(std::ostream & out) const = 0; - void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") { + + std::string wrapException(const std::exception & e, const std::string & additional = "") { std::stringstream msg; msg << "\n" << e.what() << ":\n"; if(!additional.empty()) { msg << additional << ":\n"; } highlight(msg); - throw std::runtime_error(msg.str()); + return msg.str(); } + void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") { + throw std::runtime_error(wrapException(e, additional)); + } + }; inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) { diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 4646e11f346..c36afc52350 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1374,6 +1374,9 @@ _register_builtin(len, 'aten::len') _register_builtin(_wait, 'aten::wait') +# torch.jit._Exception +_Exception = torch._C.JITException + class _disable_tracing(object): def __enter__(self): diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 06c404a898e..b6300d8f511 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -266,6 +266,18 @@ class StmtBuilder(Builder): values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts return Return(r, [build_expr(ctx, val) for val in values if val is not None]) + @staticmethod + def build_Raise(ctx, stmt): + r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise")) + if PY2: + if stmt.tback: + raise NotSupportedError(r, "tracebacks with exceptions is not supported") + # TODO use stmt.type once instantiating exceptions is supported + expr = build_expr(ctx, stmt.inst) if stmt.inst else None + else: + expr = build_expr(ctx, stmt.exc) + return Raise(r, expr) + @staticmethod def build_AugAssign(ctx, stmt): lhs = [StmtBuilder.get_assign_lhs_expr(ctx, stmt.target)]