mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: There's an annoying O(N^2) in module export logic that makes saving some of the models (if they have many classes) take an eternity. I'm not familiar enough with this code to properly untangle the deps and make it a pure hash lookup. So I just added a side lookup table for raw pointers. It's still quadratic, but it's O(num_classes^2) instead of O(num_classes * num_references) which already gives huge savings. Pull Request resolved: https://github.com/pytorch/pytorch/pull/44589 Test Plan: Tested with one of the offending models - just loading and saving a TorchScript file: ``` Before: load 1.9239683151245117 save 165.74712467193604 After: load 1.9409027099609375 save 1.4711427688598633 ``` Reviewed By: suo Differential Revision: D23675278 Pulled By: dzhulgakov fbshipit-source-id: 8f3fa7730941085ea20d9255b49a149ac1bf64fe
134 lines
3.7 KiB
C++
134 lines
3.7 KiB
C++
|
|
#include <cstdlib>
|
|
#include <iomanip>
|
|
#include <sstream>
|
|
#include <vector>
|
|
|
|
#include <ATen/core/function.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/serialization/python_print.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// gets a string representation of a node header
|
|
// (e.g. outputs, a node kind and outputs)
|
|
std::string getHeader(const Node* node) {
|
|
std::stringstream ss;
|
|
node->print(ss, 0, {}, false, false, false, false);
|
|
return ss.str();
|
|
}
|
|
|
|
// Parses a ':'-separated logging specification into a map from file name
// (without extension) to logging level.
//
// `option` may be null. The specification is always prefixed with
// "function:", so the key "function" is present with level 0.
//
// For every non-empty entry:
//   * the level is the position just past the last '>' in the entry (0 when
//     the entry contains no '>'), i.e. ">>foo" maps "foo" to level 2;
//   * the key is the text between that position and the last '.' in the
//     entry (or the end of the entry when there is no '.').
// When the same key appears more than once, the first occurrence wins.
static std::unordered_map<std::string, size_t> parseJITLogOption(
    const char* option) {
  std::string spec = "function:";
  if (option != nullptr) {
    spec += option;
  }

  std::unordered_map<std::string, size_t> levels_by_file;
  std::istringstream entries(spec);
  std::string entry;
  while (std::getline(entries, entry, ':')) {
    if (entry.empty()) {
      continue;
    }

    // Both the start of the file name and the logging level are derived from
    // the position of the last '>' marker.
    const auto last_marker = entry.find_last_of('>');
    const size_t name_begin =
        (last_marker == std::string::npos) ? 0 : last_marker + 1;
    const size_t level = name_begin;

    // Drop the extension, if any.
    const auto last_dot = entry.find_last_of('.');
    const size_t name_end =
        (last_dot == std::string::npos) ? entry.size() : last_dot;

    // emplace, like insert, leaves an already-present key untouched.
    levels_by_file.emplace(
        entry.substr(name_begin, name_end - name_begin), level);
  }

  return levels_by_file;
}
|
|
|
|
bool is_enabled(const char* cfname, JitLoggingLevels level) {
|
|
static const char* c_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
|
|
static const std::unordered_map<std::string, size_t> files_to_levels =
|
|
parseJITLogOption(c_log_level);
|
|
std::string fname{cfname};
|
|
fname = c10::detail::StripBasename(fname);
|
|
auto end_index = fname.find_last_of('.') == std::string::npos
|
|
? fname.size()
|
|
: fname.find_last_of('.');
|
|
auto fname_no_ext = fname.substr(0, end_index);
|
|
|
|
auto it = files_to_levels.find(fname_no_ext);
|
|
if (it == files_to_levels.end()) {
|
|
return false;
|
|
}
|
|
|
|
return level <= static_cast<JitLoggingLevels>(it->second);
|
|
}
|
|
|
|
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
|
|
// we won't have access to an original function, so we have to construct
|
|
// a dummy function to give to PythonPrint
|
|
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
|
|
torch::jit::GraphFunction func("source_dump", graph, nullptr);
|
|
std::vector<at::IValue> constants;
|
|
PrintDepsTable deps;
|
|
PythonPrint pp(constants, deps);
|
|
pp.printFunction(func);
|
|
return pp.str();
|
|
}
|
|
|
|
// Prepends `prefix` to every line of `in_str`.
//
// `in_str` is split on '\n'. Each line, including empty ones, is emitted as
// `prefix + line + '\n'`, so the result always ends with a newline unless
// `in_str` is empty, in which case the result is empty as well.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream in_ss(in_str);

  // Append straight into the result rather than going through a second
  // stream with std::endl, which would force a flush for every line.
  std::string out;
  out.reserve(in_str.size() + prefix.size() + 1);

  std::string line;
  while (std::getline(in_ss, line)) {
    out += prefix;
    out += line;
    out += '\n';
  }

  return out;
}
|
|
|
|
// Builds a prefix of the form "[<level> <file>:<line>] " and applies it to
// every line of `in_str` via the (prefix, in_str) overload.
//   level  - streamed with operator<< for JitLoggingLevels
//   fn     - source file name; passed through c10::detail::StripBasename
//   l      - line number, zero-padded to a minimum width of 3
//   in_str - the text to prefix
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::ostringstream tag;
  tag << "[" << level << " "
      << c10::detail::StripBasename(std::string(fn)) << ":"
      << std::setfill('0') << std::setw(3) << l << "] ";

  return jit_log_prefix(tag.str(), in_str);
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
|
|
switch (level) {
|
|
case JitLoggingLevels::GRAPH_DUMP:
|
|
out << "DUMP";
|
|
break;
|
|
case JitLoggingLevels::GRAPH_UPDATE:
|
|
out << "UPDATE";
|
|
break;
|
|
case JitLoggingLevels::GRAPH_DEBUG:
|
|
out << "DEBUG";
|
|
break;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Invalid level");
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|