Make debug_pkl smaller by only emitting unique traces. (#72596)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72596

debug_pkl file inside of pytorch's .pt file consists of a list of SourceRanges. Each SourceRange points to a Source which is a stack track, filename, and start, end numbers. Those are emitted in debug_pkl file as strings.

Since many SourceRange shares the same source, the string for trace can be deduped.

The newer format saves a set of unique traces in a tuple, then each SourceRange will save the offset of it's trace w.r.t. position in that tuple. (i.e. manually applying dictionary compression).

The above helps with smaller file size. On loading, if we copy each trace to Source as string the runtime memory would still blowup.
To mitigate this, we use SourceView directly instead of source which will take the reference of string inside of Deserializer and make that into string_view. This is safe because Deserializer is hold by Unpickler by shared_ptr, and Unpickler is also hold by shared_ptr by another Source object. That Source object will be alive during the model construction.

Test Plan:
unit test

Took original file (312271638_930.predictor.disagg.local); loaded with `torch.jit.load` save again with `torch.jit.save`. Unzip both, look at contents:
```
[qihan@devvm5585.vll0 ~]$ du archive -h
4.0K    archive/xl_model_weights
3.7M    archive/extra
8.0K    archive/code/__torch__/caffe2/torch/fb/model_transform/splitting
8.0K    archive/code/__torch__/caffe2/torch/fb/model_transform
8.0K    archive/code/__torch__/caffe2/torch/fb
8.0K    archive/code/__torch__/caffe2/torch
8.0K    archive/code/__torch__/caffe2
20M     archive/code/__torch__/torch/fx/graph_module
20M     archive/code/__torch__/torch/fx
8.0K    archive/code/__torch__/torch/classes
20M     archive/code/__torch__/torch
20M     archive/code/__torch__
20M     archive/code
2.7M    archive/constants
35M     archive
[qihan@devvm5585.vll0 ~]$ du resaved -h
4.0K    resaved/extra
8.0K    resaved/code/__torch__/caffe2/torch/fb/model_transform/splitting
8.0K    resaved/code/__torch__/caffe2/torch/fb/model_transform
8.0K    resaved/code/__torch__/caffe2/torch/fb
8.0K    resaved/code/__torch__/caffe2/torch
8.0K    resaved/code/__torch__/caffe2
1.3M    resaved/code/__torch__/torch/fx/graph_module
1.3M    resaved/code/__torch__/torch/fx
8.0K    resaved/code/__torch__/torch/classes
1.4M    resaved/code/__torch__/torch
1.4M    resaved/code/__torch__
1.4M    resaved/code
2.7M    resaved/constants
13M     resaved
[qihan@devvm5585.vll0 ~]$
```

Reviewed By: JasonHanwen

Differential Revision: D33994011

fbshipit-source-id: 8e6224c6e942e91c3403f686c8f0937d1002ed41
(cherry picked from commit a7014dd4029308c95007f362a57c31796d686647)
This commit is contained in:
Han Qi 2022-02-24 00:40:53 -08:00 committed by PyTorch MergeBot
parent 86deecd7be
commit 3d37f5b052
24 changed files with 576 additions and 203 deletions

View File

@ -0,0 +1,32 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/frontend/source_range.h>
using namespace ::testing;
using namespace ::torch::jit;
TEST(SourceRangeTest, test_find) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));
std::vector<c10::string_view> pieces{*strings[0], *strings[1]};
StringCordView view(pieces, strings);
auto x = view.find("rldni", 0);
EXPECT_EQ(x, 8);
}
TEST(SourceRangeTest, test_substr) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));
std::vector<c10::string_view> pieces{*strings[0], *strings[1]};
StringCordView view(pieces, strings);
auto x = view.substr(4, 10).str();
EXPECT_EQ(x, view.str().substr(4, 10));
EXPECT_EQ(view.substr(0, view.size()).str(), view.str());
}

View File

@ -47,7 +47,9 @@ static inline void trim(std::string& s) {
trim(substring_s); \
auto exception_string = std::string(e.what()); \
trim(exception_string); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
<< " Error was: \n" \
<< exception_string; \
}
namespace torch {

View File

@ -4453,7 +4453,8 @@ def foo(xyz):
return list(debug_files)
debug_files = debug_records_from_mod(ft3)
for debug_file in debug_files:
for dfile in debug_files:
_, table, debug_file = dfile
for i in range(len(debug_file) - 1):
offset, source_range_tag, source_range = debug_file[i]
offset2, source_range_tag2, source_range2 = debug_file[i + 1]

View File

@ -2,6 +2,7 @@
#include <ATen/core/Reduction.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Optional.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
@ -27,8 +28,13 @@ namespace jit {
namespace {
struct SchemaParser {
SchemaParser(const std::string& str)
: L(std::make_shared<SourceView>(c10::string_view(str))),
explicit SchemaParser(const std::string& str)
: L(std::make_shared<Source>(
c10::string_view(str),
c10::nullopt,
0,
nullptr,
Source::DONT_COPY)),
type_parser(L, /*parse_complete_tensor_types*/ false) {}
either<OperatorName, FunctionSchema> parseDeclaration() {

View File

@ -190,7 +190,7 @@ struct TORCH_API SharedParserData {
// find the longest match of str.substring(pos) against a token, return true
// if successful filling in kind, start,and len
bool match(
c10::string_view str,
StringCordView str,
size_t pos,
bool continuation, // are we inside a scope where newlines don't count
// (e.g. inside parens)
@ -241,12 +241,12 @@ struct TORCH_API SharedParserData {
// invariant: the next token is not whitespace or newline
*start = pos;
// check for a valid number
if (isNumber(str, pos, len)) {
if (isNumber(str.piece(0), pos, len)) {
*kind = TK_NUMBER;
return true;
}
// check for string
if (isString(str, pos, len)) {
if (isString(str.piece(0), pos, len)) {
*kind = TK_STRINGLITERAL;
return true;
}
@ -368,7 +368,7 @@ struct TORCH_API SharedParserData {
return isspace(n) && n != '\n';
}
// Make an exception ignoring comments for type annotation comments
bool isTypeComment(c10::string_view str, size_t pos) {
bool isTypeComment(StringCordView str, size_t pos) {
const std::string type_string = "# type:";
if (str.size() < pos + type_string.length()) {
return false;
@ -387,7 +387,7 @@ struct Token {
SourceRange range;
Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
std::string text() {
return range.text();
return range.text().str();
}
std::string kindString() const {
return kindToString(kind);
@ -395,7 +395,7 @@ struct Token {
};
struct Lexer {
explicit Lexer(std::shared_ptr<SourceView> source)
explicit Lexer(std::shared_ptr<Source> source)
: source(std::move(source)),
pos(0),
nesting(0),
@ -518,25 +518,19 @@ struct Lexer {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t length;
AT_ASSERT(source);
auto src = source->text_str();
if (!shared.match(
source->text(),
pos,
nesting > 0,
whitespace_token,
&kind,
&start,
&length)) {
src, pos, nesting > 0, whitespace_token, &kind, &start, &length)) {
expected(
"a valid token",
Token(
(source->text())[start], SourceRange(source, start, start + 1)));
Token(source->char_at(start), SourceRange(source, start, start + 1)));
}
auto t = Token(kind, SourceRange(source, start, start + length));
pos = start + length;
return t;
}
std::shared_ptr<SourceView> source;
std::shared_ptr<Source> source;
size_t pos;
size_t nesting; // depth of ( [ { nesting...
std::vector<int> indent_stack; // stack of indentation level of blocks

View File

@ -46,7 +46,7 @@ Decl mergeTypesFromTypeComment(
}
struct ParserImpl {
explicit ParserImpl(const std::shared_ptr<SourceView>& source)
explicit ParserImpl(const std::shared_ptr<Source>& source)
: L(source), shared(sharedParserData()) {}
Ident parseIdent() {
@ -801,7 +801,7 @@ struct ParserImpl {
SharedParserData& shared;
};
Parser::Parser(const std::shared_ptr<SourceView>& src)
Parser::Parser(const std::shared_ptr<Source>& src)
: pImpl(new ParserImpl(src)) {}
Parser::~Parser() = default;

View File

@ -17,7 +17,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
bool is_method);
struct TORCH_API Parser {
explicit Parser(const std::shared_ptr<SourceView>& src);
explicit Parser(const std::shared_ptr<Source>& src);
TreeRef parseFunction(bool is_method);
TreeRef parseClass();
Decl parseTypeComment();

View File

@ -227,7 +227,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
// expression and base type names.
if (resolver_) {
if (auto typePtr =
resolver_->resolveType(expr.range().text(), expr.range())) {
resolver_->resolveType(expr.range().text().str(), expr.range())) {
return typePtr;
}
}

View File

@ -4,13 +4,140 @@
namespace torch {
namespace jit {
// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
StringCordView::StringCordView() {
accumulated_sizes_.push_back(0);
}
StringCordView::StringCordView(
std::vector<c10::string_view> inputs,
std::vector<std::shared_ptr<std::string>> ownerships)
: pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
accumulated_sizes_.push_back(0);
size_t running_sum = 0;
for (auto& s : pieces_) {
if (s.size() > 0) {
running_sum += s.size();
accumulated_sizes_.push_back(running_sum);
}
}
}
size_t StringCordView::find(const std::string& tok, size_t start) const {
if (tok.size() == 0) {
return 0;
}
if ((size() - start) < tok.size()) {
return std::string::npos;
}
Iterator begin = iter_for_pos(start);
Iterator end_iter = end();
size_t offset = start;
for (; begin != end_iter; ++begin, ++offset) {
if (*begin == tok[0]) {
auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
if (mis.second == tok.end()) {
// no mismatch, and second string (tok) is exhausted.
return offset;
}
if (mis.first == end_iter) {
// this str is exhausted but tok is not
return std::string::npos;
}
}
}
return std::string::npos;
}
StringCordView StringCordView::substr(size_t start, size_t size) const {
std::vector<c10::string_view> pieces;
std::vector<std::shared_ptr<std::string>> ownerships;
if (start >= this->size()) {
// out of bounds
return StringCordView();
}
if (start + size >= this->size()) {
size = this->size() - start;
}
Iterator begin = iter_for_pos(start);
Iterator end = iter_for_pos(start + size);
if (begin.line_ == end.line_) {
// same line
pieces.push_back(pieces_[begin.line_].substr(begin.pos_, size));
} else {
pieces.push_back(pieces_[begin.line_].substr(begin.pos_));
size_t last_line = pieces_.size();
if (end != this->end() && end.line_ < last_line) {
// end is within the string
last_line = end.line_;
}
for (size_t i = begin.line_ + 1; i < last_line; i++) {
pieces.push_back(pieces_[i]);
}
if (end != this->end()) {
pieces.push_back(pieces_[end.line_].substr(0, end.pos_));
}
}
// share ownership
std::copy(
owned_strings_.begin(),
owned_strings_.end(),
std::back_inserter(ownerships));
return StringCordView(std::move(pieces), std::move(ownerships));
}
bool StringCordView::operator==(const std::string& rhs) {
if (size() != rhs.size()) {
return false;
}
auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
// both need to exhaust
return res.first == end() && res.second == rhs.end();
}
bool StringCordView::operator==(const StringCordView& rhs) {
if (size() != rhs.size()) {
return false;
}
auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
// both need to exhaust
return res.first == end() && res.second == rhs.end();
}
StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
if (pos == 0) {
return begin();
}
if (pos >= size()) {
return end();
}
auto upper = std::upper_bound(
accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
if (upper == accumulated_sizes_.end()) {
return end();
}
size_t line = upper - accumulated_sizes_.begin() - 1;
assert(accumulated_sizes_[line] <= pos);
assert(accumulated_sizes_[line + 1] > pos);
return Iterator(this, line, pos - accumulated_sizes_[line], size() - pos);
}
size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
return (
std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
}
c10::optional<SourceRange> SourceView::findSourceRangeThatGenerated(
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
const SourceRange& range) {
if (!gen_ranges_) {
return c10::nullopt;
@ -73,7 +200,7 @@ C10_EXPORT void SourceRange::print_with_context(
return;
}
c10::string_view str = source_view_->text();
auto str = source_view_->text_str().str();
if (size() == str.size()) {
// this is just the entire file, not a subset, so print it out.
// primarily used to print out python stack traces
@ -141,7 +268,7 @@ C10_EXPORT void SourceRange::print_with_context(
line_end = start();
while (line_start < range_end) {
// move line_end to end of line
while (str[line_end] != '\n' && line_end < str.size()) {
while (line_end < str.size() && str[line_end] != '\n') {
++line_end;
}
// print line of code

View File

@ -4,43 +4,172 @@
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <unordered_map>
namespace torch {
namespace jit {
class SourceRangeUnpickler;
struct SourceRange;
// SourceView represents a code segment. It keeps track of:
// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
struct TORCH_API StringCordView {
StringCordView();
StringCordView(
std::vector<c10::string_view> inputs,
std::vector<std::shared_ptr<std::string>> ownerships);
size_t size() const {
return accumulated_sizes_.back();
}
size_t find(const std::string& tok, size_t start) const;
StringCordView substr(size_t start, size_t size) const;
char at(size_t index) const {
return *iter_for_pos(index);
}
char operator[](size_t index) const {
return at(index);
}
std::string str() const {
std::stringstream ss;
for (auto s : pieces_) {
ss << std::string(s);
}
return ss.str();
}
bool operator==(const std::string& rhs);
bool operator==(const StringCordView& rhs);
c10::string_view piece(size_t index) const {
return pieces_[index];
}
struct Iterator {
Iterator(
const StringCordView* str,
size_t start_line,
size_t start_pos,
size_t size)
: line_(start_line), pos_(start_pos), str_(str), size_(size) {}
explicit Iterator(const StringCordView* str)
: Iterator(str, 0, 0, str->size()) {}
Iterator(const Iterator&) = default;
Iterator(Iterator&&) = default;
Iterator& operator=(const Iterator&) = default;
Iterator& operator=(Iterator&&) = default;
Iterator operator++() {
if (size_ == 0) {
return *this;
}
if ((pos_ + 1) < str_->pieces_[line_].size()) {
pos_++;
} else {
line_++;
pos_ = 0;
}
return *this;
}
Iterator operator++(int) {
Iterator prev(*this);
++(*this);
return prev;
}
bool operator==(const Iterator& rhs) const {
if (!has_next() && !rhs.has_next()) {
return true;
}
return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_);
}
bool operator!=(const Iterator& rhs) {
return !((*this) == rhs);
}
bool has_next() const {
return size_ > 0 && (line_ < str_->pieces_.size());
}
char operator*() const {
TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size());
TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size());
return str_->pieces_[line_].at(pos_);
}
private:
size_t line_;
size_t pos_;
const StringCordView* str_;
size_t size_;
friend struct StringCordView;
};
Iterator begin() const {
return Iterator(this, 0, 0, size());
}
Iterator end() const {
return Iterator(this, pieces_.size(), 0, 0);
}
private:
Iterator iter_for_pos(size_t pos) const;
std::vector<c10::string_view> pieces_;
std::vector<size_t> accumulated_sizes_;
std::vector<std::shared_ptr<std::string>> owned_strings_;
};
// Source represents a code segment. It keeps track of:
// - text_view : the view into text of the code segment
// - filename (optional) : if present, represents the name of the file from
// which the code segment originated.
// - starting_line_no : represents the line in the original file where the
// code segment started.
struct SourceView {
explicit SourceView(
struct TORCH_API Source {
// Whether or not Source should copy the string passed in the constructor.
enum CopiesString { COPIES_STRING, DONT_COPY };
explicit Source(
c10::string_view text_view,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_view_(text_view),
filename_(c10::nullopt),
starting_line_no_(0),
c10::optional<std::string> filename = c10::nullopt,
size_t starting_line_no = 0,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr,
CopiesString copies_str = COPIES_STRING)
: filename_(std::move(filename)),
starting_line_no_(starting_line_no),
gen_ranges_(std::move(gen_ranges)) {
if (copies_str == COPIES_STRING) {
std::shared_ptr<std::string> allocated_str =
std::make_shared<std::string>(text_view.data(), text_view.size());
text_view_ = StringCordView({*allocated_str}, {allocated_str});
} else {
text_view_ = StringCordView({text_view}, {});
}
calc_line_start_offsets();
}
SourceView(
c10::string_view text_view,
c10::optional<std::string> filename,
size_t starting_line_no,
Source(
StringCordView str,
c10::optional<std::string> filename = c10::nullopt,
size_t starting_line_no = 0,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_view_(text_view),
: text_view_(str),
filename_(std::move(filename)),
starting_line_no_(starting_line_no),
gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}
// Given a line number (within source_), return the byte offset of the
// beginning of that line.
size_t offset_for_line(size_t line) const {
@ -54,11 +183,9 @@ struct SourceView {
// Calculate the line (within the code segment) on which `offset` resides.
size_t lineno_for_offset(size_t offset) const {
return std::upper_bound(
line_starting_offsets_.begin(),
line_starting_offsets_.end(),
offset) -
line_starting_offsets_.begin() - 1;
auto iter = std::upper_bound(
line_starting_offsets_.begin(), line_starting_offsets_.end(), offset);
return iter - line_starting_offsets_.begin() - 1;
}
// Calculate the line (within the original source file, if present) on which
@ -71,11 +198,27 @@ struct SourceView {
}
}
const c10::string_view text() const {
StringCordView get_line(size_t lineno) const {
auto start = offset_for_line(lineno);
auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start
: text_view_.size() - start;
return text_view_.substr(start, size);
}
// Note: this makes a copy
StringCordView text_str() const {
return text_view_;
}
const c10::optional<std::string>& filename() const {
char char_at(size_t index) const {
return text_view_.at(index);
}
size_t size() const {
return text_view_.size();
}
c10::optional<std::string>& filename() {
return filename_;
}
@ -86,18 +229,20 @@ struct SourceView {
c10::optional<SourceRange> findSourceRangeThatGenerated(
const SourceRange& range);
protected:
c10::string_view text_view_;
~Source() = default;
private:
void calc_line_start_offsets() {
line_starting_offsets_.clear();
line_starting_offsets_.push_back(0);
size_t pos = 0;
while ((pos = text().find('\n', pos)) != std::string::npos) {
while ((pos = text_view_.find("\n", pos)) != std::string::npos) {
line_starting_offsets_.push_back(++pos);
}
}
StringCordView text_view_;
c10::optional<std::string> filename_;
// If filename_ is not present, starting_line_no_ is don't care
size_t starting_line_no_;
@ -108,67 +253,15 @@ struct SourceView {
std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
};
// Source represents a code segment like SourceView, but the former owns a copy
// of source text while the latter doesn't.
struct Source : public SourceView {
explicit Source(
std::string text,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text, gen_ranges), text_(std::move(text)) {
text_view_ = text_;
}
explicit Source(
c10::string_view text_view,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text_view, gen_ranges),
text_(text_view.begin(), text_view.end()) {
text_view_ = text_;
}
explicit Source(
std::string text,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text, filename, starting_line_no, gen_ranges),
text_(std::move(text)) {
text_view_ = text_;
}
explicit Source(
c10::string_view text_view,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text_view, filename, starting_line_no, gen_ranges),
text_(text_view.begin(), text_view.end()) {
text_view_ = text_;
}
// Constructor that deepcopies and owns source text referenced in
// `source_view`.
explicit Source(const SourceView& source_view) : SourceView(source_view) {
text_ = std::string(text_view_.begin(), text_view_.end());
text_view_ = text_;
}
std::string text_;
};
// A SourceRange is a reference to subset of a Source, specified by `start` and
// `end` byte offsets into the source text.
struct TORCH_API SourceRange {
SourceRange(
std::shared_ptr<SourceView> source_view_,
size_t start_,
size_t end_)
SourceRange(std::shared_ptr<Source> source_view_, size_t start_, size_t end_)
: source_view_(std::move(source_view_)), start_(start_), end_(end_) {}
SourceRange() : source_view_(nullptr), start_(0), end_(0) {}
const std::string text() const {
auto text_view = source_view_->text().substr(start(), end() - start());
return std::string(text_view.begin(), text_view.end());
const StringCordView text() const {
return source_view_->text_str().substr(start(), end() - start());
}
size_t size() const {
return end() - start();
@ -183,7 +276,7 @@ struct TORCH_API SourceRange {
bool highlight,
const std::string& funcname) const;
const std::shared_ptr<SourceView>& source() const {
const std::shared_ptr<Source>& source() const {
return source_view_;
}
size_t start() const {
@ -229,7 +322,7 @@ struct TORCH_API SourceRange {
}
protected:
std::shared_ptr<SourceView> source_view_;
std::shared_ptr<Source> source_view_;
private:
size_t start_;
@ -237,13 +330,16 @@ struct TORCH_API SourceRange {
};
// OwnedSourceRange is just like a SourceRange except that it owns a `Source`
// instead of `SourceView`. Thus OwnedSourceRange owns a copy of source text.
// instead of `Source`. Thus OwnedSourceRange owns a copy of source text.
struct OwnedSourceRange : public SourceRange {
OwnedSourceRange(const SourceRange& source_range)
explicit OwnedSourceRange(const SourceRange& source_range)
: SourceRange(source_range) {
const auto& source = source_range.source();
if (source) {
source_view_ = std::make_shared<Source>(*source);
source_view_ = std::make_shared<Source>(
source->text_str().str(),
source->filename(),
source->starting_line_no());
}
}
};
@ -281,3 +377,14 @@ using SourceRangeTagMap =
} // namespace jit
} // namespace torch
namespace std {
template <>
struct iterator_traits<torch::jit::StringCordView::Iterator> {
using value_type = char;
using difference_type = ptrdiff_t;
using pointer = char*;
using reference = char&;
using iterator_category = std::forward_iterator_tag;
};
} // namespace std

View File

@ -21,26 +21,26 @@ namespace jit {
*/
class TORCH_API SourceRef : public CustomClassHolder {
public:
explicit SourceRef(std::shared_ptr<SourceView> source_view)
explicit SourceRef(std::shared_ptr<Source> source_view)
: source_view_(std::move(source_view)) {}
bool operator==(const SourceRef& other) const {
return source_view_ == other.source_view_;
}
bool operator<(const SourceView& other) const {
bool operator<(const Source& other) const {
return source_view_.get() < &other;
}
friend bool operator<(const SourceView& other, const SourceRef& self) {
friend bool operator<(const Source& other, const SourceRef& self) {
return &other < self.source_view_.get();
}
bool operator<(const SourceRef& other) const {
return *this < *other.source_view_.get();
}
const SourceView* operator->() const {
const Source* operator->() const {
return source_view_.get();
}
private:
std::shared_ptr<SourceView> source_view_;
std::shared_ptr<Source> source_view_;
};
} // namespace jit

View File

@ -122,17 +122,26 @@ MobileDebugTable::MobileDebugTable(
at::DataPtr debug_data;
size_t debug_size{0};
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
auto ivalues =
std::move(*jit::unpickle(
auto ivalueTuple = jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()),
debug_size,
nullptr,
{},
c10::parseType)
.toTuple())
.elements();
SourceRangeDeserializer deserializer;
for (auto& val : ivalues) {
c10::parseType);
const auto& ivalues = ivalueTuple.toTuple()->elements();
IValue lines;
std::unique_ptr<SourceRangeDeserializer> deserializer;
if (ivalues.size() == 3 && ivalues[0].isString() &&
kFormatWithStringTable == ivalues[0].toStringRef()) {
// new format
deserializer = std::make_unique<SourceRangeDeserializer>(ivalues[1]);
lines = ivalues[2];
} else {
deserializer = std::make_unique<SourceRangeDeserializer>();
lines = ivalueTuple;
}
for (auto& val : lines.toTuple()->elements()) {
auto tup_elems = std::move(*std::move(val).toTuple()).elements();
// For BC we decode only tuples with 3 elements
// assuming it contains
@ -140,7 +149,7 @@ MobileDebugTable::MobileDebugTable(
if (tup_elems.size() == 3) {
int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
auto source_range =
deserializer.deserialize(tup_elems[kSourceRangeIndex]);
deserializer->deserialize(tup_elems[kSourceRangeIndex]);
source_range_map.emplace(debug_handle, std::move(source_range));
}
}

View File

@ -3,6 +3,7 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
namespace torch {
namespace jit {

View File

@ -7,7 +7,7 @@ bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
const auto source_n = n->sourceRange().source();
const auto source_m = m->sourceRange().source();
return (
(source_n->text() == source_m->text()) &&
(source_n->text_str() == source_m->text_str()) &&
(source_n->starting_line_no() == source_m->starting_line_no()));
}

View File

@ -104,9 +104,8 @@ void initTreeViewBindings(PyObject* module) {
return SourceRange(self.source_, start, end);
})
.def_property_readonly("source", [](const SourceRangeFactory& self) {
auto text_view = self.source_->text();
std::string text(text_view.begin(), text_view.end());
return text;
auto text_view = self.source_->text_str().str();
return text_view;
});
py::class_<TreeView>(m, "TreeView")

View File

@ -61,7 +61,7 @@ auto initBindings() {
return static_cast<int64_t>((*self)->starting_line_no());
})
.def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
return (*self)->text();
return (*self)->text_str().str();
});
torch::class_<InstructionStats>("profiling", "InstructionStats")

View File

@ -22,7 +22,7 @@ std::string qualifierToArchivePath(
return export_prefix + path + "." + kExportSuffix;
}
std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
caffe2::serialize::PyTorchStreamReader& reader,
const std::string& export_prefix,
const std::string& qualifier) {

View File

@ -12,7 +12,7 @@ class PyTorchStreamReader;
namespace torch {
namespace jit {
struct SourceView;
struct Source;
// Convert a class type's qualifier name to the corresponding path the source
// file it should be written to.
@ -23,7 +23,7 @@ std::string qualifierToArchivePath(
const std::string& qualifier,
const std::string& export_prefix);
std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
caffe2::serialize::PyTorchStreamReader& reader,
const std::string& export_prefix,
const std::string& qualifier);

View File

@ -159,7 +159,7 @@ void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
return;
}
loaded_sources_.insert(qualifier);
std::shared_ptr<SourceView> src = source_loader_(qualifier);
std::shared_ptr<Source> src = source_loader_(qualifier);
// The importer, when looking for classes/functions doesn't know if 'foo'
// contains definitions or if it is a prefix of 'foo.bar', we only figure it

View File

@ -20,8 +20,7 @@
namespace torch {
namespace jit {
using SourceLoader =
std::function<std::shared_ptr<SourceView>(const std::string&)>;
using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>;
struct SourceImporterImpl : public Resolver,
std::enable_shared_from_this<SourceImporterImpl> {

View File

@ -1,8 +1,10 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <algorithm>
namespace torch {
namespace jit {
@ -10,45 +12,80 @@ namespace jit {
class SourceRangeSerializer {
public:
// Serialize SourceRange as Tuple[SourceType, int, int]
// where SourceType = Tuple[str, Optional[str], int, List[int]],
// where SourceType = Tuple[int, int, int, List[int]],
// The first 2 ints are positions into the vector returned by textSaved
// after all the Ranges are processed. textSaved() returns a vector of str
// the serialized form of Source
c10::IValue serialize(const SourceRange& sr);
const std::vector<c10::IValue>& texts_saved() {
return texts_;
}
SourceRangeSerializer() {
texts_.emplace_back("");
text_to_idx_[texts_.back().toStringRef()] = 0;
}
private:
// Serialize Source as Tuple[str, Optional[str], int, List[int]]
// This caches serialized sources, since many SourceRanges can
// refer to the same one.
c10::IValue serialize_source(const std::shared_ptr<SourceView>& s);
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
std::unordered_map<std::shared_ptr<SourceView>, c10::IValue>
serialized_sources;
int64_t store_text_and_get_index(const std::string& text_view);
std::vector<c10::IValue> texts_;
std::unordered_map<c10::string_view, int64_t> text_to_idx_;
};
SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
const auto& tup_elems = iv.toTupleRef().elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::shared_ptr<SourceView> source_ = deserialize_source(tup_elems[0]);
std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
int64_t start_ = tup_elems[1].toInt();
int64_t end_ = tup_elems[2].toInt();
return SourceRange(source_, start_, end_);
}
std::shared_ptr<SourceView> SourceRangeDeserializer::deserialize_source(
std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
const c10::IValue& iv) {
auto tup = iv.toTuple();
auto it = cached_sources.find(tup);
if (it != cached_sources.end()) {
return it->second;
}
std::shared_ptr<Source> source;
const auto& tup_elems = tup->elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>();
if (!text_table_.empty()) {
const auto& textIndex = tup_elems[0].toIntList();
int64_t fnameIndex = tup_elems[1].toInt();
int64_t starting_line_no_ = tup_elems[2].toInt();
c10::optional<std::string> filename = c10::nullopt;
auto source = std::make_shared<Source>(
filename = *text_table_[fnameIndex];
std::vector<c10::string_view> pieces;
std::vector<std::shared_ptr<std::string>> strs;
for (int64_t i : textIndex) {
pieces.emplace_back(*text_table_[i]);
strs.emplace_back(text_table_[i]);
}
StringCordView str_cord(std::move(pieces), std::move(strs));
source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
} else {
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ =
tup_elems[1].toOptional<std::string>();
int64_t starting_line_no_ = tup_elems[2].toInt();
source = std::make_shared<Source>(
std::move(text_), std::move(filename_), starting_line_no_);
}
cached_sources[tup] = source;
return source;
}
@ -58,17 +95,41 @@ c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end());
}
int64_t SourceRangeSerializer::store_text_and_get_index(
const std::string& text_view) {
auto text_iter = text_to_idx_.find(text_view);
if (text_iter == text_to_idx_.end()) {
int64_t text_pos = static_cast<int64_t>(texts_.size());
texts_.emplace_back(text_view);
text_to_idx_[texts_.back().toStringView()] = text_pos;
return text_pos;
} else {
return text_iter->second;
}
}
c10::IValue SourceRangeSerializer::serialize_source(
const std::shared_ptr<SourceView>& s) {
const std::shared_ptr<Source>& s) {
if (serialized_sources.count(s)) {
return serialized_sources.at(s);
}
c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
c10::List<int64_t> lines;
if (s == nullptr) {
serialized = c10::ivalue::Tuple::create({"", "", 0});
serialized = c10::ivalue::Tuple::create({lines, 0, 0});
} else {
for (size_t lineno = 0; lineno < s->num_lines(); lineno++) {
std::string line_content = s->get_line(lineno).str();
int64_t text_pos = store_text_and_get_index(line_content);
lines.push_back(text_pos);
}
int64_t fname_pos = 0;
if (s->filename().has_value()) {
fname_pos = store_text_and_get_index(*s->filename());
}
serialized = c10::ivalue::Tuple::create(
{s->text(), s->filename(), (int64_t)s->starting_line_no()});
{lines, fname_pos, (int64_t)s->starting_line_no()});
}
serialized_sources[s] = serialized;
return serialized;
@ -86,14 +147,19 @@ std::vector<char> SourceRangePickler::pickle(
if (it != source_range_tags.end()) {
source_range_tag = it->second;
}
ivalues.emplace_back(c10::ivalue::Tuple::create(
{(int64_t)range.bytes,
srs->serialize(range.range),
static_cast<int64_t>(source_range_tag)}));
}
std::vector<at::Tensor> table;
auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
auto result = jit::pickle(ivalue, &table);
auto result = jit::pickle(
c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
&table);
TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
return result;
}
@ -103,7 +169,7 @@ ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
size_t size)
: data(std::move(data)),
size(size),
deserializer(new SourceRangeDeserializer()),
deserializer(nullptr),
unpickled_records(nullptr) {}
void ConcreteSourceRangeUnpickler::unpickle() {
@ -119,10 +185,19 @@ void ConcreteSourceRangeUnpickler::unpickle() {
{},
c10::parseType)
.toTuple();
const auto& ivalues = ivaluesTuple->elements();
const auto& ivalues = ivaluesTuple->elements();
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
IValue lines;
if (ivalues[0].isString() &&
kFormatWithStringTable == ivalues[0].toStringRef()) {
deserializer.reset(new SourceRangeDeserializer(ivalues[1]));
lines = ivalues[2];
} else {
deserializer.reset(new SourceRangeDeserializer());
lines = ivaluesTuple;
}
for (auto& val : lines.toTuple()->elements()) {
const auto& tup_elems = val.toTupleRef().elements();
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);

View File

@ -20,6 +20,7 @@ class SourceRangeSerializer;
static constexpr size_t kByteOffsetIndex = 0;
static constexpr size_t kSourceRangeIndex = 1;
static constexpr size_t kSourceRangeTagIndex = 2;
constexpr c10::string_view kFormatWithStringTable = "FORMAT_WITH_STRING_TABLE";
class SourceRangePickler {
public:
@ -35,14 +36,21 @@ class SourceRangePickler {
class SourceRangeDeserializer {
public:
SourceRangeDeserializer() = default;
explicit SourceRangeDeserializer(c10::IValue text_table) {
for (const auto& x : text_table.toTuple()->elements()) {
text_table_.emplace_back(std::make_shared<std::string>(x.toStringRef()));
}
}
SourceRange deserialize(const c10::IValue& iv);
private:
std::shared_ptr<SourceView> deserialize_source(const c10::IValue& iv);
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv);
std::unordered_map<
c10::intrusive_ptr<c10::ivalue::Tuple>,
std::shared_ptr<SourceView>>
std::shared_ptr<Source>>
cached_sources;
std::vector<std::shared_ptr<std::string>> text_table_;
};
class SourceRangeUnpickler {

View File

@ -94,7 +94,7 @@ size_t assertFind(
const SourceRange& search_range,
const std::string& sub,
const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
auto pos = search_range.source()->text().find(sub, search_range.start());
auto pos = search_range.source()->text_str().find(sub, search_range.start());
if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
auto found_range =
SourceRange(search_range.source(), search_range.start(), sub.size());
@ -122,19 +122,18 @@ size_t assertFind(
}
size_t assertFind(
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const std::string& sub,
size_t start,
const Check& check) {
return assertFind(
SourceRange(source, start, source->text().size()), sub, check);
return assertFind(SourceRange(source, start, source->size()), sub, check);
}
void assertNotFind(
const SourceRange& search_range,
const std::string& sub,
const Check& check) {
auto pos = search_range.source()->text().find(sub, search_range.start());
auto pos = search_range.source()->text_str().find(sub, search_range.start());
if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
auto found_range =
SourceRange(search_range.source(), pos, sub.size() + pos);
@ -202,9 +201,7 @@ struct FileCheckImpl {
friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
private:
bool parseSingleCheck(
const std::shared_ptr<SourceView>& source,
size_t* start) {
bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
{CHECK, ": "},
{CHECK_NEXT, "-NEXT: "},
@ -217,31 +214,35 @@ struct FileCheckImpl {
for (const auto& check_pair : check_pairs) {
const std::string& check_suffix = check_pair.second;
auto suffix_pos = source->text().find(check_suffix, *start);
auto suffix_pos = source->text_str().find(check_suffix, *start);
if (suffix_pos != *start) {
continue;
}
size_t end_check_string = suffix_pos + check_suffix.size();
CheckType type = check_pair.first;
c10::optional<size_t> count = c10::nullopt;
auto end_line = source->text().find('\n', end_check_string);
auto end_line = source->text_str().find("\n", end_check_string);
bool exactly = false;
if (type == CHECK_COUNT) {
const std::string exact = "EXACTLY-";
if (source->text().find(exact, end_check_string) == end_check_string) {
if (source->text_str().find(exact, end_check_string) ==
end_check_string) {
exactly = true;
end_check_string += exact.size();
}
size_t end =
assertFind(SourceRange(source, end_check_string, end_line), ":");
auto count_view =
source->text().substr(end_check_string, end - end_check_string);
auto count_view = source->text_str()
.substr(end_check_string, end - end_check_string)
.str();
count = c10::stoll(std::string(count_view.begin(), count_view.end()));
end_check_string = end + 2; // add ':' and the space
}
auto check = Check(
type,
source->text().substr(end_check_string, end_line - end_check_string),
source->text_str()
.substr(end_check_string, end_line - end_check_string)
.str(),
count);
addCheck(check);
if (exactly) {
@ -253,32 +254,30 @@ struct FileCheckImpl {
return false;
}
size_t findNextStart(
const std::shared_ptr<SourceView>& source,
size_t prev_end) {
size_t start = source->text().find('#', prev_end);
size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
size_t start = source->text_str().find("#", prev_end);
if (start == std::string::npos) {
return start;
}
start += 1;
static constexpr size_t max_whitespace = 6;
size_t i = 0;
while (start + i < source->text().size() && i < max_whitespace) {
auto c = source->text().at(start + i);
while (start + i < source->size() && i < max_whitespace) {
auto c = source->char_at(start + i);
if (c != ' ' && c != '\t') {
break;
}
i++;
}
static const std::string check = "CHECK";
if (source->text().substr(start + i, check.size()) == check) {
if (source->text_str().substr(start + i, check.size()) == check) {
return start + i + check.size();
} else {
return findNextStart(source, start + i + 1);
}
}
void parseStrings(const std::shared_ptr<SourceView>& source) {
void parseStrings(const std::shared_ptr<Source>& source) {
size_t start = 0;
start = findNextStart(source, 0);
while (start != std::string::npos) {
@ -297,7 +296,7 @@ struct FileCheckImpl {
void doCheckNot(
const std::vector<Check>& nots,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const SourceRange& prev,
const SourceRange& next) {
auto start = prev.end(); // inclusive
@ -314,7 +313,7 @@ struct FileCheckImpl {
// Checks that source token is highlighted, does not advance search range.
void doCheckSourceHighlighted(
const Check& check,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
size_t start_offset) {
auto construct_error_and_throw = [&](size_t error_start_pos) {
SourceRange error_range(
@ -330,8 +329,8 @@ struct FileCheckImpl {
size_t search_start_offset = start_offset;
bool found_token_at_least_once = false;
size_t pos = search_start_offset;
while (pos < source->text().size()) {
pos = source->text().find(check.search_str_, search_start_offset);
while (pos < source->size()) {
pos = source->text_str().find(check.search_str_, search_start_offset);
if (pos == std::string::npos) {
break;
}
@ -349,17 +348,16 @@ struct FileCheckImpl {
auto highlight_start_offset =
source->offset_for_line(highlight_lineno) + col;
auto highlight_end_offset = std::min(
highlight_start_offset + check.search_str_.size(),
source->text().size());
highlight_start_offset + check.search_str_.size(), source->size());
if (highlight_end_offset >= source->text().size()) {
if (highlight_end_offset >= source->size()) {
construct_error_and_throw(pos);
}
bool found_highlight = true;
for (const auto posi :
c10::irange(highlight_start_offset, highlight_end_offset)) {
if (source->text()[posi] != '~') {
if (source->char_at(posi) != '~') {
found_highlight = false;
}
}
@ -390,7 +388,7 @@ struct FileCheckImpl {
SourceRange matchDagGroup(
const std::vector<Check>& group,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const SourceRange& prev) {
size_t group_beg = std::string::npos;
size_t group_end = 0;
@ -408,7 +406,7 @@ struct FileCheckImpl {
SourceRange matchGroup(
const std::vector<Check>& group,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const SourceRange& prev) {
AT_ASSERT(group.size() != 0);
CheckType type = group[0].type_;
@ -467,7 +465,7 @@ struct FileCheckImpl {
return SourceRange(source, start_range, end_range);
}
void doChecks(const std::shared_ptr<SourceView>& source) {
void doChecks(const std::shared_ptr<Source>& source) {
SourceRange prev(source, 0, 0);
for (size_t i = 0; i < groups.size(); i++) {
const auto& curr_group = groups[i];
@ -484,7 +482,7 @@ struct FileCheckImpl {
++i; // already checked the group after
} else {
SourceRange end_of_file(
source, source->text().size() + 1, source->text().size() + 1);
source, source->size() + 1, source->size() + 1);
doCheckNot(curr_group, source, prev, end_of_file);
}
}

View File

@ -258,7 +258,22 @@ def get_model_info(
# Parse debug info and add begin/end markers if not present
# to ensure that we cover the entire source code.
debug_info_t = pickle.loads(raw_debug)
assert isinstance(debug_info_t, tuple)
text_table = None
if (len(debug_info_t) == 3 and
isinstance(debug_info_t[0], str) and
debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
_, text_table, content = debug_info_t
def parse_new_format(line):
# (0, (('', '', 0), 0, 0))
num, ((text_indexes, fname_idx, offset), start, end), tag = line
text = ''.join(text_table[x] for x in text_indexes) # type: ignore
fname = text_table[fname_idx] # type: ignore
return num, ((text, fname, offset), start, end), tag
debug_info_t = map(parse_new_format, content)
debug_info = list(debug_info_t)
if not debug_info:
debug_info.append((0, (('', '', 0), 0, 0)))