mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Fixes https://github.com/pytorch/pytorch/issues/4661
- Adds warnings in the engine's `execute` function so they can be triggered through both the C++ and Python code paths.
- Adds an RAII guard version of `c10::Warning::set_warnAlways` and replaces all prior usages of `set_warnAlways` with the new guard.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59412
Reviewed By: jbschlosser
Differential Revision: D28969294
Pulled By: soulitzer
fbshipit-source-id: b03369c926a3be18ce1cf363b39edd82a14245f0
220 lines
5.3 KiB
C++
#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Type.h>

#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <typeinfo>
|
|
|
|
namespace c10 {
|
|
|
|
Error::Error(std::string msg, std::string backtrace, const void* caller)
|
|
: msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
|
|
refresh_what();
|
|
}
|
|
|
|
// PyTorch-style error message
|
|
// Error::Error(SourceLocation source_location, const std::string& msg)
|
|
// NB: This is defined in Logging.cpp for access to GetFetchStackTrace
|
|
|
|
// Caffe2-style error message
|
|
Error::Error(
    const char* file,
    const uint32_t line,
    const char* condition,
    const std::string& msg,
    const std::string& backtrace,
    const void* caller)
    // Delegates to the (msg, backtrace, caller) constructor after formatting
    // "[enforce fail at <basename>:<line>] <condition>. <msg>".
    : Error(
          str("[enforce fail at ",
              detail::StripBasename(file), ":", line,
              "] ", condition, ". ", msg),
          backtrace,
          caller) {}
|
|
|
|
std::string Error::compute_what(bool include_backtrace) const {
|
|
std::ostringstream oss;
|
|
|
|
oss << msg_;
|
|
|
|
if (context_.size() == 1) {
|
|
// Fold error and context in one line
|
|
oss << " (" << context_[0] << ")";
|
|
} else {
|
|
for (const auto& c : context_) {
|
|
oss << "\n " << c;
|
|
}
|
|
}
|
|
|
|
if (include_backtrace) {
|
|
oss << "\n" << backtrace_;
|
|
}
|
|
|
|
return oss.str();
|
|
}
|
|
|
|
// Rebuilds both cached renderings of the error: with the backtrace (what_)
// and without it (what_without_backtrace_).
void Error::refresh_what() {
  const bool with_backtrace = true;
  what_ = compute_what(/*include_backtrace=*/with_backtrace);
  what_without_backtrace_ = compute_what(/*include_backtrace=*/!with_backtrace);
}
|
|
|
|
// Appends one context entry and re-renders the cached what() strings.
void Error::add_context(std::string new_msg) {
  context_.emplace_back(std::move(new_msg));
  // TODO: each call re-renders both cached strings, so n calls cost O(n^2).
  // Rendering lazily would fix that if it ever becomes a problem. Any such
  // change must stay thread safe, because what() is almost certainly expected
  // to be callable from multiple threads at once.
  refresh_what();
}
|
|
|
|
namespace detail {
|
|
|
|
void torchCheckFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const std::string& msg) {
|
|
throw ::c10::Error({func, file, line}, msg);
|
|
}
|
|
|
|
void torchCheckFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const char* msg) {
|
|
throw ::c10::Error({func, file, line}, msg);
|
|
}
|
|
|
|
void torchInternalAssertFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const char* condMsg,
|
|
const char* userMsg) {
|
|
torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
|
|
}
|
|
|
|
// This should never be called. It is provided in case of compilers
|
|
// that don't do any dead code stripping in debug builds.
|
|
void torchInternalAssertFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const char* condMsg,
|
|
const std::string& userMsg) {
|
|
torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
namespace Warning {
|
|
|
|
namespace {
|
|
WarningHandler* getBaseHandler() {
|
|
static WarningHandler base_warning_handler_ = WarningHandler();
|
|
return &base_warning_handler_;
|
|
};
|
|
|
|
class ThreadWarningHandler {
|
|
public:
|
|
ThreadWarningHandler() = delete;
|
|
|
|
static WarningHandler* get_handler() {
|
|
if (!warning_handler_) {
|
|
warning_handler_ = getBaseHandler();
|
|
}
|
|
return warning_handler_;
|
|
}
|
|
|
|
static void set_handler(WarningHandler* handler) {
|
|
warning_handler_ = handler;
|
|
}
|
|
|
|
private:
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
static thread_local WarningHandler* warning_handler_;
|
|
};
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
|
|
|
|
} // namespace
|
|
|
|
void warn(
|
|
const SourceLocation& source_location,
|
|
const std::string& msg,
|
|
const bool verbatim) {
|
|
ThreadWarningHandler::get_handler()->process(source_location, msg, verbatim);
|
|
}
|
|
|
|
// Overload for a compile-time empty message. The tag argument carries no
// data, so it is left unnamed (this avoids -Wunused-parameter) and an empty
// string is emitted through the C-string overload.
void warn(
    SourceLocation source_location,
    detail::CompileTimeEmptyString /*msg*/,
    const bool verbatim) {
  warn(source_location, "", verbatim);
}
|
|
|
|
// C-string overload: forwards to this thread's current warning handler.
// `msg` is converted to std::string by WarningHandler::process's signature.
void warn(
    SourceLocation source_location,
    const char* msg,
    const bool verbatim) {
  WarningHandler* const handler = ThreadWarningHandler::get_handler();
  handler->process(source_location, msg, verbatim);
}
|
|
|
|
void set_warning_handler(WarningHandler* handler) noexcept(true) {
|
|
ThreadWarningHandler::set_handler(handler);
|
|
}
|
|
|
|
WarningHandler* get_warning_handler() noexcept(true) {
|
|
return ThreadWarningHandler::get_handler();
|
|
}
|
|
|
|
// Process-wide switch (a plain global, not thread_local) exposed through
// set_warnAlways / get_warnAlways.
// NOTE(review): this is a non-atomic bool shared by all threads. If it is
// toggled while other threads read it, that is a data race; consider
// std::atomic<bool> (its type is left unchanged here because the symbol has
// external linkage).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
bool warn_always = false;

// Overwrites the process-wide warn-always flag.
void set_warnAlways(bool setting) noexcept {
  warn_always = setting;
}

// Reads the process-wide warn-always flag.
bool get_warnAlways() noexcept {
  const bool current = warn_always;
  return current;
}
|
|
|
|
WarnAlways::WarnAlways(bool setting /*=true*/)
|
|
: prev_setting(get_warnAlways()) {
|
|
set_warnAlways(setting);
|
|
}
|
|
|
|
WarnAlways::~WarnAlways() {
|
|
set_warnAlways(prev_setting);
|
|
}
|
|
|
|
} // namespace Warning
|
|
|
|
void WarningHandler::process(
|
|
const SourceLocation& source_location,
|
|
const std::string& msg,
|
|
const bool /*verbatim*/) {
|
|
LOG_AT_FILE_LINE(WARNING, source_location.file, source_location.line)
|
|
<< "Warning: " << msg << " (function " << source_location.function << ")";
|
|
}
|
|
|
|
std::string GetExceptionString(const std::exception& e) {
|
|
#ifdef __GXX_RTTI
|
|
return demangle(typeid(e).name()) + ": " + e.what();
|
|
#else
|
|
return std::string("Exception (no RTTI available): ") + e.what();
|
|
#endif // __GXX_RTTI
|
|
}
|
|
|
|
} // namespace c10
|