diff --git a/torch/csrc/jit/frontend/function_schema_parser.cpp b/torch/csrc/jit/frontend/function_schema_parser.cpp index d7e10a1177b..2ad43928419 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.cpp +++ b/torch/csrc/jit/frontend/function_schema_parser.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -108,7 +109,7 @@ struct SchemaParser { std::string name = L.expect(TK_IDENT).text(); if (L.nextIf(':')) { L.expect(':'); - name = name + "::" + L.expect(TK_IDENT).text(); + name = fmt::format("{}::{}", name, L.expect(TK_IDENT).text_view()); } std::string overload_name = ""; if (L.nextIf('.')) { @@ -240,28 +241,31 @@ struct SchemaParser { } case TK_IDENT: { auto tok = L.next(); - auto text = tok.text(); + auto text_view = tok.text_view(); // NB: float/complex/long are here for BC purposes. Other dtypes // are handled via str2dtype. // Please don't add more cases to this if-else block. - if ("float" == text) { + if ("float" == text_view) { return static_cast(at::kFloat); - } else if ("complex" == text) { + } else if ("complex" == text_view) { return static_cast(at::kComplexFloat); - } else if ("long" == text) { + } else if ("long" == text_view) { return static_cast(at::kLong); - } else if ("strided" == text) { + } else if ("strided" == text_view) { return static_cast(at::kStrided); - } else if ("Mean" == text) { + } else if ("Mean" == text_view) { return static_cast(at::Reduction::Mean); - } else if ("contiguous_format" == text) { + } else if ("contiguous_format" == text_view) { return static_cast(c10::MemoryFormat::Contiguous); - } else if ( - isPossiblyOptionalScalarType(real_type) && - str2dtype.count(text) > 0) { - return static_cast(str2dtype.at(text)); } else { - throw(ErrorReport(L.cur().range) << "invalid numeric default value"); + auto text = tok.text(); + if (isPossiblyOptionalScalarType(real_type) && + str2dtype.count(text) > 0) { + return static_cast(str2dtype.at(text)); + } else { + throw( + ErrorReport(L.cur().range) << "invalid numeric default value"); + } } } default: diff --git a/torch/csrc/jit/frontend/lexer.h b/torch/csrc/jit/frontend/lexer.h index f36e421c822..0faf6ff24da 100644 --- a/torch/csrc/jit/frontend/lexer.h +++ b/torch/csrc/jit/frontend/lexer.h @@ -397,9 +397,14 @@ struct Token { int kind; SourceRange range; Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {} - std::string text() { + std::string text() const { return std::string(range.token_text()); } + + std::string_view text_view() const { + return range.token_text(); + } + std::string kindString() const { return kindToString(kind); } diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 917f70a281c..a7bce8d9eeb 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -308,7 +308,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { return; } bool shape_symbol = false; - if (L.cur().kind == TK_IDENT && L.cur().text() == "SS") { + if (L.cur().kind == TK_IDENT && L.cur().text_view() == "SS") { L.next(); L.expect('('); L.expect('-'); @@ -377,7 +377,7 @@ SchemaTypeParser::parseFakeAndRealType() { }); fake_value = real_value = c10::TypeFactory::create(std::move(types)); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Future") { L.next(); // Future L.expect('('); auto p = parseType(); @@ -385,7 +385,7 @@ SchemaTypeParser::parseFakeAndRealType() { auto subalias = std::move(p.second); L.expect(')'); fake_value = real_value = c10::TypeFactory::create(subtype); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Await") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Await") { L.next(); // Await L.expect('('); auto p = parseType(); @@ -393,7 +393,7 @@ SchemaTypeParser::parseFakeAndRealType() { auto subalias = std::move(p.second); L.expect(')'); fake_value = real_value = c10::TypeFactory::create(subtype); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "RRef") { L.next(); // RRef L.expect('('); auto p = parseType(); @@ -401,11 +401,11 @@ SchemaTypeParser::parseFakeAndRealType() { auto subalias = std::move(p.second); L.expect(')'); fake_value = real_value = c10::TypeFactory::create(subtype); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Tensor") { L.next(); fake_value = real_value = c10::TypeFactory::get(); alias_info = parseAliasAnnotation(); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Dict") { L.next(); L.expect('('); auto key_type = parseType().first; @@ -415,7 +415,7 @@ SchemaTypeParser::parseFakeAndRealType() { alias_info = parseAliasAnnotation(); fake_value = real_value = c10::TypeFactory::create(key_type, value_type); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Union") { L.next(); L.expect('('); std::vector types; @@ -433,7 +433,7 @@ SchemaTypeParser::parseFakeAndRealType() { parseTensorDType(L.cur().text())) { fake_value = real_value = parseRefinedTensor(); alias_info = parseAliasAnnotation(); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") { + } else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "__torch__") { L.next(); L.expect('.'); auto torch_tok = L.expect(TK_IDENT);