#pragma once #include "lexer.h" #include "tree.h" #include "tree_views.h" namespace torch { namespace jit { namespace script { Decl mergeTypesFromTypeComment(Decl decl, Decl type_annotation_decl, bool is_method) { auto expected_num_annotations = decl.params().size(); if (is_method) { // `self` argument expected_num_annotations -= 1; } if (expected_num_annotations != type_annotation_decl.params().size()) { throw ErrorReport(type_annotation_decl.range()) << "Number of type annotations (" << type_annotation_decl.params().size() << ") did not match the number of " << "function parameters (" << expected_num_annotations << ")"; } auto old = decl.params(); auto _new = type_annotation_decl.params(); // Merge signature idents and ranges with annotation types std::vector new_params; size_t i = is_method ? 1 : 0; size_t j = 0; if (is_method) { new_params.push_back(old[0]); } for (; i < decl.params().size(); ++i, ++j) { new_params.push_back(Param::create(old[i].range(), old[i].ident(), _new[j].type())); } return Decl::create(decl.range(), List::create(decl.range(), new_params), type_annotation_decl.return_type()); } struct Parser { explicit Parser(const std::string& str) : L(str), shared(sharedParserData()) {} Ident parseIdent() { auto t = L.expect(TK_IDENT); // whenever we parse something that has a TreeView type we always // use its create method so that the accessors and the constructor // of the Compound tree are in the same place. return Ident::create(t.range, t.text()); } TreeRef createApply(Expr expr) { TreeList attributes; auto range = L.cur().range; TreeList inputs; parseOperatorArguments(inputs, attributes); return Apply::create( range, expr, List(makeList(range, std::move(inputs))), List(makeList(range, std::move(attributes)))); } // exp | expr, | expr, expr, ... TreeRef parseExpOrExpTuple(int end) { auto prefix = parseExp(); if(L.cur().kind == ',') { std::vector exprs = { prefix }; while(L.cur().kind != end) { L.expect(','); if (L.cur().kind == end) break; exprs.push_back(parseExp()); } auto list = List::create(prefix.range(), exprs); prefix = TupleLiteral::create(list.range(), list); } return prefix; } // things like a 1.0 or a(4) that are not unary/binary expressions // and have higher precedence than all of them TreeRef parseBaseExp() { TreeRef prefix; switch (L.cur().kind) { case TK_NUMBER: { prefix = parseConst(); } break; case TK_TRUE: case TK_FALSE: case TK_NONE: { auto k = L.cur().kind; auto r = L.cur().range; prefix = c(k, r, {}); L.next(); } break; case '(': { L.next(); if (L.nextIf(')')) { /// here we have the empty tuple case std::vector vecExpr; List listExpr = List::create(L.cur().range, vecExpr); prefix = TupleLiteral::create(L.cur().range, listExpr); break; } prefix = parseExpOrExpTuple(')'); L.expect(')'); } break; case '[': { auto list = parseList('[', ',', ']', &Parser::parseExp); prefix = ListLiteral::create(list.range(), List(list)); } break; case TK_STRINGLITERAL: { prefix = parseStringLiteral(); } break; default: { Ident name = parseIdent(); prefix = Var::create(name.range(), name); } break; } while (true) { if (L.nextIf('.')) { const auto name = parseIdent(); prefix = Select::create(name.range(), Expr(prefix), Ident(name)); } else if (L.cur().kind == '(') { prefix = createApply(Expr(prefix)); } else if (L.cur().kind == '[') { prefix = parseSubscript(prefix); } else { break; } } return prefix; } TreeRef parseOptionalReduction() { auto r = L.cur().range; switch (L.cur().kind) { case TK_PLUS_EQ: case TK_MINUS_EQ: case TK_TIMES_EQ: case TK_DIV_EQ: { int modifier = L.next().text()[0]; return c(modifier, r, {}); } break; default: { L.expect('='); return c('=', r, {}); // no reduction } break; } } TreeRef parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) { auto cond = parseExp(); L.expect(TK_ELSE); auto false_branch = parseExp(binary_prec); return c(TK_IF_EXPR, range, {cond, true_branch, false_branch}); } // parse the longest expression whose binary operators have // precedence strictly greater than 'precedence' // precedence == 0 will parse _all_ expressions // this is the core loop of 'top-down precedence parsing' Expr parseExp() { return parseExp(0); } Expr parseExp(int precedence) { TreeRef prefix = nullptr; int unary_prec; if (shared.isUnary(L.cur().kind, &unary_prec)) { auto kind = L.cur().kind; auto pos = L.cur().range; L.next(); auto unary_kind = kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind; auto subexp = parseExp(unary_prec); // fold '-' into constant numbers, so that attributes can accept // things like -1 if(unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) { prefix = Const::create(subexp.range(), "-" + Const(subexp).text()); } else { prefix = c(unary_kind, pos, {subexp}); } } else { prefix = parseBaseExp(); } int binary_prec; while (shared.isBinary(L.cur().kind, &binary_prec)) { if (binary_prec <= precedence) // not allowed to parse something which is // not greater than 'precedence' break; int kind = L.cur().kind; auto pos = L.cur().range; L.next(); if (shared.isRightAssociative(kind)) binary_prec--; // special case for trinary operator if (kind == TK_IF) { prefix = parseTrinary(prefix, pos, binary_prec); continue; } prefix = c(kind, pos, {prefix, parseExp(binary_prec)}); } return Expr(prefix); } template List parseList(int begin, int sep, int end, T (Parser::*parse)()) { auto r = L.cur().range; if (begin != TK_NOTHING) L.expect(begin); std::vector elements; if (L.cur().kind != end) { do { elements.push_back((this->*parse)()); } while (L.nextIf(sep)); } if (end != TK_NOTHING) L.expect(end); return List::create(r, elements); } Const parseConst() { auto range = L.cur().range; auto t = L.expect(TK_NUMBER); return Const::create(t.range, t.text()); } bool isCharCount(char c, const std::string& str, size_t start, int len) { //count checks from [start, start + len) return start + len <= str.size() && std::count(str.begin() + start, str.begin() + start + len, c) == len; } std::string parseString(const SourceRange& range, const std::string &str) { int quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1; auto ret_str = str.substr(quote_len, str.size() - quote_len * 2); size_t pos = ret_str.find('\\'); while(pos != std::string::npos) { //invariant: pos has to escape a character because it is a valid string char c = ret_str[pos + 1]; switch (ret_str[pos + 1]) { case '\\': case '\'': case '\"': case '\n': break; case 'a': c = '\a'; break; case 'b': c = '\b'; break; case 'f': c = '\f'; break; case 'n': c = '\n'; break; case 'v': c = '\v'; break; default: throw ErrorReport(range) << " octal and hex escaped sequences are not supported"; } ret_str.replace(pos, /* num to erase */ 2, /* num copies */ 1, c); pos = ret_str.find('\\', pos + 1); } return ret_str; } StringLiteral parseStringLiteral() { auto range = L.cur().range; std::stringstream ss; while(L.cur().kind == TK_STRINGLITERAL) { auto literal_range = L.cur().range; ss << parseString(literal_range, L.next().text()); } return StringLiteral::create(range, ss.str()); } Expr parseAttributeValue() { return parseExp(); } void parseOperatorArguments(TreeList& inputs, TreeList& attributes) { L.expect('('); if (L.cur().kind != ')') { do { if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') { auto ident = parseIdent(); L.expect('='); auto v = parseAttributeValue(); attributes.push_back(Attribute::create(ident.range(), Ident(ident), v)); } else { inputs.push_back(parseExp()); } } while (L.nextIf(',')); } L.expect(')'); } // Parse expr's of the form [a:], [:b], [a:b], [:] Expr parseSubscriptExp() { TreeRef first, second; auto range = L.cur().range; if (L.cur().kind != ':') { first = parseExp(); } if (L.nextIf(':')) { if (L.cur().kind != ',' && L.cur().kind != ']') { second = parseExp(); } auto maybe_first = first ? Maybe::create(range, Expr(first)) : Maybe::create(range); auto maybe_second = second ? Maybe::create(range, Expr(second)) : Maybe::create(range); return SliceExpr::create(range, maybe_first, maybe_second); } else { return Expr(first); } } TreeRef parseSubscript(TreeRef value) { const auto range = L.cur().range; auto subscript_exprs = parseList('[', ',', ']', &Parser::parseSubscriptExp); return Subscript::create(range, Expr(value), subscript_exprs); } TreeRef parseParam() { auto ident = parseIdent(); TreeRef type; if (L.nextIf(':')) { type = parseExp(); } else { type = Var::create(L.cur().range, Ident::create(L.cur().range, "Tensor")); } return Param::create(type->range(), Ident(ident), Expr(type)); } Param parseBareTypeAnnotation() { auto type = parseExp(); return Param::create(type.range(), Ident::create(type.range(), ""), type); } TreeRef parseTypeComment(bool parse_full_line=false) { auto range = L.cur().range; if (parse_full_line) { L.expect(TK_TYPE_COMMENT); } auto param_types = parseList('(', ',', ')', &Parser::parseBareTypeAnnotation); TreeRef return_type; if (L.nextIf(TK_ARROW)) { auto return_type_range = L.cur().range; return_type = Maybe::create(return_type_range, parseExp()); } else { return_type = Maybe::create(L.cur().range); } if (!parse_full_line) L.expect(TK_NEWLINE); return Decl::create(range, param_types, Maybe(return_type)); } // 'first' has already been parsed since expressions can exist // alone on a line: // first[,other,lhs] = rhs Assign parseAssign(List list) { auto red = parseOptionalReduction(); auto rhs = parseExpOrExpTuple(TK_NEWLINE); L.expect(TK_NEWLINE); return Assign::create(list.range(), list, AssignKind(red), Expr(rhs)); } TreeRef parseStmt() { switch (L.cur().kind) { case TK_IF: return parseIf(); case TK_WHILE: return parseWhile(); case TK_FOR: return parseFor(); case TK_GLOBAL: { auto range = L.next().range; auto idents = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseIdent); L.expect(TK_NEWLINE); return Global::create(range, idents); } case TK_RETURN: { auto range = L.next().range; // XXX: TK_NEWLINE makes it accept an empty list auto values = parseList(TK_NOTHING, ',', TK_NEWLINE, &Parser::parseExp); return Return::create(range, values); } case TK_RAISE: { auto range = L.next().range; auto expr = parseExp(); L.expect(TK_NEWLINE); return Raise::create(range, expr); } default: { List exprs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); if (L.cur().kind != TK_NEWLINE) { return parseAssign(exprs); } else { L.expect(TK_NEWLINE); return ExprStmt::create(exprs[0].range(), exprs); } } } } TreeRef parseOptionalIdentList() { TreeRef list = nullptr; if (L.cur().kind == '(') { list = parseList('(', ',', ')', &Parser::parseIdent); } else { list = c(TK_LIST, L.cur().range, {}); } return list; } TreeRef parseIf(bool expect_if=true) { auto r = L.cur().range; if (expect_if) L.expect(TK_IF); auto cond = parseExp(); L.expect(':'); auto true_branch = parseStatements(); auto false_branch = makeList(L.cur().range, {}); if (L.nextIf(TK_ELSE)) { L.expect(':'); false_branch = parseStatements(); } else if (L.nextIf(TK_ELIF)) { // NB: this needs to be a separate statement, since the call to parseIf // mutates the lexer state, and thus causes a heap-use-after-free in // compilers which evaluate argument expressions LTR auto range = L.cur().range; false_branch = makeList(range, {parseIf(false)}); } return If::create(r, Expr(cond), List(true_branch), List(false_branch)); } TreeRef parseWhile() { auto r = L.cur().range; L.expect(TK_WHILE); auto cond = parseExp(); L.expect(':'); auto body = parseStatements(); return While::create(r, Expr(cond), List(body)); } TreeRef parseFor() { auto r = L.cur().range; L.expect(TK_FOR); auto targets = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); L.expect(TK_IN); auto itrs = parseList(TK_NOTHING, ',', TK_NOTHING, &Parser::parseExp); L.expect(':'); auto body = parseStatements(); return For::create(r, targets, itrs, body); } TreeRef parseStatements(bool expect_indent=true) { auto r = L.cur().range; if (expect_indent) L.expect(TK_INDENT); TreeList stmts; for (size_t i=0; ; ++i) { auto stmt = parseStmt(); stmts.push_back(stmt); if (L.nextIf(TK_DEDENT)) break; } return c(TK_LIST, r, std::move(stmts)); } Decl parseDecl() { auto paramlist = parseList('(', ',', ')', &Parser::parseParam); // Parse return type annotation TreeRef return_type; if (L.nextIf(TK_ARROW)) { // Exactly one expression for return type annotation auto return_type_range = L.cur().range; return_type = Maybe::create(return_type_range, parseExp()); } else { // Default to returning single tensor. TODO: better sentinel value? return_type = Maybe::create(L.cur().range); } L.expect(':'); return Decl::create(paramlist.range(), List(paramlist), Maybe(return_type)); } TreeRef parseFunction(bool is_method) { L.expect(TK_DEF); auto name = parseIdent(); auto decl = parseDecl(); // Handle type annotations specified in a type comment as the first line of // the function. L.expect(TK_INDENT); if (L.nextIf(TK_TYPE_COMMENT)) { auto type_annotation_decl = Decl(parseTypeComment()); decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method); } auto stmts_list = parseStatements(false); return Def::create(name.range(), Ident(name), Decl(decl), List(stmts_list)); } Lexer& lexer() { return L; } private: // short helpers to create nodes TreeRef c(int kind, const SourceRange& range, TreeList&& trees) { return Compound::create(kind, range, std::move(trees)); } TreeRef makeList(const SourceRange& range, TreeList&& trees) { return c(TK_LIST, range, std::move(trees)); } Lexer L; SharedParserData& shared; }; } // namespace script } // namespace jit } // namespace torch