mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44587 Currently the JIT logging level is skewed by one. The following test demonstrates it: ``` $ cat test.py import torch def foo(a,b): return a*a*b torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_set_texpr_fuser_enabled(True) f = torch.jit.script(foo) for _ in range(10): f(torch.rand(10), torch.rand(10)) $ cat test_logging_levels.sh PYTORCH_JIT_LOG_LEVEL="tensorexpr_fuser" python test.py 2>&1 | grep DUMP >& /dev/null && echo OK || echo FAIL PYTORCH_JIT_LOG_LEVEL="tensorexpr_fuser" python test.py 2>&1 | grep UPDATE >& /dev/null && echo FAIL || echo OK PYTORCH_JIT_LOG_LEVEL="tensorexpr_fuser" python test.py 2>&1 | grep DEBUG >& /dev/null && echo FAIL || echo OK PYTORCH_JIT_LOG_LEVEL=">tensorexpr_fuser" python test.py 2>&1 | grep DUMP >& /dev/null && echo OK || echo FAIL PYTORCH_JIT_LOG_LEVEL=">tensorexpr_fuser" python test.py 2>&1 | grep UPDATE >& /dev/null && echo OK || echo FAIL PYTORCH_JIT_LOG_LEVEL=">tensorexpr_fuser" python test.py 2>&1 | grep DEBUG >& /dev/null && echo FAIL || echo OK PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python test.py 2>&1 | grep DUMP >& /dev/null && echo OK || echo FAIL PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python test.py 2>&1 | grep UPDATE >& /dev/null && echo OK || echo FAIL PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python test.py 2>&1 | grep DEBUG >& /dev/null && echo OK || echo FAIL ``` Before this change: ``` OK FAIL OK OK OK FAIL OK OK OK ``` With this change everything passes. Differential Revision: D23666813 Test Plan: Imported from OSS Reviewed By: bertmaher Pulled By: ZolotukhinM fbshipit-source-id: 4adaa5a3d06deadf54eae014a0d76588cdc5e20a
134 lines
3.7 KiB
C++
134 lines
3.7 KiB
C++
|
|
#include <cstdlib>
|
|
#include <iomanip>
|
|
#include <sstream>
|
|
#include <vector>
|
|
|
|
#include <ATen/core/function.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/serialization/python_print.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// gets a string representation of a node header
|
|
// (e.g. outputs, a node kind and outputs)
|
|
std::string getHeader(const Node* node) {
|
|
std::stringstream ss;
|
|
node->print(ss, 0, {}, false, false, false, false);
|
|
return ss.str();
|
|
}
|
|
|
|
// Parses a PYTORCH_JIT_LOG_LEVEL option string (may be null) into a map from
// source file name to the highest logging level enabled for that file.
//
// The option is a ':'-separated list of entries. Each entry is a file name,
// optionally with an extension (everything from the last '.' is dropped),
// prefixed by zero or more '>' markers. The number of leading '>' markers is
// the logging level, so "foo" -> 0, ">foo" -> 1 and ">>foo.cpp" -> 2 for
// "foo". An entry for "function" at level 0 is always present.
//
// Whitespace around an entry is ignored. Empty entries and entries made only
// of '>' markers are skipped. When a file is listed more than once, the first
// occurrence wins.
static std::unordered_map<std::string, size_t> parseJITLogOption(
    const char* option) {
  std::istringstream in_ss(std::string("function:") + (option ? option : ""));

  std::unordered_map<std::string, size_t> files_to_levels;
  std::string entry;
  while (std::getline(in_ss, entry, ':')) {
    // Tolerate whitespace around entries, e.g. "a: >b".
    const auto first = entry.find_first_not_of(" \t\r\n");
    if (first == std::string::npos) {
      continue; // empty or whitespace-only entry
    }
    const auto last = entry.find_last_not_of(" \t\r\n");
    entry = entry.substr(first, last - first + 1);

    // The level is the count of *leading* '>' markers. Using the position of
    // the last '>' anywhere in the entry would let a stray '>' (e.g. "x>y")
    // silently enable a high verbosity level for an unintended name.
    const auto name_begin = entry.find_first_not_of('>');
    if (name_begin == std::string::npos) {
      continue; // only '>' markers, no file name
    }
    const size_t logging_level = name_begin;

    // Strip the extension, if any. The prefix contains only '>', so any '.'
    // found lies inside the file name.
    auto name_end = entry.find_last_of('.');
    if (name_end == std::string::npos) {
      name_end = entry.size();
    }
    if (name_end == name_begin) {
      continue; // extension only, e.g. ">.cpp"
    }

    files_to_levels.insert(
        {entry.substr(name_begin, name_end - name_begin), logging_level});
  }

  return files_to_levels;
}
|
|
|
|
bool is_enabled(const char* cfname, JitLoggingLevels level) {
|
|
static const char* c_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
|
|
static const std::unordered_map<std::string, size_t> files_to_levels =
|
|
parseJITLogOption(c_log_level);
|
|
std::string fname{cfname};
|
|
fname = c10::detail::StripBasename(fname);
|
|
auto end_index = fname.find_last_of('.') == std::string::npos
|
|
? fname.size()
|
|
: fname.find_last_of('.');
|
|
auto fname_no_ext = fname.substr(0, end_index);
|
|
|
|
auto it = files_to_levels.find(fname_no_ext);
|
|
if (it == files_to_levels.end()) {
|
|
return false;
|
|
}
|
|
|
|
return level <= static_cast<JitLoggingLevels>(it->second);
|
|
}
|
|
|
|
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
|
|
// we won't have access to an original function, so we have to construct
|
|
// a dummy function to give to PythonPrint
|
|
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
|
|
torch::jit::GraphFunction func("source_dump", graph, nullptr);
|
|
std::vector<at::IValue> constants;
|
|
std::vector<c10::NamedTypePtr> deps;
|
|
PythonPrint pp(constants, deps);
|
|
pp.printFunction(func);
|
|
return pp.str();
|
|
}
|
|
|
|
// Prepends `prefix` to every line of `in_str` and returns the result.
// Every output line ends with '\n', including the last one even when
// `in_str` does not end with a newline. An empty `in_str` yields "".
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream in_ss(in_str);
  std::ostringstream out_ss;
  std::string line;
  while (std::getline(in_ss, line)) {
    // '\n' instead of std::endl: flushing the string stream after every line
    // is useless extra work.
    out_ss << prefix << line << '\n';
  }
  return out_ss.str();
}
|
|
|
|
// Prefixes every line of `in_str` with a header "[<LEVEL> <file>:<line>] ".
// <LEVEL> is `level` streamed through its operator<<. <file> is `fn` passed
// through c10::detail::StripBasename. <line> is `l` zero-padded to a width
// of at least three digits.
std::string jit_log_prefix(
    JitLoggingLevels level,
    const char* fn,
    int l,
    const std::string& in_str) {
  std::stringstream header;
  header << "[" << level << " "
         << c10::detail::StripBasename(std::string(fn)) << ":"
         << std::setfill('0') << std::setw(3) << l << "] ";
  return jit_log_prefix(header.str(), in_str);
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
|
|
switch (level) {
|
|
case JitLoggingLevels::GRAPH_DUMP:
|
|
out << "DUMP";
|
|
break;
|
|
case JitLoggingLevels::GRAPH_UPDATE:
|
|
out << "UPDATE";
|
|
break;
|
|
case JitLoggingLevels::GRAPH_DEBUG:
|
|
out << "DEBUG";
|
|
break;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Invalid level");
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|