diff --git a/test/cpp/jit/test_lexer.cpp b/test/cpp/jit/test_lexer.cpp index 465adbf6ecb..1a9ddf9d5f3 100644 --- a/test/cpp/jit/test_lexer.cpp +++ b/test/cpp/jit/test_lexer.cpp @@ -29,7 +29,7 @@ TEST(LexerTest, AllTokens) { TEST(LexerTest, SlightlyOffIsNot) { std::vector suffixes = {"", " ", "**"}; for (const auto& suffix : suffixes) { - std::vector extras = {"n", "no", "no3"}; + std::vector extras = {"n", "no", "no3", "note"}; for (const auto& extra : extras) { std::string s = "is " + extra + suffix; Lexer l(std::make_shared(s)); @@ -45,7 +45,7 @@ TEST(LexerTest, SlightlyOffIsNot) { TEST(LexerTest, SlightlyOffNotIn) { std::vector suffixes = {"", " ", "**"}; for (const auto& suffix : suffixes) { - std::vector extras = {"i", "i3"}; + std::vector extras = {"i", "i3", "inn"}; for (const auto& extra : extras) { std::string s = "not " + extra + suffix; Lexer l(std::make_shared(s)); @@ -57,32 +57,4 @@ TEST(LexerTest, SlightlyOffNotIn) { } } } - -TEST(LexerTest, IsNoteBug) { - // The code string `is note` is lexed as TK_ISNOT followed by a - // TK_IDENT that is an e. This is not how it works in Python, but - // presumably we need to maintain this behavior. - Lexer l(std::make_shared("is note")); - const auto is_not_tok = l.next(); - EXPECT_EQ(is_not_tok.kind, TK_ISNOT); - const auto e_tok = l.next(); - EXPECT_EQ(e_tok.kind, TK_IDENT); - EXPECT_EQ(e_tok.range.text(), "e"); - const auto eof_tok = l.next(); - EXPECT_EQ(eof_tok.kind, TK_EOF); -} - -TEST(LexerTest, NotInpBug) { - // Another manifestation of the above IsNoteBug; `not inp` is lexed - // as TK_NOT_IN followed by a TK_IDENT that is a p. Again, not how - // it works in Python. 
- Lexer l(std::make_shared("not inp")); - const auto not_in_tok = l.next(); - EXPECT_EQ(not_in_tok.kind, TK_NOTIN); - const auto p_tok = l.next(); - EXPECT_EQ(p_tok.kind, TK_IDENT); - EXPECT_EQ(p_tok.range.text(), "p"); - const auto eof_tok = l.next(); - EXPECT_EQ(eof_tok.kind, TK_EOF); -} } // namespace torch::jit diff --git a/torch/csrc/jit/frontend/lexer.h b/torch/csrc/jit/frontend/lexer.h index 0faf6ff24da..ecdb571b802 100644 --- a/torch/csrc/jit/frontend/lexer.h +++ b/torch/csrc/jit/frontend/lexer.h @@ -1,13 +1,16 @@ #pragma once #include #include +#include #include #include #include #include #include +#include #include #include +#include #include #include #include @@ -133,51 +136,10 @@ enum TokenKind { TORCH_API std::string kindToString(int kind); TORCH_API int stringToKind(const std::string& str); -// nested hash tables that indicate char-by-char what is a valid token. -struct TokenTrie; -using TokenTrieRef = std::unique_ptr; -struct TokenTrie { - TokenTrie() = default; - void insert(const char* str, int tok) { - if (*str == '\0') { - AT_ASSERT(kind == 0); - kind = tok; - return; - } - - for (size_t i = 0, e = child_chars.size(); i < e; ++i) { - if (child_chars[i] == *str) { - child_tries[i]->insert(str + 1, tok); - return; - } - } - - child_chars.emplace_back(*str); - child_tries.emplace_back(std::make_unique()); - child_tries.back()->insert(str + 1, tok); - } - int kind{0}; // 0 == invalid token - - std::vector child_chars; - std::vector child_tries; -}; - // stuff that is shared against all TC lexers/parsers and is initialized only // once. 
struct TORCH_API SharedParserData { - SharedParserData() : head(new TokenTrie()) { - for (const char* c = valid_single_char_tokens; *c; c++) { - std::string str(1, *c); - head->insert(str.c_str(), *c); - } - -#define ADD_CASE(tok, _, tokstring) \ - if (*(tokstring) != '\0') { \ - head->insert((tokstring), (tok)); \ - } - TC_FORALL_TOKEN_KINDS(ADD_CASE) -#undef ADD_CASE - } + SharedParserData() = default; bool match( StringCordView::Iterator pos, @@ -248,41 +210,213 @@ struct TORCH_API SharedParserData { return true; } - // check for either an ident or a token - // ident tracks whether what we have scanned so far could be an identifier - // matched indicates if we have found any match. - bool matched = false; - bool ident = true; - TokenTrie* cur = head.get(); - // for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); - // i++) - for (size_t i = 0; pos.has_next() && (ident || cur != nullptr); - ++pos, ++i) { - ident = ident && validIdent(i, *pos); - if (ident) { - matched = true; - *end = pos.next_iter(); - *kind = TK_IDENT; - } - // check for token second, so that e.g. 'max' matches the token TK_MAX - // rather the - // identifier 'max' - if (cur) { - const auto begin_it = cur->child_chars.begin(); - const auto end_it = cur->child_chars.end(); - const auto ch_it = std::find(begin_it, end_it, *pos); - - cur = (ch_it == end_it) ? nullptr - : cur->child_tries[ch_it - begin_it].get(); - - if (cur && cur->kind != 0) { - matched = true; - *end = pos.next_iter(); - *kind = cur->kind; - } - } + if (isalpha(*pos) || *pos == '_') { + matchIdentOrKeyword(pos, kind, end); + return true; } - return matched; + + // Hand-coded DFA matching for tokens that cannot be confused with + // identifiers. We could use a lexer generator toolkit like Flex + // or re2c instead, but that would add another dependency, and I + // expect this component to change infrequently given that PyTorch + // 2.0 is years old already. 
Note that the tests in test_lexer.cpp + // should guarantee that we don't forget to update this when we + // update TC_FORALL_TOKEN_KINDS. + const auto next_pos = pos.next_iter(); + switch (*pos) { + case '+': { + if (pos.has_next() && *next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_PLUS_EQ; + return true; + } + goto single_char_token; + } + case '-': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_MINUS_EQ; + return true; + } + if (*next_pos == '>') { + *end = next_pos.next_iter(); + *kind = TK_ARROW; + return true; + } + } + goto single_char_token; + case '*': + if (pos.has_next()) { + if (*next_pos == '*') { + if (next_pos.has_next() && *next_pos.next_iter() == '=') { + *end = next_pos.next_iter().next_iter(); + *kind = TK_POW_EQ; + return true; + } + *end = next_pos.next_iter(); + *kind = TK_POW; + return true; + } + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_TIMES_EQ; + return true; + } + } + goto single_char_token; + case '/': + if (pos.has_next()) { + if (*next_pos == '/') { + *end = next_pos.next_iter(); + *kind = TK_FLOOR_DIV; + return true; + } + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_DIV_EQ; + return true; + } + } + goto single_char_token; + case '%': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_MOD_EQ; + return true; + } + } + goto single_char_token; + case '=': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_EQ; + return true; + } + } + goto single_char_token; + case '>': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_GE; + return true; + } + if (*next_pos == '>') { + if (next_pos.has_next() && *next_pos.next_iter() == '=') { + *end = next_pos.next_iter().next_iter(); + *kind = TK_RSHIFT_EQ; + return true; + } + *end = next_pos.next_iter(); + *kind = TK_RSHIFT; + return true; + } + } + goto 
single_char_token; + case '<': + if (pos.has_next()) { + if (*next_pos == '=') { + if (next_pos.has_next() && *next_pos.next_iter() == '>') { + *end = next_pos.next_iter().next_iter(); + *kind = TK_EQUIVALENT; + return true; + } + *end = next_pos.next_iter(); + *kind = TK_LE; + return true; + } + if (*next_pos == '<') { + if (next_pos.has_next() && *next_pos.next_iter() == '=') { + *end = next_pos.next_iter().next_iter(); + *kind = TK_LSHIFT_EQ; + return true; + } + *end = next_pos.next_iter(); + *kind = TK_LSHIFT; + return true; + } + } + goto single_char_token; + case '.': + if (pos.has_next()) { + if (*next_pos == '.' && next_pos.has_next() && + *next_pos.next_iter() == '.') { + *end = next_pos.next_iter().next_iter(); + *kind = TK_DOTS; + return true; + } + } + goto single_char_token; + case '!': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_NE; + return true; + } + } + goto single_char_token; + case '&': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_BIT_AND_EQ; + return true; + } + } + goto single_char_token; + case '^': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_BIT_XOR_EQ; + return true; + } + } + goto single_char_token; + case '|': + if (pos.has_next()) { + if (*next_pos == '=') { + *end = next_pos.next_iter(); + *kind = TK_BIT_OR_EQ; + return true; + } + } + goto single_char_token; + case '#': + *end = pos + std::strlen("# type:"); + *kind = TK_TYPE_COMMENT; + return true; + case '@': + case '(': + case ')': + case '[': + case ']': + case ':': + case ',': + case '{': + case '}': + case '?': + case '~': + single_char_token: + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + std::strchr(valid_single_char_tokens, *pos) != nullptr, + "Did you forget to add the character `", + *pos, + "` to valid_single_char_tokens?"); + *end = next_pos; + *kind = *pos; + return true; + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + 
std::strchr(valid_single_char_tokens, *pos) == nullptr, + "Did you forget to add the character `", + *pos, + "` to the above switch statement?"); + return false; } bool isUnary(int kind, int* prec); @@ -299,8 +433,196 @@ struct TORCH_API SharedParserData { } private: - bool validIdent(size_t i, char n) { - return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); + void matchIdentOrKeyword( + StringCordView::Iterator pos, + int* kind, + StringCordView::Iterator* end) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pos.has_next()); + static constexpr char kIsNot[] = "is not"; + static constexpr char kNotIn[] = "not in"; + constexpr char kMaybeIsNot = 'i'; + constexpr char kMaybeNotIn = 'n'; + constexpr int kIsNotSpaceIndex = 2; + constexpr int kNotInSpaceIndex = 3; + auto start = pos; + char possible_special_token = *pos; + // The longest tokens are 8 chars. + std::array token_chars; + token_chars.fill('\0'); + token_chars[0] = possible_special_token; + ++pos; + size_t i; + auto valid_ident_char = [](const char ch) { + return isalpha(ch) || ch == '_' || isdigit(ch); + }; + for (i = 1; pos.has_next(); ++pos, ++i) { + auto ch = *pos; + if (possible_special_token == kMaybeIsNot) { + if (ch != kIsNot[i]) { + if (i >= kIsNotSpaceIndex + 1) { + // Kick out to the after-loop flow, which will correctly + // record that we found TK_IS. + break; + } + possible_special_token = '\0'; + } else if (ch == ' ') { + continue; + } + if (possible_special_token && i == sizeof(kIsNot) - 2 && + (!pos.has_next() || !valid_ident_char(*(pos + 1)))) { + *kind = TK_ISNOT; + *end = pos.next_iter(); + return; + } + } else if (possible_special_token == kMaybeNotIn) { + if (ch != kNotIn[i]) { + if (i >= kNotInSpaceIndex + 1) { + // Kick out to the after-loop flow, which will correctly + // record that we found TK_NOT. 
+ break; + } + possible_special_token = '\0'; + } else if (ch == ' ') { + continue; + } + + if (possible_special_token && i == sizeof(kNotIn) - 2 && + (!pos.has_next() || !valid_ident_char(*(pos + 1)))) { + *kind = TK_NOTIN; + *end = pos.next_iter(); + return; + } + } + if (valid_ident_char(ch)) { + if (i < token_chars.size()) { + token_chars[i] = ch; + } + continue; + } + break; + } + + // These two possible_special_token checks have to be after the + // loop and not in the loop because we might see end-of-input + // (e.g., the entire input `not p`). + if (possible_special_token == kMaybeIsNot) { + if (i >= kIsNotSpaceIndex) { + *kind = TK_IS; + *end = start + kIsNotSpaceIndex; + return; + } + } else if (possible_special_token == kMaybeNotIn) { + if (i >= kNotInSpaceIndex) { + *kind = TK_NOT; + *end = start + kNotInSpaceIndex; + return; + } + } + + *end = pos; + *kind = identTokenKind(token_chars, i); + } + + template + static constexpr uint64_t stringToUint64(const char (&str)[N]) { + static_assert(N <= 9); + uint64_t result = 0; + for (auto i : c10::irange(N)) { + if (!str[i]) { + return result; + } +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + result |= static_cast(str[i]) << (8 * i); +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + result |= static_cast(str[i]) << (56 - 8 * i); +#else +#error "Unexpected or undefined value of __BYTE_ORDER__" +#endif + } + return result; + } + + static int identTokenKind( + const std::array& token_chars, + size_t token_length) { + if (token_length > token_chars.size()) { + return TK_IDENT; + } +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + static_assert(stringToUint64("and") == 0x646e61); + static_assert(stringToUint64("Ellipsis") == 0x73697370696c6c45); +#else // big-endian: the first character occupies the most significant byte + static_assert(stringToUint64("and") == 0x616e640000000000); + static_assert(stringToUint64("Ellipsis") == 0x456c6c6970736973); +#endif + + std::uint64_t token = 0; + std::memcpy(&token, token_chars.data(), token_chars.size()); + // FWIW, based on checking Godbolt this 
probably compiles down to + // binary or linear search over the integers representing our + // strings. I tried an alternate version that switched on the + // first character of the token, but it doesn't seem to matter for + // performance. + switch (token) { + case stringToUint64("Ellipsis"): + return TK_ELLIPSIS; + case stringToUint64("False"): + return TK_FALSE; + case stringToUint64("None"): + return TK_NONE; + case stringToUint64("NoneType"): + return TK_NONE_TYPE; + case stringToUint64("True"): + return TK_TRUE; + case stringToUint64("and"): + return TK_AND; + case stringToUint64("as"): + return TK_AS; + case stringToUint64("assert"): + return TK_ASSERT; + case stringToUint64("break"): + return TK_BREAK; + case stringToUint64("class"): + return TK_CLASS_DEF; + case stringToUint64("continue"): + return TK_CONTINUE; + case stringToUint64("def"): + return TK_DEF; + case stringToUint64("del"): + return TK_DELETE; + case stringToUint64("elif"): + return TK_ELIF; + case stringToUint64("else"): + return TK_ELSE; + case stringToUint64("for"): + return TK_FOR; + case stringToUint64("global"): + return TK_GLOBAL; + case stringToUint64("if"): + return TK_IF; + case stringToUint64("import"): + return TK_IMPORT; + case stringToUint64("in"): + return TK_IN; + case stringToUint64("is"): + return TK_IS; + case stringToUint64("not"): + return TK_NOT; + case stringToUint64("or"): + return TK_OR; + case stringToUint64("pass"): + return TK_PASS; + case stringToUint64("raise"): + return TK_RAISE; + case stringToUint64("return"): + return TK_RETURN; + case stringToUint64("while"): + return TK_WHILE; + case stringToUint64("with"): + return TK_WITH; + default: + return TK_IDENT; + } } // 1. 
skip whitespace @@ -387,8 +709,6 @@ struct TORCH_API SharedParserData { auto match_string = str.substr(pos, type_string.size()); return match_string == type_string; } - - TokenTrieRef head; }; TORCH_API SharedParserData& sharedParserData(); diff --git a/torch/csrc/jit/frontend/parser_constants.h b/torch/csrc/jit/frontend/parser_constants.h index cf51d10b098..3a902aba589 100644 --- a/torch/csrc/jit/frontend/parser_constants.h +++ b/torch/csrc/jit/frontend/parser_constants.h @@ -1,5 +1,6 @@ #pragma once namespace torch::jit { -static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~"; +[[maybe_unused]] static const char* valid_single_char_tokens = + "+-*/%@()[]:,={}><.?!&^|~"; } // namespace torch::jit