Add & use Token::text_view() (which returns a string_view unlike text()) (#151804)

Sadly, I can't just fix text() because that might cause lifetime issues in somebody's code.

Differential Revision: [D73376715](https://our.internmc.facebook.com/intern/diff/D73376715/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151804
Approved by: https://github.com/zou3519, https://github.com/cyyever, https://github.com/Skylion007, https://github.com/malfet
ghstack dependencies: #151801, #151802, #151803
This commit is contained in:
Scott Wolchok 2025-04-23 17:22:23 -07:00 committed by PyTorch MergeBot
parent 0559741d7f
commit 89a85d0954
3 changed files with 31 additions and 22 deletions

View File

@ -3,6 +3,7 @@
#include <ATen/core/Reduction.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
#include <fmt/format.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/csrc/jit/frontend/schema_type_parser.h>
@ -108,7 +109,7 @@ struct SchemaParser {
std::string name = L.expect(TK_IDENT).text();
if (L.nextIf(':')) {
L.expect(':');
name = name + "::" + L.expect(TK_IDENT).text();
name = fmt::format("{}::{}", name, L.expect(TK_IDENT).text_view());
}
std::string overload_name = "";
if (L.nextIf('.')) {
@ -240,28 +241,31 @@ struct SchemaParser {
}
case TK_IDENT: {
auto tok = L.next();
auto text = tok.text();
auto text_view = tok.text_view();
// NB: float/complex/long are here for BC purposes. Other dtypes
// are handled via str2dtype.
// Please don't add more cases to this if-else block.
if ("float" == text) {
if ("float" == text_view) {
return static_cast<int64_t>(at::kFloat);
} else if ("complex" == text) {
} else if ("complex" == text_view) {
return static_cast<int64_t>(at::kComplexFloat);
} else if ("long" == text) {
} else if ("long" == text_view) {
return static_cast<int64_t>(at::kLong);
} else if ("strided" == text) {
} else if ("strided" == text_view) {
return static_cast<int64_t>(at::kStrided);
} else if ("Mean" == text) {
} else if ("Mean" == text_view) {
return static_cast<int64_t>(at::Reduction::Mean);
} else if ("contiguous_format" == text) {
} else if ("contiguous_format" == text_view) {
return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
} else if (
isPossiblyOptionalScalarType(real_type) &&
} else {
auto text = tok.text();
if (isPossiblyOptionalScalarType(real_type) &&
str2dtype.count(text) > 0) {
return static_cast<int64_t>(str2dtype.at(text));
} else {
throw(ErrorReport(L.cur().range) << "invalid numeric default value");
throw(
ErrorReport(L.cur().range) << "invalid numeric default value");
}
}
}
default:

View File

@ -397,9 +397,14 @@ struct Token {
int kind;
SourceRange range;
Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
std::string text() {
std::string text() const {
return std::string(range.token_text());
}
std::string_view text_view() const {
return range.token_text();
}
std::string kindString() const {
return kindToString(kind);
}

View File

@ -308,7 +308,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
return;
}
bool shape_symbol = false;
if (L.cur().kind == TK_IDENT && L.cur().text() == "SS") {
if (L.cur().kind == TK_IDENT && L.cur().text_view() == "SS") {
L.next();
L.expect('(');
L.expect('-');
@ -377,7 +377,7 @@ SchemaTypeParser::parseFakeAndRealType() {
});
fake_value = real_value =
c10::TypeFactory::create<TupleType>(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Future") {
L.next(); // Future
L.expect('(');
auto p = parseType();
@ -385,7 +385,7 @@ SchemaTypeParser::parseFakeAndRealType() {
auto subalias = std::move(p.second);
L.expect(')');
fake_value = real_value = c10::TypeFactory::create<FutureType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Await") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Await") {
L.next(); // Await
L.expect('(');
auto p = parseType();
@ -393,7 +393,7 @@ SchemaTypeParser::parseFakeAndRealType() {
auto subalias = std::move(p.second);
L.expect(')');
fake_value = real_value = c10::TypeFactory::create<AwaitType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "RRef") {
L.next(); // RRef
L.expect('(');
auto p = parseType();
@ -401,11 +401,11 @@ SchemaTypeParser::parseFakeAndRealType() {
auto subalias = std::move(p.second);
L.expect(')');
fake_value = real_value = c10::TypeFactory::create<RRefType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Tensor") {
L.next();
fake_value = real_value = c10::TypeFactory::get<TensorType>();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Dict") {
L.next();
L.expect('(');
auto key_type = parseType().first;
@ -415,7 +415,7 @@ SchemaTypeParser::parseFakeAndRealType() {
alias_info = parseAliasAnnotation();
fake_value = real_value =
c10::TypeFactory::create<DictType>(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Union") {
L.next();
L.expect('(');
std::vector<TypePtr> types;
@ -433,7 +433,7 @@ SchemaTypeParser::parseFakeAndRealType() {
parseTensorDType(L.cur().text())) {
fake_value = real_value = parseRefinedTensor();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "__torch__") {
L.next();
L.expect('.');
auto torch_tok = L.expect(TK_IDENT);