mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add & use Token::text_view() (which returns a string_view unlike text()) (#151804)
Sadly, I can't just fix text() because that might cause lifetime issues in somebody's code. Differential Revision: [D73376715](https://our.internmc.facebook.com/intern/diff/D73376715/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/151804 Approved by: https://github.com/zou3519, https://github.com/cyyever, https://github.com/Skylion007, https://github.com/malfet ghstack dependencies: #151801, #151802, #151803
This commit is contained in:
parent
0559741d7f
commit
89a85d0954
|
|
@ -3,6 +3,7 @@
|
|||
#include <ATen/core/Reduction.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <fmt/format.h>
|
||||
#include <torch/csrc/jit/frontend/lexer.h>
|
||||
#include <torch/csrc/jit/frontend/parse_string_literal.h>
|
||||
#include <torch/csrc/jit/frontend/schema_type_parser.h>
|
||||
|
|
@ -108,7 +109,7 @@ struct SchemaParser {
|
|||
std::string name = L.expect(TK_IDENT).text();
|
||||
if (L.nextIf(':')) {
|
||||
L.expect(':');
|
||||
name = name + "::" + L.expect(TK_IDENT).text();
|
||||
name = fmt::format("{}::{}", name, L.expect(TK_IDENT).text_view());
|
||||
}
|
||||
std::string overload_name = "";
|
||||
if (L.nextIf('.')) {
|
||||
|
|
@ -240,28 +241,31 @@ struct SchemaParser {
|
|||
}
|
||||
case TK_IDENT: {
|
||||
auto tok = L.next();
|
||||
auto text = tok.text();
|
||||
auto text_view = tok.text_view();
|
||||
// NB: float/complex/long are here for BC purposes. Other dtypes
|
||||
// are handled via str2dtype.
|
||||
// Please don't add more cases to this if-else block.
|
||||
if ("float" == text) {
|
||||
if ("float" == text_view) {
|
||||
return static_cast<int64_t>(at::kFloat);
|
||||
} else if ("complex" == text) {
|
||||
} else if ("complex" == text_view) {
|
||||
return static_cast<int64_t>(at::kComplexFloat);
|
||||
} else if ("long" == text) {
|
||||
} else if ("long" == text_view) {
|
||||
return static_cast<int64_t>(at::kLong);
|
||||
} else if ("strided" == text) {
|
||||
} else if ("strided" == text_view) {
|
||||
return static_cast<int64_t>(at::kStrided);
|
||||
} else if ("Mean" == text) {
|
||||
} else if ("Mean" == text_view) {
|
||||
return static_cast<int64_t>(at::Reduction::Mean);
|
||||
} else if ("contiguous_format" == text) {
|
||||
} else if ("contiguous_format" == text_view) {
|
||||
return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
|
||||
} else if (
|
||||
isPossiblyOptionalScalarType(real_type) &&
|
||||
str2dtype.count(text) > 0) {
|
||||
return static_cast<int64_t>(str2dtype.at(text));
|
||||
} else {
|
||||
throw(ErrorReport(L.cur().range) << "invalid numeric default value");
|
||||
auto text = tok.text();
|
||||
if (isPossiblyOptionalScalarType(real_type) &&
|
||||
str2dtype.count(text) > 0) {
|
||||
return static_cast<int64_t>(str2dtype.at(text));
|
||||
} else {
|
||||
throw(
|
||||
ErrorReport(L.cur().range) << "invalid numeric default value");
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -397,9 +397,14 @@ struct Token {
|
|||
int kind;
|
||||
SourceRange range;
|
||||
Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
|
||||
std::string text() {
|
||||
std::string text() const {
|
||||
return std::string(range.token_text());
|
||||
}
|
||||
|
||||
std::string_view text_view() const {
|
||||
return range.token_text();
|
||||
}
|
||||
|
||||
std::string kindString() const {
|
||||
return kindToString(kind);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -308,7 +308,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
|
|||
return;
|
||||
}
|
||||
bool shape_symbol = false;
|
||||
if (L.cur().kind == TK_IDENT && L.cur().text() == "SS") {
|
||||
if (L.cur().kind == TK_IDENT && L.cur().text_view() == "SS") {
|
||||
L.next();
|
||||
L.expect('(');
|
||||
L.expect('-');
|
||||
|
|
@ -377,7 +377,7 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
});
|
||||
fake_value = real_value =
|
||||
c10::TypeFactory::create<TupleType>(std::move(types));
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Future") {
|
||||
L.next(); // Future
|
||||
L.expect('(');
|
||||
auto p = parseType();
|
||||
|
|
@ -385,7 +385,7 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
auto subalias = std::move(p.second);
|
||||
L.expect(')');
|
||||
fake_value = real_value = c10::TypeFactory::create<FutureType>(subtype);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Await") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Await") {
|
||||
L.next(); // Await
|
||||
L.expect('(');
|
||||
auto p = parseType();
|
||||
|
|
@ -393,7 +393,7 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
auto subalias = std::move(p.second);
|
||||
L.expect(')');
|
||||
fake_value = real_value = c10::TypeFactory::create<AwaitType>(subtype);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "RRef") {
|
||||
L.next(); // RRef
|
||||
L.expect('(');
|
||||
auto p = parseType();
|
||||
|
|
@ -401,11 +401,11 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
auto subalias = std::move(p.second);
|
||||
L.expect(')');
|
||||
fake_value = real_value = c10::TypeFactory::create<RRefType>(subtype);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Tensor") {
|
||||
L.next();
|
||||
fake_value = real_value = c10::TypeFactory::get<TensorType>();
|
||||
alias_info = parseAliasAnnotation();
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Dict") {
|
||||
L.next();
|
||||
L.expect('(');
|
||||
auto key_type = parseType().first;
|
||||
|
|
@ -415,7 +415,7 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
alias_info = parseAliasAnnotation();
|
||||
fake_value = real_value =
|
||||
c10::TypeFactory::create<DictType>(key_type, value_type);
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "Union") {
|
||||
L.next();
|
||||
L.expect('(');
|
||||
std::vector<TypePtr> types;
|
||||
|
|
@ -433,7 +433,7 @@ SchemaTypeParser::parseFakeAndRealType() {
|
|||
parseTensorDType(L.cur().text())) {
|
||||
fake_value = real_value = parseRefinedTensor();
|
||||
alias_info = parseAliasAnnotation();
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
|
||||
} else if (L.cur().kind == TK_IDENT && L.cur().text_view() == "__torch__") {
|
||||
L.next();
|
||||
L.expect('.');
|
||||
auto torch_tok = L.expect(TK_IDENT);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user