mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
First step at adding exceptions (#12789)
Summary:
This is a first step towards adding exceptions. We need minimal support in order to begin converting the torch library to weak script mode (which is the main goal here).
Some limitations (that are documented in the tests & compiler):
1. Cannot assign exceptions to variables
2. Any name after raise is being treated as a valid Exception
3. No control flow analysis yet. Below a will be undefined:
if True:
a = 1
else:
raise Exception("Hi")
return a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12789
Differential Revision: D12848936
Pulled By: eellison
fbshipit-source-id: 1f60ceef2381040486123ec797e97d65b074862d
This commit is contained in:
parent
c7027a511f
commit
59f8e8ada7
|
|
@ -72,6 +72,7 @@ namespace c10 {
|
||||||
_(prim, StoreWorld) \
|
_(prim, StoreWorld) \
|
||||||
_(prim, DummyWorld) \
|
_(prim, DummyWorld) \
|
||||||
_(prim, fork) \
|
_(prim, fork) \
|
||||||
|
_(prim, RaiseException) \
|
||||||
_(aten, append) \
|
_(aten, append) \
|
||||||
_(aten, __not__) \
|
_(aten, __not__) \
|
||||||
FORALL_ATEN_BASE_SYMBOLS(_) \
|
FORALL_ATEN_BASE_SYMBOLS(_) \
|
||||||
|
|
|
||||||
4
test/expect/TestScript.test_python_frontend_py2.expect
Normal file
4
test/expect/TestScript.test_python_frontend_py2.expect
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
(def
|
||||||
|
(ident fn)
|
||||||
|
(decl (list) (option))
|
||||||
|
(list (raise (option))))
|
||||||
10
test/expect/TestScript.test_python_frontend_py3.expect
Normal file
10
test/expect/TestScript.test_python_frontend_py3.expect
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
(def
|
||||||
|
(ident fn)
|
||||||
|
(decl (list) (option))
|
||||||
|
(list
|
||||||
|
(raise
|
||||||
|
(option
|
||||||
|
(apply
|
||||||
|
(variable (ident Exception))
|
||||||
|
(list (string_literal hello))
|
||||||
|
(list))))))
|
||||||
|
|
@ -3860,6 +3860,20 @@ a")
|
||||||
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
|
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
|
||||||
self.assertExpected(str(ast))
|
self.assertExpected(str(ast))
|
||||||
|
|
||||||
|
@unittest.skipIf(not PY2, "Requires python 2")
|
||||||
|
def test_python_frontend_py2(self):
|
||||||
|
def fn():
|
||||||
|
raise Exception("hello")
|
||||||
|
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
|
||||||
|
self.assertExpected(str(ast))
|
||||||
|
|
||||||
|
@unittest.skipIf(PY2, "Requires python 3")
|
||||||
|
def test_python_frontend_py3(self):
|
||||||
|
def fn():
|
||||||
|
raise Exception("hello")
|
||||||
|
ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
|
||||||
|
self.assertExpected(str(ast))
|
||||||
|
|
||||||
def _make_scalar_vars(self, arr, dtype):
|
def _make_scalar_vars(self, arr, dtype):
|
||||||
return [torch.tensor(val, dtype=dtype) for val in arr]
|
return [torch.tensor(val, dtype=dtype) for val in arr]
|
||||||
|
|
||||||
|
|
@ -4676,7 +4690,7 @@ a")
|
||||||
tester = self
|
tester = self
|
||||||
|
|
||||||
class Foo(torch.jit.ScriptModule):
|
class Foo(torch.jit.ScriptModule):
|
||||||
__constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
|
__constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Foo, self).__init__(False)
|
super(Foo, self).__init__(False)
|
||||||
|
|
@ -4695,14 +4709,8 @@ a")
|
||||||
self.h = type(1)
|
self.h = type(1)
|
||||||
with tester.assertRaisesRegex(TypeError, "not a valid constant"):
|
with tester.assertRaisesRegex(TypeError, "not a valid constant"):
|
||||||
self.i = (3, 4, {})
|
self.i = (3, 4, {})
|
||||||
self.j = (6, (1, 2, 3), 8)
|
|
||||||
|
|
||||||
@torch.jit.script_method
|
|
||||||
def forward(self, x):
|
|
||||||
return x + self.a + self.b + self.f[0] + self.j[1][2]
|
|
||||||
|
|
||||||
f = Foo()
|
f = Foo()
|
||||||
self.assertEqual(f(torch.ones(1)), torch.ones(1) + 1 + 1.2 + 3 + 3)
|
|
||||||
|
|
||||||
def test_script_module_for(self):
|
def test_script_module_for(self):
|
||||||
class M(torch.jit.ScriptModule):
|
class M(torch.jit.ScriptModule):
|
||||||
|
|
@ -5078,17 +5086,6 @@ a")
|
||||||
v = torch.randn(1, device='cuda')
|
v = torch.randn(1, device='cuda')
|
||||||
self.assertEqual(foo(v), 0)
|
self.assertEqual(foo(v), 0)
|
||||||
|
|
||||||
def test_script_storage_offset(self):
|
|
||||||
@torch.jit.script
|
|
||||||
def foo(a):
|
|
||||||
return a.storage_offset()
|
|
||||||
|
|
||||||
v = torch.randn(5)
|
|
||||||
self.assertEqual(foo(v), 0)
|
|
||||||
|
|
||||||
v.set_(v.storage(), 3, [1], [1])
|
|
||||||
self.assertEquals(foo(v), 3)
|
|
||||||
|
|
||||||
def test_script_chunk(self):
|
def test_script_chunk(self):
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def foo(a):
|
def foo(a):
|
||||||
|
|
@ -7535,6 +7532,66 @@ a")
|
||||||
ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
|
ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
|
||||||
self.assertExpectedGraph(ge_graph, 'jit')
|
self.assertExpectedGraph(ge_graph, 'jit')
|
||||||
|
|
||||||
|
def test_exceptions(self):
|
||||||
|
cu = torch.jit.CompilationUnit('''
|
||||||
|
def foo(cond):
|
||||||
|
if bool(cond):
|
||||||
|
raise ValueError(3)
|
||||||
|
return 1
|
||||||
|
''')
|
||||||
|
|
||||||
|
cu.foo(torch.tensor(0))
|
||||||
|
with self.assertRaisesRegex(torch.jit._Exception, "Exception"):
|
||||||
|
cu.foo(torch.tensor(1))
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def foo(cond):
|
||||||
|
a = 3
|
||||||
|
if bool(cond):
|
||||||
|
raise ArbitraryError(a, "hi")
|
||||||
|
if False:
|
||||||
|
raise ArbitraryError
|
||||||
|
return a
|
||||||
|
|
||||||
|
foo(torch.tensor(0))
|
||||||
|
# we don't currently validate the name of the exception
|
||||||
|
with self.assertRaisesRegex(torch.jit._Exception, "Exception"):
|
||||||
|
foo(torch.tensor(1))
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def foo():
|
||||||
|
a = Exception()
|
||||||
|
raise a
|
||||||
|
|
||||||
|
# a gets DCEd because the expression following raise is ignored
|
||||||
|
with self.assertRaisesRegex(torch.jit._Exception, "failed in interpreter"):
|
||||||
|
foo()
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def foo_except_used():
|
||||||
|
a = Exception()
|
||||||
|
print(a)
|
||||||
|
raise a
|
||||||
|
|
||||||
|
# a not DCEd
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
|
||||||
|
foo_except_used()
|
||||||
|
|
||||||
|
# We don't validate the expr following raise
|
||||||
|
@torch.jit.script
|
||||||
|
def foo():
|
||||||
|
raise 3 + 4
|
||||||
|
|
||||||
|
# no control flow analysis yet
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "undefined value a"):
|
||||||
|
@torch.jit.script
|
||||||
|
def foo():
|
||||||
|
if True:
|
||||||
|
a = 1
|
||||||
|
else:
|
||||||
|
raise Exception("Hi")
|
||||||
|
return a
|
||||||
|
|
||||||
def test_weak_script_function(self):
|
def test_weak_script_function(self):
|
||||||
outer_var = 10
|
outer_var = 10
|
||||||
outer_var2 = 11
|
outer_var2 = 11
|
||||||
|
|
@ -9309,6 +9366,7 @@ def add_nn_module_test(module_name, constructor_args, call_args, skipTestIf=()):
|
||||||
call_args_str = ', '.join(actuals)
|
call_args_str = ', '.join(actuals)
|
||||||
call = "self.submodule({})".format(call_args_str)
|
call = "self.submodule({})".format(call_args_str)
|
||||||
script = script_method_template.format(method_args, call)
|
script = script_method_template.format(method_args, call)
|
||||||
|
print(script)
|
||||||
|
|
||||||
# Create module to use the script method
|
# Create module to use the script method
|
||||||
class TheModule(torch.jit.ScriptModule):
|
class TheModule(torch.jit.ScriptModule):
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@
|
||||||
#include "torch/csrc/jit/function_schema.h"
|
#include "torch/csrc/jit/function_schema.h"
|
||||||
#include "torch/csrc/jit/operator.h"
|
#include "torch/csrc/jit/operator.h"
|
||||||
#include "torch/csrc/jit/fusers/interface.h"
|
#include "torch/csrc/jit/fusers/interface.h"
|
||||||
|
#include "torch/csrc/jit/script/jit_exception.h"
|
||||||
|
|
||||||
#include "caffe2/serialize/inline_container.h"
|
#include "caffe2/serialize/inline_container.h"
|
||||||
|
|
||||||
|
|
@ -84,6 +85,8 @@ std::string runJITCPPTests();
|
||||||
void initJITBindings(PyObject *module) {
|
void initJITBindings(PyObject *module) {
|
||||||
auto m = py::handle(module).cast<py::module>();
|
auto m = py::handle(module).cast<py::module>();
|
||||||
|
|
||||||
|
py::register_exception<JITException>(m, "JITException");
|
||||||
|
|
||||||
py::class_<python::IODescriptor>(m, "IODescriptor");
|
py::class_<python::IODescriptor>(m, "IODescriptor");
|
||||||
|
|
||||||
m.def("_jit_init", loadPythonClasses)
|
m.def("_jit_init", loadPythonClasses)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@
|
||||||
#include "torch/csrc/jit/constants.h"
|
#include "torch/csrc/jit/constants.h"
|
||||||
#include "torch/csrc/jit/operator.h"
|
#include "torch/csrc/jit/operator.h"
|
||||||
#include "torch/csrc/variable_tensor_functions.h"
|
#include "torch/csrc/variable_tensor_functions.h"
|
||||||
|
#include "torch/csrc/jit/script/jit_exception.h"
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
@ -664,10 +665,15 @@ struct InterpreterStateImpl {
|
||||||
}
|
}
|
||||||
pc = new_pc;
|
pc = new_pc;
|
||||||
} catch(std::exception & e) {
|
} catch(std::exception & e) {
|
||||||
if(!instructions[pc].debug_location)
|
if (!instructions[pc].debug_location) {
|
||||||
throw; // rethrow original exception
|
throw;
|
||||||
// throw a new exception with enhanced debugging information
|
}
|
||||||
instructions[pc].debug_location->wrapAndRethrowException(e, "operation failed in interpreter");
|
auto msg = instructions[pc].debug_location->wrapException(e, "operation failed in interpreter");
|
||||||
|
if (dynamic_cast<JITException *>(&e)) {
|
||||||
|
throw JITException(msg);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(msg);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
current_pc = pc;
|
current_pc = pc;
|
||||||
|
|
|
||||||
|
|
@ -16,17 +16,10 @@ std::unordered_set<Symbol> skip_list = {
|
||||||
prim::If,
|
prim::If,
|
||||||
prim::Loop, //TODO: handle Loop
|
prim::Loop, //TODO: handle Loop
|
||||||
prim::Print,
|
prim::Print,
|
||||||
|
prim::RaiseException,
|
||||||
prim::PythonOp, //may have side effects
|
prim::PythonOp, //may have side effects
|
||||||
prim::LoadWorld,
|
prim::LoadWorld,
|
||||||
prim::StoreWorld,
|
prim::StoreWorld,
|
||||||
//all the rand functions from native_functions.yaml
|
|
||||||
aten::rand,
|
|
||||||
aten::rand_like,
|
|
||||||
aten::randint,
|
|
||||||
aten::randint_like,
|
|
||||||
aten::randn,
|
|
||||||
aten::randn_like,
|
|
||||||
aten::randperm,
|
|
||||||
prim::Constant,
|
prim::Constant,
|
||||||
prim::Undefined,
|
prim::Undefined,
|
||||||
prim::NoneGenerator,
|
prim::NoneGenerator,
|
||||||
|
|
@ -129,7 +122,8 @@ void ConstantPropagation(Node* n, bool recurse) {
|
||||||
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
|
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
|
||||||
return v->node()->kind() == prim::Constant;
|
return v->node()->kind() == prim::Constant;
|
||||||
});
|
});
|
||||||
bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0;
|
bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0
|
||||||
|
&& !n->isNondeterministic();
|
||||||
auto run_blocks = [&]() {
|
auto run_blocks = [&]() {
|
||||||
if (recurse) {
|
if (recurse) {
|
||||||
for (Block * block : n->blocks()) {
|
for (Block * block : n->blocks()) {
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ bool hasSideEffects(Node * node, bool_memo_type& memo) {
|
||||||
return it->second;
|
return it->second;
|
||||||
bool has_side_effects =
|
bool has_side_effects =
|
||||||
node->kind() == prim::Print || node->kind() == prim::StoreWorld ||
|
node->kind() == prim::Print || node->kind() == prim::StoreWorld ||
|
||||||
|
node->kind() == prim::RaiseException ||
|
||||||
std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) {
|
std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) {
|
||||||
return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) {
|
return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) {
|
||||||
return hasSideEffects(n, memo);
|
return hasSideEffects(n, memo);
|
||||||
|
|
|
||||||
|
|
@ -348,6 +348,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) {
|
||||||
}
|
}
|
||||||
case prim::PythonOp:
|
case prim::PythonOp:
|
||||||
case prim::Print:
|
case prim::Print:
|
||||||
|
case prim::RaiseException:
|
||||||
case prim::Undefined: {
|
case prim::Undefined: {
|
||||||
setUnshapedType(node);
|
setUnshapedType(node);
|
||||||
return;
|
return;
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include "torch/csrc/jit/ir.h"
|
#include "torch/csrc/jit/ir.h"
|
||||||
#include "torch/csrc/jit/operator.h"
|
#include "torch/csrc/jit/operator.h"
|
||||||
#include "torch/csrc/jit/custom_operator.h"
|
#include "torch/csrc/jit/custom_operator.h"
|
||||||
|
#include "torch/csrc/jit/script/jit_exception.h"
|
||||||
|
|
||||||
#include "torch/csrc/variable_tensor_functions.h"
|
#include "torch/csrc/variable_tensor_functions.h"
|
||||||
|
|
||||||
|
|
@ -255,6 +256,15 @@ RegisterOperators reg({
|
||||||
return 0;
|
return 0;
|
||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
|
Operator(
|
||||||
|
prim::RaiseException,
|
||||||
|
[](Node* node) -> Operation {
|
||||||
|
return [](Stack& stack) {
|
||||||
|
throw JITException(pop(stack).toStringRef());
|
||||||
|
return 0;
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
|
||||||
// Load x, y
|
// Load x, y
|
||||||
// loads values from registers onto the stack, the actual callback does
|
// loads values from registers onto the stack, the actual callback does
|
||||||
// nothing since the stack manipulation is already encoded in inst.inputs
|
// nothing since the stack manipulation is already encoded in inst.inputs
|
||||||
|
|
|
||||||
|
|
@ -948,6 +948,9 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case TK_RAISE:
|
||||||
|
emitRaise(Raise(stmt));
|
||||||
|
break;
|
||||||
case TK_RETURN:
|
case TK_RETURN:
|
||||||
throw ErrorReport(stmt) << "return statements can appear only at the end "
|
throw ErrorReport(stmt) << "return statements can appear only at the end "
|
||||||
<< "of the function body";
|
<< "of the function body";
|
||||||
|
|
@ -1294,6 +1297,27 @@ private:
|
||||||
emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
|
emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Currently we do not support assigning exceptions to variables,
|
||||||
|
// a = Exception("hi")
|
||||||
|
// raise a
|
||||||
|
//
|
||||||
|
// We ignore the expression following raise
|
||||||
|
//
|
||||||
|
// NYI: add exception logic to control-flow nodes
|
||||||
|
// if True:
|
||||||
|
// a = 1
|
||||||
|
// else
|
||||||
|
// raise Exception("Hi")
|
||||||
|
// print(a)
|
||||||
|
void emitRaise(const Raise& stmt) {
|
||||||
|
const std::string exception = "Exception";
|
||||||
|
auto string_input = insertConstant(*graph, exception, stmt.range());
|
||||||
|
graph->insertNode(graph->create(prim::RaiseException, {string_input}, 0)
|
||||||
|
->setSourceLocation(std::make_shared<SourceRange>(stmt.range())));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Validate that the `lhs` Expr's in an assignment statement are valid. That
|
// Validate that the `lhs` Expr's in an assignment statement are valid. That
|
||||||
// is:
|
// is:
|
||||||
//
|
//
|
||||||
|
|
|
||||||
16
torch/csrc/jit/script/jit_exception.h
Normal file
16
torch/csrc/jit/script/jit_exception.h
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
struct JITException
|
||||||
|
: public std::runtime_error {
|
||||||
|
JITException() = default;
|
||||||
|
explicit JITException(const std::string& msg)
|
||||||
|
: std::runtime_error(msg) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -84,7 +84,9 @@ namespace script {
|
||||||
_(TK_ARROW, "arrow", "->") \
|
_(TK_ARROW, "arrow", "->") \
|
||||||
_(TK_DECL, "decl", "") \
|
_(TK_DECL, "decl", "") \
|
||||||
_(TK_SLICE_EXPR, "slice expr", "") \
|
_(TK_SLICE_EXPR, "slice expr", "") \
|
||||||
_(TK_TYPE_COMMENT, "type comment", "# type:")
|
_(TK_TYPE_COMMENT, "type comment", "# type:") \
|
||||||
|
_(TK_RAISE, "raise", "raise")
|
||||||
|
|
||||||
|
|
||||||
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!";
|
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!";
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -382,6 +382,12 @@ struct Parser {
|
||||||
auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp);
|
auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp);
|
||||||
return Return::create(range, values);
|
return Return::create(range, values);
|
||||||
}
|
}
|
||||||
|
case TK_RAISE: {
|
||||||
|
auto range = L.next().range;
|
||||||
|
auto expr = parseExp();
|
||||||
|
L.expect(TK_NEWLINE);
|
||||||
|
return Raise::create(range, expr);
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
List<Expr> exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp);
|
List<Expr> exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp);
|
||||||
if (L.cur().kind != TK_NEWLINE) {
|
if (L.cur().kind != TK_NEWLINE) {
|
||||||
|
|
|
||||||
|
|
@ -131,6 +131,10 @@ void initTreeViewBindings(PyObject *module) {
|
||||||
.def(py::init([](const SourceRange& range, std::vector<Expr> values) {
|
.def(py::init([](const SourceRange& range, std::vector<Expr> values) {
|
||||||
return Return::create(range, wrap_list(range, std::move(values)));
|
return Return::create(range, wrap_list(range, std::move(values)));
|
||||||
}));
|
}));
|
||||||
|
py::class_<Raise, Stmt>(m, "Raise")
|
||||||
|
.def(py::init([](const SourceRange& range, Expr *expr) {
|
||||||
|
return Raise::create(range, wrap_maybe(range, expr));
|
||||||
|
}));
|
||||||
py::class_<If, Stmt>(m, "If")
|
py::class_<If, Stmt>(m, "If")
|
||||||
.def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) {
|
.def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) {
|
||||||
return If::create(range, cond,
|
return If::create(range, cond,
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ namespace script {
|
||||||
// | Assign(List<Expr> lhs, AssignType maybe_reduce, Expr rhs) TK_ASSIGN
|
// | Assign(List<Expr> lhs, AssignType maybe_reduce, Expr rhs) TK_ASSIGN
|
||||||
// | Return(List<Expr> values) TK_RETURN
|
// | Return(List<Expr> values) TK_RETURN
|
||||||
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT
|
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT
|
||||||
|
// | Raise(Expr expr) TK_RAISE
|
||||||
//
|
//
|
||||||
// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
|
// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
|
||||||
// | BinOp(Expr lhs, Expr rhs)
|
// | BinOp(Expr lhs, Expr rhs)
|
||||||
|
|
@ -208,6 +209,7 @@ struct Stmt : public TreeView {
|
||||||
case TK_ASSIGN:
|
case TK_ASSIGN:
|
||||||
case TK_RETURN:
|
case TK_RETURN:
|
||||||
case TK_EXPR_STMT:
|
case TK_EXPR_STMT:
|
||||||
|
case TK_RAISE:
|
||||||
return;
|
return;
|
||||||
default:
|
default:
|
||||||
throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
|
throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
|
||||||
|
|
@ -464,6 +466,19 @@ struct Return : public Stmt {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Raise : public Stmt {
|
||||||
|
explicit Raise(const TreeRef& tree) : Stmt(tree) {
|
||||||
|
tree_->match(TK_RAISE);
|
||||||
|
}
|
||||||
|
Maybe<Expr> expr() const {
|
||||||
|
return Maybe<Expr>(subtree(0));
|
||||||
|
}
|
||||||
|
static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
|
||||||
|
return Raise(Compound::create(TK_RAISE, range, {expr}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
struct ExprStmt : public Stmt {
|
struct ExprStmt : public Stmt {
|
||||||
explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
|
explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
|
||||||
tree_->match(TK_EXPR_STMT);
|
tree_->match(TK_EXPR_STMT);
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,20 @@ namespace torch { namespace jit {
|
||||||
struct SourceLocation {
|
struct SourceLocation {
|
||||||
virtual ~SourceLocation() = default;
|
virtual ~SourceLocation() = default;
|
||||||
virtual void highlight(std::ostream & out) const = 0;
|
virtual void highlight(std::ostream & out) const = 0;
|
||||||
void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") {
|
|
||||||
|
std::string wrapException(const std::exception & e, const std::string & additional = "") {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "\n" << e.what() << ":\n";
|
msg << "\n" << e.what() << ":\n";
|
||||||
if(!additional.empty()) {
|
if(!additional.empty()) {
|
||||||
msg << additional << ":\n";
|
msg << additional << ":\n";
|
||||||
}
|
}
|
||||||
highlight(msg);
|
highlight(msg);
|
||||||
throw std::runtime_error(msg.str());
|
return msg.str();
|
||||||
}
|
}
|
||||||
|
void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") {
|
||||||
|
throw std::runtime_error(wrapException(e, additional));
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
|
inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
|
||||||
|
|
|
||||||
|
|
@ -1374,6 +1374,9 @@ _register_builtin(len, 'aten::len')
|
||||||
|
|
||||||
_register_builtin(_wait, 'aten::wait')
|
_register_builtin(_wait, 'aten::wait')
|
||||||
|
|
||||||
|
# torch.jit._Exception
|
||||||
|
_Exception = torch._C.JITException
|
||||||
|
|
||||||
|
|
||||||
class _disable_tracing(object):
|
class _disable_tracing(object):
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|
|
||||||
|
|
@ -266,6 +266,18 @@ class StmtBuilder(Builder):
|
||||||
values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts
|
values = (stmt.value,) if not isinstance(stmt.value, ast.Tuple) else stmt.value.elts
|
||||||
return Return(r, [build_expr(ctx, val) for val in values if val is not None])
|
return Return(r, [build_expr(ctx, val) for val in values if val is not None])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_Raise(ctx, stmt):
|
||||||
|
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
|
||||||
|
if PY2:
|
||||||
|
if stmt.tback:
|
||||||
|
raise NotSupportedError(r, "tracebacks with exceptions is not supported")
|
||||||
|
# TODO use stmt.type once instantiating exceptions is supported
|
||||||
|
expr = build_expr(ctx, stmt.inst) if stmt.inst else None
|
||||||
|
else:
|
||||||
|
expr = build_expr(ctx, stmt.exc)
|
||||||
|
return Raise(r, expr)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_AugAssign(ctx, stmt):
|
def build_AugAssign(ctx, stmt):
|
||||||
lhs = [StmtBuilder.get_assign_lhs_expr(ctx, stmt.target)]
|
lhs = [StmtBuilder.get_assign_lhs_expr(ctx, stmt.target)]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user