pytorch/torch/csrc/jit/jit_log.cpp
Zachary DeVito 58ed8ca9e1 clean up exported source format (#28129)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28129

The previous PR in the stack removed the need to order classes/functions
or have correct import statements. This resolved circular dependency issues
that can arise when class constructors like ModuleList put new instances
of themselves in a common namespace.

This PR changes our export format to no longer produce this information.
By doing so we can make the logic significantly simpler, since we just
keep track of an individual PythonPrint object per file.

Notes:
* PythonPrint was changed to manage its own stream/list of ranges. It
was doing this anyway internally; this just makes the API clearer.
* Since we are changing the serialization format, I also removed op_version_set.
It is now replaced with the VERSION number that is written in the zip archive.
This further simplifies the code emission process.
* A test of op_version_set was removed since there is no longer any behavior
to test.

Test Plan: Imported from OSS

Differential Revision: D17961610

Pulled By: zdevito

fbshipit-source-id: ada362c4ca34d05393a1a7e799c94785ab9d9825
2019-10-16 22:47:24 -07:00

135 lines
3.7 KiB
C++

#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <vector>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/function.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/script/error_report.h>
namespace torch {
namespace jit {
// Returns the header of `node` rendered as a string.
//
// The text is produced by Node::print, called with an indent level of 0, an
// empty third argument and all four trailing boolean options set to false.
// NOTE(review): the header presumably consists of the node's outputs, its
// kind and its inputs -- confirm against Node::print.
std::string getHeader(const Node *node) {
  std::stringstream header_stream;
  node->print(
      header_stream,
      /*level=*/0,
      {},
      false,
      false,
      false,
      false);
  return header_stream.str();
}
// Parses a JIT logging option string into a map from file name (without
// extension) to logging level.
//
// `option` is a ':'-separated list of entries and may be null. For each
// non-empty entry:
//   * the name starts right after the last '>' (or at the beginning when the
//     entry contains no '>') and stops before the last '.', so "foo.cpp" and
//     "foo" yield the same key;
//   * the level is 1 when the entry contains no '>', otherwise
//     (index of the last '>') + 2 -- i.e. ">foo" is level 2, ">>foo" level 3.
// When a name is listed more than once, the first occurrence wins.
//
// "function" is always present in the result: it defaults to level 1, but an
// explicit entry in `option` takes precedence over that default.
static std::unordered_map<std::string, size_t>
parseJITLogOption(const char *option) {
  std::unordered_map<std::string, size_t> files_to_levels;

  if (option) {
    std::stringstream in_ss(option);
    std::string line;
    while (std::getline(in_ss, line, ':')) {
      if (line.empty()) {
        continue;
      }
      const auto index_at = line.find_last_of('>');
      const bool has_marker = index_at != std::string::npos;
      const size_t begin_index = has_marker ? index_at + 1 : 0;
      const size_t logging_level = has_marker ? index_at + 2 : 1;

      const auto dot_index = line.find_last_of('.');
      const size_t end_index =
          dot_index == std::string::npos ? line.size() : dot_index;

      // insert() keeps an existing entry, so the first occurrence of a name
      // in the option string is the one that sticks.
      files_to_levels.insert(
          {line.substr(begin_index, end_index - begin_index), logging_level});
    }
  }

  // Add the built-in default last so that it never shadows a level the user
  // requested explicitly for "function".
  files_to_levels.emplace("function", size_t{1});
  return files_to_levels;
}
// Reports whether logging at `level` is enabled for the source file `cfname`.
//
// The PYTORCH_JIT_LOG_LEVEL environment variable is read and parsed with
// parseJITLogOption into function-local statics, so this happens only on the
// first call. `cfname` is reduced with c10::detail::StripBasename and cut at
// its last '.', and the result is looked up in the parsed table. Files that
// are not listed are never enabled; listed files are enabled for every level
// that compares <= the level configured for them.
bool is_enabled(const char *cfname, JitLoggingLevels level) {
  static const char* c_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
  static const std::unordered_map<std::string, size_t> files_to_levels =
      parseJITLogOption(c_log_level);

  const std::string base_name =
      c10::detail::StripBasename(std::string{cfname});
  // substr(0, npos) yields the whole string, which covers names without '.'.
  const std::string stem = base_name.substr(0, base_name.find_last_of('.'));

  const auto entry = files_to_levels.find(stem);
  return entry != files_to_levels.end() &&
      level <= static_cast<JitLoggingLevels>(entry->second);
}
// Renders `graph` as source text via PythonPrint.
//
// PythonPrint::printFunction takes a Function, but the call site of
// `log_function` in `GraphExecutor` has no access to the original one. A
// placeholder Function named "source_dump" is therefore wrapped around the
// graph purely so that it can be handed to the printer.
std::string log_function(const std::shared_ptr<torch::jit::Graph> &graph) {
  std::vector<at::Tensor> tensor_table;
  std::vector<c10::NamedTypePtr> class_deps;
  PythonPrint printer(tensor_table, class_deps, false);

  torch::jit::Function placeholder("source_dump", graph, nullptr);
  printer.printFunction(placeholder);
  return printer.str();
}
// Prepends `prefix` to every line of `in_str`.
//
// `in_str` is split on '\n'; each resulting line is emitted as
// prefix + line + '\n'. Consequently the result always ends with a newline
// when `in_str` is non-empty (even if `in_str` did not), and is empty when
// `in_str` is empty.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream lines(in_str);
  std::string prefixed;
  for (std::string current; std::getline(lines, current);) {
    prefixed += prefix;
    prefixed += current;
    prefixed += '\n';
  }
  return prefixed;
}
// Prefixes every line of `in_str` with a "[<level> <file>:<line>] " tag.
//
// <level> is `level` as formatted by its stream insertion operator, <file> is
// `fn` passed through c10::detail::StripBasename, and <line> is `l` padded
// with leading zeros to a minimum width of 3. The per-line prefixing itself
// is delegated to the (prefix, in_str) overload.
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::stringstream tag;
  tag << "[" << level << " "
      << c10::detail::StripBasename(std::string(fn)) << ":"
      << std::setfill('0') << std::setw(3) << l
      << "] ";
  return jit_log_prefix(tag.str(), in_str);
}
// Writes the short name of a logging level to `out` and returns `out`:
// GRAPH_DUMP -> "DUMP", GRAPH_UPDATE -> "UPDATE", GRAPH_DEBUG -> "DEBUG".
// Any other value trips TORCH_INTERNAL_ASSERT with "Invalid level" before
// anything is written.
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
  switch (level) {
    case JitLoggingLevels::GRAPH_DUMP:
      return out << "DUMP";
    case JitLoggingLevels::GRAPH_UPDATE:
      return out << "UPDATE";
    case JitLoggingLevels::GRAPH_DEBUG:
      return out << "DEBUG";
    default:
      TORCH_INTERNAL_ASSERT(false, "Invalid level");
  }
  return out;
}
} // namespace jit
} // namespace torch