Summary:
This is a first step towards adding exceptions. We need minimal support in order to begin converting the torch library to weak script mode (which is the main goal here).
Some limitations (documented in the tests & compiler):
1. Exceptions cannot be assigned to variables.
2. Any name after raise is treated as a valid Exception.
3. No control-flow analysis yet. In the example below, a will be undefined:
   if True:
       a = 1
   else:
       raise Exception("Hi")
   return a
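For illustration (an assumed reading of limitations 1 and 2 above, not an example taken from the PR):
   raise Exception("Hi")    # supported: raising a name or call directly
   e = Exception("Hi")      # unsupported: exceptions cannot be assigned
   raise e                  #   to variables (limitation 1)
   raise some_name          # accepted without checking: any name after raise
                            #   is treated as a valid Exception (limitation 2)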
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12789
Differential Revision: D12848936
Pulled By: eellison
fbshipit-source-id: 1f60ceef2381040486123ec797e97d65b074862d
#pragma once
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/source_range.h"
#include <torch/csrc/utils/memory.h>
#include <clocale>

namespace torch {
namespace jit {
namespace script {

// single-character tokens are just the character itself, e.g. '+';
// multi-character tokens need an entry here.
// if the third entry is not the empty string, it is used
// in the lexer to match this token.

// These kinds are also used in Tree.h as the kind of the AST node.
// Some kinds (TK_APPLY, TK_LIST) are only used in the AST and are never
// produced by the lexer.

#define TC_FORALL_TOKEN_KINDS(_) \
  _(TK_EOF, "eof", "") \
  _(TK_WHITESPACE, "whitespace", "") \
  _(TK_NUMBER, "number", "") \
  _(TK_NEWLINE, "newline", "") \
  _(TK_INDENT, "indent", "") \
  _(TK_DEDENT, "dedent", "") \
  _(TK_DEF, "def", "def") \
  _(TK_EQUIVALENT, "equivalent", "<=>") \
  _(TK_IDENT, "ident", "") \
  _(TK_STRING, "string", "") \
  _(TK_STRINGLITERAL, "string_literal", "") \
  _(TK_CONST, "const", "") \
  _(TK_LIST, "list", "") \
  _(TK_OPTION, "option", "") \
  _(TK_APPLY, "apply", "") \
  _(TK_COMPREHENSION, "comprehension", "") \
  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
  _(TK_PARAM, "param", "") \
  _(TK_INFERRED, "inferred", "") \
  _(TK_ACCESS, "access", "") \
  _(TK_ASSIGN, "assign", "") \
  _(TK_ATTRIBUTE, "attribute", "") \
  _(TK_IF, "if", "if") \
  _(TK_ELSE, "else", "else") \
  _(TK_ELIF, "elif", "elif") \
  _(TK_WHILE, "while", "while") \
  _(TK_EXPR_STMT, "expression statement", "") \
  _(TK_RETURN, "return", "return") \
  _(TK_NE, "ne", "!=") \
  _(TK_EQ, "eq", "==") \
  _(TK_LE, "le", "<=") \
  _(TK_GE, "ge", ">=") \
  _(TK_IF_EXPR, "if", "") \
  _(TK_TRUE, "True", "True") \
  _(TK_FALSE, "False", "False") \
  _(TK_NONE, "None", "None") \
  _(TK_AND, "and", "and") \
  _(TK_OR, "or", "or") \
  _(TK_NOT, "not", "not") \
  _(TK_CAST, "cast", "") \
  _(TK_PLUS_EQ, "+=", "+=") \
  _(TK_MINUS_EQ, "-=", "-=") \
  _(TK_TIMES_EQ, "*=", "*=") \
  _(TK_DIV_EQ, "/=", "/=") \
  _(TK_GLOBAL, "global", "global") \
  _(TK_BUILT_IN, "built-in", "") \
  _(TK_SUBSCRIPT, "subscript", "") \
  _(TK_VAR, "variable", "") \
  _(TK_NOTHING, "nothing", "") \
  _(TK_LIST_LITERAL, "list-literal", "") \
  _(TK_TUPLE_LITERAL, "tuple-literal", "") \
  _(TK_FOR, "for", "for") \
  _(TK_IN, "in", "in") \
  _(TK_STARRED, "starred", "") \
  _(TK_UNARY_MINUS, "unary minus", "") \
  _(TK_POW, "pow operator", "**") \
  _(TK_ARROW, "arrow", "->") \
  _(TK_DECL, "decl", "") \
  _(TK_SLICE_EXPR, "slice expr", "") \
  _(TK_TYPE_COMMENT, "type comment", "# type:") \
  _(TK_RAISE, "raise", "raise")

static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!";

enum TokenKind {
  // we use characters to represent themselves, so skip all valid characters
  // before assigning enum values to multi-char tokens.
  TK_DUMMY_START = 256,
#define DEFINE_TOKEN(tok, _, _2) tok,
  TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
#undef DEFINE_TOKEN
};
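
// Note (illustrative): a single-char token's kind is its ASCII value
// (e.g. '+' == 43), while every TK_* kind is assigned a value above
// TK_DUMMY_START == 256, so the two ranges never collide.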

std::string kindToString(int kind);
int stringToKind(std::string str);

// a trie that records, char-by-char, which strings form valid tokens.
struct TokenTrie;
using TokenTrieRef = std::unique_ptr<TokenTrie>;
struct TokenTrie {
  TokenTrie() : kind(0) {}
  void insert(const char* str, int tok) {
    if (*str == '\0') {
      JIT_ASSERT(kind == 0);
      kind = tok;
      return;
    }

    for (size_t i = 0, e = child_chars.size(); i < e; ++i) {
      if (child_chars[i] == *str) {
        child_tries[i]->insert(str + 1, tok);
        return;
      }
    }

    child_chars.emplace_back(*str);
    child_tries.emplace_back(torch::make_unique<TokenTrie>());
    child_tries.back()->insert(str + 1, tok);
  }
  int kind; // 0 == invalid token

  std::vector<char> child_chars;
  std::vector<TokenTrieRef> child_tries;
};
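
// Example (illustrative): after
//   head->insert("<", '<');
//   head->insert("<=", TK_LE);
// the root maps '<' to kind '<', and that child maps '=' to TK_LE one level
// deeper; the matcher below walks this structure to prefer the longest token.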

// stuff that is shared across all TC lexers/parsers and is initialized only
// once.
struct SharedParserData {
  SharedParserData() : head(new TokenTrie()) {
    std::stringstream ss;
    for (const char* c = valid_single_char_tokens; *c; c++) {
      std::string str(1, *c);
      head->insert(str.c_str(), *c);
    }

#define ADD_CASE(tok, _, tokstring) \
  if (*(tokstring) != '\0') { \
    head->insert((tokstring), (tok)); \
  }
    TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
  }
#ifdef _WIN32
  double strtod_c(const char* str, char** end) {
    static _locale_t loc = _create_locale(LC_ALL, "C");
    return _strtod_l(str, end, loc);
  }
#else
  double strtod_c(const char* str, char** end) {
    static locale_t loc = newlocale(LC_ALL_MASK, "C", nullptr);
    return strtod_l(str, end, loc);
  }
#endif
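
  // strtod_c pins number parsing to the "C" locale so lexing is
  // deterministic: under, say, a German global locale, plain strtod would
  // treat ',' as the decimal separator and stop parsing "1.5" at the '.'.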

  // 1. skip whitespace
  // 2. handle comment or newline
  //
  bool isNumber(const std::string& str, size_t start, size_t* len) {
    char first = str[start];
    // strtod allows numbers to start with + or - or nan or inf
    // http://en.cppreference.com/w/cpp/string/byte/strtof
    // but we want only the number part, otherwise 1+3 will turn into two
    // adjacent numbers in the lexer
    if (first == '-' || first == '+' || isalpha(first))
      return false;
    const char* startptr = str.c_str() + start;
    char* endptr;
    strtod_c(startptr, &endptr);
    *len = endptr - startptr;
    return *len > 0;
  }
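
  // Example (illustrative): for str == "1+3", isNumber(str, 0, &len) sets
  // len == 1 (just the "1"); '+' and "3" are then lexed as separate tokens,
  // per the comment above about not absorbing leading signs.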

  bool isCharCount(char c, const std::string& str, size_t start, int len) {
    // count checks from [start, start + len)
    return start + len <= str.size() &&
        std::count(str.begin() + start, str.begin() + start + len, c) == len;
  }

  // python concatenates all adjacent strings "a" "b" == "ab"
  // strings can be enclosed with 1 or 3 single or double quotes
  // if enclosed with 3 quotes newlines are valid
  // as elsewhere, backslash and new line should be ignored
  bool isString(const std::string& str, size_t start, size_t* len) {
    char quote = str[start];
    if (quote != '\"' && quote != '\'')
      return false;
    int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;

    // end is now set past the opening quotation marks
    size_t end = start + quote_len;
    while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
      if (str[end] == '\n' && quote_len != 3) {
        return false;
      }
      // handle escaped characters. advances past escaped quotation marks,
      // escaped newlines and escaped backslashes
      if (str[end] == '\\') {
        end++;
      }
      end++;
    }
    // set length equal to the complete string including quotations
    *len = end - start + quote_len;
    // if end finished without going past the last character of the string,
    // then there is a match
    return end < str.size();
  }
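
  // Examples (illustrative) of what isString accepts:
  //   'a'        -> quote_len == 1
  //   "a\"b"     -> escaped quote inside a double-quoted string
  //   '''a
  //   b'''       -> newlines are only legal when quote_len == 3
  // An unterminated literal such as "abc runs off the end of the input,
  // so end == str.size() and the function returns false.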

  bool isblank(int n) {
    return isspace(n) && n != '\n';
  }
  // comments are normally skipped, but type annotation comments ("# type:")
  // are the exception: they are kept and lexed as tokens
  bool isTypeComment(const std::string& str, size_t pos) {
    const std::string type_string = "# type:";
    if (str.size() < pos + type_string.length()) {
      return false;
    }
    auto match_string = str.substr(pos, type_string.size());
    return match_string == type_string;
  }
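
  // Example (illustrative): in the line
  //   x = foo(y)  # type: Tensor
  // the "# type:" suffix is not skipped like an ordinary comment; match()
  // below lets it reach the trie, which emits a TK_TYPE_COMMENT token.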

  // find the longest match of str.substring(pos) against a token, return true
  // if successful, filling in kind, start, and len
  bool match(
      const std::string& str,
      size_t pos,
      bool continuation, // are we inside a scope where newlines don't count
                         // (e.g. inside parens)
      bool whitespace_token, // should we treat whitespace as a token
      int* kind,
      size_t* start,
      size_t* len) {
    *start = pos;
    // skip whitespace
    while (pos < str.size() && isblank(str[pos]))
      pos++;

    // special handling
    if (pos < str.size()) {
      if (str[pos] == '#' && !isTypeComment(str, pos)) {
        // skip comments
        while (pos < str.size() && str[pos] != '\n')
          pos++;
        // tail call, handle whitespace and more comments
        return match(
            str, pos, continuation, whitespace_token, kind, start, len);
      }
      if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' &&
          !whitespace_token) {
        return match(str, pos + 2, continuation, false, kind, start, len);
      }
      if (str[pos] == '\n') {
        return match(
            str, pos + 1, continuation, !continuation, kind, start, len);
      }
    }
    if (pos == str.size()) {
      *kind = TK_EOF;
      *start = pos;
      *len = 0;
      return true;
    }
    // invariant: the next token is not whitespace or newline
    if (whitespace_token) {
      *kind = TK_WHITESPACE;
      *len = pos - *start;
      return true;
    }
    *start = pos;
    // check for a valid number
    if (isNumber(str, pos, len)) {
      *kind = TK_NUMBER;
      return true;
    }
    // check for string
    if (isString(str, pos, len)) {
      *kind = TK_STRINGLITERAL;
      return true;
    }

    // check for either an ident or a token
    // ident tracks whether what we have scanned so far could be an identifier
    // matched indicates if we have found any match.
    bool matched = false;
    bool ident = true;
    TokenTrie* cur = head.get();
    for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
      ident = ident && validIdent(i, str[pos + i]);
      if (ident) {
        matched = true;
        *len = i + 1;
        *kind = TK_IDENT;
      }
      // check for token second, so that e.g. 'max' matches the token TK_MAX
      // rather than the identifier 'max'
      if (cur) {
        size_t child_offset = 0;
        for (size_t e = cur->child_chars.size(); child_offset < e;
             ++child_offset) {
          if (cur->child_chars[child_offset] == str[pos + i])
            break;
        }

        cur = (child_offset == cur->child_chars.size())
            ? nullptr
            : cur->child_tries[child_offset].get();

        if (cur && cur->kind != 0) {
          matched = true;
          *len = i + 1;
          *kind = cur->kind;
        }
      }
    }
    return matched;
  }
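
  // Illustrative matches: for "def(" both the identifier scan and the trie
  // reach length 3, but the trie result is recorded second, so the kind is
  // TK_DEF rather than TK_IDENT; for "define" the identifier keeps growing
  // to length 6 and wins, yielding TK_IDENT.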

  bool isUnary(int kind, int* prec);
  bool isBinary(int kind, int* prec);
  bool isRightAssociative(int kind) {
    switch (kind) {
      case '?':
      case TK_POW:
        return true;
      default:
        return false;
    }
  }

 private:
  bool validIdent(size_t i, char n) {
    return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
  }
  TokenTrieRef head;
};

SharedParserData& sharedParserData();

struct Token {
  int kind;
  SourceRange range;
  Token(int kind, const SourceRange& range) : kind(kind), range(range) {}
  std::string text() {
    return range.text();
  }
  std::string kindString() const {
    return kindToString(kind);
  }
};

struct Lexer {
  explicit Lexer(const std::string& str)
      : file(std::make_shared<std::string>(str)),
        pos(0),
        nesting(0),
        indent_stack(),
        next_tokens(),
        shared(sharedParserData()) {
    auto first_indent = lexRaw(true);
    indent_stack.push_back(first_indent.range.size());
    lex();
  }
  // Return the current token, and then move to the next one
  Token next() {
    if (next_tokens.size() == 0)
      reportError("Lexer invariant violated: empty token queue");
    Token r = next_tokens.front();
    next_tokens.erase(next_tokens.begin());
    if (next_tokens.size() == 0) {
      lex();
    }
    return r;
  }
  // Skip the current token if it matches the given kind
  bool nextIf(int kind) {
    if (cur().kind != kind)
      return false;
    next();
    return true;
  }

  [[noreturn]] void reportError(const std::string& what) {
    reportError(what, cur());
  }
  [[noreturn]] void reportError(const std::string& what, const Token& t) {
    std::stringstream ss;
    ss << what << ":\n";
    t.range.highlight(ss);
    throw std::runtime_error(ss.str());
  }
  [[noreturn]] void expected(const std::string& what, const Token& t) {
    std::stringstream ss;
    ss << "expected " << what << " but found '" << t.kindString()
       << "' here:\n";
    t.range.highlight(ss);
    throw std::runtime_error(ss.str());
  }
  [[noreturn]] void expected(const std::string& what) {
    expected(what, cur());
  }
  // Check that the current token has a given kind, return the current token,
  // and advance to the next one.
  Token expect(int kind) {
    if (cur().kind != kind) {
      expected(kindToString(kind));
    }
    return next();
  }
  Token& lookahead() {
    if (next_tokens.size() < 2) {
      lex();
    }
    return next_tokens[1];
  }
  Token& cur() {
    return next_tokens.front();
  }

 private:
  void lex() {
    auto r = lexRaw();
    switch (r.kind) {
      case '(':
      case '[':
      case '{':
        nesting++;
        break;
      case ')':
      case ']':
      case '}':
        nesting--;
        break;
      case TK_WHITESPACE: {
        int depth = r.range.size();
        if (depth > indent_stack.back()) {
          indent_stack.push_back(depth);
          r.kind = TK_INDENT;
        } else if (depth == indent_stack.back()) {
          r.kind = TK_NEWLINE;
        } else {
          next_tokens.emplace_back(TK_NEWLINE, r.range);
          while (indent_stack.back() != depth) {
            indent_stack.pop_back();
            next_tokens.emplace_back(TK_DEDENT, r.range);
            if (indent_stack.size() == 0) {
              reportError("invalid indent level", r);
            }
          }
          return; // We've already queued the tokens
        }
      } break;
      case TK_EOF:
        if (indent_stack.size() > 1) {
          next_tokens.emplace_back(TK_NEWLINE, r.range);
          next_tokens.emplace_back(TK_DEDENT, r.range);
          indent_stack.pop_back();
          return;
        }
        break;
      default:
        break;
    }
    next_tokens.push_back(std::move(r));
  }
  Token lexRaw(bool whitespace_token = false) {
    int kind;
    size_t start;
    size_t length;
    JIT_ASSERT(file);
    if (!shared.match(
            *file,
            pos,
            nesting > 0,
            whitespace_token,
            &kind,
            &start,
            &length)) {
      expected(
          "a valid token",
          Token((*file)[start], SourceRange(file, start, start + 1)));
    }
    auto t = Token(kind, SourceRange(file, start, start + length));
    pos = start + length;
    return t;
  }

  std::shared_ptr<std::string> file;
  size_t pos;
  size_t nesting; // depth of ( [ { nesting...
  std::vector<int> indent_stack; // stack of indentation levels of blocks
  // Invariant: this should always contain at least a single element
  std::vector<Token> next_tokens;
  SharedParserData& shared;
};
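
// Example (illustrative) of the stream this lexer produces for
//   def f(x):
//       return x
// roughly: TK_DEF TK_IDENT '(' TK_IDENT ')' ':' TK_INDENT TK_RETURN
// TK_IDENT TK_NEWLINE TK_DEDENT TK_EOF, with INDENT/NEWLINE/DEDENT
// synthesized in lex() from the whitespace depths kept on indent_stack.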

} // namespace script
} // namespace jit
} // namespace torch