Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert D34455360: Multisect successfully blamed D34455360 for test failures
Summary:
This diff reverts D34455360 (61d6c43864). D34455360 (61d6c43864) is causing the following tests to fail, and this revert diff is either the revert of the blamed diff or the revert of the stack of diffs that must be reverted to back out the blamed diff.

Tests affected:
- https://www.internalfb.com/intern/test/562950004334605/

Multisect link: https://www.internalfb.com/intern/testinfra/multisect/756170

Test Plan: NA

Reviewed By: zhxchen17

Differential Revision: D34596156

fbshipit-source-id: a465bca0094db3caf6130c80f1ed49eea981359b
(cherry picked from commit ef5e5578c64ce9827570757fb016aafa9c782c6a)
This commit is contained in:
parent
2e039b653e
commit
0723639b60
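For context, a small self-contained sketch (not part of this commit; the types are hypothetical stand-ins, not the real torch classes) of the call-site change the revert makes. D34455360 had SourceRange::text() return a cord-like view that callers converted with .str(); this revert restores a plain std::string return, which is why the hunks below drop ".str()" at several call sites.

// Hypothetical illustration only; CordLikeText stands in for StringCordView.
#include <iostream>
#include <string>

struct CordLikeText {
  std::string data;
  std::string str() const { return data; }
};

struct RangeBeforeRevert { // API shape introduced by D34455360
  CordLikeText text() const { return {"x + h"}; }
};

struct RangeAfterRevert { // API shape this revert restores
  std::string text() const { return "x + h"; }
};

int main() {
  std::cout << RangeBeforeRevert().text().str() << "\n"; // callers needed .str()
  std::cout << RangeAfterRevert().text() << "\n"; // plain std::string again
  return 0;
}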
@@ -1,32 +0,0 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/frontend/source_range.h>

using namespace ::testing;
using namespace ::torch::jit;

TEST(SourceRangeTest, test_find) {
  std::vector<std::shared_ptr<std::string>> strings;
  strings.push_back(std::make_shared<std::string>("hello world"));
  strings.push_back(std::make_shared<std::string>("nihaoma"));

  std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

  StringCordView view(pieces, strings);

  auto x = view.find("rldni", 0);
  EXPECT_EQ(x, 8);
}

TEST(SourceRangeTest, test_substr) {
  std::vector<std::shared_ptr<std::string>> strings;
  strings.push_back(std::make_shared<std::string>("hello world"));
  strings.push_back(std::make_shared<std::string>("nihaoma"));

  std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

  StringCordView view(pieces, strings);

  auto x = view.substr(4, 10).str();
  EXPECT_EQ(x, view.str().substr(4, 10));
  EXPECT_EQ(view.substr(0, view.size()).str(), view.str());
}
@@ -143,38 +143,6 @@ TEST(BackendTest, TestCompiler) {
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestCompilerWithStringTable) {
  setShouldUseFormatWithStringTable(true);
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs);

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
  auto res = lm.forward(inputs);
  AT_ASSERT(res.toTensor().equal(ref.toTensor()));

  std::stringstream ss;
  lm._save_for_mobile(ss);
  auto mlm = _load_for_mobile(ss);
  auto mres = mlm.forward(inputs);
  setShouldUseFormatWithStringTable(false);
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestComposite) {
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
@@ -415,56 +383,6 @@ Traceback of TorchScript (most recent call last):
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestCompilerWithStringTable) {
  setShouldUseFormatWithStringTable(true);
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);

  std::stringstream ss;
  lm._save_for_mobile(ss, ExtraFilesMap(), true);
  auto mlm = _load_for_mobile(ss);
  std::string error_pattern = R"(
Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x: Tensor, h: Tensor):
        return self.__loweredModule__.forward(x, h)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, h, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, h):
        return x + h
               ~~~~~ <--- HERE
)";
  setShouldUseFormatWithStringTable(false);
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
  Module a("A");
  a.define(R"(
@@ -38,18 +38,16 @@ static inline void trim(std::string& s) {
}
} // namespace

#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
  try { \
    (void)statement; \
    FAIL(); \
  } catch (const std::exception& e) { \
    std::string substring_s(substring); \
    trim(substring_s); \
    auto exception_string = std::string(e.what()); \
    trim(exception_string); \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
        << " Error was: \n" \
        << exception_string; \
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
  try { \
    (void)statement; \
    FAIL(); \
  } catch (const std::exception& e) { \
    std::string substring_s(substring); \
    trim(substring_s); \
    auto exception_string = std::string(e.what()); \
    trim(exception_string); \
    ASSERT_NE(exception_string.find(substring_s), std::string::npos); \
  }

namespace torch {
@@ -2,7 +2,6 @@

#include <ATen/core/Reduction.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Optional.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
@@ -28,13 +27,8 @@ namespace jit {

namespace {
struct SchemaParser {
  explicit SchemaParser(const std::string& str)
      : L(std::make_shared<Source>(
            c10::string_view(str),
            c10::nullopt,
            0,
            nullptr,
            Source::DONT_COPY)),
  SchemaParser(const std::string& str)
      : L(std::make_shared<SourceView>(c10::string_view(str))),
        type_parser(L, /*parse_complete_tensor_types*/ false) {}

  either<OperatorName, FunctionSchema> parseDeclaration() {
@@ -190,7 +190,7 @@ struct TORCH_API SharedParserData {
  // find the longest match of str.substring(pos) against a token, return true
  // if successful filling in kind, start,and len
  bool match(
      StringCordView str,
      c10::string_view str,
      size_t pos,
      bool continuation, // are we inside a scope where newlines don't count
                         // (e.g. inside parens)
@@ -241,12 +241,12 @@ struct TORCH_API SharedParserData {
    // invariant: the next token is not whitespace or newline
    *start = pos;
    // check for a valid number
    if (isNumber(str.piece(0), pos, len)) {
    if (isNumber(str, pos, len)) {
      *kind = TK_NUMBER;
      return true;
    }
    // check for string
    if (isString(str.piece(0), pos, len)) {
    if (isString(str, pos, len)) {
      *kind = TK_STRINGLITERAL;
      return true;
    }
@@ -369,7 +369,7 @@ struct TORCH_API SharedParserData {
    return isspace(n) && n != '\n';
  }
  // Make an exception ignoring comments for type annotation comments
  bool isTypeComment(StringCordView str, size_t pos) {
  bool isTypeComment(c10::string_view str, size_t pos) {
    const std::string type_string = "# type:";
    if (str.size() < pos + type_string.length()) {
      return false;
@@ -388,7 +388,7 @@ struct Token {
  SourceRange range;
  Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
  std::string text() {
    return range.text().str();
    return range.text();
  }
  std::string kindString() const {
    return kindToString(kind);
@@ -396,7 +396,7 @@ struct Token {
};

struct Lexer {
  explicit Lexer(std::shared_ptr<Source> source)
  explicit Lexer(std::shared_ptr<SourceView> source)
      : source(std::move(source)),
        pos(0),
        nesting(0),
@@ -519,19 +519,25 @@ struct Lexer {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    size_t length;
    AT_ASSERT(source);
    auto src = source->text_str();
    if (!shared.match(
            src, pos, nesting > 0, whitespace_token, &kind, &start, &length)) {
            source->text(),
            pos,
            nesting > 0,
            whitespace_token,
            &kind,
            &start,
            &length)) {
      expected(
          "a valid token",
          Token(source->char_at(start), SourceRange(source, start, start + 1)));
          Token(
              (source->text())[start], SourceRange(source, start, start + 1)));
    }
    auto t = Token(kind, SourceRange(source, start, start + length));
    pos = start + length;
    return t;
  }

  std::shared_ptr<Source> source;
  std::shared_ptr<SourceView> source;
  size_t pos;
  size_t nesting; // depth of ( [ { nesting...
  std::vector<int> indent_stack; // stack of indentation level of blocks
@@ -46,7 +46,7 @@ Decl mergeTypesFromTypeComment(
}

struct ParserImpl {
  explicit ParserImpl(const std::shared_ptr<Source>& source)
  explicit ParserImpl(const std::shared_ptr<SourceView>& source)
      : L(source), shared(sharedParserData()) {}

  Ident parseIdent() {
@@ -801,7 +801,7 @@ struct ParserImpl {
  SharedParserData& shared;
};

Parser::Parser(const std::shared_ptr<Source>& src)
Parser::Parser(const std::shared_ptr<SourceView>& src)
    : pImpl(new ParserImpl(src)) {}

Parser::~Parser() = default;
@@ -17,7 +17,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
    bool is_method);

struct TORCH_API Parser {
  explicit Parser(const std::shared_ptr<Source>& src);
  explicit Parser(const std::shared_ptr<SourceView>& src);
  TreeRef parseFunction(bool is_method);
  TreeRef parseClass();
  Decl parseTypeComment();
@@ -227,7 +227,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
  // expression and base type names.
  if (resolver_) {
    if (auto typePtr =
            resolver_->resolveType(expr.range().text().str(), expr.range())) {
            resolver_->resolveType(expr.range().text(), expr.range())) {
      return typePtr;
    }
  }
@@ -4,140 +4,13 @@

namespace torch {
namespace jit {

// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
StringCordView::StringCordView() {
  accumulated_sizes_.push_back(0);
}

StringCordView::StringCordView(
    std::vector<c10::string_view> inputs,
    std::vector<std::shared_ptr<std::string>> ownerships)
    : pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
  accumulated_sizes_.push_back(0);
  size_t running_sum = 0;
  for (auto& s : pieces_) {
    if (s.size() > 0) {
      running_sum += s.size();
      accumulated_sizes_.push_back(running_sum);
    }
  }
}

size_t StringCordView::find(const std::string& tok, size_t start) const {
  if (tok.size() == 0) {
    return 0;
  }

  if ((size() - start) < tok.size()) {
    return std::string::npos;
  }

  Iterator begin = iter_for_pos(start);
  Iterator end_iter = end();
  size_t offset = start;
  for (; begin != end_iter; ++begin, ++offset) {
    if (*begin == tok[0]) {
      auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
      if (mis.second == tok.end()) {
        // no mismatch, and second string (tok) is exhausted.
        return offset;
      }
      if (mis.first == end_iter) {
        // this str is exhausted but tok is not
        return std::string::npos;
      }
    }
  }
  return std::string::npos;
}

StringCordView StringCordView::substr(size_t start, size_t size) const {
  std::vector<c10::string_view> pieces;
  std::vector<std::shared_ptr<std::string>> ownerships;
  if (start >= this->size()) {
    // out of bounds
    return StringCordView();
  }
  if (start + size >= this->size()) {
    size = this->size() - start;
  }
  Iterator begin = iter_for_pos(start);
  Iterator end = iter_for_pos(start + size);

  if (begin.line_ == end.line_) {
    // same line
    pieces.push_back(pieces_[begin.line_].substr(begin.pos_, size));
  } else {
    pieces.push_back(pieces_[begin.line_].substr(begin.pos_));

    size_t last_line = pieces_.size();
    if (end != this->end() && end.line_ < last_line) {
      // end is within the string
      last_line = end.line_;
    }
    for (size_t i = begin.line_ + 1; i < last_line; i++) {
      pieces.push_back(pieces_[i]);
    }
    if (end != this->end()) {
      pieces.push_back(pieces_[end.line_].substr(0, end.pos_));
    }
  }

  // share ownership
  std::copy(
      owned_strings_.begin(),
      owned_strings_.end(),
      std::back_inserter(ownerships));

  return StringCordView(std::move(pieces), std::move(ownerships));
}

bool StringCordView::operator==(const std::string& rhs) {
  if (size() != rhs.size()) {
    return false;
  }
  auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
  // both need to exhaust
  return res.first == end() && res.second == rhs.end();
}

bool StringCordView::operator==(const StringCordView& rhs) {
  if (size() != rhs.size()) {
    return false;
  }
  auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
  // both need to exhaust
  return res.first == end() && res.second == rhs.end();
}

StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
  if (pos == 0) {
    return begin();
  }
  if (pos >= size()) {
    return end();
  }
  auto upper = std::upper_bound(
      accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
  if (upper == accumulated_sizes_.end()) {
    return end();
  }
  size_t line = upper - accumulated_sizes_.begin() - 1;
  assert(accumulated_sizes_[line] <= pos);
  assert(accumulated_sizes_[line + 1] > pos);
  return Iterator(this, line, pos - accumulated_sizes_[line], size() - pos);
}

size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
  return (
      std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
      std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
}

c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
c10::optional<SourceRange> SourceView::findSourceRangeThatGenerated(
    const SourceRange& range) {
  if (!gen_ranges_) {
    return c10::nullopt;
@@ -200,7 +73,7 @@ C10_EXPORT void SourceRange::print_with_context(
    return;
  }

  auto str = source_view_->text_str().str();
  c10::string_view str = source_view_->text();
  if (size() == str.size()) {
    // this is just the entire file, not a subset, so print it out.
    // primarily used to print out python stack traces
@@ -268,7 +141,7 @@ C10_EXPORT void SourceRange::print_with_context(
    line_end = start();
  while (line_start < range_end) {
    // move line_end to end of line
    while (line_end < str.size() && str[line_end] != '\n') {
    while (str[line_end] != '\n' && line_end < str.size()) {
      ++line_end;
    }
    // print line of code
@ -4,172 +4,43 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
class SourceRangeUnpickler;
|
||||
struct SourceRange;
|
||||
|
||||
// A stringlike class backed by a vector of string_view
|
||||
// the string represented are logically the concatenation of the string_views
|
||||
// This has advantage of not needing continues memory.
|
||||
struct TORCH_API StringCordView {
|
||||
StringCordView();
|
||||
StringCordView(
|
||||
std::vector<c10::string_view> inputs,
|
||||
std::vector<std::shared_ptr<std::string>> ownerships);
|
||||
|
||||
size_t size() const {
|
||||
return accumulated_sizes_.back();
|
||||
}
|
||||
|
||||
size_t find(const std::string& tok, size_t start) const;
|
||||
StringCordView substr(size_t start, size_t size) const;
|
||||
|
||||
char at(size_t index) const {
|
||||
return *iter_for_pos(index);
|
||||
}
|
||||
char operator[](size_t index) const {
|
||||
return at(index);
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
std::stringstream ss;
|
||||
for (auto s : pieces_) {
|
||||
ss << std::string(s);
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool operator==(const std::string& rhs);
|
||||
|
||||
bool operator==(const StringCordView& rhs);
|
||||
|
||||
c10::string_view piece(size_t index) const {
|
||||
return pieces_[index];
|
||||
}
|
||||
|
||||
struct Iterator {
|
||||
Iterator(
|
||||
const StringCordView* str,
|
||||
size_t start_line,
|
||||
size_t start_pos,
|
||||
size_t size)
|
||||
: line_(start_line), pos_(start_pos), str_(str), size_(size) {}
|
||||
explicit Iterator(const StringCordView* str)
|
||||
: Iterator(str, 0, 0, str->size()) {}
|
||||
Iterator(const Iterator&) = default;
|
||||
Iterator(Iterator&&) = default;
|
||||
Iterator& operator=(const Iterator&) = default;
|
||||
Iterator& operator=(Iterator&&) = default;
|
||||
|
||||
Iterator operator++() {
|
||||
if (size_ == 0) {
|
||||
return *this;
|
||||
}
|
||||
if ((pos_ + 1) < str_->pieces_[line_].size()) {
|
||||
pos_++;
|
||||
} else {
|
||||
line_++;
|
||||
pos_ = 0;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Iterator operator++(int) {
|
||||
Iterator prev(*this);
|
||||
++(*this);
|
||||
return prev;
|
||||
}
|
||||
|
||||
bool operator==(const Iterator& rhs) const {
|
||||
if (!has_next() && !rhs.has_next()) {
|
||||
return true;
|
||||
}
|
||||
return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_);
|
||||
}
|
||||
bool operator!=(const Iterator& rhs) {
|
||||
return !((*this) == rhs);
|
||||
}
|
||||
bool has_next() const {
|
||||
return size_ > 0 && (line_ < str_->pieces_.size());
|
||||
}
|
||||
|
||||
char operator*() const {
|
||||
TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size());
|
||||
TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size());
|
||||
return str_->pieces_[line_].at(pos_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t line_;
|
||||
size_t pos_;
|
||||
const StringCordView* str_;
|
||||
size_t size_;
|
||||
friend struct StringCordView;
|
||||
};
|
||||
|
||||
Iterator begin() const {
|
||||
return Iterator(this, 0, 0, size());
|
||||
}
|
||||
Iterator end() const {
|
||||
return Iterator(this, pieces_.size(), 0, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
Iterator iter_for_pos(size_t pos) const;
|
||||
|
||||
std::vector<c10::string_view> pieces_;
|
||||
std::vector<size_t> accumulated_sizes_;
|
||||
std::vector<std::shared_ptr<std::string>> owned_strings_;
|
||||
};
|
||||
|
||||
// Source represents a code segment. It keeps track of:
|
||||
// SourceView represents a code segment. It keeps track of:
|
||||
// - text_view : the view into text of the code segment
|
||||
// - filename (optional) : if present, represents the name of the file from
|
||||
// which the code segment originated.
|
||||
// - starting_line_no : represents the line in the original file where the
|
||||
// code segment started.
|
||||
struct TORCH_API Source {
|
||||
// Whether or not Source should copy the string passed in the constructor.
|
||||
enum CopiesString { COPIES_STRING, DONT_COPY };
|
||||
|
||||
explicit Source(
|
||||
struct SourceView {
|
||||
explicit SourceView(
|
||||
c10::string_view text_view,
|
||||
c10::optional<std::string> filename = c10::nullopt,
|
||||
size_t starting_line_no = 0,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr,
|
||||
CopiesString copies_str = COPIES_STRING)
|
||||
: filename_(std::move(filename)),
|
||||
starting_line_no_(starting_line_no),
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: text_view_(text_view),
|
||||
filename_(c10::nullopt),
|
||||
starting_line_no_(0),
|
||||
gen_ranges_(std::move(gen_ranges)) {
|
||||
if (copies_str == COPIES_STRING) {
|
||||
std::shared_ptr<std::string> allocated_str =
|
||||
std::make_shared<std::string>(text_view.data(), text_view.size());
|
||||
text_view_ = StringCordView({*allocated_str}, {allocated_str});
|
||||
} else {
|
||||
text_view_ = StringCordView({text_view}, {});
|
||||
}
|
||||
|
||||
calc_line_start_offsets();
|
||||
}
|
||||
|
||||
explicit Source(
|
||||
StringCordView str,
|
||||
c10::optional<std::string> filename = c10::nullopt,
|
||||
size_t starting_line_no = 0,
|
||||
SourceView(
|
||||
c10::string_view text_view,
|
||||
c10::optional<std::string> filename,
|
||||
size_t starting_line_no,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: text_view_(str),
|
||||
: text_view_(text_view),
|
||||
filename_(std::move(filename)),
|
||||
starting_line_no_(starting_line_no),
|
||||
gen_ranges_(std::move(gen_ranges)) {
|
||||
calc_line_start_offsets();
|
||||
}
|
||||
|
||||
// Given a line number (within source_), return the byte offset of the
|
||||
// beginning of that line.
|
||||
size_t offset_for_line(size_t line) const {
|
||||
|
|
@ -183,9 +54,11 @@ struct TORCH_API Source {
|
|||
|
||||
// Calculate the line (within the code segment) on which `offset` resides.
|
||||
size_t lineno_for_offset(size_t offset) const {
|
||||
auto iter = std::upper_bound(
|
||||
line_starting_offsets_.begin(), line_starting_offsets_.end(), offset);
|
||||
return iter - line_starting_offsets_.begin() - 1;
|
||||
return std::upper_bound(
|
||||
line_starting_offsets_.begin(),
|
||||
line_starting_offsets_.end(),
|
||||
offset) -
|
||||
line_starting_offsets_.begin() - 1;
|
||||
}
|
||||
|
||||
// Calculate the line (within the original source file, if present) on which
|
||||
|
|
@ -198,27 +71,11 @@ struct TORCH_API Source {
|
|||
}
|
||||
}
|
||||
|
||||
StringCordView get_line(size_t lineno) const {
|
||||
auto start = offset_for_line(lineno);
|
||||
auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start
|
||||
: text_view_.size() - start;
|
||||
return text_view_.substr(start, size);
|
||||
}
|
||||
|
||||
// Note: this makes a copy
|
||||
StringCordView text_str() const {
|
||||
const c10::string_view text() const {
|
||||
return text_view_;
|
||||
}
|
||||
|
||||
char char_at(size_t index) const {
|
||||
return text_view_.at(index);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return text_view_.size();
|
||||
}
|
||||
|
||||
c10::optional<std::string>& filename() {
|
||||
const c10::optional<std::string>& filename() const {
|
||||
return filename_;
|
||||
}
|
||||
|
||||
|
|
@ -229,20 +86,18 @@ struct TORCH_API Source {
|
|||
c10::optional<SourceRange> findSourceRangeThatGenerated(
|
||||
const SourceRange& range);
|
||||
|
||||
~Source() = default;
|
||||
protected:
|
||||
c10::string_view text_view_;
|
||||
|
||||
private:
|
||||
void calc_line_start_offsets() {
|
||||
line_starting_offsets_.clear();
|
||||
line_starting_offsets_.push_back(0);
|
||||
size_t pos = 0;
|
||||
while ((pos = text_view_.find("\n", pos)) != std::string::npos) {
|
||||
while ((pos = text().find('\n', pos)) != std::string::npos) {
|
||||
line_starting_offsets_.push_back(++pos);
|
||||
}
|
||||
}
|
||||
|
||||
StringCordView text_view_;
|
||||
|
||||
c10::optional<std::string> filename_;
|
||||
// If filename_ is not present, starting_line_no_ is don't care
|
||||
size_t starting_line_no_;
|
||||
|
|
@ -253,15 +108,67 @@ struct TORCH_API Source {
|
|||
std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
|
||||
};
|
||||
|
||||
// Source represents a code segment like SourceView, but the former owns a copy
|
||||
// of source text while the latter doesn't.
|
||||
struct Source : public SourceView {
|
||||
explicit Source(
|
||||
std::string text,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: SourceView(text, gen_ranges), text_(std::move(text)) {
|
||||
text_view_ = text_;
|
||||
}
|
||||
|
||||
explicit Source(
|
||||
c10::string_view text_view,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: SourceView(text_view, gen_ranges),
|
||||
text_(text_view.begin(), text_view.end()) {
|
||||
text_view_ = text_;
|
||||
}
|
||||
|
||||
explicit Source(
|
||||
std::string text,
|
||||
c10::optional<std::string> filename,
|
||||
size_t starting_line_no,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: SourceView(text, filename, starting_line_no, gen_ranges),
|
||||
text_(std::move(text)) {
|
||||
text_view_ = text_;
|
||||
}
|
||||
|
||||
explicit Source(
|
||||
c10::string_view text_view,
|
||||
c10::optional<std::string> filename,
|
||||
size_t starting_line_no,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: SourceView(text_view, filename, starting_line_no, gen_ranges),
|
||||
text_(text_view.begin(), text_view.end()) {
|
||||
text_view_ = text_;
|
||||
}
|
||||
|
||||
// Constructor that deepcopies and owns source text referenced in
|
||||
// `source_view`.
|
||||
explicit Source(const SourceView& source_view) : SourceView(source_view) {
|
||||
text_ = std::string(text_view_.begin(), text_view_.end());
|
||||
text_view_ = text_;
|
||||
}
|
||||
|
||||
std::string text_;
|
||||
};
|
||||
|
||||
// A SourceRange is a reference to subset of a Source, specified by `start` and
|
||||
// `end` byte offsets into the source text.
|
||||
struct TORCH_API SourceRange {
|
||||
SourceRange(std::shared_ptr<Source> source_view_, size_t start_, size_t end_)
|
||||
SourceRange(
|
||||
std::shared_ptr<SourceView> source_view_,
|
||||
size_t start_,
|
||||
size_t end_)
|
||||
: source_view_(std::move(source_view_)), start_(start_), end_(end_) {}
|
||||
SourceRange() : source_view_(nullptr), start_(0), end_(0) {}
|
||||
|
||||
const StringCordView text() const {
|
||||
return source_view_->text_str().substr(start(), end() - start());
|
||||
const std::string text() const {
|
||||
auto text_view = source_view_->text().substr(start(), end() - start());
|
||||
return std::string(text_view.begin(), text_view.end());
|
||||
}
|
||||
size_t size() const {
|
||||
return end() - start();
|
||||
|
|
@ -276,7 +183,7 @@ struct TORCH_API SourceRange {
|
|||
bool highlight,
|
||||
const std::string& funcname) const;
|
||||
|
||||
const std::shared_ptr<Source>& source() const {
|
||||
const std::shared_ptr<SourceView>& source() const {
|
||||
return source_view_;
|
||||
}
|
||||
size_t start() const {
|
||||
|
|
@ -322,7 +229,7 @@ struct TORCH_API SourceRange {
|
|||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Source> source_view_;
|
||||
std::shared_ptr<SourceView> source_view_;
|
||||
|
||||
private:
|
||||
size_t start_;
|
||||
|
|
@ -330,16 +237,13 @@ struct TORCH_API SourceRange {
|
|||
};
|
||||
|
||||
// OwnedSourceRange is just like a SourceRange except that it owns a `Source`
|
||||
// instead of `Source`. Thus OwnedSourceRange owns a copy of source text.
|
||||
// instead of `SourceView`. Thus OwnedSourceRange owns a copy of source text.
|
||||
struct OwnedSourceRange : public SourceRange {
|
||||
explicit OwnedSourceRange(const SourceRange& source_range)
|
||||
OwnedSourceRange(const SourceRange& source_range)
|
||||
: SourceRange(source_range) {
|
||||
const auto& source = source_range.source();
|
||||
if (source) {
|
||||
source_view_ = std::make_shared<Source>(
|
||||
source->text_str().str(),
|
||||
source->filename(),
|
||||
source->starting_line_no());
|
||||
source_view_ = std::make_shared<Source>(*source);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -377,14 +281,3 @@ using SourceRangeTagMap =
|
|||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct iterator_traits<torch::jit::StringCordView::Iterator> {
|
||||
using value_type = char;
|
||||
using difference_type = ptrdiff_t;
|
||||
using pointer = char*;
|
||||
using reference = char&;
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
};
|
||||
} // namespace std
|
||||
|
|
|
|||
|
|
@ -21,26 +21,26 @@ namespace jit {
|
|||
*/
|
||||
class TORCH_API SourceRef : public CustomClassHolder {
|
||||
public:
|
||||
explicit SourceRef(std::shared_ptr<Source> source_view)
|
||||
explicit SourceRef(std::shared_ptr<SourceView> source_view)
|
||||
: source_view_(std::move(source_view)) {}
|
||||
bool operator==(const SourceRef& other) const {
|
||||
return source_view_ == other.source_view_;
|
||||
}
|
||||
bool operator<(const Source& other) const {
|
||||
bool operator<(const SourceView& other) const {
|
||||
return source_view_.get() < &other;
|
||||
}
|
||||
friend bool operator<(const Source& other, const SourceRef& self) {
|
||||
friend bool operator<(const SourceView& other, const SourceRef& self) {
|
||||
return &other < self.source_view_.get();
|
||||
}
|
||||
bool operator<(const SourceRef& other) const {
|
||||
return *this < *other.source_view_.get();
|
||||
}
|
||||
const Source* operator->() const {
|
||||
const SourceView* operator->() const {
|
||||
return source_view_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Source> source_view_;
|
||||
std::shared_ptr<SourceView> source_view_;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -122,26 +122,17 @@ MobileDebugTable::MobileDebugTable(
|
|||
at::DataPtr debug_data;
|
||||
size_t debug_size{0};
|
||||
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
|
||||
auto ivalueTuple = jit::unpickle(
|
||||
reinterpret_cast<const char*>(debug_data.get()),
|
||||
debug_size,
|
||||
nullptr,
|
||||
{},
|
||||
c10::parseType);
|
||||
const auto& ivalues = ivalueTuple.toTuple()->elements();
|
||||
IValue lines;
|
||||
std::unique_ptr<SourceRangeDeserializer> deserializer;
|
||||
if (ivalues.size() == 3 && ivalues[0].isString() &&
|
||||
kFormatWithStringTable == ivalues[0].toStringRef()) {
|
||||
// new format
|
||||
deserializer = std::make_unique<SourceRangeDeserializer>(ivalues[1]);
|
||||
lines = ivalues[2];
|
||||
} else {
|
||||
deserializer = std::make_unique<SourceRangeDeserializer>();
|
||||
lines = ivalueTuple;
|
||||
}
|
||||
|
||||
for (auto& val : lines.toTuple()->elements()) {
|
||||
auto ivalues =
|
||||
std::move(*jit::unpickle(
|
||||
reinterpret_cast<const char*>(debug_data.get()),
|
||||
debug_size,
|
||||
nullptr,
|
||||
{},
|
||||
c10::parseType)
|
||||
.toTuple())
|
||||
.elements();
|
||||
SourceRangeDeserializer deserializer;
|
||||
for (auto& val : ivalues) {
|
||||
auto tup_elems = std::move(*std::move(val).toTuple()).elements();
|
||||
// For BC we decode only tuples with 3 elements
|
||||
// assuming it contains
|
||||
|
|
@ -149,7 +140,7 @@ MobileDebugTable::MobileDebugTable(
|
|||
if (tup_elems.size() == 3) {
|
||||
int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
|
||||
auto source_range =
|
||||
deserializer->deserialize(tup_elems[kSourceRangeIndex]);
|
||||
deserializer.deserialize(tup_elems[kSourceRangeIndex]);
|
||||
source_range_map.emplace(debug_handle, std::move(source_range));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/ir/scope.h>
|
||||
#include <torch/csrc/jit/serialization/source_range_serialization.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
|
|||
const auto source_n = n->sourceRange().source();
|
||||
const auto source_m = m->sourceRange().source();
|
||||
return (
|
||||
(source_n->text_str() == source_m->text_str()) &&
|
||||
(source_n->text() == source_m->text()) &&
|
||||
(source_n->starting_line_no() == source_m->starting_line_no()));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,8 +104,9 @@ void initTreeViewBindings(PyObject* module) {
|
|||
return SourceRange(self.source_, start, end);
|
||||
})
|
||||
.def_property_readonly("source", [](const SourceRangeFactory& self) {
|
||||
auto text_view = self.source_->text_str().str();
|
||||
return text_view;
|
||||
auto text_view = self.source_->text();
|
||||
std::string text(text_view.begin(), text_view.end());
|
||||
return text;
|
||||
});
|
||||
|
||||
py::class_<TreeView>(m, "TreeView")
|
||||
|
|
|
|||
|
|
@ -2191,10 +2191,6 @@ void initJitScriptBindings(PyObject* module) {
|
|||
m.def(
|
||||
"_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });
|
||||
|
||||
m.def(
|
||||
"_set_should_use_format_with_string_table",
|
||||
setShouldUseFormatWithStringTable);
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
|
||||
m, "LoggerBase");
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ auto initBindings() {
|
|||
return static_cast<int64_t>((*self)->starting_line_no());
|
||||
})
|
||||
.def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
|
||||
return (*self)->text_str().str();
|
||||
return (*self)->text();
|
||||
});
|
||||
|
||||
torch::class_<InstructionStats>("profiling", "InstructionStats")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ std::string qualifierToArchivePath(
|
|||
return export_prefix + path + "." + kExportSuffix;
|
||||
}
|
||||
|
||||
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
|
||||
std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
|
||||
caffe2::serialize::PyTorchStreamReader& reader,
|
||||
const std::string& export_prefix,
|
||||
const std::string& qualifier) {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class PyTorchStreamReader;
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
struct Source;
|
||||
struct SourceView;
|
||||
|
||||
// Convert a class type's qualifier name to the corresponding path the source
|
||||
// file it should be written to.
|
||||
|
|
@ -23,7 +23,7 @@ std::string qualifierToArchivePath(
|
|||
const std::string& qualifier,
|
||||
const std::string& export_prefix);
|
||||
|
||||
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
|
||||
std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
|
||||
caffe2::serialize::PyTorchStreamReader& reader,
|
||||
const std::string& export_prefix,
|
||||
const std::string& qualifier);
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
|
|||
return;
|
||||
}
|
||||
loaded_sources_.insert(qualifier);
|
||||
std::shared_ptr<Source> src = source_loader_(qualifier);
|
||||
std::shared_ptr<SourceView> src = source_loader_(qualifier);
|
||||
|
||||
// The importer, when looking for classes/functions doesn't know if 'foo'
|
||||
// contains definitions or if it is a prefix of 'foo.bar', we only figure it
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>;
|
||||
using SourceLoader =
|
||||
std::function<std::shared_ptr<SourceView>(const std::string&)>;
|
||||
|
||||
struct SourceImporterImpl : public Resolver,
|
||||
std::enable_shared_from_this<SourceImporterImpl> {
|
||||
|
|
|
|||
|
|
@ -1,97 +1,54 @@
|
|||
#include <torch/csrc/jit/serialization/source_range_serialization.h>
|
||||
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Flags.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// "Whether to emit compact debug_pkl when saving a model to .pt file."
|
||||
// "Compact file is smaller but cannot be loaded by old torch binaries."
|
||||
// TODO(qihan) remove when all binaries are using string table.
|
||||
thread_local bool should_use_format_with_string_table_ = false;
|
||||
|
||||
class SourceRangeSerializer {
|
||||
public:
|
||||
// Serialize SourceRange as Tuple[SourceType, int, int]
|
||||
// where SourceType = Tuple[int, int, int, List[int]],
|
||||
// The first 2 ints are positions into the vector returned by textSaved
|
||||
// after all the Ranges are processed. textSaved() returns a vector of str
|
||||
// where SourceType = Tuple[str, Optional[str], int, List[int]],
|
||||
// the serialized form of Source
|
||||
c10::IValue serialize(const SourceRange& sr);
|
||||
|
||||
const std::vector<c10::IValue>& texts_saved() {
|
||||
return texts_;
|
||||
}
|
||||
|
||||
SourceRangeSerializer() {
|
||||
texts_.emplace_back("");
|
||||
text_to_idx_[texts_.back().toStringRef()] = 0;
|
||||
}
|
||||
|
||||
private:
|
||||
// Serialize Source as Tuple[str, Optional[str], int, List[int]]
|
||||
// This caches serialized sources, since many SourceRanges can
|
||||
// refer to the same one.
|
||||
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
|
||||
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
|
||||
c10::IValue serialize_source(const std::shared_ptr<SourceView>& s);
|
||||
|
||||
int64_t store_text_and_get_index(const std::string& text_view);
|
||||
|
||||
std::vector<c10::IValue> texts_;
|
||||
std::unordered_map<c10::string_view, int64_t> text_to_idx_;
|
||||
std::unordered_map<std::shared_ptr<SourceView>, c10::IValue>
|
||||
serialized_sources;
|
||||
};
|
||||
|
||||
SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
|
||||
const auto& tup_elems = iv.toTupleRef().elements();
|
||||
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
|
||||
std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
|
||||
std::shared_ptr<SourceView> source_ = deserialize_source(tup_elems[0]);
|
||||
int64_t start_ = tup_elems[1].toInt();
|
||||
int64_t end_ = tup_elems[2].toInt();
|
||||
return SourceRange(source_, start_, end_);
|
||||
}
|
||||
|
||||
std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
|
||||
std::shared_ptr<SourceView> SourceRangeDeserializer::deserialize_source(
|
||||
const c10::IValue& iv) {
|
||||
auto tup = iv.toTuple();
|
||||
auto it = cached_sources.find(tup);
|
||||
if (it != cached_sources.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::shared_ptr<Source> source;
|
||||
|
||||
const auto& tup_elems = tup->elements();
|
||||
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
|
||||
if (!text_table_.empty()) {
|
||||
const auto& textIndex = tup_elems[0].toIntList();
|
||||
int64_t fnameIndex = tup_elems[1].toInt();
|
||||
int64_t starting_line_no_ = tup_elems[2].toInt();
|
||||
c10::optional<std::string> filename = c10::nullopt;
|
||||
std::string text_ = tup_elems[0].toString()->string();
|
||||
c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>();
|
||||
int64_t starting_line_no_ = tup_elems[2].toInt();
|
||||
|
||||
filename = *text_table_[fnameIndex];
|
||||
|
||||
std::vector<c10::string_view> pieces;
|
||||
std::vector<std::shared_ptr<std::string>> strs;
|
||||
|
||||
for (int64_t i : textIndex) {
|
||||
pieces.emplace_back(*text_table_[i]);
|
||||
strs.emplace_back(text_table_[i]);
|
||||
}
|
||||
|
||||
StringCordView str_cord(std::move(pieces), std::move(strs));
|
||||
|
||||
source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
|
||||
} else {
|
||||
std::string text_ = tup_elems[0].toString()->string();
|
||||
c10::optional<std::string> filename_ =
|
||||
tup_elems[1].toOptional<std::string>();
|
||||
int64_t starting_line_no_ = tup_elems[2].toInt();
|
||||
source = std::make_shared<Source>(
|
||||
std::move(text_), std::move(filename_), starting_line_no_);
|
||||
}
|
||||
auto source = std::make_shared<Source>(
|
||||
std::move(text_), std::move(filename_), starting_line_no_);
|
||||
cached_sources[tup] = source;
|
||||
return source;
|
||||
}
|
||||
|
|
@ -101,50 +58,17 @@ c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
|
|||
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end());
|
||||
}
|
||||
|
||||
int64_t SourceRangeSerializer::store_text_and_get_index(
|
||||
const std::string& text_view) {
|
||||
auto text_iter = text_to_idx_.find(text_view);
|
||||
if (text_iter == text_to_idx_.end()) {
|
||||
int64_t text_pos = static_cast<int64_t>(texts_.size());
|
||||
texts_.emplace_back(text_view);
|
||||
text_to_idx_[texts_.back().toStringView()] = text_pos;
|
||||
return text_pos;
|
||||
} else {
|
||||
return text_iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
c10::IValue SourceRangeSerializer::serialize_source(
|
||||
const std::shared_ptr<Source>& s) {
|
||||
const std::shared_ptr<SourceView>& s) {
|
||||
if (serialized_sources.count(s)) {
|
||||
return serialized_sources.at(s);
|
||||
}
|
||||
c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
|
||||
c10::List<int64_t> lines;
|
||||
if (should_use_format_with_string_table_) {
|
||||
if (s == nullptr) {
|
||||
serialized = c10::ivalue::Tuple::create({lines, 0, 0});
|
||||
} else {
|
||||
for (size_t lineno = 0; lineno < s->num_lines(); lineno++) {
|
||||
std::string line_content = s->get_line(lineno).str();
|
||||
int64_t text_pos = store_text_and_get_index(line_content);
|
||||
lines.push_back(text_pos);
|
||||
}
|
||||
|
||||
int64_t fname_pos = 0;
|
||||
if (s->filename().has_value()) {
|
||||
fname_pos = store_text_and_get_index(*s->filename());
|
||||
}
|
||||
serialized = c10::ivalue::Tuple::create(
|
||||
{lines, fname_pos, (int64_t)s->starting_line_no()});
|
||||
}
|
||||
if (s == nullptr) {
|
||||
serialized = c10::ivalue::Tuple::create({"", "", 0});
|
||||
} else {
|
||||
if (s == nullptr) {
|
||||
serialized = c10::ivalue::Tuple::create({"", "", 0});
|
||||
} else {
|
||||
serialized = c10::ivalue::Tuple::create(
|
||||
{s->text_str().str(), s->filename(), (int64_t)s->starting_line_no()});
|
||||
}
|
||||
serialized = c10::ivalue::Tuple::create(
|
||||
{s->text(), s->filename(), (int64_t)s->starting_line_no()});
|
||||
}
|
||||
serialized_sources[s] = serialized;
|
||||
return serialized;
|
||||
|
|
@ -162,24 +86,14 @@ std::vector<char> SourceRangePickler::pickle(
|
|||
if (it != source_range_tags.end()) {
|
||||
source_range_tag = it->second;
|
||||
}
|
||||
|
||||
ivalues.emplace_back(c10::ivalue::Tuple::create(
|
||||
{(int64_t)range.bytes,
|
||||
srs->serialize(range.range),
|
||||
static_cast<int64_t>(source_range_tag)}));
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> table;
|
||||
auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
|
||||
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
|
||||
std::vector<char> result;
|
||||
if (should_use_format_with_string_table_) {
|
||||
result = jit::pickle(
|
||||
c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
|
||||
&table);
|
||||
} else {
|
||||
result = jit::pickle(ivalue, &table);
|
||||
}
|
||||
auto result = jit::pickle(ivalue, &table);
|
||||
TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
|
||||
return result;
|
||||
}
|
||||
|
|
@ -189,7 +103,7 @@ ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
|
|||
size_t size)
|
||||
: data(std::move(data)),
|
||||
size(size),
|
||||
deserializer(nullptr),
|
||||
deserializer(new SourceRangeDeserializer()),
|
||||
unpickled_records(nullptr) {}
|
||||
|
||||
void ConcreteSourceRangeUnpickler::unpickle() {
|
||||
|
|
@ -205,19 +119,10 @@ void ConcreteSourceRangeUnpickler::unpickle() {
|
|||
{},
|
||||
c10::parseType)
|
||||
.toTuple();
|
||||
|
||||
const auto& ivalues = ivaluesTuple->elements();
|
||||
|
||||
unpickled_records = std::make_shared<SourceRangeRecords>();
|
||||
IValue lines;
|
||||
if (ivalues[0].isString() &&
|
||||
kFormatWithStringTable == ivalues[0].toStringRef()) {
|
||||
deserializer.reset(new SourceRangeDeserializer(ivalues[1]));
|
||||
lines = ivalues[2];
|
||||
} else {
|
||||
deserializer.reset(new SourceRangeDeserializer());
|
||||
lines = ivaluesTuple;
|
||||
}
|
||||
for (auto& val : lines.toTuple()->elements()) {
|
||||
for (auto& val : ivalues) {
|
||||
const auto& tup_elems = val.toTupleRef().elements();
|
||||
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
|
||||
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
|
||||
|
|
@ -247,10 +152,5 @@ c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
|
|||
return c10::nullopt;
|
||||
}
|
||||
|
||||
TORCH_API void setShouldUseFormatWithStringTable(
|
||||
bool should_use_format_with_string_table) {
|
||||
should_use_format_with_string_table_ = should_use_format_with_string_table;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ class SourceRangeSerializer;
|
|||
static constexpr size_t kByteOffsetIndex = 0;
|
||||
static constexpr size_t kSourceRangeIndex = 1;
|
||||
static constexpr size_t kSourceRangeTagIndex = 2;
|
||||
constexpr c10::string_view kFormatWithStringTable = "FORMAT_WITH_STRING_TABLE";
|
||||
|
||||
class SourceRangePickler {
|
||||
public:
|
||||
|
|
@ -36,21 +35,14 @@ class SourceRangePickler {
|
|||
|
||||
class SourceRangeDeserializer {
|
||||
public:
|
||||
SourceRangeDeserializer() = default;
|
||||
explicit SourceRangeDeserializer(c10::IValue text_table) {
|
||||
for (const auto& x : text_table.toTuple()->elements()) {
|
||||
text_table_.emplace_back(std::make_shared<std::string>(x.toStringRef()));
|
||||
}
|
||||
}
|
||||
SourceRange deserialize(const c10::IValue& iv);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv);
|
||||
std::shared_ptr<SourceView> deserialize_source(const c10::IValue& iv);
|
||||
std::unordered_map<
|
||||
c10::intrusive_ptr<c10::ivalue::Tuple>,
|
||||
std::shared_ptr<Source>>
|
||||
std::shared_ptr<SourceView>>
|
||||
cached_sources;
|
||||
std::vector<std::shared_ptr<std::string>> text_table_;
|
||||
};
|
||||
|
||||
class SourceRangeUnpickler {
|
||||
|
|
@ -61,8 +53,5 @@ class SourceRangeUnpickler {
|
|||
virtual ~SourceRangeUnpickler() = default;
|
||||
};
|
||||
|
||||
TORCH_API void setShouldUseFormatWithStringTable(
|
||||
bool should_use_format_with_string_table);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ size_t assertFind(
|
|||
const SourceRange& search_range,
|
||||
const std::string& sub,
|
||||
const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
|
||||
auto pos = search_range.source()->text_str().find(sub, search_range.start());
|
||||
auto pos = search_range.source()->text().find(sub, search_range.start());
|
||||
if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
|
||||
auto found_range =
|
||||
SourceRange(search_range.source(), search_range.start(), sub.size());
|
||||
|
|
@ -122,18 +122,19 @@ size_t assertFind(
|
|||
}
|
||||
|
||||
size_t assertFind(
|
||||
const std::shared_ptr<Source>& source,
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
const std::string& sub,
|
||||
size_t start,
|
||||
const Check& check) {
|
||||
return assertFind(SourceRange(source, start, source->size()), sub, check);
|
||||
return assertFind(
|
||||
SourceRange(source, start, source->text().size()), sub, check);
|
||||
}
|
||||
|
||||
void assertNotFind(
|
||||
const SourceRange& search_range,
|
||||
const std::string& sub,
|
||||
const Check& check) {
|
||||
auto pos = search_range.source()->text_str().find(sub, search_range.start());
|
||||
auto pos = search_range.source()->text().find(sub, search_range.start());
|
||||
if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
|
||||
auto found_range =
|
||||
SourceRange(search_range.source(), pos, sub.size() + pos);
|
||||
|
|
@ -201,7 +202,9 @@ struct FileCheckImpl {
|
|||
friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
|
||||
|
||||
private:
|
||||
bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
|
||||
bool parseSingleCheck(
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
size_t* start) {
|
||||
const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
|
||||
{CHECK, ": "},
|
||||
{CHECK_NEXT, "-NEXT: "},
|
||||
|
|
@ -214,35 +217,31 @@ struct FileCheckImpl {
|
|||
|
||||
for (const auto& check_pair : check_pairs) {
|
||||
const std::string& check_suffix = check_pair.second;
|
||||
auto suffix_pos = source->text_str().find(check_suffix, *start);
|
||||
auto suffix_pos = source->text().find(check_suffix, *start);
|
||||
if (suffix_pos != *start) {
|
||||
continue;
|
||||
}
|
||||
size_t end_check_string = suffix_pos + check_suffix.size();
|
||||
CheckType type = check_pair.first;
|
||||
c10::optional<size_t> count = c10::nullopt;
|
||||
auto end_line = source->text_str().find("\n", end_check_string);
|
||||
auto end_line = source->text().find('\n', end_check_string);
|
||||
bool exactly = false;
|
||||
if (type == CHECK_COUNT) {
|
||||
const std::string exact = "EXACTLY-";
|
||||
if (source->text_str().find(exact, end_check_string) ==
|
||||
end_check_string) {
|
||||
if (source->text().find(exact, end_check_string) == end_check_string) {
|
||||
exactly = true;
|
||||
end_check_string += exact.size();
|
||||
}
|
||||
size_t end =
|
||||
assertFind(SourceRange(source, end_check_string, end_line), ":");
|
||||
auto count_view = source->text_str()
|
||||
.substr(end_check_string, end - end_check_string)
|
||||
.str();
|
||||
auto count_view =
|
||||
source->text().substr(end_check_string, end - end_check_string);
|
||||
count = c10::stoll(std::string(count_view.begin(), count_view.end()));
|
||||
end_check_string = end + 2; // add ':' and the space
|
||||
}
|
||||
auto check = Check(
|
||||
type,
|
||||
source->text_str()
|
||||
.substr(end_check_string, end_line - end_check_string)
|
||||
.str(),
|
||||
source->text().substr(end_check_string, end_line - end_check_string),
|
||||
count);
|
||||
addCheck(check);
|
||||
if (exactly) {
|
||||
|
|
@ -254,30 +253,32 @@ struct FileCheckImpl {
|
|||
return false;
|
||||
}
|
||||
|
||||
size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
|
||||
size_t start = source->text_str().find("#", prev_end);
|
||||
size_t findNextStart(
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
size_t prev_end) {
|
||||
size_t start = source->text().find('#', prev_end);
|
||||
if (start == std::string::npos) {
|
||||
return start;
|
||||
}
|
||||
start += 1;
|
||||
static constexpr size_t max_whitespace = 6;
|
||||
size_t i = 0;
|
||||
while (start + i < source->size() && i < max_whitespace) {
|
||||
auto c = source->char_at(start + i);
|
||||
while (start + i < source->text().size() && i < max_whitespace) {
|
||||
auto c = source->text().at(start + i);
|
||||
if (c != ' ' && c != '\t') {
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
static const std::string check = "CHECK";
|
||||
if (source->text_str().substr(start + i, check.size()) == check) {
|
||||
if (source->text().substr(start + i, check.size()) == check) {
|
||||
return start + i + check.size();
|
||||
} else {
|
||||
return findNextStart(source, start + i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
void parseStrings(const std::shared_ptr<Source>& source) {
|
||||
void parseStrings(const std::shared_ptr<SourceView>& source) {
|
||||
size_t start = 0;
|
||||
start = findNextStart(source, 0);
|
||||
while (start != std::string::npos) {
|
||||
|
|
@ -296,7 +297,7 @@ struct FileCheckImpl {
|
|||
|
||||
void doCheckNot(
|
||||
const std::vector<Check>& nots,
|
||||
const std::shared_ptr<Source>& source,
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
const SourceRange& prev,
|
||||
const SourceRange& next) {
|
||||
auto start = prev.end(); // inclusive
|
||||
|
|
@ -313,7 +314,7 @@ struct FileCheckImpl {
|
|||
// Checks that source token is highlighted, does not advance search range.
|
||||
void doCheckSourceHighlighted(
|
||||
const Check& check,
|
||||
const std::shared_ptr<Source>& source,
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
size_t start_offset) {
|
||||
auto construct_error_and_throw = [&](size_t error_start_pos) {
|
||||
SourceRange error_range(
|
||||
|
|
@ -329,8 +330,8 @@ struct FileCheckImpl {
|
|||
size_t search_start_offset = start_offset;
|
||||
bool found_token_at_least_once = false;
|
||||
size_t pos = search_start_offset;
|
||||
while (pos < source->size()) {
|
||||
pos = source->text_str().find(check.search_str_, search_start_offset);
|
||||
while (pos < source->text().size()) {
|
||||
pos = source->text().find(check.search_str_, search_start_offset);
|
||||
if (pos == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
|
|
@ -348,16 +349,17 @@ struct FileCheckImpl {
|
|||
auto highlight_start_offset =
|
||||
source->offset_for_line(highlight_lineno) + col;
|
||||
auto highlight_end_offset = std::min(
|
||||
highlight_start_offset + check.search_str_.size(), source->size());
|
||||
highlight_start_offset + check.search_str_.size(),
|
||||
source->text().size());
|
||||
|
||||
if (highlight_end_offset >= source->size()) {
|
||||
if (highlight_end_offset >= source->text().size()) {
|
||||
construct_error_and_throw(pos);
|
||||
}
|
||||
|
||||
bool found_highlight = true;
|
||||
for (const auto posi :
|
||||
c10::irange(highlight_start_offset, highlight_end_offset)) {
|
||||
if (source->char_at(posi) != '~') {
|
||||
if (source->text()[posi] != '~') {
|
||||
found_highlight = false;
|
||||
}
|
||||
}
|
||||
|
|
@ -388,7 +390,7 @@ struct FileCheckImpl {
|
|||
|
||||
SourceRange matchDagGroup(
|
||||
const std::vector<Check>& group,
|
||||
const std::shared_ptr<Source>& source,
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
const SourceRange& prev) {
|
||||
size_t group_beg = std::string::npos;
|
||||
size_t group_end = 0;
|
||||
|
|
@ -406,7 +408,7 @@ struct FileCheckImpl {
|
|||
|
||||
SourceRange matchGroup(
|
||||
const std::vector<Check>& group,
|
||||
const std::shared_ptr<Source>& source,
|
||||
const std::shared_ptr<SourceView>& source,
|
||||
const SourceRange& prev) {
|
||||
AT_ASSERT(group.size() != 0);
|
||||
CheckType type = group[0].type_;
|
||||
|
|
@ -465,7 +467,7 @@ struct FileCheckImpl {
|
|||
return SourceRange(source, start_range, end_range);
|
||||
}
|
||||
|
||||
void doChecks(const std::shared_ptr<Source>& source) {
|
||||
void doChecks(const std::shared_ptr<SourceView>& source) {
|
||||
SourceRange prev(source, 0, 0);
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
const auto& curr_group = groups[i];
|
||||
|
|
@ -482,7 +484,7 @@ struct FileCheckImpl {
|
|||
++i; // already checked the group after
|
||||
} else {
|
||||
SourceRange end_of_file(
|
||||
source, source->size() + 1, source->size() + 1);
|
||||
source, source->text().size() + 1, source->text().size() + 1);
|
||||
doCheckNot(curr_group, source, prev, end_of_file);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -258,22 +258,7 @@ def get_model_info(
        # Parse debug info and add begin/end markers if not present
        # to ensure that we cover the entire source code.
        debug_info_t = pickle.loads(raw_debug)
        text_table = None

        if (len(debug_info_t) == 3 and
                isinstance(debug_info_t[0], str) and
                debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
            _, text_table, content = debug_info_t

            def parse_new_format(line):
                # (0, (('', '', 0), 0, 0))
                num, ((text_indexes, fname_idx, offset), start, end), tag = line
                text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                fname = text_table[fname_idx]  # type: ignore[index]
                return num, ((text, fname, offset), start, end), tag

            debug_info_t = map(parse_new_format, content)

        assert isinstance(debug_info_t, tuple)
        debug_info = list(debug_info_t)
        if not debug_info:
            debug_info.append((0, (('', '', 0), 0, 0)))