// Mirror of https://github.com/zebrajr/pytorch.git
// (synced 2025-12-06 12:20:52 +01:00)
//
// Summary:
// * Deletes all weak script decorators, their associated data structures, and
//   their methods.
// * To keep supporting the standard library in script, this enables recursive
//   script on any function defined in `torch.nn`.
// * Most changes in `torch/nn` are the result of
//   `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed
//   manual editing to use `ignore` and `export` to keep supporting the
//   overloaded `forward` methods.
// * `Sequential`/`ModuleList` no longer need to be added to constants, since
//   they are compiled on demand.
//
// This should also fix https://github.com/pytorch/pytorch/issues/22212
//
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212
// Differential Revision: D15988346
// Pulled By: driazati
// fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
//
// 701 lines, 21 KiB, C++
#include <torch/csrc/jit/script/parser.h>
|
|
#include <c10/util/Optional.h>
|
|
#include <torch/csrc/jit/script/lexer.h>
|
|
#include <torch/csrc/jit/script/parse_string_literal.h>
|
|
#include <torch/csrc/jit/script/tree.h>
|
|
#include <torch/csrc/jit/script/tree_views.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace script {
|
|
|
|
// Merges the types from a function's type comment into the declaration parsed
// from its `def` line.
//
// `decl` supplies each parameter's name, range and default value.
// `type_annotation_decl` (as produced by parseTypeComment) supplies one type
// per parameter plus the return type. For methods (`is_method == true`) the
// first parameter of `decl` (`self`) has no counterpart in the type comment
// and is copied through unchanged.
//
// Throws ErrorReport if a method has no parameters at all (so no `self`), or if
// the number of annotations does not match the number of non-self parameters.
Decl mergeTypesFromTypeComment(
    const Decl& decl,
    const Decl& type_annotation_decl,
    bool is_method) {
  auto old = decl.params();
  auto _new = type_annotation_decl.params();
  const size_t num_params = old.size();

  // Guard before subtracting: the count is unsigned, so `0 - 1` would wrap
  // around and the mismatch error below would report a nonsensical number.
  if (is_method && num_params == 0) {
    throw ErrorReport(decl.range()) << "Methods must have a self argument";
  }

  // `self` is not annotated in a method's type comment.
  const size_t num_skipped = is_method ? 1 : 0;
  const size_t expected_num_annotations = num_params - num_skipped;
  if (expected_num_annotations != _new.size()) {
    throw ErrorReport(type_annotation_decl.range())
        << "Number of type annotations (" << _new.size()
        << ") did not match the number of "
        << (is_method ? "method" : "function") << " parameters ("
        << expected_num_annotations << ")";
  }

  // Merge signature idents and ranges with annotation types.
  std::vector<Param> new_params;
  new_params.reserve(num_params);
  if (is_method) {
    new_params.push_back(old[0]);
  }
  for (size_t i = num_skipped, j = 0; i < num_params; ++i, ++j) {
    new_params.emplace_back(old[i].withType(_new[j].type()));
  }

  return Decl::create(
      decl.range(),
      List<Param>::create(decl.range(), new_params),
      type_annotation_decl.return_type());
}
|
|
|
|
struct ParserImpl {
|
|
explicit ParserImpl(const std::shared_ptr<Source>& source)
|
|
: L(source), shared(sharedParserData()) {}
|
|
|
|
Ident parseIdent() {
|
|
auto t = L.expect(TK_IDENT);
|
|
// whenever we parse something that has a TreeView type we always
|
|
// use its create method so that the accessors and the constructor
|
|
// of the Compound tree are in the same place.
|
|
return Ident::create(t.range, t.text());
|
|
}
|
|
TreeRef createApply(const Expr& expr) {
|
|
TreeList attributes;
|
|
auto range = L.cur().range;
|
|
TreeList inputs;
|
|
parseArguments(inputs, attributes);
|
|
return Apply::create(
|
|
range,
|
|
expr,
|
|
List<Expr>(makeList(range, std::move(inputs))),
|
|
List<Attribute>(makeList(range, std::move(attributes))));
|
|
}
|
|
|
|
static bool followsTuple(int kind) {
|
|
switch (kind) {
|
|
case TK_PLUS_EQ:
|
|
case TK_MINUS_EQ:
|
|
case TK_TIMES_EQ:
|
|
case TK_DIV_EQ:
|
|
case TK_NEWLINE:
|
|
case '=':
|
|
case ')':
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// exp | expr, | expr, expr, ...
|
|
Expr parseExpOrExpTuple() {
|
|
auto prefix = parseExp();
|
|
if (L.cur().kind == ',') {
|
|
std::vector<Expr> exprs = {prefix};
|
|
while (L.nextIf(',')) {
|
|
if (followsTuple(L.cur().kind))
|
|
break;
|
|
exprs.push_back(parseExp());
|
|
}
|
|
auto list = List<Expr>::create(prefix.range(), exprs);
|
|
prefix = TupleLiteral::create(list.range(), list);
|
|
}
|
|
return prefix;
|
|
}
|
|
// things like a 1.0 or a(4) that are not unary/binary expressions
|
|
// and have higher precedence than all of them
|
|
TreeRef parseBaseExp() {
|
|
TreeRef prefix;
|
|
switch (L.cur().kind) {
|
|
case TK_NUMBER: {
|
|
prefix = parseConst();
|
|
} break;
|
|
case TK_TRUE:
|
|
case TK_FALSE:
|
|
case TK_NONE: {
|
|
auto k = L.cur().kind;
|
|
auto r = L.cur().range;
|
|
prefix = create_compound(k, r, {});
|
|
L.next();
|
|
} break;
|
|
case '(': {
|
|
L.next();
|
|
if (L.nextIf(')')) {
|
|
/// here we have the empty tuple case
|
|
std::vector<Expr> vecExpr;
|
|
List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
|
|
prefix = TupleLiteral::create(L.cur().range, listExpr);
|
|
break;
|
|
}
|
|
prefix = parseExpOrExpTuple();
|
|
L.expect(')');
|
|
} break;
|
|
case '[': {
|
|
auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
|
|
|
|
if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
|
|
prefix = *list.begin();
|
|
} else {
|
|
for (auto se : list) {
|
|
if (se.kind() == TK_LIST_COMP) {
|
|
throw ErrorReport(list.range())
|
|
<< " expected a single list comprehension within '[' , ']'";
|
|
}
|
|
}
|
|
prefix = ListLiteral::create(list.range(), List<Expr>(list));
|
|
}
|
|
|
|
} break;
|
|
case '{': {
|
|
L.next();
|
|
std::vector<Expr> keys;
|
|
std::vector<Expr> values;
|
|
auto range = L.cur().range;
|
|
if (L.cur().kind != '}') {
|
|
do {
|
|
keys.push_back(parseExp());
|
|
L.expect(':');
|
|
values.push_back(parseExp());
|
|
} while (L.nextIf(','));
|
|
}
|
|
L.expect('}');
|
|
prefix = DictLiteral::create(
|
|
range,
|
|
List<Expr>::create(range, keys),
|
|
List<Expr>::create(range, values));
|
|
} break;
|
|
case TK_STRINGLITERAL: {
|
|
prefix = parseConcatenatedStringLiterals();
|
|
} break;
|
|
case TK_DOTS: {
|
|
prefix = Dots::create(L.cur().range);
|
|
L.next();
|
|
} break;
|
|
default: {
|
|
Ident name = parseIdent();
|
|
prefix = Var::create(name.range(), name);
|
|
} break;
|
|
}
|
|
while (true) {
|
|
if (L.nextIf('.')) {
|
|
const auto name = parseIdent();
|
|
prefix = Select::create(name.range(), Expr(prefix), Ident(name));
|
|
} else if (L.cur().kind == '(') {
|
|
prefix = createApply(Expr(prefix));
|
|
} else if (L.cur().kind == '[') {
|
|
prefix = parseSubscript(prefix);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
return prefix;
|
|
}
|
|
TreeRef parseAssignmentOp() {
|
|
auto r = L.cur().range;
|
|
switch (L.cur().kind) {
|
|
case TK_PLUS_EQ:
|
|
case TK_MINUS_EQ:
|
|
case TK_TIMES_EQ:
|
|
case TK_DIV_EQ: {
|
|
int modifier = L.next().text()[0];
|
|
return create_compound(modifier, r, {});
|
|
} break;
|
|
default: {
|
|
L.expect('=');
|
|
return create_compound('=', r, {}); // no reduction
|
|
} break;
|
|
}
|
|
}
|
|
TreeRef parseTrinary(
|
|
TreeRef true_branch,
|
|
const SourceRange& range,
|
|
int binary_prec) {
|
|
auto cond = parseExp();
|
|
L.expect(TK_ELSE);
|
|
auto false_branch = parseExp(binary_prec);
|
|
return create_compound(TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
|
|
}
|
|
// parse the longest expression whose binary operators have
|
|
// precedence strictly greater than 'precedence'
|
|
// precedence == 0 will parse _all_ expressions
|
|
// this is the core loop of 'top-down precedence parsing'
|
|
Expr parseExp() {
|
|
return parseExp(0);
|
|
}
|
|
Expr parseExp(int precedence) {
|
|
TreeRef prefix;
|
|
int unary_prec;
|
|
if (shared.isUnary(L.cur().kind, &unary_prec)) {
|
|
auto kind = L.cur().kind;
|
|
auto pos = L.cur().range;
|
|
L.next();
|
|
auto unary_kind =
|
|
kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind;
|
|
auto subexp = parseExp(unary_prec);
|
|
// fold '-' into constant numbers, so that attributes can accept
|
|
// things like -1
|
|
if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
|
|
prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
|
|
} else {
|
|
prefix = create_compound(unary_kind, pos, {subexp});
|
|
}
|
|
} else {
|
|
prefix = parseBaseExp();
|
|
}
|
|
int binary_prec;
|
|
while (shared.isBinary(L.cur().kind, &binary_prec)) {
|
|
if (binary_prec <= precedence) // not allowed to parse something which is
|
|
// not greater than 'precedence'
|
|
break;
|
|
|
|
int kind = L.cur().kind;
|
|
auto pos = L.cur().range;
|
|
L.next();
|
|
if (shared.isRightAssociative(kind))
|
|
binary_prec--;
|
|
|
|
// special case for trinary operator
|
|
if (kind == TK_IF) {
|
|
prefix = parseTrinary(prefix, pos, binary_prec);
|
|
continue;
|
|
}
|
|
|
|
if (kind == TK_FOR) {
|
|
// TK_FOR targets should only parse exprs prec greater than 4, which only
|
|
// includes subset of Exprs that suppose to be on the LHS according to the
|
|
// python grammer https://docs.python.org/3/reference/grammar.html
|
|
auto target = parseLHSExp();
|
|
L.expect(TK_IN);
|
|
auto iter = parseExp();
|
|
prefix = ListComp::create(pos, Expr(prefix), target, iter);
|
|
continue;
|
|
}
|
|
|
|
prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
|
|
}
|
|
return Expr(prefix);
|
|
}
|
|
|
|
void parseSequence(
|
|
int begin,
|
|
int sep,
|
|
int end,
|
|
const std::function<void()>& parse) {
|
|
if (begin != TK_NOTHING) {
|
|
L.expect(begin);
|
|
}
|
|
while (end != L.cur().kind) {
|
|
parse();
|
|
if (!L.nextIf(sep)) {
|
|
if (end != TK_NOTHING) {
|
|
L.expect(end);
|
|
}
|
|
return;
|
|
}
|
|
}
|
|
L.expect(end);
|
|
}
|
|
template <typename T>
|
|
List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
|
|
auto r = L.cur().range;
|
|
std::vector<T> elements;
|
|
parseSequence(
|
|
begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
|
|
return List<T>::create(r, elements);
|
|
}
|
|
|
|
Const parseConst() {
|
|
auto range = L.cur().range;
|
|
auto t = L.expect(TK_NUMBER);
|
|
return Const::create(t.range, t.text());
|
|
}
|
|
|
|
StringLiteral parseConcatenatedStringLiterals() {
|
|
auto range = L.cur().range;
|
|
std::stringstream ss;
|
|
while (L.cur().kind == TK_STRINGLITERAL) {
|
|
auto literal_range = L.cur().range;
|
|
ss << parseStringLiteral(literal_range, L.next().text());
|
|
}
|
|
return StringLiteral::create(range, ss.str());
|
|
}
|
|
|
|
Expr parseAttributeValue() {
|
|
return parseExp();
|
|
}
|
|
|
|
void parseArguments(TreeList& inputs, TreeList& attributes) {
|
|
parseSequence('(', ',', ')', [&] {
|
|
if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
|
|
auto ident = parseIdent();
|
|
L.expect('=');
|
|
auto v = parseAttributeValue();
|
|
attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
|
|
} else {
|
|
inputs.push_back(parseExp());
|
|
}
|
|
});
|
|
}
|
|
|
|
// parse LHS acceptable exprs, which only includes subset of Exprs that prec is
|
|
// greater than 4 according to the python grammer
|
|
Expr parseLHSExp() {
|
|
return parseExp(4);
|
|
}
|
|
|
|
// Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with "::"
|
|
Expr parseSubscriptExp() {
|
|
TreeRef first, second, third;
|
|
auto range = L.cur().range;
|
|
if (L.cur().kind != ':') {
|
|
first = parseExp();
|
|
}
|
|
if (L.nextIf(':')) {
|
|
if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
|
|
second = parseExp();
|
|
}
|
|
if (L.nextIf(':')) {
|
|
if (L.cur().kind != ',' && L.cur().kind != ']') {
|
|
third = parseExp();
|
|
}
|
|
}
|
|
auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
|
|
: Maybe<Expr>::create(range);
|
|
auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
|
|
: Maybe<Expr>::create(range);
|
|
auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
|
|
: Maybe<Expr>::create(range);
|
|
return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
|
|
} else {
|
|
return Expr(first);
|
|
}
|
|
}
|
|
|
|
TreeRef parseSubscript(const TreeRef& value) {
|
|
const auto range = L.cur().range;
|
|
|
|
auto subscript_exprs =
|
|
parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
|
|
|
|
return Subscript::create(range, Expr(value), subscript_exprs);
|
|
}
|
|
|
|
Maybe<Expr> maybeParseTypeAnnotation() {
|
|
if (L.nextIf(':')) {
|
|
// NB: parseExp must not be called inline, since argument evaluation order
|
|
// changes when L.cur().range is mutated with respect to the parseExp()
|
|
// call.
|
|
auto expr = parseExp();
|
|
return Maybe<Expr>::create(expr.range(), expr);
|
|
} else {
|
|
return Maybe<Expr>::create(L.cur().range);
|
|
}
|
|
}
|
|
|
|
TreeRef parseFormalParam(bool kwarg_only) {
|
|
auto ident = parseIdent();
|
|
TreeRef type = maybeParseTypeAnnotation();
|
|
TreeRef def;
|
|
if (L.nextIf('=')) {
|
|
// NB: parseExp must not be called inline, since argument evaluation order
|
|
// changes when L.cur().range is mutated with respect to the parseExp()
|
|
// call.
|
|
auto expr = parseExp();
|
|
def = Maybe<Expr>::create(expr.range(), expr);
|
|
} else {
|
|
def = Maybe<Expr>::create(L.cur().range);
|
|
}
|
|
return Param::create(
|
|
type->range(), Ident(ident), Maybe<Expr>(type), Maybe<Expr>(def), kwarg_only);
|
|
}
|
|
|
|
Param parseBareTypeAnnotation() {
|
|
auto type = parseExp();
|
|
return Param::create(
|
|
type.range(),
|
|
Ident::create(type.range(), ""),
|
|
Maybe<Expr>::create(type.range(), type),
|
|
Maybe<Expr>::create(type.range()),
|
|
/*kwarg_only=*/false);
|
|
}
|
|
|
|
Decl parseTypeComment() {
|
|
auto range = L.cur().range;
|
|
L.expect(TK_TYPE_COMMENT);
|
|
auto param_types =
|
|
parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
|
|
TreeRef return_type;
|
|
if (L.nextIf(TK_ARROW)) {
|
|
auto return_type_range = L.cur().range;
|
|
return_type = Maybe<Expr>::create(return_type_range, parseExp());
|
|
} else {
|
|
return_type = Maybe<Expr>::create(L.cur().range);
|
|
}
|
|
return Decl::create(range, param_types, Maybe<Expr>(return_type));
|
|
}
|
|
|
|
// 'first' has already been parsed since expressions can exist
|
|
// alone on a line:
|
|
// first[,other,lhs] = rhs
|
|
TreeRef parseAssign(const Expr& lhs) {
|
|
auto type = maybeParseTypeAnnotation();
|
|
auto op = parseAssignmentOp();
|
|
auto rhs = parseExpOrExpTuple();
|
|
L.expect(TK_NEWLINE);
|
|
if (op->kind() == '=') {
|
|
return Assign::create(lhs.range(), lhs, Expr(rhs), type);
|
|
} else {
|
|
// this is an augmented assignment
|
|
if (lhs.kind() == TK_TUPLE_LITERAL) {
|
|
throw ErrorReport(lhs.range())
|
|
<< " augmented assignment can only have one LHS expression";
|
|
}
|
|
return AugAssign::create(lhs.range(), lhs, AugAssignKind(op), Expr(rhs));
|
|
}
|
|
}
|
|
|
|
TreeRef parseStmt() {
|
|
switch (L.cur().kind) {
|
|
case TK_IF:
|
|
return parseIf();
|
|
case TK_WHILE:
|
|
return parseWhile();
|
|
case TK_FOR:
|
|
return parseFor();
|
|
case TK_GLOBAL: {
|
|
auto range = L.next().range;
|
|
auto idents =
|
|
parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
|
|
L.expect(TK_NEWLINE);
|
|
return Global::create(range, idents);
|
|
}
|
|
case TK_RETURN: {
|
|
auto range = L.next().range;
|
|
Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple()
|
|
: Expr(create_compound(TK_NONE, range, {}));
|
|
L.expect(TK_NEWLINE);
|
|
return Return::create(range, value);
|
|
}
|
|
case TK_RAISE: {
|
|
auto range = L.next().range;
|
|
auto expr = parseExp();
|
|
L.expect(TK_NEWLINE);
|
|
return Raise::create(range, expr);
|
|
}
|
|
case TK_ASSERT: {
|
|
auto range = L.next().range;
|
|
auto cond = parseExp();
|
|
Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
|
|
if (L.nextIf(',')) {
|
|
auto msg = parseExp();
|
|
maybe_first = Maybe<Expr>::create(range, Expr(msg));
|
|
}
|
|
L.expect(TK_NEWLINE);
|
|
return Assert::create(range, cond, maybe_first);
|
|
}
|
|
case TK_PASS: {
|
|
auto range = L.next().range;
|
|
L.expect(TK_NEWLINE);
|
|
return Pass::create(range);
|
|
}
|
|
case TK_DEF: {
|
|
return parseFunction(/*is_method=*/false);
|
|
}
|
|
default: {
|
|
auto lhs = parseExpOrExpTuple();
|
|
if (L.cur().kind != TK_NEWLINE) {
|
|
return parseAssign(lhs);
|
|
} else {
|
|
L.expect(TK_NEWLINE);
|
|
return ExprStmt::create(lhs.range(), lhs);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
TreeRef parseIf(bool expect_if = true) {
|
|
auto r = L.cur().range;
|
|
if (expect_if)
|
|
L.expect(TK_IF);
|
|
auto cond = parseExp();
|
|
L.expect(':');
|
|
auto true_branch = parseStatements();
|
|
auto false_branch = makeList(L.cur().range, {});
|
|
if (L.nextIf(TK_ELSE)) {
|
|
L.expect(':');
|
|
false_branch = parseStatements();
|
|
} else if (L.nextIf(TK_ELIF)) {
|
|
// NB: this needs to be a separate statement, since the call to parseIf
|
|
// mutates the lexer state, and thus causes a heap-use-after-free in
|
|
// compilers which evaluate argument expressions LTR
|
|
auto range = L.cur().range;
|
|
false_branch = makeList(range, {parseIf(false)});
|
|
}
|
|
return If::create(
|
|
r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
|
|
}
|
|
TreeRef parseWhile() {
|
|
auto r = L.cur().range;
|
|
L.expect(TK_WHILE);
|
|
auto cond = parseExp();
|
|
L.expect(':');
|
|
auto body = parseStatements();
|
|
return While::create(r, Expr(cond), List<Stmt>(body));
|
|
}
|
|
|
|
TreeRef parseFor() {
|
|
auto r = L.cur().range;
|
|
L.expect(TK_FOR);
|
|
auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
|
|
auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
|
|
auto body = parseStatements();
|
|
return For::create(r, targets, itrs, body);
|
|
}
|
|
|
|
TreeRef parseStatements(bool expect_indent = true) {
|
|
auto r = L.cur().range;
|
|
if (expect_indent) {
|
|
L.expect(TK_INDENT);
|
|
}
|
|
TreeList stmts;
|
|
do {
|
|
stmts.push_back(parseStmt());
|
|
} while (!L.nextIf(TK_DEDENT));
|
|
return create_compound(TK_LIST, r, std::move(stmts));
|
|
}
|
|
|
|
Maybe<Expr> parseReturnAnnotation() {
|
|
if (L.nextIf(TK_ARROW)) {
|
|
// Exactly one expression for return type annotation
|
|
auto return_type_range = L.cur().range;
|
|
return Maybe<Expr>::create(return_type_range, parseExp());
|
|
} else {
|
|
return Maybe<Expr>::create(L.cur().range);
|
|
}
|
|
}
|
|
|
|
List<Param> parseFormalParams() {
|
|
auto r = L.cur().range;
|
|
std::vector<Param> params;
|
|
bool kwarg_only = false;
|
|
parseSequence('(', ',', ')', [&] {
|
|
if (!kwarg_only && L.nextIf('*')) {
|
|
kwarg_only = true;
|
|
} else {
|
|
params.emplace_back(parseFormalParam(kwarg_only));
|
|
}
|
|
});
|
|
return List<Param>::create(r, params);
|
|
}
|
|
Decl parseDecl() {
|
|
// Parse return type annotation
|
|
List<Param> paramlist = parseFormalParams();
|
|
TreeRef return_type;
|
|
Maybe<Expr> return_annotation = parseReturnAnnotation();
|
|
L.expect(':');
|
|
return Decl::create(
|
|
paramlist.range(), List<Param>(paramlist), return_annotation);
|
|
}
|
|
|
|
TreeRef parseNamedTuple(const Ident& name) {
|
|
const auto& range = name.range();
|
|
L.expect(')');
|
|
L.expect(':');
|
|
L.expect(TK_INDENT);
|
|
std::vector<Ident> fields;
|
|
std::vector<Maybe<Expr>> type_exprs;
|
|
while (L.cur().kind != TK_DEDENT) {
|
|
fields.push_back(parseIdent());
|
|
type_exprs.push_back(maybeParseTypeAnnotation());
|
|
L.expect(TK_NEWLINE);
|
|
}
|
|
L.expect(TK_DEDENT);
|
|
return NamedTupleDef::create(
|
|
range,
|
|
name,
|
|
List<Ident>::create(range, fields),
|
|
List<Maybe<Expr>>::create(range, type_exprs));
|
|
}
|
|
|
|
TreeRef parseClassLike() {
|
|
L.expect(TK_CLASS_DEF);
|
|
const auto name = parseIdent();
|
|
if (L.nextIf('(')) {
|
|
// Only support inheriting from NamedTuple right now.
|
|
if (L.cur().kind == TK_IDENT && L.cur().text() == "NamedTuple") {
|
|
L.next();
|
|
return parseNamedTuple(name);
|
|
} else {
|
|
L.reportError("Inheritance is not supported for TorchScript classes");
|
|
}
|
|
}
|
|
L.expect(':');
|
|
|
|
L.expect(TK_INDENT);
|
|
std::vector<Def> methods;
|
|
while (L.cur().kind != TK_DEDENT) {
|
|
methods.push_back(Def(parseFunction(/*is_method=*/true)));
|
|
}
|
|
L.expect(TK_DEDENT);
|
|
|
|
return ClassDef::create(
|
|
name.range(), name, List<Def>::create(name.range(), methods));
|
|
}
|
|
|
|
TreeRef parseFunction(bool is_method) {
|
|
L.expect(TK_DEF);
|
|
auto name = parseIdent();
|
|
auto decl = parseDecl();
|
|
|
|
// Handle type annotations specified in a type comment as the first line of
|
|
// the function.
|
|
L.expect(TK_INDENT);
|
|
if (L.cur().kind == TK_TYPE_COMMENT) {
|
|
auto type_annotation_decl = Decl(parseTypeComment());
|
|
L.expect(TK_NEWLINE);
|
|
decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
|
|
}
|
|
|
|
auto stmts_list = parseStatements(false);
|
|
return Def::create(
|
|
name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
|
|
}
|
|
Lexer& lexer() {
|
|
return L;
|
|
}
|
|
|
|
private:
|
|
// short helpers to create nodes
|
|
TreeRef create_compound(
|
|
int kind,
|
|
const SourceRange& range,
|
|
TreeList&& trees) {
|
|
return Compound::create(kind, range, std::move(trees));
|
|
}
|
|
TreeRef makeList(const SourceRange& range, TreeList&& trees) {
|
|
return create_compound(TK_LIST, range, std::move(trees));
|
|
}
|
|
Lexer L;
|
|
SharedParserData& shared;
|
|
};
|
|
|
|
Parser::Parser(const std::shared_ptr<Source>& src)
|
|
: pImpl(new ParserImpl(src)) {}
|
|
|
|
Parser::~Parser() = default;
|
|
|
|
TreeRef Parser::parseFunction(bool is_method) {
|
|
return pImpl->parseFunction(is_method);
|
|
}
|
|
TreeRef Parser::parseClassLike() {
|
|
return pImpl->parseClassLike();
|
|
}
|
|
Lexer& Parser::lexer() {
|
|
return pImpl->lexer();
|
|
}
|
|
Decl Parser::parseTypeComment() {
|
|
return pImpl->parseTypeComment();
|
|
}
|
|
Expr Parser::parseExp() {
|
|
return pImpl->parseExp();
|
|
}
|
|
|
|
} // namespace script
|
|
} // namespace jit
|
|
} // namespace torch
|