pytorch/torch/csrc/jit/jit_log.cpp
Dmytro Dzhulgakov 2f4c31ce3a [jit] Speed up saving in case of many classes (#44589)
Summary:
There's an annoying O(N^2) in the module export logic that makes saving some models (those with many classes) take an eternity.

I'm not super familiar with this code to properly untangle the deps and make it a pure hash lookup. So I just added a side lookup table for raw pointers. It's still quadratic, but it's O(num_classes^2) instead of O(num_classes * num_references) which already gives huge savings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44589

Test Plan:
Tested with one of the offending models - just loading and saving a TorchScript file:

```
Before:
load 1.9239683151245117
save 165.74712467193604

After:
load 1.9409027099609375
save 1.4711427688598633
```

Reviewed By: suo

Differential Revision: D23675278

Pulled By: dzhulgakov

fbshipit-source-id: 8f3fa7730941085ea20d9255b49a149ac1bf64fe
2020-09-15 15:10:45 -07:00

134 lines
3.7 KiB
C++

#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <vector>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>
namespace torch {
namespace jit {
// Returns the header of `node` as a string
// (e.g. outputs, a node kind and inputs).
// The text is produced by Node::print with indentation level 0, an empty
// group list and all four trailing boolean options switched off.
std::string getHeader(const Node* node) {
  std::stringstream header;
  node->print(
      header,
      /*level=*/0,
      {},
      false,
      false,
      false,
      false);
  return header.str();
}
// Parses a JIT logging option string into a map from file name (without
// extension) to logging level.
//
// `option` is a ':'-separated list of entries of the form `[>...]name[.ext]`;
// it may be null, in which case only the implicit default entry is produced.
// For each non-empty entry:
//   * the level is the position just past the last '>' in the entry (i.e. the
//     number of leading '>' characters for a well-formed entry), or 0 when the
//     entry contains no '>';
//   * the key is the text between the last '>' and the last '.' (or the end of
//     the entry when there is no '.').
//
// An entry for `function` at level 0 is always prepended. Later entries
// override earlier ones, so an explicit entry such as `>>function` replaces
// that implicit default instead of being silently dropped.
static std::unordered_map<std::string, size_t> parseJITLogOption(
    const char* option) {
  std::stringstream in_ss;
  // Implicit default: `function` is always present in the result.
  in_ss << "function:";
  if (option) {
    in_ss << option;
  }

  std::unordered_map<std::string, size_t> files_to_levels;
  std::string line;
  while (std::getline(in_ss, line, ':')) {
    // Skip empty entries produced by leading, trailing or doubled ':'.
    if (line.empty()) {
      continue;
    }

    const auto index_at = line.find_last_of('>');
    // The name starts right after the last '>', and that same offset doubles
    // as the logging level.
    const size_t begin_index =
        index_at == std::string::npos ? 0 : index_at + 1;
    const size_t logging_level = begin_index;

    const auto index_dot = line.find_last_of('.');
    const size_t end_index =
        index_dot == std::string::npos ? line.size() : index_dot;

    const auto filename = line.substr(begin_index, end_index - begin_index);
    // Assign rather than insert: insert() keeps the first value seen for a
    // key, which made the prepended `function` default (and any earlier
    // duplicate) win over what the user asked for later in the list.
    files_to_levels[filename] = logging_level;
  }

  return files_to_levels;
}
// Reports whether logging at `level` is enabled for the source file `cfname`.
//
// The PYTORCH_JIT_LOG_LEVEL environment variable is read and parsed once, on
// the first call; the resulting file -> level table is reused afterwards.
// `cfname` is reduced to its base name (via c10::detail::StripBasename) with
// the extension removed before it is looked up. Files that are not in the
// table are never enabled; otherwise `level` must not exceed the configured
// level for that file.
bool is_enabled(const char* cfname, JitLoggingLevels level) {
  static const char* c_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
  static const std::unordered_map<std::string, size_t> files_to_levels =
      parseJITLogOption(c_log_level);

  const std::string base_name =
      c10::detail::StripBasename(std::string(cfname));
  // substr(0, npos) yields the whole name when there is no extension.
  const std::string stem = base_name.substr(0, base_name.rfind('.'));

  const auto entry = files_to_levels.find(stem);
  return entry != files_to_levels.end() &&
      level <= static_cast<JitLoggingLevels>(entry->second);
}
// Renders `graph` as source text using PythonPrint.
// `GraphExecutor`, where `log_function` is invoked, has no access to the
// original function, so a placeholder GraphFunction named "source_dump" is
// wrapped around the graph purely to have something to hand to PythonPrint.
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
  // Scratch tables required by the PythonPrint constructor; they are not
  // read back after printing.
  std::vector<at::IValue> constant_table;
  PrintDepsTable class_deps;

  torch::jit::GraphFunction placeholder("source_dump", graph, nullptr);
  PythonPrint printer(constant_table, class_deps);
  printer.printFunction(placeholder);
  return printer.str();
}
// Prepends `prefix` to every line of `in_str`.
// Lines are split on '\n'; each output line is terminated with '\n', including
// the last one even if `in_str` did not end with a newline. An empty `in_str`
// yields an empty string.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream lines(in_str);
  std::string prefixed;
  for (std::string row; std::getline(lines, row);) {
    prefixed += prefix;
    prefixed += row;
    prefixed += '\n';
  }
  return prefixed;
}
// Prefixes every line of `in_str` with "[<level> <file>:<line>] ".
// `<level>` is the streamed form of `level`, `<file>` is `fn` passed through
// c10::detail::StripBasename, and `<line>` is `l` zero-padded to a minimum
// width of three. The per-line work is delegated to the string-prefix
// overload of jit_log_prefix.
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::stringstream tag;
  tag << "[" << level << " "
      << c10::detail::StripBasename(std::string(fn)) << ":"
      << std::setfill('0') << std::setw(3) << l << "] ";
  return jit_log_prefix(tag.str(), in_str);
}
// Streams the short textual name of a logging level:
// GRAPH_DUMP -> "DUMP", GRAPH_UPDATE -> "UPDATE", GRAPH_DEBUG -> "DEBUG".
// Any other value trips TORCH_INTERNAL_ASSERT with "Invalid level".
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
  switch (level) {
    case JitLoggingLevels::GRAPH_DUMP:
      return out << "DUMP";
    case JitLoggingLevels::GRAPH_UPDATE:
      return out << "UPDATE";
    case JitLoggingLevels::GRAPH_DEBUG:
      return out << "DEBUG";
    default:
      break;
  }
  // Only reached for values outside the three known levels.
  TORCH_INTERNAL_ASSERT(false, "Invalid level");
  return out;
}
} // namespace jit
} // namespace torch